# [Python] plain-text view / copy code -- leftover from the web page this script was copied from
# coding=utf-8
import copy
from collections import Counter
from math import log
# --------------------------------------------------------
# Data initialisation
# --------------------------------------------------------
def InputData():
    """Return the built-in toy training set and its feature names.

    Returns:
        A tuple ``(initData, dataLables)``. ``initData`` is a fresh list of
        11 rows; each row holds three 0/1 feature values followed by the
        class result ``'yes'`` or ``'no'``. ``dataLables`` names the three
        feature columns in column order.
    """
    feature_names = ['feature_A', 'feature_B', 'feature_C']
    samples = [
        # feature_A, feature_B, feature_C, result
        [0, 0, 0, 'no'],
        [0, 1, 1, 'yes'],
        [1, 0, 0, 'yes'],
        [0, 1, 1, 'yes'],
        [0, 1, 0, 'yes'],
        [1, 0, 1, 'yes'],
        [0, 0, 1, 'yes'],
        [0, 1, 0, 'yes'],
        [1, 0, 1, 'yes'],
        [1, 1, 0, 'yes'],
        [0, 0, 0, 'no'],
    ]
    return samples, feature_names


# Module-level training data and feature names, used by the demo call at the
# bottom of the script.
initData, dataLables = InputData()
# --------------------------------------------------------
# Information entropy of the current data set
# --------------------------------------------------------
def CalcEntropy(initData):
    """Return the Shannon entropy, in bits, of the results in *initData*.

    The result of each row is its last element. An empty data set has
    entropy ``0``.
    """
    total = len(initData)
    labels = [record[-1] for record in initData]
    # Relative frequency of each distinct result. With no rows the set is
    # empty, so the division never runs with total == 0.
    shares = [float(labels.count(label)) / total for label in set(labels)]
    entropy = 0
    for share in shares:
        entropy -= share * log(share, 2)
    return entropy
# --------------------------------------------------------
# Split the data set; the input data set must not be changed
# --------------------------------------------------------
def SplitData(initData, featureNum, featureValue):
    """Return the rows whose column *featureNum* equals *featureValue*,
    with that column removed.

    The input rows are never modified: each selected row is copied before
    the column is deleted, and rows that are not selected are not copied
    at all.

    Args:
        initData: list of rows; the last element of a row is its result.
        featureNum: index of the column to test and drop.
        featureValue: value the column must equal for the row to be kept.

    Returns:
        A new list of new row lists, in the original row order.
    """
    newData = []
    for row in initData:
        if row[featureNum] == featureValue:
            # Copy only the rows that are kept. A shallow copy is enough here:
            # feature values and results are collected into sets when choosing
            # splits, so they are expected to be hashable (immutable) scalars.
            reduced = list(row)
            del reduced[featureNum]
            newData.append(reduced)
    return newData
# --------------------------------------------------------
# Choose the best split feature: the attribute with the
# largest information gain, i.e. the smallest conditional
# entropy, becomes the split attribute.
# --------------------------------------------------------
def ChooseFeature(initData):
    """Return the index of the feature column that best splits *initData*.

    For every feature column (all columns except the last, which holds the
    result), the weighted entropy of the results after splitting on that
    column is computed. The column with the smallest value (equivalently the
    largest information gain) wins; on an exact tie the lowest column index
    is kept. When the rows contain no feature columns, ``0`` is returned.
    """
    featureCount = len(initData[0]) - 1
    rowCount = len(initData)
    bestFeatureNum = 0
    lowestEntropy = float('inf')
    for i in range(featureCount):
        # Reset for every feature. Accumulating across columns made each
        # later column look no better than the first, so column 0 always won.
        condEntropy = 0.0
        for value in set(row[i] for row in initData):
            subData = SplitData(initData, i, value)
            condEntropy += float(len(subData)) / rowCount * CalcEntropy(subData)
        if condEntropy < lowestEntropy:
            bestFeatureNum = i
            lowestEntropy = condEntropy
    return bestFeatureNum
# --------------------------------------------------------
# Get the element that occurs most often in a list
# --------------------------------------------------------
def MaxList(lt):
    """Return the element that occurs most often in *lt*.

    On a tie, the element that appears first in *lt* is returned. Elements
    must be hashable.

    Raises:
        ValueError: if *lt* is empty (previously an UnboundLocalError).
    """
    if not lt:
        raise ValueError('MaxList() arg is an empty sequence')
    counts = Counter(lt)
    # max() returns the first maximal item it meets while scanning lt in order,
    # which reproduces the left-to-right "first to reach the top count" rule
    # in O(n) instead of calling list.count() for every element.
    return max(lt, key=lambda item: counts[item])
# ------------------------------------------------------------
# Build the tree recursively as nested dicts. Recursion stops
# when (1) every row of the current data set has the same
# result, or (2) the results still differ but no feature
# columns / labels are left to split on; then the most frequent
# value of the last column is returned. The caller's label list
# is never modified, so every branch of the loop below starts
# from the same labels.
# ------------------------------------------------------------
def CreateTree(data, lables):
    """Return a decision tree for *data* as nested dicts.

    Args:
        data: non-empty list of rows; each row holds the feature values
            followed by the result in its last element.
        lables: feature names, one per feature column of *data*, in column
            order. The list is not modified.

    Returns:
        Either a result value (leaf) or ``{featureName: {value: subtree}}``.
    """
    resultList = [row[-1] for row in data]
    if resultList.count(resultList[-1]) == len(resultList):
        return resultList[-1]  # every row agrees: leaf
    if len(data[0]) == 1:
        return MaxList(resultList)  # no features left: majority vote
    bestFeatureNum = ChooseFeature(data)
    bestFeatureName = lables[bestFeatureNum]
    # Drop the chosen column's name by position (removing by value would hit
    # the wrong entry if two features share a name), building a new list so
    # the caller's labels are left untouched. The recursive calls never
    # mutate it, so one copy can be shared by all branches.
    subLables = lables[:bestFeatureNum] + lables[bestFeatureNum + 1:]
    tree = {bestFeatureName: {}}
    for value in set(row[bestFeatureNum] for row in data):
        # The feature value is the branch linking this node to its subtree.
        tree[bestFeatureName][value] = CreateTree(
            SplitData(data, bestFeatureNum, value), subLables)
    return tree
# Demo: build the decision tree for the module-level training data and print it.
print(CreateTree(initData,dataLables))