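Finding the root node of a decision tree using entropy

The original data looks like this:


The task: use a decision tree to predict, for a given record, whether "buy_computer" is "yes" or "no".

This post only computes the root node of the decision tree. Which attribute should be the root? That is decided with entropy: for each attribute (age, income, and so on), compute the entropy of the data after splitting on that attribute, then compare the attributes.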

Take age as an example. The table above has 14 records in total, 5 of which have age "<=30"; of those 5, 3 have "buy_computer" = "no" and 2 have "yes".
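For reference, these are the standard ID3 formulas that the code evaluates, with the age counts plugged in. D is the full set of 14 records, p_i is the fraction of records in class i, and D_j is the subset for the j-th value of attribute A. The code writes each term as p·log2(1/p), which equals -p·log2(p). The middle age group's 4 records all fall into one class, so its entropy term is 0.

```latex
$$
\begin{aligned}
\mathrm{Info}(D) &= -\sum_{i} p_i \log_2 p_i
  = -\tfrac{9}{14}\log_2\tfrac{9}{14} - \tfrac{5}{14}\log_2\tfrac{5}{14} \approx 0.940 \\
\mathrm{Info}_A(D) &= \sum_{j} \frac{|D_j|}{|D|}\,\mathrm{Info}(D_j) \\
\mathrm{Info}_{age}(D) &= \tfrac{5}{14}\left(-\tfrac{2}{5}\log_2\tfrac{2}{5} - \tfrac{3}{5}\log_2\tfrac{3}{5}\right)
  + \tfrac{4}{14}\cdot 0
  + \tfrac{5}{14}\left(-\tfrac{3}{5}\log_2\tfrac{3}{5} - \tfrac{2}{5}\log_2\tfrac{2}{5}\right) \approx 0.694 \\
\mathrm{Gain}(A) &= \mathrm{Info}(D) - \mathrm{Info}_A(D), \qquad \mathrm{Gain}(age) \approx 0.247
\end{aligned}
$$
```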

Now these counts can be plugged into the entropy formula. Below is a Java implementation:

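package net.highersoft.ml;

public class DecisionTree {
	
	public static void main(String[] args) {
		// Entropy of the whole data set: 9 "yes" and 5 "no" out of 14 records.
		// Java has no log2, so log2(x) is written as Math.log(x)/Math.log(2).
		double total=9.0/14*(Math.log(14.0/9)/Math.log(2))+5.0/14*(Math.log(14.0/5)/Math.log(2));
		System.out.println("total:"+total);
		
		// age, written out by hand. Each term is (subset size / 14) * (entropy of the subset).
		// "<=30": 5 records, 2 yes / 3 no
		double lessThan30=5.0/14*(2.0/5*(Math.log(5.0/2)/Math.log(2))+3.0/5*(Math.log(5.0/3)/Math.log(2)));
		// 31-40: 4 records, all yes; the "no" term 0*log(4/0) is taken as 0
		double thirTo40=4.0/14*(4.0/4*(Math.log(4.0/4)/Math.log(2))+0.0);
		// ">40": 5 records, 3 yes / 2 no
		double moreThan40=5.0/14*(3.0/5*(Math.log(5.0/3)/Math.log(2))+2.0/5*(Math.log(5.0/2)/Math.log(2)));
		// "Shang" is pinyin for entropy
		double ageShang=(lessThan30+thirTo40+moreThan40);
		System.out.println("age entropy:"+ageShang);
		System.out.println("age information gain:"+(total-ageShang));
		
		// The remaining attributes use the helper; counts are {yes, no} for each attribute value
		double incomeLow= comput(14,4,new int[] {3,1});
		double incomeMid=comput(14,6,new int[] {4,2});
		double incomeHigh=comput(14,4,new int[] {2,2});
		double incomeShang=(incomeLow+incomeMid+incomeHigh);
		System.out.println("income entropy:"+incomeShang);
		System.out.println("income information gain:"+(total-incomeShang));
		
		double studentYes=comput(14,7,new int[] {6,1});
		double studentNo=comput(14,7,new int[] {3,4});
		double studentShang=studentYes+studentNo;
		System.out.println("student entropy:"+studentShang);
		System.out.println("student information gain:"+(total-studentShang));
		
		double creditFair=comput(14,8,new int[] {6,2});
		double creditEx=comput(14,6,new int[] {3,3});
		double creditShang=creditFair+creditEx;
		System.out.println("credit entropy:"+creditShang);
		System.out.println("credit information gain:"+(total-creditShang));
	}
	
	/**
	 * Weighted entropy of one attribute value.
	 * total:   number of records in the whole data set
	 * part:    number of records that have this attribute value
	 * memeber: class counts within those records, e.g. {yes, no}
	 * Returns (part/total) * entropy(memeber).
	 */
	public static double comput(int total,int part,int[] memeber) {
		double ratPer=part*1.0/total;
		double sum=0;
		for(int i=0;i<memeber.length;i++) {
			if(memeber[i]==0) {
				continue; // 0*log(0) is defined as 0; skipping avoids 0*Infinity = NaN
			}
			sum+=memeber[i]*1.0/part*(Math.log(part*1.0/memeber[i])/Math.log(2));
		}
		return ratPer*sum;
	}

}

The expressions for lessThan30 and the other age terms were so long that I got some parentheses wrong, so a few results disagreed with the textbook, which cost me a whole day. That is why the later calculations are wrapped in a helper function.

The results are:

total:0.940285958670631
age entropy:0.6935361388961919
age information gain:0.2467498197744391
income entropy:0.9110633930116762
income information gain:0.02922256565895487
student entropy:0.7884504573082894
student information gain:0.15183550136234159
credit entropy:0.8921589282623614
credit information gain:0.0481270304082696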
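The same calculation can be expressed once and applied to every attribute, which removes the hand-written parentheses that caused trouble. A minimal Python 3 sketch, assuming the {yes, no} counts per attribute value are exactly the ones passed to the Java code; `entropy`, `info_gain` and `splits_by_attr` are illustrative names, not from the original:

```python
import math

def entropy(counts):
    """Entropy in bits of a class distribution given as counts; zero counts contribute 0."""
    n = sum(counts)
    return -sum(c / n * math.log2(c / n) for c in counts if c > 0)

def info_gain(class_counts, splits):
    """Gain = Info(D) - sum(|Dj|/|D| * Info(Dj)); splits holds one class-count list per attribute value."""
    n = sum(class_counts)
    expected = sum(sum(s) / n * entropy(s) for s in splits)
    return entropy(class_counts) - expected

# [yes, no] counts per attribute value, as used in the Java code
splits_by_attr = {
    "age":     [[2, 3], [4, 0], [3, 2]],
    "income":  [[3, 1], [4, 2], [2, 2]],
    "student": [[6, 1], [3, 4]],
    "credit":  [[6, 2], [3, 3]],
}
class_counts = [9, 5]

gains = {attr: info_gain(class_counts, s) for attr, s in splits_by_attr.items()}
root = max(gains, key=gains.get)  # attribute with the largest information gain
for attr, g in gains.items():
    print(attr, round(g, 4))
print("root:", root)
```

Counting the [yes, no] pairs is the only data-dependent step; everything else is the same formula applied four times, so the gains should agree with the Java output above up to floating-point noise.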

Computing base-2 logarithms in Java is a bit clumsy, so here is the income part in Python, which is much cleaner:

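import math

# Weighted entropy of each income value: (subset size / 14) * entropy of the subset
incomeLow=4.0/14*(3.0/4*math.log(4.0/3,2)+1.0/4*math.log(4.0/1,2))    # low: 3 yes / 1 no
incomeMid=6.0/14*(4.0/6*math.log(6.0/4,2)+2.0/6*math.log(6.0/2,2))    # medium: 4 yes / 2 no
incomeHigh=4.0/14*(2.0/4*math.log(4.0/2,2)+2.0/4*math.log(4.0/2,2))   # high: 2 yes / 2 no

print(incomeLow)
print(incomeMid)
print(incomeHigh)

# Information gain of income, using the exact total entropy instead of the rounded 0.94
total=9.0/14*math.log(14.0/9,2)+5.0/14*math.log(14.0/5,2)
print(total-incomeLow-incomeMid-incomeHigh)
print("-------------")
# Check the medium term piece by piece
mid1=4.0/6*math.log(6.0/4,2)
mid2=2.0/6*math.log(6.0/2,2)
print(mid1)
print(mid2)

print(6.0/14*(mid1+mid2))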
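The count lists used above ([3, 1], [4, 2], [2, 2]) come from tallying (attribute value, class) pairs in the data table. A minimal sketch of that tallying step with `collections.Counter`, assuming Python 3. The `pairs` list is hypothetical: it is rebuilt only from the stated income counts, because the full table is not reproduced here, and `weighted_entropy` is an illustrative name:

```python
import math
from collections import Counter, defaultdict

def weighted_entropy(pairs):
    """Info_A(D) for a list of (attribute_value, class_label) tuples."""
    n = len(pairs)
    by_value = defaultdict(Counter)  # attribute value -> Counter of class labels
    for value, label in pairs:
        by_value[value][label] += 1
    total = 0.0
    for counts in by_value.values():
        m = sum(counts.values())
        h = -sum(c / m * math.log2(c / m) for c in counts.values())
        total += m / n * h
    return total

# Hypothetical (income, buy_computer) pairs reproducing the counts above:
# low 3 yes / 1 no, medium 4 yes / 2 no, high 2 yes / 2 no
pairs = ([("low", "yes")] * 3 + [("low", "no")]
         + [("medium", "yes")] * 4 + [("medium", "no")] * 2
         + [("high", "yes")] * 2 + [("high", "no")] * 2)

print(weighted_entropy(pairs))
```

With the real table, the same function could be called once per attribute column, so the counts never need to be typed in by hand.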



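In the end, age has the largest information gain (0.2467498197744391), so age becomes the root node. I had wondered whether information gain could be negative and should be taken as an absolute value, but it cannot be negative: splitting a set never makes the weighted average entropy of the subsets larger than the entropy of the original set, so the gain is always greater than or equal to 0. All four gains computed above are indeed positive.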
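Whether a gain can ever come out negative is easy to probe numerically. The sketch below (an illustration, not a proof) assumes Python 3, draws random two-class splits, and checks that Info(D) - Info_A(D) never drops below zero beyond floating-point noise. The helpers are redefined here so the block runs on its own:

```python
import math
import random

def entropy(counts):
    """Entropy in bits; zero counts contribute 0."""
    n = sum(counts)
    return -sum(c / n * math.log2(c / n) for c in counts if c > 0)

def info_gain(splits):
    """splits: one [yes, no] count list per attribute value; the parent counts are their column sums."""
    parent = [sum(col) for col in zip(*splits)]
    n = sum(parent)
    return entropy(parent) - sum(sum(s) / n * entropy(s) for s in splits)

random.seed(0)
min_gain = float("inf")
for _ in range(10000):
    k = random.randint(2, 4)  # number of distinct attribute values
    splits = [[random.randint(0, 10), random.randint(0, 10)] for _ in range(k)]
    if any(sum(s) == 0 for s in splits):
        continue  # skip splits with an empty branch
    min_gain = min(min_gain, info_gain(splits))
print(min_gain)
```

The smallest value found is zero or a tiny rounding error around zero, consistent with the gain being non-negative.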

By 程忠   2020-11-03 08:43:48
