决策树ID3

您所在的位置:网站首页 决策树概念正确的是 决策树ID3

决策树ID3

2024-02-21 02:33| 来源: 网络整理| 查看: 265

最近在看决策树,在B站上看到了一位前辈的讲课视频,讲得非常详细,于是自己动手实现了一个基于ID3的决策树。

说来惭愧,我是新手,所以并没有使用现成的机器学习库,只导入了 math 和 numpy,其余部分都是用原生 Python 写的。也没有画出最终决策树的结构图。

这次练习只是为了让我对决策树有更深入的了解,等后续学会使用外部包之后,再做改进。

#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""Minimal ID3 decision tree built from a small in-module training set.

The tree alternates between two kinds of nodes:

* a *column* node, whose ``name`` is a condition attribute such as
  ``'Outlook'`` and whose children are value nodes;
* a *value* node, whose ``name`` is one observed value of that column and
  whose single child is either another column node or a leaf.

A leaf is a node whose ``next`` is ``None``; its ``name`` is the decision.
"""
import math

import numpy as np

# Training set used to build the decision tree.  The first four items of a
# row are the condition attributes (in ``columns`` order); the last item is
# the decision label.
data = [
    ['Sunny', 'Hot', 'High', 'Weak', 'No'],
    ['Sunny', 'Hot', 'High', 'Strong', 'No'],
    ['Overcast', 'Hot', 'High', 'Weak', 'Yes'],
    ['Rain', 'Mild', 'High', 'Weak', 'Yes'],
    ['Rain', 'Cool', 'Normal', 'Weak', 'Yes'],
    ['Rain', 'Cool', 'Normal', 'Strong', 'No'],
    ['Overcast', 'Cool', 'Normal', 'Strong', 'Yes'],
    ['Sunny', 'Mild', 'High', 'Weak', 'No'],
    ['Sunny', 'Cool', 'Normal', 'Weak', 'Yes'],
    ['Rain', 'Mild', 'Normal', 'Weak', 'Yes'],
    ['Sunny', 'Mild', 'Normal', 'Strong', 'Yes'],
    ['Overcast', 'Mild', 'High', 'Strong', 'Yes'],
    ['Overcast', 'Hot', 'Normal', 'Weak', 'Yes'],
    ['Rain', 'Mild', 'High', 'Strong', 'No'],
]
columns = ['Outlook', 'Temperature', 'Humidity', 'Wind']
columns_index = {
    'Outlook': 0,
    'Temperature': 1,
    'Humidity': 2,
    'Wind': 3,
}


# Step 1: entropy of the decision attribute.
def calculate_entropy(path):
    """Return the label entropy of the rows of ``data`` selected by *path*.

    Args:
        path: dict mapping a column name to the value a row must have in
            that column.  An empty dict selects every row.

    Returns:
        A ``(entropy, filtered_data)`` tuple.  ``filtered_data`` is the list
        of matching rows in their original order; ``entropy`` is the base-2
        entropy of their last item, or ``0`` when no row matches.

    Raises:
        KeyError: if *path* names a column missing from ``columns_index``.
    """
    filtered_data = [
        line for line in data
        if all(line[columns_index[column]] == value
               for column, value in path.items())
    ]

    label_counts = {}
    for line in filtered_data:
        label_counts[line[-1]] = label_counts.get(line[-1], 0) + 1

    decision_entropy = 0
    total = len(filtered_data)
    for count in label_counts.values():
        probability = count / total
        decision_entropy -= probability * math.log2(probability)
    return decision_entropy, filtered_data


# Step 2: conditional entropy of each condition attribute.  For every
# candidate column the rows are grouped by value, the entropy of each group
# is weighted by the group's share of the rows, and the column with the
# largest information gain wins.
def child_node(parent_score, nodes, node_data, path):
    """Pick the column of *nodes* with the highest information gain.

    Args:
        parent_score: entropy of *node_data* before splitting.
        nodes: candidate column names.
        node_data: the rows to split.  The per-value entropies are obtained
            from ``calculate_entropy``, so these rows are expected to be the
            rows of ``data`` selected by *path*.
        path: column/value constraints already fixed above this node.  It is
            copied, never modified.

    Returns:
        A ``(column, attributes)`` tuple.  ``attributes`` maps each value of
        the chosen column seen in *node_data* to a dict holding one count
        per decision label plus the total under the key ``'count'``.  When
        no column has a strictly positive gain, the first candidate is
        returned.

    Raises:
        ValueError: if *nodes* or *node_data* is empty.
    """
    if not nodes or not node_data:
        raise ValueError('child_node needs at least one column and one row')

    # Group the rows: column -> value -> {label: n, ..., 'count': total}.
    node_dict = {}
    for line in node_data:
        label = line[-1]
        for node in nodes:
            values = node_dict.setdefault(node, {})
            stats = values.setdefault(line[columns_index[node]], {})
            stats[label] = stats.get(label, 0) + 1
            stats['count'] = stats.get('count', 0) + 1

    root = next(iter(node_dict))
    increment = 0
    row_count = len(node_data)
    for node, values in node_dict.items():
        scores = []
        weights = []
        for value, stats in values.items():
            current_path = dict(path)
            current_path[node] = value
            value_entropy, _ = calculate_entropy(current_path)
            scores.append(value_entropy)
            weights.append(stats['count'] / row_count)
        # Weighted average entropy after splitting on this column.
        node_score = float(np.dot(scores, weights))
        gain = parent_score - node_score
        if gain > increment:
            increment = gain
            root = node
    return root, node_dict[root]


def _majority_label(rows):
    """Return the most frequent last item of *rows* (first seen wins ties)."""
    counts = {}
    for row in rows:
        counts[row[-1]] = counts.get(row[-1], 0) + 1
    return max(counts, key=counts.get)


def find_attribute(root, attributes, path, tree_node):
    """Grow the subtree below the column node *tree_node*, in place.

    For every value in *attributes* a value node is appended to
    ``tree_node.next``.  Below it comes either a leaf (the branch is pure,
    or no unused column is left, in which case the majority label is used)
    or a new column node that is expanded recursively.

    Args:
        root: name of the column represented by *tree_node*.
        attributes: iterable of the values of *root* to branch on.
        path: constraints fixed above *tree_node*.  It is modified while the
            function runs; the *root* entry is removed before returning.
        tree_node: the ``Node`` for *root*; its ``next`` must be a list.
    """
    for attribute in attributes:
        path[root] = attribute
        entropy, node_data = calculate_entropy(path)
        attribute_node = Node(attribute, tree_node, [])
        tree_node.next.append(attribute_node)

        if not node_data:
            # No training row reaches this branch: nothing to predict from.
            continue

        # Columns already fixed on the path cannot split the rows further;
        # offering them again would recurse forever on contradictory rows.
        remaining = [column for column in columns if column not in path]
        if entropy == 0.0 or not remaining:
            leaf = Node(_majority_label(node_data), attribute_node, None)
            attribute_node.next.append(leaf)
            continue

        node, child_attributes = child_node(entropy, remaining, node_data, path)
        temp_node = Node(node, attribute_node, [])
        attribute_node.next.append(temp_node)
        find_attribute(node, child_attributes, path, temp_node)
    path.pop(root, None)


class Node:
    """One node of the decision tree.

    Attributes:
        name: column name, column value, or decision label.
        before: parent ``Node``, or ``None`` for the tree root.
        next: list of child nodes, or ``None`` for a leaf.
    """

    def __init__(self, name, before, next):
        self.next = next
        self.before = before
        self.name = name

    def __repr__(self):
        return 'Node({!r})'.format(self.name)


# Prediction, once the tree has been built.
def predict(root, line):
    """Classify *line* by walking the tree from the column node *root*.

    Args:
        root: a column node or a leaf.
        line: dict mapping column names to values.

    Returns:
        The decision label, or ``None`` when *line* lacks a column the tree
        asks for or holds a value with no branch.  The label is also printed
        as ``result <label>`` when a leaf is reached.
    """
    if root.next is None:
        print('result', root.name)
        return root.name

    attribute = line.get(root.name)
    if attribute is None:
        return None
    for value_node in root.next:
        if value_node.name != attribute:
            continue
        for child in value_node.next:
            result = predict(child, line)
            if result is not None:
                return result
    return None


def build_tree():
    """Build the decision tree for ``data`` and return its root ``Node``."""
    entropy, node_data = calculate_entropy({})
    root, attributes = child_node(entropy, columns, node_data, {})
    head = Node(root, None, [])
    find_attribute(root, attributes, {}, head)
    return head


def main():
    """Build the tree and print the prediction for two sample rows."""
    head = build_tree()
    test_data = [
        {
            'Temperature': 'Hot',
            'Humidity': 'High',
            'Wind': 'Weak',
            'Outlook': 'Sunny',
        },
        {
            'Outlook': 'Overcast',
            'Temperature': 'Hot',
            'Humidity': 'High',
            'Wind': 'Weak',
        },
    ]
    for line in test_data:
        print(line)
        predict(head, line)


if __name__ == '__main__':
    main()


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3