wangchen0823 发表于 2022-7-1 18:22

求问一个python的画图问题

本帖最后由 wangchen0823 于 2022-7-1 19:05 编辑

问题是关于画决策树的:我现在已经得到了一个表示决策树的字典,想把它可视化成一棵树。但是由于字典太大,画出来的树节点内容都挤在一起了。

像下面这样。


想请教一下:代码应该怎么改进,才能让画出来的树不这么拥挤?

下面这个是画图的代码

# Shared matplotlib style dictionaries used by the tree-drawing helpers.
# Box style for decision (internal) nodes.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
# Box style for leaf nodes.
leafNode = {"boxstyle": "round4", "fc": "0.8"}
# Arrow style passed as ``arrowprops`` when a node is annotated.
arrow_args = {"arrowstyle": "<-"}

def plotNode(nodeText,centerPt,parentPt,nodeType):
    """Draw one boxed node on the shared axes, linked by an arrow to its parent.

    Args:
        nodeText: Label shown inside the node's box.
        centerPt: ``(x, y)`` position of the node text, in axes-fraction
            coordinates.
        parentPt: ``(x, y)`` position of the parent end of the arrow, in
            axes-fraction coordinates.
        nodeType: Dict of box properties passed as ``bbox``.
    """
    axes = createPlot.ax1
    axes.annotate(
        nodeText,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    The tree is expected in the form
    ``{feature: {branch_value: subtree_or_label}}``: a non-empty dict whose
    first key is the root feature and whose value maps each branch value to
    either another tree of the same form (a dict) or a leaf label (any
    non-dict value).

    Args:
        myTree: The decision tree as a nested dict.

    Returns:
        The number of leaves; this is used as the horizontal extent of the
        plotted tree.
    """
    # The root feature is the first (normally the only) key of the tree.
    firstStr = next(iter(myTree))
    # Descend into the branches of the root; recursing on ``myTree`` itself
    # would never terminate.
    secondDict = myTree[firstStr]
    numLeafs = 0
    for child in secondDict.values():
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    """Return the number of decision levels in a nested-dict decision tree.

    The tree is expected in the form
    ``{feature: {branch_value: subtree_or_label}}``: a non-empty dict whose
    first key is the root feature and whose value maps each branch value to
    either another tree of the same form (a dict) or a leaf label (any
    non-dict value).

    Args:
        myTree: The decision tree as a nested dict.

    Returns:
        The length of the longest root-to-leaf path, counted in decision
        nodes; a root with only leaf children has depth 1.
    """
    # The root feature is the first (normally the only) key of the tree.
    firstStr = next(iter(myTree))
    # Descend into the branches of the root; recursing on ``myTree`` itself
    # would never terminate.
    secondDict = myTree[firstStr]
    maxDepth = 0
    for child in secondDict.values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
def plotMidText(cntrPt,parentPt,txtString):
    """Write a label at the midpoint between a node and its parent.

    Args:
        cntrPt: ``(x, y)`` position of the child node.
        parentPt: ``(x, y)`` position of the parent node.
        txtString: Text to place halfway along the connecting edge.
    """
    # Points are (x, y) tuples, so the midpoint must be computed per
    # component; subtracting the tuples directly raises TypeError.
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree,parentPt,nodeTxt):
    """Recursively draw a nested-dict decision tree on the shared axes.

    The tree is expected in the form
    ``{feature: {branch_value: subtree_or_label}}``. Layout state is kept in
    function attributes that the caller must initialise before the first
    call: ``plotTree.totalW`` (total leaf count), ``plotTree.totalD`` (tree
    depth), and the running cursor ``plotTree.xOff`` / ``plotTree.yOff``.

    Args:
        myTree: The (sub)tree to draw, as a nested dict.
        parentPt: ``(x, y)`` position of the parent node the root of this
            subtree is connected to.
        nodeTxt: Label written on the edge from the parent to this subtree.
    """
    numLeafs = getNumLeafs(myTree)
    # The root feature is the first (normally the only) key of the tree.
    firstStr = next(iter(myTree))
    # Centre the decision node horizontally over the leaves it spans.
    cntrPt = (
        plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
        plotTree.yOff,
    )
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Move one level down before drawing the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # Each leaf takes the next horizontal slot.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            leafPt = (plotTree.xOff, plotTree.yOff)
            plotNode(child, leafPt, cntrPt, leafNode)
            plotMidText(leafPt, cntrPt, str(key))
    # Restore the vertical cursor so sibling subtrees start at the right level.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree, figsize=None):
    """Draw a nested-dict decision tree in matplotlib figure 1 and show it.

    The axes are stored on ``createPlot.ax1`` and the layout state on
    ``plotTree`` function attributes, which the drawing helpers read.

    Args:
        inTree: The decision tree in the form
            ``{feature: {branch_value: subtree_or_label}}``.
        figsize: Optional ``(width, height)`` in inches. For large trees,
            pass a bigger size so that the nodes have room and do not
            overlap. When ``None`` the figure size is left unchanged.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    if figsize is not None:
        fig.set_size_inches(figsize)
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf slot to the left so leaves are centred in their slots.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    # The root is placed at x=0.5; using the same point as its "parent"
    # avoids drawing a stray arrow from outside the axes.
    plotTree(inTree, (0.5, 1.0), '')
    # tight_layout only has an effect once the axes and artists exist.
    fig.tight_layout()
    plt.show()

perfectddk 发表于 2022-7-2 10:00

我把你的代码复制下来运行看一下
页: [1]
查看完整版本: 求问一个python的画图问题