9. 分类回归树模型CART(Classification & Regression Tree)

您所在的位置:网站首页 树状分类法怎么画 9. 分类回归树模型CART(Classification & Regression Tree)

9. 分类回归树模型CART(Classification & Regression Tree)

2024-07-15 22:37| 来源: 网络整理| 查看: 265

1. 简介

树模型直白且清晰,它即可以用来分类也可以用来预测,他最大的特点是容易解释,这在实际应用中十分关键。树通过在predictor中创建许多的分支来创建(IF ELSE)的规则,例如"IF 年龄2,则分类为1"。树的创建的基本思想包括两条,第一,recursive partitioning (用于树的构建),第二,pruning(用于树的剪枝)

图一:示例 2. 树的创建

a. 树的结构 树的节点分为两种:decision nodes (splitting nodes) 和 terminal nodes (leaves of a tree)

b. recursive partitioning 树中的predictor可以是continuous的,binary的,或者ordinal的,如下图所示(以二维为例,若是高维则产生高维矩形区域),当选定了split以后,会分离出两个矩形区域

图2.1 示例, split之前.png 图2.2 示例, split之后.png

Numerical Predictor:分离线的位置在两个连续点之间,因此在上图的例子中,按照income来,可能存在的位置是[38.1, 45.3, 50.1, ..., 109.5],如果按照lot size来可能是[14.4, 15.4,..., 23],然后对于这些可能存在的点,利用分好之后整体impurity来进行排序(见本section的c部分)

Categorical Predictor:对于categoircal,每次分离创建出两个子集,比如一个categorical variable有四个值{a, b, c, d},则第一次split会创建出{a}, {b, c, d}; {b}, {a, c, d} ... {a, c}; {b, d},假如一个predictor有m个值,那么需要创建m(而不是m-1)个dummy variable

对于树模型而言,normalize数据与否不影响结果

c. 选定分离点 分离点的选定依靠的是impurity,分离之后需要使得分离之后变得更加pure 对于一个区域,impurty的定义有两种(假设categorical variable有m个分类)

i. Gini index:G = 1-\Sigma_{k=1}^m{p^2(k)} Gini index的取值范围是[0, \frac{m-1}{m}] 在区域内的点全为一个种类时取到最低点: G = 1-1^2-0^2-...-0^2=0 在区域内的点每个种类的点的个数相同的时候取到最高点: G = 1-(\frac{1}{m})^2-(\frac{1}{m})^2-...-(\frac{1}{m})^2=1-\frac{1}{m}=\frac{m-1}{m} 显然,种类越多,m值越大,则Gini index的最高点越高。Gini index的值越低,说明该区域的impurity越小

ii. Entropy: E = -\Sigma_{k=1}^mp{(k)}log{p(k)} Entropy的取值范围为[0,logm] (注:该log的底可以是任意数,2, 10, e都可以,但是为了方便一般选择2) 在区域内的点全为一个种类时取到最低点: E = -1*log1-0*log0-...-0*log0=0 在区域内的点每个种类的点的个数相同的时候取到最高点: E = -\frac{1}{m}log\frac{1}{m}-\frac{1}{m}log\frac{1}{m}-...-\frac{1}{m}log\frac{1}{m}=logm 可以看到,种类越多,m值越大,则Entropy的最高点越高。Entropy的值越低,说明该区域的impurity越小

iii. 整合两个区域的purity 使用i或者ii的方法计算完impurity以后,首先计算第一个区域的有多少个点,按百分比贡献总impurity,第二个区域也用同样的方法来计算。 假设第一个区域有n_1个点,Impurity为G_1或者E_1,第二个区域有n_2个点,Impurity为G_2或者E_2则总impurity为: Impurity_{new} = \frac{n_1}{n_1+n_2}*G_1(或E_1)+\frac{n_2}{n_1+n_2}*G_2(或E_2) Impurity的差值被称作Gain of the split,其公式为: J = Impurity_{old}-Impurity_{new} 如果Gain of the split越大,说明降低的impurity越多,split point选取的值越好

