用 Java 生成决策树

1. 数据准备：仍以之前的数据作为训练数据。不过为了查询方便，我把它存进了数据库，MySQL 的 SQL 代码如下：

建表语句：

-- Training samples for the decision tree: each feature column holds a categorical label
-- (e.g. age = 'less30' / '30to40' / 'more40'); buys_computer is the column the service predicts.
CREATE TABLE `decision_tree` (
  `id` int(11) NOT NULL AUTO_INCREMENT,
  `age` varchar(45) DEFAULT NULL,
  `income` varchar(45) DEFAULT NULL,
  `student` varchar(45) DEFAULT NULL,
  `credit_rating` varchar(45) DEFAULT NULL,
  `buys_computer` varchar(45) DEFAULT NULL,
  PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

插入数据（我是手工添加的）：

-- The 14 hand-entered training samples. Columns are named explicitly so the statement
-- does not rely on the table's physical column order.
INSERT INTO `decision_tree` (`id`, `age`, `income`, `student`, `credit_rating`, `buys_computer`) VALUES
  (1,  'less30', 'high',   'no',  'fair',      'no'),
  (2,  'less30', 'high',   'no',  'excellent', 'no'),
  (3,  '30to40', 'high',   'no',  'fair',      'yes'),
  (4,  'more40', 'medium', 'no',  'fair',      'yes'),
  (5,  'more40', 'low',    'yes', 'fair',      'yes'),
  (6,  'more40', 'low',    'yes', 'excellent', 'no'),
  (7,  '30to40', 'low',    'yes', 'excellent', 'yes'),
  (8,  'less30', 'medium', 'no',  'fair',      'no'),
  (9,  'less30', 'low',    'yes', 'fair',      'yes'),
  (10, 'more40', 'medium', 'yes', 'fair',      'yes'),
  (11, 'less30', 'medium', 'yes', 'excellent', 'yes'),
  (12, '30to40', 'medium', 'no',  'excellent', 'yes'),
  (13, '30to40', 'high',   'yes', 'fair',      'yes'),
  (14, 'more40', 'medium', 'no',  'excellent', 'no');


2. 开发环境准备

我使用的是 Java 的 Spring Boot，并集成了 MyBatis 等组件。在 Application 启动后，调用下面 DecisionService 类的 execute 方法：

package net.highersoft.ml;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.ibatis.datasource.unpooled.UnpooledDataSource;
import org.apache.ibatis.mapping.Environment;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.SqlSession;
import org.apache.ibatis.session.SqlSessionFactory;
import org.apache.ibatis.session.SqlSessionFactoryBuilder;
import org.apache.ibatis.transaction.jdbc.JdbcTransactionFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import com.github.pagehelper.util.StringUtil;

import net.highersoft.mapper.DecisionTreeMapper;

/**
 * Builds an information-gain (ID3-style) decision tree from the {@code decision_tree} table.
 * <p>
 * For every node it prints the gain of each candidate column and the column chosen for the
 * split. It then recurses once per distinct value of that column, with the chosen
 * column/value pair added to the node's filter conditions.
 * <p>
 * The entropy arithmetic is delegated to the injected {@link DecisionTree} bean.
 * NOTE(review): {@code computTotal} is assumed to return the node entropy and
 * {@code computDim} one weighted term of the conditional entropy; confirm against
 * DecisionTree.
 */
@Service
public class DecisionService {
	private static final Log log = LogFactory.getLog(DecisionService.class);

	/** Session opened by {@link #init()} and closed again at the end of {@link #execute()}. */
	private SqlSession analysisSession;
	/** Mapper bound to {@link #analysisSession}. */
	private DecisionTreeMapper decisionTreeMapper;
	final static String driver = "com.mysql.cj.jdbc.Driver";
	@Autowired
	DecisionTree decisionTree;

	/**
	 * Opens a database session and builds the tree that predicts {@code buys_computer} from
	 * the four feature columns.
	 * <p>
	 * The session is always closed afterwards. Failures while building the tree are logged,
	 * not rethrown.
	 */
	public void execute() {
		init();
		try {
			String divCol = "buys_computer";
			String[] trainCols = new String[] { "age", "income", "student", "credit_rating" };
			queryDecision(divCol, trainCols.length, trainCols, new HashMap<String, Object>());
		} catch (Exception e) {
			log.error("Failed to build decision tree", e);
		} finally {
			// Release the JDBC connection held by the session opened in init().
			closeSession();
		}
	}

	/**
	 * Builds a standalone MyBatis session: unpooled data source, JDBC transactions, and no
	 * Spring involvement.
	 *
	 * @param url      JDBC URL prefix, e.g. {@code jdbc:mysql://localhost:3306/}
	 * @param db       database name; appended to {@code url} and also used as the environment id
	 * @param userName database user
	 * @param password database password
	 * @param mappers  mapper interfaces to register with the configuration
	 * @return an auto-committing session; the caller must close it
	 */
	private SqlSession createSession(String url, String db, String userName, String password, Class<?>[] mappers) {
		UnpooledDataSource ds = new UnpooledDataSource();
		ds.setDriver(driver);
		ds.setUrl(url + db);
		ds.setUsername(userName);
		ds.setPassword(password);
		Environment environment = new Environment(db, new JdbcTransactionFactory(), ds);
		Configuration configuration = new Configuration(environment);
		configuration.setLazyLoadingEnabled(true);
		for (Class<?> mapper : mappers) {
			configuration.addMapper(mapper);
		}
		SqlSessionFactory sqlSessionFactory = new SqlSessionFactoryBuilder().build(configuration);
		// true = auto-commit, equivalent to Connection.setAutoCommit(true); the no-arg openSession() defaults to false.
		return sqlSessionFactory.openSession(true);
	}

	/**
	 * Opens a session on the local {@code data_analysis} database and binds the mapper to it.
	 * Any session left open by a previous call is closed first.
	 */
	public void init() {
		closeSession();
		// NOTE(review): URL and credentials are hard-coded placeholders; move them to configuration.
		analysisSession = createSession("jdbc:mysql://localhost:3306/", "data_analysis", "xxx", "xxx",
				new Class<?>[] { DecisionTreeMapper.class });
		decisionTreeMapper = analysisSession.getMapper(DecisionTreeMapper.class);
	}

	/** Closes the current session, if any, and clears it so it is never closed twice. */
	private void closeSession() {
		if (analysisSession != null) {
			analysisSession.close();
			analysisSession = null;
			decisionTreeMapper = null;
		}
	}

	/**
	 * Parses the {@code num} column of a row map returned by the mapper.
	 * The JDBC numeric type of {@code count(*)} is not fixed, so it is parsed via its string form.
	 */
	private static int toCount(Map<?, ?> row) {
		return Integer.parseInt(String.valueOf(row.get("num")));
	}

	/**
	 * Converts a mapper result (distinct value -> row map with a {@code num} column) into
	 * distinct value -> row count.
	 */
	private Map<String, Integer> getTypeNum(Map<String, Map> distinctType) {
		Map<String, Integer> typeNum = new HashMap<>();
		for (Entry<String, Map> entry : distinctType.entrySet()) {
			typeNum.put(entry.getKey(), toCount(entry.getValue()));
		}
		return typeNum;
	}

	/**
	 * For every distinct value of {@code chooseCol}, counts the rows per value of
	 * {@code targCol} among the rows that match {@code filterSql} and have that value.
	 *
	 * @param distinctType distinct values of {@code chooseCol} (only the keys are used)
	 * @param chooseCol    candidate split column
	 * @param targCol      target (predicted) column
	 * @param filterSql    {@code " where ..."} clause for the current node, or {@code ""}
	 * @return chooseCol value -> (targCol value -> row count)
	 */
	private Map<String, Map<String, Integer>> getSubTypeNum(Map<String, Integer> distinctType, String chooseCol,
			String targCol, String filterSql) {
		Map<String, Map<String, Integer>> subTypeNum = new HashMap<>();
		for (Entry<String, Integer> entry : distinctType.entrySet()) {
			Map<String, Map> dimTypeMap = decisionTreeMapper.querySubTargetType(chooseCol, entry.getKey(), targCol,
					filterSql);
			subTypeNum.put(entry.getKey(), getTypeNum(dimTypeMap));
		}
		return subTypeNum;
	}

	/**
	 * Sums {@link DecisionTree#computDim} over every value of a candidate column.
	 * Presumably this is the conditional entropy of the target given that column.
	 *
	 * @param total          number of rows at the current node
	 * @param subDimNum      candidate-column value -> row count
	 * @param evereyPartNum  candidate-column value -> (target value -> row count)
	 */
	private double computDim(int total, Map<String, Integer> subDimNum, Map<String, Map<String, Integer>> evereyPartNum) {
		double sum = 0;
		for (Entry<String, Integer> entry : subDimNum.entrySet()) {
			Collection<Integer> vals = evereyPartNum.get(entry.getKey()).values();
			sum += decisionTree.computDim(total, entry.getValue(), vals.toArray(new Integer[vals.size()]));
		}
		return sum;
	}

	/**
	 * Recursively grows the tree from the node described by {@code preCond} and prints each
	 * decision to stdout.
	 *
	 * @param divCol    target (predicted) column, e.g. {@code buys_computer}
	 * @param colSize   number of feature columns at the root; recursion stops once
	 *                  {@code preCond} holds {@code colSize - 1} conditions
	 * @param trainCols feature columns not yet used on this path
	 * @param preCond   column -> value conditions that select this node's rows
	 */
	private void queryDecision(String divCol, int colSize, String[] trainCols, Map<String, Object> preCond) {
		String filterSql = getFilterSql(preCond);
		filterSql = StringUtils.isEmpty(filterSql) ? "" : " where " + filterSql;
		int totalNum = decisionTreeMapper.queryTotal(filterSql);
		Map<String, Integer> typeNum = getTypeNum(decisionTreeMapper.queryTargetType(divCol, filterSql));

		double decisionTotal = decisionTree.computTotal(totalNum, typeNum.values().toArray(new Integer[typeNum.size()]));
		if (decisionTotal == 0) {
			// Zero node entropy: all rows share one target value, so this node is a leaf.
			// (The message wording is kept as published.)
			System.out.println("信息增益为0,找到根结点,前置条件:" + preCond);
			return;
		}

		// A column is only chosen if its gain is strictly positive.
		double maxGain = 0;
		String finalChooseCol = "";
		// Value counts of the chosen column, reused below instead of querying them again.
		Map<String, Integer> finalColNum = null;
		for (String trainCol : trainCols) {
			Map<String, Integer> trainColNum = getTypeNum(decisionTreeMapper.queryTargetType(trainCol, filterSql));
			Map<String, Map<String, Integer>> everyNum = getSubTypeNum(trainColNum, trainCol, divCol, filterSql);
			double condEntropy = computDim(totalNum, trainColNum, everyNum);
			double gain = decisionTotal - condEntropy;
			if (gain > maxGain) {
				maxGain = gain;
				finalChooseCol = trainCol;
				finalColNum = trainColNum;
			}
			System.out.println("Gain(" + trainCol + "):" + gain);
		}

		System.out.println("选择列:" + finalChooseCol + ",前置条件:" + preCond);
		if (preCond.size() == colSize - 1 || StringUtils.isBlank(finalChooseCol)) {
			return;
		}

		// Next level: one child per distinct value of the chosen column, without that column.
		List<String> nextTrainCols = new ArrayList<>(Arrays.asList(trainCols));
		nextTrainCols.remove(finalChooseCol);
		String[] nextCols = nextTrainCols.toArray(new String[nextTrainCols.size()]);
		for (Entry<String, Integer> entry : finalColNum.entrySet()) {
			Map<String, Object> nextPreCond = new HashMap<>(preCond);
			nextPreCond.put(finalChooseCol, entry.getKey());
			queryDecision(divCol, colSize, nextCols, nextPreCond);
		}
	}

	/**
	 * Renders {@code preCond} as an AND-joined SQL condition (without the {@code where}
	 * keyword).
	 * <p>
	 * String values become quoted literals with embedded quotes doubled. {@code null}
	 * becomes {@code is null*}. Any other value is written via {@code toString()}.
	 * Keys are spliced in as column names and must be trusted.
	 * NOTE(review): backslashes are not escaped. This is safe under MySQL's
	 * NO_BACKSLASH_ESCAPES mode; confirm the values never contain {@code \}.
	 *
	 * @return the condition, or {@code ""} when {@code preCond} is empty
	 */
	private String getFilterSql(Map<String, Object> preCond) {
		StringBuilder sb = new StringBuilder();
		for (Entry<String, Object> entry : preCond.entrySet()) {
			if (sb.length() > 0) {
				sb.append(" and ");
			}
			Object val = entry.getValue();
			sb.append(entry.getKey());
			if (val == null) {
				sb.append(" is null");
			} else if (val instanceof String) {
				sb.append("='").append(((String) val).replace("'", "''")).append('\'');
			} else {
				sb.append('=').append(val);
			}
		}
		return sb.toString();
	}
}


该类依赖的数据库查询 Mapper 如下：

package net.highersoft.mapper;

import java.util.Map;

import org.apache.ibatis.annotations.CacheNamespace;
import org.apache.ibatis.annotations.MapKey;
import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.annotations.Select;

/**
 * MyBatis mapper for the {@code decision_tree} table.
 * <p>
 * Every {@code ${...}} placeholder (column names and {@code filterSql}) is spliced into the
 * SQL as raw text, not bound as a parameter. Callers must pass trusted fragments only. Only
 * {@code #{chooseVal}} is a bound parameter.
 * <p>
 * {@code @CacheNamespace(readWrite = false)} enables this namespace's second-level cache in
 * read-only mode. NOTE(review): cached results may be shared between callers, so do not
 * mutate the returned maps.
 */
@CacheNamespace(readWrite = false)
public interface DecisionTreeMapper {

	/**
	 * Counts rows.
	 *
	 * @param filterSql a {@code " where ..."} clause, or {@code ""} for all rows
	 */
	@Select("select count(*) from decision_tree ${filterSql}")
	int queryTotal(@Param("filterSql") String filterSql);

	/**
	 * Groups the filtered rows by {@code col}.
	 *
	 * @return distinct value of {@code col} -> row map with columns {@code dimType} and {@code num}
	 */
	@Select("select  ${col} dimType,count(*) num from decision_tree ${filterSql} group by ${col}")
	@MapKey("dimType")
	Map<String, Map> queryTargetType(@Param("col") String col, @Param("filterSql") String filterSql);

	/**
	 * Among the filtered rows whose {@code chooseCol} equals {@code chooseVal}, groups by
	 * {@code targCol}.
	 *
	 * @return distinct value of {@code targCol} -> row map with columns {@code dimType} and {@code num}
	 */
	@Select("select  ${targCol} dimType,count(*) num from (select *  from decision_tree ${filterSql} ) a where ${chooseCol}=#{chooseVal} group by ${targCol}")
	@MapKey("dimType")
	Map<String, Map> querySubTargetType(@Param("chooseCol") String chooseCol, @Param("chooseVal") String chooseVal,
			@Param("targCol") String targCol, @Param("filterSql") String filterSql);
}


最后，代码的输出如下：

Gain(age):0.2467498197744391
Gain(income):0.02922256565895487
Gain(student):0.15183550136234159
Gain(credit_rating):0.0481270304082696
选择列:age,前置条件:{}
Gain(income):0.019973094021975002
Gain(student):0.019973094021975002
Gain(credit_rating):0.9709505944546687
选择列:credit_rating,前置条件:{age=more40}
信息增益为0,找到根结点,前置条件:{age=more40, credit_rating=excellent}
信息增益为0,找到根结点,前置条件:{age=more40, credit_rating=fair}
信息增益为0,找到根结点,前置条件:{age=30to40}
Gain(income):0.5709505944546687
Gain(student):0.9709505944546687
Gain(credit_rating):0.019973094021975002
选择列:student,前置条件:{age=less30}
信息增益为0,找到根结点,前置条件:{student=no, age=less30}
信息增益为0,找到根结点,前置条件:{student=yes, age=less30}


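根据这些输出，就可以画出决策树了。（注：输出中的“信息增益为0,找到根结点”实际表示该分支的信息熵为 0，即已到达叶子结点。）树的结构如下，箭头后为 buys_computer 的预测结果：

- age = 30to40 → yes
- age = less30 → 按 student 划分：student = no → no；student = yes → yes
- age = more40 → 按 credit_rating 划分：credit_rating = fair → yes；credit_rating = excellent → no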


文/程忠 浏览次数:0次   2020-11-03 19:57:47

相关阅读


评论: