用java生成决策树
1. 数据准备:仍以之前的数据作为训练数据。不过为了方便查询,我把它存进了数据库。MySQL 的 SQL 代码如下:
创建:
-- Training samples, one row per sample. All attributes are stored as categorical strings;
-- buys_computer is the column the Java code below uses as the class label.
-- The statement is terminated with ';' so it can be run in the same script as the INSERT below.
CREATE TABLE `decision_tree` (
  `id` int(11) NOT NULL AUTO_INCREMENT,
  `age` varchar(45) DEFAULT NULL,
  `income` varchar(45) DEFAULT NULL,
  `student` varchar(45) DEFAULT NULL,
  `credit_rating` varchar(45) DEFAULT NULL,
  `buys_computer` varchar(45) DEFAULT NULL,
  PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
插入数据(我是手工录入的):
-- The 14 training samples, listed with an explicit column order so the statement
-- does not depend on the table's physical column order.
INSERT INTO `decision_tree`
  (`id`, `age`, `income`, `student`, `credit_rating`, `buys_computer`)
VALUES
  (1,  'less30', 'high',   'no',  'fair',      'no'),
  (2,  'less30', 'high',   'no',  'excellent', 'no'),
  (3,  '30to40', 'high',   'no',  'fair',      'yes'),
  (4,  'more40', 'medium', 'no',  'fair',      'yes'),
  (5,  'more40', 'low',    'yes', 'fair',      'yes'),
  (6,  'more40', 'low',    'yes', 'excellent', 'no'),
  (7,  '30to40', 'low',    'yes', 'excellent', 'yes'),
  (8,  'less30', 'medium', 'no',  'fair',      'no'),
  (9,  'less30', 'low',    'yes', 'fair',      'yes'),
  (10, 'more40', 'medium', 'yes', 'fair',      'yes'),
  (11, 'less30', 'medium', 'yes', 'excellent', 'yes'),
  (12, '30to40', 'medium', 'no',  'excellent', 'yes'),
  (13, '30to40', 'high',   'yes', 'fair',      'yes'),
  (14, 'more40', 'medium', 'no',  'excellent', 'no');
如下图:
2.开发环境准备
我使用的是 Java 的 Spring Boot,并集成了 MyBatis 等组件。Application 启动后,运行下面 DecisionService 类的 execute 方法:
package net.highersoft.ml; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import org.apache.commons.lang3.StringUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.ibatis.datasource.unpooled.UnpooledDataSource; import org.apache.ibatis.mapping.Environment; import org.apache.ibatis.session.Configuration; import org.apache.ibatis.session.SqlSession; import org.apache.ibatis.session.SqlSessionFactory; import org.apache.ibatis.session.SqlSessionFactoryBuilder; import org.apache.ibatis.transaction.jdbc.JdbcTransactionFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import com.github.pagehelper.util.StringUtil; import net.highersoft.mapper.DecisionTreeMapper; @Service public class DecisionService { private static Log log = LogFactory.getLog(DecisionService.class); private SqlSession analysisSession; private DecisionTreeMapper decisionTreeMapper; final static String driver = "com.mysql.cj.jdbc.Driver"; @Autowired DecisionTree decisionTree; public void execute() { init(); try { String divCol="buys_computer"; String[] trainCols=new String[] {"age","income","student","credit_rating"}; queryDecision(divCol,trainCols.length,trainCols,new HashMap()); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } } private SqlSession createSessioin(String url, String db, String userName, String password, Class cls[]) { UnpooledDataSource ds = new UnpooledDataSource(); ds.setDriver(driver); ds.setUrl(url + db); ds.setUsername(userName); ds.setPassword(password); //TransactionFactory transactionFactory = new JdbcTransactionFactory(); JdbcTransactionFactory transactionFactory = new JdbcTransactionFactory(); Environment environment = new Environment(db, transactionFactory, ds); Configuration configuration = 
new Configuration(environment); configuration.setLazyLoadingEnabled(true); // configuration.setUseActualParamName(false); // to test legacy style reference // (#{0} #{1}) // configuration.getTypeAliasRegistry().registerAlias(TestMybatis.class); for (Class mapper : cls) { configuration.addMapper(mapper); } SqlSessionFactory sqlSessionFactory = new SqlSessionFactoryBuilder().build(configuration); //true事务自动提交,等同于Connection.setAutoCommit(true); 默认空是false return sqlSessionFactory.openSession(true); //return sqlSessionFactory.openSession(); } public void init() { analysisSession = createSessioin("jdbc:mysql://localhost:3306/", "data_analysis", "xxx", "xxx", new Class[] { DecisionTreeMapper.class }); decisionTreeMapper = analysisSession.getMapper(DecisionTreeMapper.class); } private Map<String,Integer> getTypeNum(Map<String,Map> distinctType){ Map<String,Integer> typeNum=new HashMap<>(); for(Entry<String,Map> entry:distinctType.entrySet()) { typeNum.put(entry.getKey(), Integer.valueOf(String.valueOf(entry.getValue().get("num")))); } return typeNum; } private Map<String,Map<String,Integer>> getSubTypeNum(Map<String,Integer> distinctType,String chooseCol,String targCol,String filterSql){ Map<String,Map<String,Integer>> subTypeNum=new HashMap<>(); for(Entry<String,Integer> entry:distinctType.entrySet()) { Map<String,Map> dimTypeMap=decisionTreeMapper.querySubTargetType(chooseCol,entry.getKey(),targCol,filterSql); Map<String,Integer> subMap=new HashMap<>(); for(Entry<String,Map> subEntry:dimTypeMap.entrySet()) { subMap.put(subEntry.getKey(), Integer.valueOf(String.valueOf(subEntry.getValue().get("num")))); } subTypeNum.put(entry.getKey(), subMap); } return subTypeNum; } private double computDim(int total,Map<String,Integer> subDimNum,Map<String,Map<String,Integer>> evereyPartNum) { double sum=0; for(Entry<String,Integer> entry:subDimNum.entrySet()) { Collection<Integer> vals=evereyPartNum.get(entry.getKey()).values(); double 
subDim=decisionTree.computDim(total,entry.getValue(),vals.toArray(new Integer[vals.size()])); sum+=subDim; } return sum; } /** * * @param divCol 目标列(预测列)(是否买电脑列) * @param trainCols * @param preCond 前置过滤条件 */ private void queryDecision(String divCol,int colSize,String[] trainCols,Map<String,Object> preCond) { //String divCol="buys_computer"; //String[] trainCols=new String[] {"age","income","student","credit_rating"}; String filterSql=getFilterSql(preCond); filterSql=StringUtil.isEmpty(filterSql)?"":" where "+filterSql; int totalNum = decisionTreeMapper.queryTotal(filterSql); Map<String,Integer> typeNum=getTypeNum(decisionTreeMapper.queryTargetType(divCol,filterSql)); //System.out.println(totalNum); //System.out.println(typeNum); double decisionTotal=decisionTree.computTotal(totalNum, typeNum.values().toArray(new Integer[typeNum.size()])); //System.out.println(decisionTotal); if(decisionTotal==0) { System.out.println("信息增益为0,找到根结点,前置条件:"+preCond); return; } double maxGain=0; String finalChooseCol=""; for(String trainCol:trainCols) { Map<String,Integer> trainColNum=getTypeNum(decisionTreeMapper.queryTargetType(trainCol,filterSql)); Map<String,Map<String,Integer>> ageEvereyNum=getSubTypeNum(trainColNum, trainCol, divCol,filterSql); double ageShang=computDim(totalNum,trainColNum,ageEvereyNum); //System.out.println(trainColNum); //System.out.println(ageEvereyNum); double gain=decisionTotal-ageShang; if(gain>maxGain) { maxGain=gain; finalChooseCol=trainCol; } System.out.println("Gain("+trainCol+"):"+gain); } System.out.println("选择列:"+finalChooseCol+",前置条件:"+preCond); if(preCond.size()==colSize-1||StringUtils.isBlank(finalChooseCol)) { return; } //下一层开始 Map<String,Integer> nextTypeNum=getTypeNum(decisionTreeMapper.queryTargetType(finalChooseCol,filterSql)); for(Entry<String,Integer> entry:nextTypeNum.entrySet()) { Map<String,Object> nextPreCond=new HashMap<>(preCond); nextPreCond.put(finalChooseCol, entry.getKey()); List<String> nextTranCols=new 
ArrayList(Arrays.asList(trainCols)); nextTranCols.remove(finalChooseCol); queryDecision(divCol,colSize,nextTranCols.toArray(new String[nextTranCols.size()]),nextPreCond); } } private String getFilterSql(Map<String,Object> preCond) { StringBuffer sb=new StringBuffer(); for(Entry<String,Object> entry:preCond.entrySet()) { if(sb.length()>0) { sb.append(" and "); } sb.append(entry.getKey()+"="); if(entry.getValue() instanceof String) { sb.append("'"+entry.getValue()+"'"); } } return sb.toString(); } }
该类所依赖的数据库查询 Mapper 如下:
package net.highersoft.mapper;