iv. 示例

图3.1 树的创建示例 问题一:寻找第一个split point 第一个split point可以是X1,X2,或者X3,讨论:

未分离前, G_{old}=1-(\frac{3}{7})^2-(\frac{4}{7})^2=0.490 对于X_1:显然split point应该为 X_1\geq0 分离后, observation 1,2,6,7被分离到一起(左半边),其中1,2,6的Y=1,7的Y=0 因此G_{left}=1-(\frac{3}{4})^2-(\frac{1}{4})^2=0.375 observation 3,4,5被分离到一起(右半边),并且Y=0 因此G_{right}=1-(\frac{3}{3})^2-(\frac{0}{3})^2=0 整合左半边右半边,左半边有4个点,右半边有3个点,因此 G_{new}=\frac{4}{7} *G_{left}+\frac{3}{7} *G_{right}=0.214 因此降低的impurity为: J(X_1)=G_{old}-G_{new}=0.276

同理,可以算出,J(X_2)=0.085 (rule: X_2 \geq 0), J(X_3)=0.085 (rule: X_3 \geq 0)J(X_1)最大,因此第一个split point选择用X_1来split,split rule为X_1 \geq 0

如果用Entropy来计算,则 E_{old}=-\frac{3}{7}log(\frac{3}{7})-\frac{4}{7}log(\frac{4}{7})=0.985 E_{left}=-\frac{3}{4}log(\frac{3}{4})-\frac{1}{4}log(\frac{1}{4})=0.562 E_{right}=-\frac{3}{3}log(\frac{3}{3})-\frac{0}{3}log(\frac{0}{3})=0 J(X_1)=0.985-\frac{4}{7}*0.562-\frac{3}{7}*0=0.521 同理,J(X_2)=J(X_3)=0.128,因此选择用X_1来split,split rule为X_1 \geq 0

问题二:按照X_1,X_2,X_3的顺序写出fully-fit tree

图3.2 fully-fit tree的创建 按照规则,所有的点都被分到对应的terminal node即可

c. python相关函数 DecisionTreeClassifier():产生的是二叉树,terminal nodes比decision nodes刚好多一个node export_graphviz():用于visualize并产生图一中的rule

3. 树模型的分类预测

若新的数据的目的是: 分类,则新的数据随着树一直落到对应的terminal node,然后由该node的数据进行投票决定 如果对于balanced data,我们可以直接通过多数的投票决定其分类 但是如果data中的class of interest十分稀少,则可以设定一个cutoff probability,显然,lower cutoff相当于减少分类为class of interest所需要的投票百分比,则该数据更偏向于被分类为class of interest 在分类问题中,有可能会出现class of interest非常稀有的情况,这种情况下 预测,则新的数据随着树一直落到对应的terminal node,然后由该node的数据进行weight得到

4. full-grown tree存在的问题及解决思路

full-grown tree存在两个问题: a. unstable,当不同的sample被选取的时候,树的结构会发生极大的改动 b. overfitting,显而易见,不赘述了

为了解决问题一,采用cross-validation的方式来进行sensitivity analysis 采用python的cross_val_score可以得到每一个fold对应的accuracy

treeClassifier = DecisionTreeClassifier(random_state=1) scores = cross_val_score(treeClassifier, train_X, train_y, cv=5) # cv=5表示5-fold validation print(’Accuracy scores of each fold: ’, [f’acc:.3f’ for acc in scores])

为了解决问题二,采用pruning的方式 a) Fine-tuning Trees: 限制树的深度 / terminal node的数量 / impurity降低的最小值 Fine-tuning最大的难点在于其参数该如何调整,关于其调参的步骤在下面的python代码中会详细讲解

Fine-tuning在python中代码的实现:

