求问一个python的画图问题
本帖最后由 wangchen0823 于 2022-7-1 19:05 编辑

就是画一个决策树的问题：现在我已经得到了一个字典，要把它可视化成一棵树。但由于字典太大，树的内容都挤在了一起。
像下面这样。
想请教一下，代码该怎么改进，才能让这棵树不那么臃肿？
下面是画图的代码：
# bbox style for decision (internal) nodes: saw-tooth outline, grey fill ("0.8").
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
# bbox style for leaf nodes: rounded outline, same grey fill.
leafNode = {"boxstyle": "round4", "fc": "0.8"}
# arrowprops for plotNode's annotate(); "<-" puts the arrow head at the text (child) end.
arrow_args = {"arrowstyle": "<-"}
def plotNode(nodeText, centerPt, parentPt, nodeType):
    """Draw one boxed node at centerPt with an arrow connecting it to parentPt.

    Both points are given in 'axes fraction' coordinates of createPlot.ax1.
    nodeType is the bbox style dict (decisionNode or leafNode).
    """
    ax = createPlot.ax1
    ax.annotate(
        nodeText,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
def getNumLeafs(myTree):
    """Count the leaves of a nested-dict decision tree.

    The tree is expected to look like
    {featureName: {branchValue: subtree_or_label, ...}}: one top-level key
    (the feature tested at this node) whose value maps each branch value to
    either another such tree (a dict) or a leaf label.

    Returns the number of leaf labels reachable from myTree (0 for an empty dict).
    """
    if not myTree:
        return 0
    firstStr = next(iter(myTree))   # feature tested at this node
    secondDict = myTree[firstStr]   # branch value -> subtree or leaf label
    numLeafs = 0
    for child in secondDict.values():
        # Recurse into sub-trees; anything that is not a dict is a leaf.
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    """Return the number of decision levels in a nested-dict decision tree.

    Uses the same {featureName: {branchValue: subtree_or_label}} layout as
    getNumLeafs. A node whose children are all leaves has depth 1.
    An empty dict has depth 0.
    """
    if not myTree:
        return 0
    firstStr = next(iter(myTree))   # feature tested at this node
    secondDict = myTree[firstStr]   # branch value -> subtree or leaf label
    maxDepth = 0
    for child in secondDict.values():
        thisDepth = 1 + getTreeDepth(child) if isinstance(child, dict) else 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth
def plotMidText(cntrPt, parentPt, txtString):
    """Write txtString at the midpoint of the edge from parentPt to cntrPt.

    Used to label a tree edge with its branch value. Points are (x, y)
    tuples in the same 'axes fraction' coordinates that plotNode uses.
    """
    # Tuples do not support element-wise arithmetic, so interpolate each axis.
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    ax = createPlot.ax1
    # Nodes are placed in axes-fraction coordinates, so draw the label in the
    # same system rather than relying on the data limits happening to be 0..1.
    ax.text(xMid, yMid, txtString, va="center", ha="center", transform=ax.transAxes)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw myTree with its root connected to parentPt.

    Layout state lives in function attributes initialised by createPlot:
      plotTree.totalW - leaf count of the whole tree (horizontal slots)
      plotTree.totalD - depth of the whole tree (vertical levels)
      plotTree.xOff   - x of the most recently placed leaf
      plotTree.yOff   - y of the level currently being drawn
    Each leaf takes a 1/totalW horizontal slot; each level steps down 1/totalD.

    nodeTxt labels the edge from parentPt to this subtree's root.
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = next(iter(myTree))   # feature tested at this node
    # Centre this decision node above the numLeafs slots its leaves will occupy.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]   # branch value -> subtree or leaf label
    plotTree.yOff -= 1.0 / plotTree.totalD   # children are drawn one level down
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            plotTree.xOff += 1.0 / plotTree.totalW   # advance to the next leaf slot
            leafPt = (plotTree.xOff, plotTree.yOff)
            plotNode(child, leafPt, cntrPt, leafNode)
            plotMidText(leafPt, cntrPt, str(key))
    plotTree.yOff += 1.0 / plotTree.totalD   # restore level for this node's siblings
def createPlot(inTree, figsize=None):
    """Draw the nested-dict decision tree inTree in matplotlib figure 1 and show it.

    figsize: optional (width, height) in inches. If omitted, the figure is
    sized from the tree itself: about 0.9 in per leaf and 1.2 in per level,
    and never smaller than matplotlib's stock 6.4 x 4.8. This keeps large
    trees from squeezing their boxes on top of each other.

    Raises ValueError if the tree has no leaves, because the layout divides
    by the leaf count.
    """
    numLeafs = getNumLeafs(inTree) if inTree else 0
    if numLeafs == 0:
        raise ValueError("createPlot() needs a tree with at least one leaf")
    depth = getTreeDepth(inTree)
    if figsize is None:
        figsize = (max(6.4, 0.9 * numLeafs), max(4.8, 1.2 * (depth + 1)))
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    # Figure 1 may already exist (figsize= is ignored then), so resize explicitly.
    fig.set_size_inches(*figsize)
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(numLeafs)
    plotTree.totalD = float(depth)
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    # The root's centre works out to (0.5, 1.0). Using the same point as its
    # "parent" means no stray arrow is drawn into the root.
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
# Forum reply that followed in the original post: "I copied your code and will run it to take a look."
页:
[1]