决策树-属性选择
现在,我们要做的是进行属性(或者说特征)的选择
光看程序清单3-2,以及把数组带进去运行一遍可能也有点不清晰,最好先看一下西瓜书
然后意思是传进去一个数据集,对于某一列(axis=0表示第1列),如果为0(value=0),那么保留这一行但是不要这个属性对应的值
import shannonEnt
dataSet = shannonEnt.dataSet
labelSet = shannonEnt.labelSet
def splitDataSet(dataSet, axis, value):
featureDataSet = []
for featureVec in dataSet:
if featureVec[axis] == value:
tempVec = featureVec[:axis]
tempVec.extend(featureVec[axis + 1:])
featureDataSet.append(tempVec)
return featureDataSet
a = splitDataSet(dataSet, 0, 0)
b = splitDataSet(dataSet, 0, 1)
print(a)
print(b)
[[1, 'no'], [1, 'no']]
[[1, 'yes'], [1, 'yes'], [0, 'no']]
这里a是对第0列取值为0的行进行了处理,b是对第0列取值为1的行进行了处理
这里想更简单点的话,用pd去掉某一列,然后再算比例也可以
接着书上使用ID3进行属性选择
思路如下:
- 首先计算总的Ent,得到总共有2个属性
- 然后对于2个属性进行遍历,对于第1个属性,得到其对应的属性取值为[1, 1, 1, 0, 0]
- 那么对于刚刚得到的[1, 1, 1, 0, 0],我们知道有两种取值,用set得到列表[1,0]
- 从这两个取值中再去用刚刚写好的splitDataSet函数得到1对应的子集,以及0对应的子集
- 这里我们能知道1对应的子集个数为3,那么由西瓜书公式4.2去计算$sigma$和Ent
结合起来代码如下
import shannonEnt
dataSet = shannonEnt.dataSet
labelSet = shannonEnt.labelSet
def splitDataSet(dataSet, axis, value):
featureDataSet = []
for featureVec in dataSet:
if featureVec[axis] == value:
tempVec = featureVec[:axis]
tempVec.extend(featureVec[axis + 1:])
featureDataSet.append(tempVec)
return featureDataSet
# a = splitDataSet(dataSet, 0, 0)
# b = splitDataSet(dataSet, 0, 1)
# print(a)
# print(b)
def bestFeature(dataSet):
# 获得特征(属性)个数,这里为2
featureNum = len(dataSet[0]) - 1
# 按西瓜书来看,计算Ent(D)
totalEntropy = shannonEnt.calcShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeature = -1
for i in range(featureNum):
# 匿名函数,[1, 1, 1, 0, 0],[1, 1, 0, 1, 1]
# 即获得每个属性对应的列向量
featList = [example[i] for example in dataSet]
# print(featList)
# 知道每个属性可能有的取值
uniqueVals = set(featList)
# print(uniqueVals)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet,i,value)
prob = len(subDataSet)/len(dataSet)
newEntropy += prob*shannonEnt.calcShannonEnt(subDataSet)
infoGain = totalEntropy - newEntropy
if infoGain > bestInfoGain:
bestInfoGain = infoGain
bestFeature = i
return bestFeature
if __name__ == "__main__":
result = bestFeature(dataSet)
print(result)
0
这里的内容,主要是要对ID3算法比较熟,可以结合西瓜书多看几遍