# 设定深度最多为30,如果节点超过20则split,split之后的impurity必须小于0.01,random-state参数决定的是随机种子用于稳定分离数据集的结果 smallClassTree = DecisionTreeClassifier(max_depth=30, min_samples_split=20, min_impurity_decrease=0.01, random_state=1)

现实运用中则需要对这些参数进行调整,因此可以用到exhaustive grid search方法来网格式搜索并寻找最佳的参数组合:

# param_grid之中每个value的组合都被搜索一遍,这里总共有4*5*5=100种组合 param_grid = { 'max_depth': [10, 20, 30, 40], 'min_samples_split': [20, 40, 60, 80, 100], 'min_impurity_decrease': [0, 0.0005, 0.001, 0.005, 0.01], } gridSearch = GridSearchCV(DecisionTreeClassifier(random_state=1), param_grid, cv=5, n_jobs=-1) # n_jobs=-1 will utilize all available CPUs gridSearch.fit(train_X, train_y) print('Initial score: ', gridSearch.best_score_) print('Initial parameters: ', gridSearch.best_params_)

打印出相应的结果后,在根据对应的结果继续细化网格搜索,例如以上如果输出的best_parameter的组合是:

{'max_depth': 10, 'min_impurity_decrease': 0.001, 'min_samples_split': 20}

则可以运行如下的代码:

# 根据第一波搜索的结果进行新的一轮网格搜索 param_grid = { 'max_depth': list(range(2, 16)), # 14 values 'min_samples_split': list(range(10, 22)), # 11 values 'min_impurity_decrease': [0.0009, 0.001, 0.0011], # 3 values } gridSearch = GridSearchCV(DecisionTreeClassifier(random_state=1), param_grid, cv=5, n_jobs=-1) gridSearch.fit(train_X, train_y) print('Improved score: ', gridSearch.best_score_) print('Improved parameters: ', gridSearch.best_params_) bestClassTree = gridSearch.best_estimator_

b) Conditional Inference Trees 对于树的大小的限制,还有一个方法 CHAID (Chi-squared Automatic Interaction Detection),当运用split point的时候,我们检测的指标是impurity降低的大小,\chi^2可以用于假设性检验这个impurity降低的量到底是不是significant。 CHAID的第一步是寻找与response variable的association最强的predictor,这个强度由\chi^2对两个变量的独立性测试的p-value决定,如果asscociation最强的predictor进行split以后purity的提升没有达到一个significant的地步,那么就停止树的增长。这种方式演化出来的树就是conditional inference trees。

c) Pruning Tree pruning的基本思路是剪掉原本树模型中的冗余枝干,这些枝干会拟合噪声,从而使结果的准确性降低。 pruning常用的算法有两个,一个是C4.5算法,一个是CART method。 对于C4.5,训练集不但被用于grow也被用于prune 对于CART,则是用validation data去prune back tree

d) CART method CART运用cost complexity function来prune back tree,cost complexity function如下: CC(T)=err(T)+\alpha L(T) 其中err(T)表示树对训练集的错误率,L(T)表示树有多少个terminal nodes,\alpha是惩罚项,显然这个数值越小越好 显然如果\alpha=0表示对terminal nodes没有惩罚,得到的会是fully-grown tree,而如果\alpha很大,则对树的terminal nodes的惩罚很大,树只会给出极少的terminal nodes。同样的,考虑到树的不稳定性,对不同的\alpha可以运用k-fold cross validation来选取得到最高准确率的\alpha

Minimum-error tree: Pruning to the Minimum Error to the Validation Set

图4. Minimum-error tree tree的生成与prune的伪代码: step 1:分成training set和validation set step 2:用training set生成树 step 3:连续prune,每次都记录下CC(T) step 4:记录下validation set取到minimum error的时候的CC(T) step 5:cross-validation,重复step 1 -> 4 step 6:每一次都记录下minimum error处的CC(T),将每一次的CC(T)做一个平均值\overline{CC(T)} step 7:回到原始数据或是新的数据,生成一棵新的树,当树的CC(T)达到\overline{CC(T)}的时候停止生长

