前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >决策树的构建原理

决策树的构建原理

作者头像
SYSU星空
发布2022-05-05 14:16:02
1.3K0
发布2022-05-05 14:16:02
举报
文章被收录于专栏:微生态与微进化

决策树(Decision Tree)是一种简单但是广泛使用的分类预测模型。通过训练数据构建决策树,可以高效的对未知的数据进行分类并作出决策。决策树有两大优点,一是决策树模型可以读性好,具有描述性,有助于人工分析;二是效率高,决策树只需要一次构建,反复使用,但是预测的最大计算次数不能超过决策树的深度。一个简单的决策树例子如下所示:

决策树构建步骤

决策树属于一种有监督的机器学习,同时也属于约束的聚类。决策树可分为分类树回归树两种,分类树对离散响应变量做决策树,回归树对连续响应变量做决策树。决策树需要预测变量的训练数据集来构建,其基本步骤如下:

①开始,所有预测变量均看作一个节点;

②遍历每个预测变量的每一种分割方式,找到最好的分割点;

③分割成两个分支N1和N2;

④对N1和N2分别继续执行2-3步,直到每个节点足够“纯”为止。

决策树的预测变量可以有两种:

①数值型(Numeric):变量类型是整数或浮点数,如“年收入”数据,可以用“>=”,“>”,“<”或“<=”作为分割条件;

②类别型(Nominal):变量只能从有限的选项中选取,比如前面例子中的“婚姻情况”,只能是“单身”,“已婚”或“离婚”,也即因子型,可以使用“=”来分割。

节点分裂标准

如何在节点下进行分类并评估分割点的好坏是决策树构建中的关键环节。如果一个分割点可以将当前的所有节点分为两类,使得每一类都很“纯”,也即分类效果良好,那么就是一个好分割点。比如上面的例子,“拥有房产”,可以将记录分成了两类,“是”的节点全部都可以偿还债务,非常“纯”;“否”的节点,可以偿还贷款和无法偿还贷款的人都有,不是很“纯”,但是两个节点加起来的纯度和与原始节点的总纯度之差最大,所以按照这种方法分割。构建决策树采用贪心策略,只考虑当前纯度差最大的情况作为分割点。

这里的“纯度”必须要量化才能进行实际比较分析,例如在MRT分析中将组内方差作为纯度,分类后组内方差之和与上一层级差别较大,说明分类是有效的。下面介绍几种纯度量化的方法:

①基尼不纯度(Gini impurity)

②信息熵(Information Entropy)

③错误率

其中P(i)为该节点下第i个分支也即分类子节点的概率,也即分到该类的观察值占全部数据的比例。Gini不纯度主要反映的是节点分类纯度,所含类别越多越不纯,Gini不纯度越高;信息熵借鉴的物理学熵的原理,该节点数据所包含的分类越多,那么该节点所蕴含的数据信息越多,混乱度也即熵也越大;错误率也即将该节点数据随即分类的错误率,显然分类越多则错误率越大。上面三个指标均是值越大,表示越“不纯”,越小表示越“纯”。假如该节点下所有观察值都属于同一类,那么该节点下只有一个分支,其分类概率为1,Gini不纯度、信息熵、错误率均为0;如果该节点下有很多分支,而且每个分支概率均匀,也即可以分成很多类,因此Gini不纯度就会高,所包含的信息熵也大,将数据进行随机分类时错误率也高。其中信息熵是最常用的纯度指标。

与MRT分类原理类似,决策树使用下一级节点(子节点)纯度的加权和与上一级节点(父母节点)纯度的差值来衡量这一节点的分类是否是有效的,这个纯度差被称为信息增益(Information Gain),其公式如下所示:

其中I为任一上述纯度量化指标,vj为第j个节点,N为节点所包含的观察值或记录的个数,k为子节点个数,该公式也可以理解为使用该节点(parent)纯度减去该节点子节点纯度的加权和来衡量该节点下的分类有效性。

分裂停止条件

决策树的构建是一个递归过程,如果不设置特定的停止条件,最终每个分支末端节点只包含一个观察值或者记录,这时节点纯度和为0,容易出现过度拟合问题,这样的分类一般是没有意义的。一般可以设置某节点下分类的观察值个数低于一个最小的阈值,即停止分割。常见的停止条件如下所示:

①如果节点中所有观测属于一类;

②如果该节点中所有观测的属性取值一致;

③如果树的深度达到设定的阈值;

④如果该节点所含观测值数小于设定的父节点应含观测数的阈值;

⑤如果该节点的子节点所含观测数将小于设定的阈值;

⑥如果没有属性能满足设定的分裂准则的阈值。