import java.util.Map;

import org.apache.ibatis.annotations.CacheNamespace;
import org.apache.ibatis.annotations.MapKey;
import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.annotations.Select;

/**
 * MyBatis mapper over the {@code decision_tree} table.
 *
 * <p>{@code filterSql}, {@code col}, {@code chooseCol} and {@code targCol} are inserted with
 * {@code ${...}}, i.e. as raw text rather than bound parameters. Callers must therefore pass either an
 * empty {@code filterSql} or a complete {@code " where ..."} clause, and only trusted column names.
 * Results go through this namespace's cache ({@code readWrite = false}).
 */
@CacheNamespace(readWrite = false)
public interface DecisionTreeMapper {

    /** Counts the rows matching {@code filterSql}. */
    @Select("select count(*) from decision_tree ${filterSql}")
    int queryTotal(@Param("filterSql") String filterSql);

    /**
     * Groups the rows matching {@code filterSql} by column {@code col}.
     *
     * @return map keyed by each distinct value of {@code col}; each entry is the row map holding
     *     {@code dimType} (that value) and {@code num} (its row count)
     */
    @Select("select ${col} dimType,count(*) num from decision_tree ${filterSql} group by ${col}")
    @MapKey("dimType")
    Map<String, Map> queryTargetType(@Param("col") String col, @Param("filterSql") String filterSql);