Best-Pruned Tree: prune

图5. Best-pruned tree 为了模型的简便性,也可以选择一个比minimum error tree更小的树,当达到了validation set的最小误差的地方时,继续pruning,使树的准确率在最小的validation error(xerror)的一个标准差的估计值(xstd)之内。

e) C4.5算法 C4.5算法是基于ID3算法的改进方法,ID3算法使用Gain of the split来grow tree,C4.5算法与ID3算法的不同点在于C4.5的split采用的是Gain ratio

Gain ratio 回顾上面的Gain of the split的公式: J = Impurity_{old}-Impurity_{new} Gain ratio,可以写成下面的形式 ratio = \frac{J(X)}{I(X)} 分母的数I(X)代表了X这个属性在原本树中的impurity 注:

之前的impurity的参照都是Y,但是I(X)的impurity的参照是X Gain ratio只能用Entropy来计算,Gini index只用在CART method里面

比如图3中的例子,对于X_1而言,observation 1,2,6,7共4个点的X_1是1,observation 3, 4, 5共3个点的X_1是-1 I(X_1) = -\frac{3}{7}log(\frac{3}{7})-\frac{4}{7}log(\frac{4}{7})^2=0.985 因此,Gain ratio为: ratio=\frac{J(X_1)}{I(X_1)}=\frac{0.521}{0.985}=0.529

C4.5相较于ID3而言,允许了numerical predictors,允许了missing value,并且也能进行树的剪枝

Numerical predictor in C4.5 因为C4.5需要计算predictor的gain,这一步可以用来将numerical predictor转换成nomial predictor(通过寻找该numerical predictor的maximum gain,并将其设为threshold) 例如

图 6.1示例数据 其中humidity是numerical predictor 图 6.2转换过程

参考:https://sefiks.com/2018/05/13/a-step-by-step-c4-5-decision-tree-example/

C4.5与missing value C4.5会在missing value的这个属性下一直落到底端寻找决策,然后将决策按照概率返回 比如缺失的数据是cost,cost的两个叶子节点分别是买(对应小于5),不买(对应大于等于5),左叶子节点有3个示例,右叶子节点有2个示例,则C4.5会返回[0.6, 0.4]对应[买,不买] 参考:https://stackoverflow.com/questions/42219073/c4-5-algorithm-missing-values https://cis.temple.edu/~giorgio/cis587/readings/id3-c45.html

C4.5的Pruning与Statistical pruning ID3没有pruning的设计,这一点用以下的例子可以给出

图 6.3 ID3的示例 如果按照ID3的算法推导,最终会得到: 图 6.4 ID3得出的树状结构 最后的两个叶子节点得到的结果是一致的,是因为倒数第二个叶子节点是由投票决定的,而C4.5允许剪枝,然而ID3则不行 参考:https://towardsdatascience.com/decision-trees-for-classification-id3-algorithm-explained-89df76e72df1

C4.5的pruning可以采用计算error rate的方式来pruning 第一步:计算子节点的error rate的权重和,权重由子节点的数据量给出 第二步:如果剪枝以后父节点的error rate比第一步小,则选择剪枝

statistical pruning在这个基础上增加了置信区间,可对结果有一定的提升

置信区间的设计 f:f代表了在训练集中有多少的错误率 p:p代表了在无限的数据集之中有多少的错误率 其关系为: p = f \pm z\sqrt{f(1-f) / N} 其中z是根据significance level定义的数值,C4.5比较upper limit of error rate,如果剪枝之后的upper limit of error rate比较小,则选择剪枝,父节点的error rate为子节点的error rate的权重和,权重由子节点的数据量给出。

statistical pruning 例如