决策树优化方案

在决策树建立过程中可能会出现过度拟合情况,也即分类过于“细”,导致对训练数据可以得到很低的错误率,但是运用到测试数据上却得到非常高的错误率。过度拟合的原因可能有以下几点:

①噪音数据:训练数据中存在噪音数据,决策树的某些节点有噪音数据作为分割标准,导致决策树无法代表真实数据;

②缺少代表性数据:训练数据没有包含所有具有代表性的数据,导致某一类数据无法很好的匹配;

③多重比较(Mulitple Comparition):当预测变量很多时,总会有一个变量与响应变量具有较好的随机相关性从而具有好的分类效果,但是使用新数据进行预测时这种效果消失,这与多元回归中由于随机相关对R2校正的原理类似。

对于存在过度拟合的决策树,有以下几种优化方案:

①修剪枝叶

决策树过度拟合往往是因为太过“茂盛”,也就是节点过多,分类过细,所以需要裁剪(Prune Tree)枝叶。裁剪枝叶的策略对决策树的正确率影响很大,主要有两种裁剪策略,一种是前置裁剪,也即在构建决策树的过程时,提前停止,可以将分裂准则设定的更严格来实现;另一种是后置裁剪,也即决策树构建好后,然后才开始裁剪,可以用单一叶节点代替整个子树,叶节点的分类采用子树中最主要的分类,也可以将一个子树完全替代另外一颗子树。

②交叉验证

使用K-Fold Validataion方法计算决策树,并裁剪到i个节点,计算错误率,最后求出平均错误率。这样可以用具有最小错误率对应的i作为最终决策树的大小,对原始决策树进行裁剪,得到最优决策树。

③自助方法

自助聚合(bagging:bootstrap aggregating)也叫装袋法,是基于自助法发展而来,也即让机器学习进行多轮,每轮在训练数据集中随机抽取n个样本进行学习,最终选取错误率低的模型。随机森林(Random Forest)就是决策树的自助聚合法,用训练数据随机的计算出许多决策树,形成了一个森林。然后用这个森林对未知数据进行预测,选取正确率最高的分类。实践证明,此算法的错误率得到了进一步的降低。这种方法背后的原理可以用“三个臭皮匠,赛过诸葛亮”这句谚语来概括。

④推进方法

推进或者说提升(boosting)方法是一种改进的决策树构建方法,其原理和随机森林类似,例如对于分类树,获得比较粗糙的分类(弱学习或者弱分类器)要比获得一个精确的分类(强学习)容易得多,提升方法就是获得很多粗糙的分类并赋予这些弱分类器相等的权重,然后根据错误率重新加权将这些“弱学习”提升为一个“强学习”。聚合推进树(aggregated boosted tree,ABT)就是决策树的提升聚合法应用。

决策树算法是整合了前面分裂准则、停止条件和优化方案的整合算法,常见的决策树算法有有ID3、CART和C4.5等,其中ID3和C4.5只用于分类,而CART既可以用于分类又可以用于回归,是现在最常用的方法。

决策树构建示例

在R中与决策树有关的常见软件包如下所示:

单棵决策树:rpart/tree/C50

随机森林:randomForest/ranger/party

梯度提升树:gbm/xgboost

决策树可视化:rpart.plot

接下来我们使用rpart包中的rpart()函数来实现CART算法建模,使用rpart.plot包中的rpart.plot()函数进行决策树可视化。首先以rpart包内置数据集kyphosis为例进行分析,该数据集为经过脊柱矫正手术的儿童驼背出现情况,包含了驼背(kyphosis)、年龄(Age,单位:月)、矫正的椎骨数目(Number)和手术矫正椎骨起始位置(Start)四个变量,如下所示:

我们想根据这些数据建立对是否出现驼背(present or absent)的决策树预测模型,如下图所示:

代码语言:javascript
复制
library(rpart)
library(rpart.plot)
fit=rpart(Kyphosis~Age+Number+Start, data=kyphosis, method="class")
rpart.plot(fit, type=2)

如图所示节点上标出了分割条件例如第一个节点是利用变量start对因变量进行分割,分割点是8.5,右边是no表示start<8.5,节点方框中显示了改节点的分类结果、出现驼背(present)的概率、该节点下样本数目占全部样本的比例,可以使用summary(fit)命令查看决策树详细构造。

rpart()函数主要参数有:

method:根据树末端的数据类型选择相应变量分割方法,有四种取值分别为连续型“anova”、离散型“class”、计数型(泊松过程)“poisson”、生存分析型“exp”。程序会根据因变量的类型自动选择方法,但一般情况下最好还是指明本参数,以便让程序清楚做哪一种树模型。

