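Finding the Root Node of a Decision Tree with Entropy
The original data looks like this:
The task: use a decision tree to determine whether a given record is "buy_computer" (yes or no).
This article only computes the root node of the decision tree. Which attribute should be the root? Answering that requires entropy. For each attribute (age, income, and so on), compute the expected entropy of the class label after splitting on that attribute. Then compute its information gain, which is the entropy of the whole data set minus that expected entropy.
Take age as an example. The table above has 14 records in total, and 5 of them have age "<=30". Of those 5, 3 have "buy_computer" = "no" and 2 have "yes".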
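For reference, these are the standard entropy and information-gain formulas that the Java code evaluates, with the age counts substituted. The code writes each term as p·log2(1/p), which equals −p·log2 p. The weights 5/14, 4/14 and 5/14 are the sizes of the "<=30", 31 to 40, and ">40" branches. Values are rounded to three decimals:

```latex
\begin{aligned}
\mathrm{Info}(D) &= -\sum_{i} p_i \log_2 p_i
  = -\tfrac{9}{14}\log_2\tfrac{9}{14} - \tfrac{5}{14}\log_2\tfrac{5}{14} \approx 0.940 \\
\mathrm{Info}_{A}(D) &= \sum_{j} \frac{|D_j|}{|D|}\,\mathrm{Info}(D_j) \\
\mathrm{Info}_{\text{age}}(D) &= \tfrac{5}{14}\left(-\tfrac{2}{5}\log_2\tfrac{2}{5} - \tfrac{3}{5}\log_2\tfrac{3}{5}\right)
  + \tfrac{4}{14}\left(-\tfrac{4}{4}\log_2\tfrac{4}{4}\right)
  + \tfrac{5}{14}\left(-\tfrac{3}{5}\log_2\tfrac{3}{5} - \tfrac{2}{5}\log_2\tfrac{2}{5}\right) \approx 0.694 \\
\mathrm{Gain}(\text{age}) &= \mathrm{Info}(D) - \mathrm{Info}_{\text{age}}(D) \approx 0.247
\end{aligned}
```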
These counts can now be plugged into the entropy formula. Here is a Java implementation:
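```java
package net.highersoft.ml;

public class DecisionTree {
    public static void main(String[] args) {
        // Info(D): entropy of the class label over all 14 records (9 of one class, 5 of the other).
        // Each term is p * log2(1/p); log2 is computed as Math.log(x) / Math.log(2).
        double total = 9.0 / 14 * (Math.log(14.0 / 9) / Math.log(2))
                + 5.0 / 14 * (Math.log(14.0 / 5) / Math.log(2));
        System.out.println("total:" + total);

        // age, written out by hand: (branch size / 14) * entropy inside the branch
        double lessThan30 = 5.0 / 14 * (2.0 / 5 * (Math.log(5.0 / 2) / Math.log(2))
                + 3.0 / 5 * (Math.log(5.0 / 3) / Math.log(2)));
        double thirTo40 = 4.0 / 14 * (4.0 / 4 * (Math.log(4.0 / 4) / Math.log(2)) + 0.0);
        double moreThan40 = 5.0 / 14 * (3.0 / 5 * (Math.log(5.0 / 3) / Math.log(2))
                + 2.0 / 5 * (Math.log(5.0 / 2) / Math.log(2)));
        double ageShang = lessThan30 + thirTo40 + moreThan40;
        System.out.println("age entropy:" + ageShang);
        System.out.println("age information gain:" + (total - ageShang));

        // The remaining attributes use the helper: comput(total, branch size, class counts in branch)
        double incomeLow = comput(14, 4, new int[] {3, 1});
        double incomeMid = comput(14, 6, new int[] {4, 2});
        double incomeHigh = comput(14, 4, new int[] {2, 2});
        double incomeShang = incomeLow + incomeMid + incomeHigh;
        System.out.println("income entropy:" + incomeShang);
        System.out.println("income information gain:" + (total - incomeShang));

        double studentYes = comput(14, 7, new int[] {6, 1});
        double studentNo = comput(14, 7, new int[] {3, 4});
        double studentShang = studentYes + studentNo;
        System.out.println("student entropy:" + studentShang);
        System.out.println("student information gain:" + (total - studentShang));

        double creditFair = comput(14, 8, new int[] {6, 2});
        double creditEx = comput(14, 6, new int[] {3, 3});
        double creditShang = creditFair + creditEx;
        System.out.println("credit entropy:" + creditShang);
        System.out.println("credit information gain:" + (total - creditShang));
    }

    /**
     * Weighted entropy of one branch of a split:
     * (part / total) * sum over classes of (m / part) * log2(part / m).
     *
     * @param total   number of records in the whole data set
     * @param part    number of records in this branch
     * @param memeber class counts within this branch
     */
    public static double comput(int total, int part, int[] memeber) {
        double ratPer = part * 1.0 / total;
        double sum = 0;
        for (int i = 0; i < memeber.length; i++) {
            if (memeber[i] == 0) {
                continue; // a class with 0 records contributes 0 (otherwise 0 * log(Infinity) gives NaN)
            }
            sum += memeber[i] * 1.0 / part * (Math.log(part * 1.0 / memeber[i]) / Math.log(2));
        }
        return ratPer * sum;
    }
}
```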
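The expressions for lessThan30 and the other age branches are very long. I got some of the parentheses wrong, so several results did not match the textbook, and that cost me a whole day. That is why the later calculations are wrapped in a helper function.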
The output is:
```
total:0.940285958670631
age entropy:0.6935361388961919
age information gain:0.2467498197744391
income entropy:0.9110633930116762
income information gain:0.02922256565895487
student entropy:0.7884504573082894
student information gain:0.15183550136234159
credit entropy:0.8921589282623614
credit information gain:0.0481270304082696
```
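Computing logarithms in Java is a bit clumsy. Math.log is the natural logarithm, so log base 2 has to be written as Math.log(x) / Math.log(2). Python's math.log accepts a base argument, which makes the income calculation much cleaner: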
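```python
import math

# Weighted entropy of each income branch: (branch size / 14) * entropy inside the branch
incomeLow = 4.0/14*(3.0/4*math.log(4.0/3, 2) + 1.0/4*math.log(4/1, 2))
incomeMid = 6.0/14*(4.0/6*math.log(6.0/4, 2) + 2.0/6*math.log(6.0/2, 2))
incomeHigh = 4.0/14*(2.0/4*math.log(4.0/2, 2) + 2.0/4*math.log(4.0/2, 2))
print(incomeLow)
print(incomeMid)
print(incomeHigh)
# Information gain of income, using Info(D) rounded to 0.94
print(0.94 - incomeLow - incomeMid - incomeHigh)
print("-------------")
# Check the medium-income branch term by term
mid1 = 4.0/6*math.log(6.0/4, 2)
mid2 = 2.0/6*math.log(6.0/2, 2)
print(mid1)
print(mid2)
print(6.0/14*(mid1 + mid2))
```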
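In the end, age has the largest information gain (0.2467498197744391), so it becomes the root node. There is no need to take an absolute value. Information gain is never negative, because the expected entropy after splitting on an attribute never exceeds the entropy before the split (Info_A(D) ≤ Info(D)). At worst, when the attribute carries no information about the class, the gain is 0.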
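Taking the idea of wrapping the arithmetic in a function one step further, here is a minimal Python sketch. It computes the gain of every attribute from the per-branch class counts used in the Java code and picks the largest gain as the root. The names `entropy`, `info_gain` and `attributes` are hypothetical helpers introduced here and are not part of the original code. Zero counts are skipped under the usual convention 0·log2 0 = 0, so the pure 31 to 40 age branch needs no special case.

```python
import math


def entropy(counts):
    """Entropy in bits of a list of class counts (0 * log2(0) taken as 0)."""
    n = sum(counts)
    return -sum(c / n * math.log2(c / n) for c in counts if c > 0)


def info_gain(class_counts, branches):
    """Gain(A) = Info(D) - sum_j |D_j|/|D| * Info(D_j)."""
    n = sum(class_counts)
    expected = sum(sum(b) / n * entropy(b) for b in branches)
    return entropy(class_counts) - expected


# Class counts inside each branch, taken from the comput(...) calls in the Java code.
# The order of the counts within a branch does not affect the entropy.
attributes = {
    "age": [[2, 3], [4, 0], [3, 2]],
    "income": [[3, 1], [4, 2], [2, 2]],
    "student": [[6, 1], [3, 4]],
    "credit": [[6, 2], [3, 3]],
}
class_counts = [9, 5]  # 9 "yes" and 5 "no" over all 14 records

gains = {name: info_gain(class_counts, b) for name, b in attributes.items()}
for name, g in sorted(gains.items(), key=lambda kv: -kv[1]):
    print(f"{name}: {g:.4f}")

root = max(gains, key=gains.get)
print("root:", root)
```

Run as-is, it lists age (0.2467), student (0.1518), credit (0.0481) and income (0.0292) in that order, then prints `root: age`. Every gain is non-negative.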