图7 C4.5剪枝示例 \alpha=0.75, z=0.69 以health-plan为例,三个节点共14个数据,其中(4+1+4=)9个成功,(2+1+2=)5个失败 如果不分,则f = 5/14, N=14 => p上限为0.49 如果分,则 节点1:f = 2/6 (2成功,共6), N=6(共6)=> p上限0.46 节点2:f = 1/2, N=2 => p上限0.74 节点3:f = 2/6, N=6 => p上限0.46 取权重和到节点4,则p上限为0.46 * 6/14 + 0.74 * 2/14 + 0.46 * 6/14=0.55 不分的情况下,error rate比分的情况要小,因此选择剪枝

5. 回归树

预测:回归树的建立和分类树类似,当有一条新数据的时候,将新数据一直落到叶子节点,回归树的预测是为该叶子节点处所有数据的平均值 impurity:在teminal node处所有数据同均值的差值的平方,最低可能是0,如果所有值都是一样的话 performance evaluation:可采用MSE, RMSE等常规summary measures

6. Multitrees Algorithm

a. 随机森林 随机森林,顾名思义就是生成许多的tree,总和这些tree的结果来实现分类和预测 随机森林的生成 第一步:boostrap,用data with replacement的方式生成多组数据 第二步:对每一组数据用随机的subset of predictors,生成分类或者回归树 第三步:组合每一棵树的结果,如果是分类,则对每棵树的结果投票,如果是预测,则对每棵树的结果做平均 随机森林不能够画出树形图,因此失去了单棵树对data的解释性,但是可以提供每个predictor的重要性,这个重要性可由每一次按照该predictor分类的情况下,impurity降低的程度来进行衡量

rf = RandomForestClassifier(n_estimators=500, random_state=1) rf.fit(train_X, train_y) # 打印出每个feature的重要性 importances = rf.feature_importances_ std = np.std([tree.feature_importances_ for tree in rf.estimators_], axis=0)

b. boosted trees boosted trees的关注点在于被错误分类的数据集,因此它通过对被错误分类的数据进行提升来产生结果

第一步:生成一棵树 第二步:给被错误分类的数据较高的probability,在这个基础上重新选取数据集 第三步:对新生成的数据集再拟合一棵树 第四步:重复二三步,直到停止的条件被满足 第五步:采用权重投票的方式,越往后的树投票的权重越高

boost = GradientBoostingClassifier() boost.fit(train_X, train_y) 7. 树的优点和缺点

优点: a. 树不需要进行data transformation,因为树的split只依赖数据的单调性 b. 树的variable section部分是自动选择,而且variable的重要性十分明显,重要的variable在树的上层结构 c. 树对于outliers十分稳定,因为split的选取只依赖数据的顺序而不是具体的值,但同时这也是它的缺点,因为一个微小的数据改动可能会导致树结构的彻底改变 d. nonlinear, nonparametric,既是优点也是缺点,优点在于它对于predictor和outcome之间的关系的容忍度很高,缺点在于它每次只关注一个predictor,这样会使得predictor之间的关系的信息丢失,比如以下的例子

图8

在这样的情况之下,树的分类的表现会很差,因为树每次的分类都是垂直或者水平的, 解决方案一:是从已知存在关系的predictor之中在产生一个新的predictor来概括他们可能存在的关系 解决方案二:则是6中的multitree的方式,random forest会在解决这种情况下比较有用。 e. 树需要大量的数据投喂,而且grow a tree需要消耗很多资源,这是因为每一次split的过程中都需要进行大量的排序操作,以及为了稳定树的结果所做的cross-validation也会增加计算的负担 f. 尽管树能够自动选择variable,但是树有一个偏好,就是如果一个preditor的潜在的split point比较多,它就会比较倾向于这个predictor,比如有很多种类的cateogircal variable和有很多值的numerical variable 解决方案一:对于有很多种类categorical variable选择组成一个个小的集合,对于很多值的numerical variable选择bin操作 解决方案二:用其他的splitting criterion来替代,比如conditional inference trees, 或者QUEST分类树 g. 树对待missing data有很好的效果,而且也不需要impute或者delte values h. 树能够产生非常清晰易懂的规则,尽管这些规则在multitree的规则下会丧失功效



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3