parms:对于连续型anova不用设置parms;对于计数型和生存分析型poisson/exp来说parms设置先验分布的变异系数,默认为1,对于离散分类型class来书parms包含三个子参数即先验概率prior、损失矩阵loss、分类纯度的度量方法split(可选gini和information),三个子函数通过list()函数包含起来。先验概率一般为不同类别的出现概率(比例),但是有时候由于样本局限性不能代表总体,则需要设置先验概率进行修正,损失矩阵也即每一个节点预测失败时的损失,常见于商业分析。

control:设置分裂准则、停止条件、优化方法、交叉验证等,通过rpart.control()函数来构建,主要参数如下:

xval:交叉验证的次数;

minsplit:最小分支节点数,如果分支包含的子节点数大于等于设定值,那么该节点会继续分划下去,否则停止;

minbucket:设置节点最小样本数,小于设定值则停止分割;

maxdepth:决策树的最大深度也即分类层级;

cp:全称为complexityparameter,节点的复杂度参数,指对每一步拆分模型的拟合优度必须提高的程度,也即信息增益。

接下来我们设置参数进行精确建树,如下所示:

代码语言:javascript
复制
ct=rpart.control(xval=10, minsplit=20, cp=0.1)
fit=rpart(Kyphosis~Age+Number+Start, data=kyphosis, method="class", control=ct, parms=list(prior=c(0.65, 0.35), split="information"))
rpart.plot(fit, branch=1, branch.type=2, type=1, extra=102, shadow.col="gray", box.col="green", border.col="blue", split.col="red", split.cex=1.2, main="Kyphosis决策树")

除了对决策树的生成进行控制外,还可以对决策树进行后期的评价与修剪,可以使用printcp()函数查看决策树的各项指标:

结果中给出了分到每一层的cp、分割点数目nsplit、相对误差rel error、交叉验证的估计误差xerror、标准误差xstd。其中相对误差为决策树不能解释的方差,其值为1减去该层的上一层累积的cp。和MRT一样,rel error和xerror越小越好,而cp越大越好。决策树的修剪可以使用prune()函数,如下所示:

代码语言:javascript
复制
fit2=prune(fit, cp=0.3) 
rpart.plot(fit2, branch=1, branch.type=2, type=1, extra=102, shadow.col="gray",box.col="green", border.col="blue", split.col="red", split.cex=1.2, main="Kyphosis决策树")

一个完整的决策树构建、交差验证、修剪以及测试数据预测流程如下所示:

代码语言:javascript
复制
##导入数据集,把目标变量转为因子
accepts=read.csv("accepts.csv")
accepts$bad_ind=as.factor(accepts$bad_ind)
names(accepts) #查看变量
代码语言:javascript
复制
accepts=accepts[,c(3,7:24)] #选择响应变量与预测变量
##将数据分为训练集和测试集(训练集样本占70%)
select=sample(1:nrow(accepts), length(accepts$bad_ind)*0.7)
train=accepts[select,]
test=accepts[-select,]
summary(train$bad_ind)
##CART建树
library(rpart)
tc=rpart.control(minsplit=20, minbucket=20, maxdepth=10, xval=5, cp=0.005)
rpart.mod=rpart(bad_ind~., data=train, method="class", parms=list(prior=c(0.65,0.35), split="gini"), control=tc)
summary(rpart.mod)
#绘制决策树
library(rpart.plot)
rpart.plot(rpart.mod, branch=1, extra=106, under=TRUE, faclen=0, cex=0.8,main="决策树")
代码语言:javascript
复制
##查看变量重要性
rpart.mod$variable.importance
代码语言:javascript
复制
#查看cp并绘制cp与交差验证误差曲线
rpart.mod$cp
代码语言:javascript
复制
plotcp(rpart.mod)
代码语言:javascript
复制
#修剪决策树并绘图
rpart.mod.pru=prune(rpart.mod, cp=0.007) 
library(rpart.plot)
rpart.plot(rpart.mod.pru,branch=1, extra=106, under=TRUE, faclen=0, cex=0.8, main="决策树")
代码语言:javascript
复制
#使用测试集进行预测
rpart.pred=predict(rpart.mod.pru, test)
pre=ifelse(rpart.pred[,2]>0.5,1,0)

部分预测结果如下所示:

我们可以统计预测正确的比例,如下所示:

代码语言:javascript
复制
accuracy=length(which(test$bad_ind==pre))/length(pre)

可以看到预测的准确率为0.74,准确率越高,说明决策树模型越好。具有多个响应变量的决策树可以使用多元回归树MRT。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-12-30,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 微生态与微进化 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档