分类树算法原理及实现

您所在的位置:网站首页 树状分类法分类依据 分类树算法原理及实现

分类树算法原理及实现

2023-07-14 08:02| 来源: 网络整理| 查看: 265

作者:归辰

来源:海边的拾遗者

导读

今天是该系列第六篇文章,介绍分类树原理及实现。

    机器学习领域中的树模型其实就是结合了数据结构中的二叉树来开展机器学习任务的方法。本文所讲解的分类树为CART树中的一种,而CART树是决策树中的一种,其它还有ID3和C4.5。决策树算法是一类常用的机器学习算法,在分类问题中,决策树算法通过样本中某一维特征属性值的分布,将样本划分到不同的类别中,而这一功能就是基于树形结构来实现的。

    本文以决策树中的CART树为例介绍分类树的原理及实现。

叶节点分裂指标

     通常有这些指标:信息增益(Information Gain)、增益率(Gain Ratio)和基尼指数(Gini Index)。

    熵(Entropy)是度量样本集合纯度最常用的一种指标,对于包含m个训练样本的数据集D{(X(1),y(1)),(X(2),y(2)),…,(X(m),y(m))},pk为数据集D中第k类别数量所占比例,则数据集D的熵为:

    将数据集D按照某个特征的值划分为两个子数据集,此时数据集D的信息熵减小了,对于给定的数据集,划分前后信息熵的减少量称为信息增益为: 

    

    其中Dp为第p个子数据集样本数,ID3就是利用该指标来进行叶节点分裂。因而可以得到增益率的计算方法为:

    其中,IV(A)被称为特征A的"固有值",也等于数据集D根据划分好的子数据集种类来计算得到的信息熵 ,C4.5就是利用增益率来进行叶节点分裂。

    基尼系数也可以作为分裂的指标,数据集D的基尼系数为:

    在CART中即用该指标来进行叶节点分裂。现在让我们用代码将其实现。

from math import pow def cal_gini_index(data):     '''input: data     output: gini 基尼指数'''      total_sample = len(data) if len(data) == 0: return 0 label_counts = label_uniq_cnt(data) # 计算数据集的Gini指数 gini = 0 for label in label_counts:         gini = gini + pow(label_counts[label], 2)         gini = 1 - float(gini) / pow(total_sample, 2) return gini      def label_uniq_cnt(data):     '''input: data     output: label_uniq_cnt数据集各标签个数'''     label_uniq_cnt = {}     for x in data:         label = x[len(x) - 1] if label not in label_uniq_cnt: label_uniq_cnt[label] = 0 label_uniq_cnt[label] = label_uniq_cnt[label] + 1 return label_uniq_cnt 分类树

     在按照特征对上述的数据进行划分的过程中,需要设置划分的终止条件,通常在算法的过程中,设置划分终止条件的方法主要有:①结点中的样本数小于给定阀值(前剪枝);②样本集的基尼指数小于给定阀值(后剪枝);③没有更多特征。

    分类树的构建过程可以分为以下几个步骤:

对于当前训练数据集,遍历所有特征及其对应的所有可能切分点,寻找最佳切分特征及其最佳切分点,使得切分之后的基尼指数最小,利用该最佳特征及其最佳切分点将训练数据集切分成两个子集,分别对应判别结果为左子树和判别结果为右子树。

重复以下的步骤直至满足停止条件:为每一个叶子节点寻找最佳切分特征及其最佳切分点,将其划分为左右子树。

生成分类树。 

    将数据集D按照某个特征的值划分为两个子数据集,此时数据集D的信息熵减小了。

    现在先为树中的节点定义一个结构类,代码如下:

class node: def __init__(self, fea=-1, value=None, results=None, right=None, left=None): self.fea = fea # 用于切分数据集的特征的列索引值 self.value = value # 设置划分的值 self.results = results # 存储叶节点所属的类别 self.right = right # 右子树 self.left = left # 左子树

     然后我们可以利用递归的方法开始构建树了,在构建树的过程中,主要有如下的几步:①计算当前的基尼指数;②尝试按照数据集中的每一个特征将树划分成左右子树,计算出最好的划分,通过递归的方式继续对左右子树进行划分;③判断当前是否还可以继续划分,若不能继续划分则退出。

