matplotlib是python中强大的画图模块。
首先确保已经安装python,然后用pip来安装matplotlib模块。
接着键入python -m pip install matplotlib进行自动的安装,系统会自动下载安装包。
安装完成后,可以用python -m pip list查看本机的安装的所有模块,确保matplotlib已经安装成功。
然后创建一个python文件,取名drawTree.py
在drawTree.py加入下面的代码,此代码适用于python3,网上有的代码在python3上会报错。
importmatplotlib.pyplotasplt
plt.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签
plt.rcParams['axes.unicode_minus']=False#用来正常显示负号
decisionNode = dict(boxstyle="sawtooth", fc="0.8") #定义文本框与箭头的格式
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="
defgetNumLeafs(myTree): #获取树叶节点的数目
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
forkeyinsecondDict.keys():
iftype(secondDict[key]).__name__=='dict':#测试节点的数据类型是不是字典,如果是则就需要递归的调用getNumLeafs()函数
numLeafs += getNumLeafs(secondDict[key])
else: numLeafs +=1
returnnumLeafs
defgetTreeDepth(myTree): #获取树的深度
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
forkeyinsecondDict.keys():
iftype(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
thisDepth = 1 + getTreeDepth(secondDict[key])
else: thisDepth = 1
ifthisDepth > maxDepth: maxDepth = thisDepth
returnmaxDepth
# 绘制带箭头的注释
defplotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
#计算父节点和子节点的中间位置,在父节点间填充文本的信息
defplotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
# 画决策树的准备方法
defplotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
numLeafs = getNumLeafs(myTree) #计算树的宽度
depth = getTreeDepth(myTree) #计算树的深度
firstStr = list(myTree.keys())[0] #the text label for this node should be this
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
forkeyinsecondDict.keys():
iftype(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
plotTree(secondDict[key],cntrPt,str(key)) #recursion
else: #it's a leaf node print the leaf node
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it's a tree, and the first element will be another dict
# 画决策树主方法
defcreatePlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks
#createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
plotTree(inTree, (0.5,1.0),'')
plt.show()
#def createPlot():
# fig = plt.figure(1, facecolor='white')
# fig.clf()
# createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
# plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
# plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
# plt.show()
defretrieveTree(i):
listOfTrees =[{'no surfacing': }}},
{'no surfacing': }, 1:'no'}}}}
]
returnlistOfTrees[i]
在dTree.py中添加drawTree.py的引用
importdrawTree
然后在主方法加调用:
drawTree.createPlot(myTree)
直接运行dTree.py,输出决策树图形
从图中可以看出本测试数据中对“长相”的要求比较高,要高于有没有钱
关注微信公众号“挨踢学霸”,获取更多人工智能技术文章
领取专属 10元无门槛券
私享最新 技术干货