    /**
     * Within the rows matching {@code filterSql} and {@code chooseCol = chooseVal}, groups by {@code targCol}.
     *
     * @param chooseVal bound as a real parameter ({@code #{...}}), unlike the other arguments
     * @return map keyed by each distinct value of {@code targCol}; each entry holds
     *     {@code dimType} and {@code num}
     */
    @Select("select ${targCol} dimType,count(*) num from (select * from decision_tree ${filterSql} ) a where ${chooseCol}=#{chooseVal} group by ${targCol}")
    @MapKey("dimType")
    Map<String, Map> querySubTargetType(@Param("chooseCol") String chooseCol,
            @Param("chooseVal") String chooseVal,
            @Param("targCol") String targCol,
            @Param("filterSql") String filterSql);
}
最后,程序的输出如下:
Gain(age):0.2467498197744391 Gain(income):0.02922256565895487 Gain(student):0.15183550136234159 Gain(credit_rating):0.0481270304082696 选择列:age,前置条件:{} Gain(income):0.019973094021975002 Gain(student):0.019973094021975002 Gain(credit_rating):0.9709505944546687 选择列:credit_rating,前置条件:{age=more40} 信息增益为0,找到叶结点,前置条件:{age=more40, credit_rating=excellent},各分类数量:{"no":{"num":2,"divCol":"no"}} 信息增益为0,找到叶结点,前置条件:{age=more40, credit_rating=fair},各分类数量:{"yes":{"num":3,"divCol":"yes"}} 信息增益为0,找到叶结点,前置条件:{age=30to40},各分类数量:{"yes":{"num":4,"divCol":"yes"}} Gain(income):0.5709505944546687 Gain(student):0.9709505944546687 Gain(credit_rating):0.019973094021975002 选择列:student,前置条件:{age=less30} 信息增益为0,找到叶结点,前置条件:{age=less30, student=no},各分类数量:{"no":{"num":3,"divCol":"no"}} 信息增益为0,找到叶结点,前置条件:{age=less30, student=yes},各分类数量:{"yes":{"num":2,"divCol":"yes"}}
根据这个信息,就可以画出决策树的图了:
理论补充:http://www.highersoft.net/html/notice/notice_900.html
相关阅读
评论:
↓ 广告开始-头部带绿为生活 ↓
↑ 广告结束-尾部支持多点击 ↑