def build_tree(data):     '''input: data 训练样本     output: node 树的根结点''' # 构建决策树,函数返回该决策树的根节点 if len(data) == 0: return node() # 1、计算当前的基尼指数 currentGini = cal_gini_index(data) bestGain = 0.0 bestCriteria = None # 存储最佳切分特征以及最佳切分点 bestSets = None # 存储切分后的两个数据集 feature_num = len(data[0]) - 1 # 样本中特征个数     # 2、通过贪心法找到最好的划分 for fea in range(0, feature_num): # 2.1、取得fea特征处所有可能的取值 feature_values = {} # 在fea位置处可能的取值         for sample in data:  feature_values[sample[fea]] = 1 # 存储特征fea处所有可能的取值 # 2.2、针对每一个可能的取值,尝试将数据集划分,并计算基尼指数 for value in feature_values.keys(): # 遍历该特征的所有切分点 # 2.2.1、 根据fea特征中的值value将数据集划分成左右子树 (set_1, set_2) = split_tree(data, fea, value) # 2.2.2、计算当前的基尼指数 nowGini = float(len(set_1) * cal_gini_index(set_1) + \ len(set_2) * cal_gini_index(set_2)) / len(data) # 2.2.3、计算基尼指数的增加量 gain = currentGini - nowGini # 2.2.4、判断此划分是否比当前的划分更好 if gain > bestGain and len(set_1) > 0 and len(set_2) > 0: bestGain = gain bestCriteria = (fea, value) bestSets = (set_1, set_2) # 3、判断划分是否结束 if bestGain > 0: right = build_tree(bestSets[0]) left = build_tree(bestSets[1]) return node(fea=bestCriteria[0], value=bestCriteria[1], \ right=right, left=left) else:         return node(results=label_uniq_cnt(data))  # 返回当前的类别标签作为最终的类别标签 def split_tree(data, fea, value):     '''input: data fea待分割特征的索引 value待分割的特征的具体值     output: (set1,set2)分割后的左右子树''' set_1 = [] set_2 = [] for x in data: if x[fea] >= value: set_1.append(x) else: set_2.append(x) return (set_1, set_2)

    函数split_tree主要用于特征的值是连续的值时的划分,当特征fea处的值是一些连续值的时候,当该处的值大于或等于待划分的值value时,将该样本划分到set_1中,否则,划分到set_2中。

预测

     当整个分类树构建完成后,利用训练样本对分类树进行训练,最终得到分类树的模型,对于未知的样本,需要用训练好的分类树的模型对其进行预测。可以将其实现: 

def predict(sample, tree):     '''input: sample需要预测的样本 tree构建好的分类树     output: tree.results所属的类别''' # 1、只是树根 if tree.results != None: return tree.results else: # 2、有左右子树 val_sample = sample[tree.fea] branch = None if val_sample >= tree.value: branch = tree.right else: branch = tree.left return predict(sample, branch)

     到这里整个流程基本就结束了~

    由于小编近期工作比较忙,机器学习系列的第五篇文章(SVM)会在后面尽快发表,各位支持和关注我的朋友请多多谅解哈~

◆ ◆ ◆  ◆ ◆

号主新书已经在京东上架了,想要一睹为快的朋友,可以直接下单购买,目前京东正在举行活动,大家可以用原价5折的预购价格购买,3天左右会发货,还是非常划算的:

扫描下方二维码即可进入京东的购买链接(https://item.jd.com/12686131.html):

数据森麟公众号的交流群已经建立,许多小伙伴已经加入其中,感谢大家的支持。大家可以在群里交流关于数据分析&数据挖掘的相关内容,还没有加入的小伙伴可以扫描下方管理员二维码,进群前一定要关注公众号奥,关注后让管理员帮忙拉进群,期待大家的加入。

管理员二维码:

猜你喜欢

● 笑死人不偿命的知乎沙雕问题排行榜

● 用Python扒出B站那些“惊为天人”的阿婆主!

● 全球股市跳水大战,谁最坑爹!

● 华农兄弟、徐大Sao&李子柒?谁才是B站美食区的最强王者?

● 你相信逛B站也能学编程吗

点击阅读原文,即可参与京东5折购书活动



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3