首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

AI机器学习-决策树算法-Python实现画决策树图形

matplotlib是python中强大的画图模块。

首先确保已经安装python,然后用pip来安装matplotlib模块。

接着键入python -m pip install matplotlib进行自动的安装,系统会自动下载安装包。

安装完成后,可以用python -m pip list查看本机的安装的所有模块,确保matplotlib已经安装成功。

然后创建一个python文件,取名drawTree.py

在drawTree.py加入下面的代码,此代码适用于python3,网上有的代码在python3上会报错。

importmatplotlib.pyplotasplt

plt.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签

plt.rcParams['axes.unicode_minus']=False#用来正常显示负号

decisionNode = dict(boxstyle="sawtooth", fc="0.8") #定义文本框与箭头的格式

leafNode = dict(boxstyle="round4", fc="0.8")

arrow_args = dict(arrowstyle="

defgetNumLeafs(myTree): #获取树叶节点的数目

numLeafs = 0

firstStr = list(myTree.keys())[0]

secondDict = myTree[firstStr]

forkeyinsecondDict.keys():

iftype(secondDict[key]).__name__=='dict':#测试节点的数据类型是不是字典,如果是则就需要递归的调用getNumLeafs()函数

numLeafs += getNumLeafs(secondDict[key])

else: numLeafs +=1

returnnumLeafs

defgetTreeDepth(myTree): #获取树的深度

maxDepth = 0

firstStr = list(myTree.keys())[0]

secondDict = myTree[firstStr]

forkeyinsecondDict.keys():

iftype(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes

thisDepth = 1 + getTreeDepth(secondDict[key])

else: thisDepth = 1

ifthisDepth > maxDepth: maxDepth = thisDepth

returnmaxDepth

# 绘制带箭头的注释

defplotNode(nodeTxt, centerPt, parentPt, nodeType):

createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',

xytext=centerPt, textcoords='axes fraction',

va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

#计算父节点和子节点的中间位置,在父节点间填充文本的信息

defplotMidText(cntrPt, parentPt, txtString):

xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]

yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]

createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

# 画决策树的准备方法

defplotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on

numLeafs = getNumLeafs(myTree) #计算树的宽度

depth = getTreeDepth(myTree) #计算树的深度

firstStr = list(myTree.keys())[0] #the text label for this node should be this

cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)

plotMidText(cntrPt, parentPt, nodeTxt)

plotNode(firstStr, cntrPt, parentPt, decisionNode)

secondDict = myTree[firstStr]

plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD

forkeyinsecondDict.keys():

iftype(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes

plotTree(secondDict[key],cntrPt,str(key)) #recursion

else: #it's a leaf node print the leaf node

plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW

plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)

plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))

plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

#if you do get a dictonary you know it's a tree, and the first element will be another dict

# 画决策树主方法

defcreatePlot(inTree):

fig = plt.figure(1, facecolor='white')

fig.clf()

axprops = dict(xticks=[], yticks=[])

createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks

#createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses

plotTree.totalW = float(getNumLeafs(inTree))

plotTree.totalD = float(getTreeDepth(inTree))

plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;

plotTree(inTree, (0.5,1.0),'')

plt.show()

#def createPlot():

# fig = plt.figure(1, facecolor='white')

# fig.clf()

# createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses

# plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)

# plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)

# plt.show()

defretrieveTree(i):

listOfTrees =[{'no surfacing': }}},

{'no surfacing': }, 1:'no'}}}}

]

returnlistOfTrees[i]

在dTree.py中添加drawTree.py的引用

importdrawTree

然后在主方法加调用:

drawTree.createPlot(myTree)

直接运行dTree.py,输出决策树图形

从图中可以看出本测试数据中对“长相”的要求比较高,要高于有没有钱

关注微信公众号“挨踢学霸”,获取更多人工智能技术文章

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180502A1QP3F00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券