weka的训练与预测

一、训练

1.准备数据文件,如这样的arff文件

@relation decisiontree_weka

@attribute age {'less30','30to40','more40'} 
@attribute income {'high','medium','low'}
@attribute student {'yes','no'}
@attribute credit_rating {'fair','excellent'}
@attribute 'buys_computer' {'yes','no'}

@data
less30,high,no,fair,no
less30,high,no,excellent,no
30to40,high,no,fair,yes
more40,medium,no,fair,yes
more40,low,yes,fair,yes
more40,low,yes,excellent,no
30to40,low,yes,excellent,yes
less30,medium,no,fair,no
less30,low,yes,fair,yes
more40,medium,yes,fair,yes
less30,medium,yes,excellent,yes
30to40,medium,no,excellent,yes
30to40,high,yes,fair,yes
more40,medium,no,excellent,no
2.打开GUI,选择参数训练

Explorer -> Preprocess ->Open File

选择刚才的文件



训练:

Classify -> Choose -> functions -> MultilayerPerceptron(多层感知机,即多层网络)

选择start训练

右侧窗口可右键 -> save model 保存模型文件



二、预测

使用代码调用保存的模型,预测分类:

package net.highersoft.weka;

import java.io.File;
import java.io.FileInputStream;

import com.google.gson.Gson;

import weka.classifiers.Classifier;
import weka.classifiers.functions.MultilayerPerceptron;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.converters.ConverterUtils.DataSource;

/**
 * Demo that loads a classifier saved as a serialized model file and uses it
 * to classify a single hand-built instance.
 *
 * <p>Usage: call {@link #init()} once, then {@link #classifyMessage()}.
 */
public class TestMultilayerPerceptron {
	/** Never assigned in this class. */
	Instances m_Data = null;
	/** Classifier deserialized from the model file by {@link #init()}. */
	Classifier m_Classifier = null;
	/** Data set read by {@link #init()}; supplies the attribute definitions for new instances. */
	Instances ins=null;

	/**
	 * Loads the serialized model and the ARFF data set, and marks the last
	 * attribute of the data set as the class attribute.
	 *
	 * @throws Exception if the model or the data file cannot be read
	 */
	public void init() throws Exception {
		File inputFile = new File("/Users/chengzhong/code/eclipse-workspace-highersoft/wekamodel/file/bp.model");
		// try-with-resources so the stream is closed even when deserialization fails.
		try (FileInputStream modelStream = new FileInputStream(inputFile)) {
			// Cast to the interface: the field only needs a Classifier, so any
			// saved classifier type can be loaded, not only MultilayerPerceptron.
			m_Classifier = (Classifier) SerializationHelper.read(modelStream);
		}
		String file = "/Users/chengzhong/code/eclipse-workspace-highersoft/wekamodel/file/dt3_test.arff";
		ins = DataSource.read(file);
		// Set the class attribute once here instead of separately in each instance factory.
		ins.setClassIndex(ins.numAttributes() - 1);
	}

	/**
	 * Classifies the instance built by {@link #makeInstance()} and prints the
	 * predicted class index, the attribute values of the instance, the class
	 * distribution and the class value stored in the instance.
	 *
	 * @throws IllegalStateException if {@link #init()} has not been called
	 * @throws Exception if the classifier fails to classify the instance
	 */
	public void classifyMessage() throws Exception{
		if (m_Classifier == null || ins == null) {
			throw new IllegalStateException("init() must be called before classifyMessage()");
		}

		// To classify a row of the loaded data set instead, use makeInstance2(rowIndex).
		Instance instance = makeInstance();

		double predicted = m_Classifier.classifyInstance(instance);
		// The instance is hand-built, so no data set row index is reported here.
		System.out.println("predicted:" + predicted);

		double[] vals = m_Classifier.distributionForInstance(instance);

		System.out.print("instance:" );
		for (int i = 0; i < instance.numAttributes(); i++) {
			Attribute attribute = instance.attribute(i);
			System.out.print(attribute.name() + ":" + instance.stringValue(attribute) + "\t");
		}
		System.out.println();

		System.out.println("distribution predicted:" + new Gson().toJson(vals));

		System.out.println("real class:" + instance.classValue());
	}

	/**
	 * Returns the row at {@code index} of the loaded data set.
	 *
	 * @param index zero-based row index into the data set
	 * @return the instance stored at that row
	 */
	private Instance makeInstance2(int index) throws Exception {
		return ins.get(index);
	}

	/**
	 * Builds one instance by hand, attached to the loaded data set so that the
	 * string labels resolve against its attribute definitions.
	 *
	 * @return the newly built instance
	 */
	private Instance makeInstance() throws Exception {
		// Size the instance from the data set rather than a hard-coded attribute count.
		Instance instance = new DenseInstance(ins.numAttributes());
		instance.setDataset(ins);

		instance.setValue(0, "more40");
		instance.setValue(1, "high");
		instance.setValue(2, "yes");
		instance.setValue(3, "excellent");
		// Attribute 4 is the class attribute; this label is what "real class" reports.
		instance.setValue(4, "yes");

		return instance;
	}

	/**
	 * Entry point: loads the model and data, then classifies the sample instance.
	 *
	 * @param args unused
	 * @throws Exception if loading or classification fails
	 */
	public static void main(String[] args) throws Exception {
		TestMultilayerPerceptron wTestInstance = new TestMultilayerPerceptron();
		wTestInstance.init();
		wTestInstance.classifyMessage();
	}

}

pom.xml依赖

<project xmlns="http://maven.apache.org/POM/4.0.0"
	xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
	xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
	<modelVersion>4.0.0</modelVersion>
	<groupId>net.highersoft</groupId>
	<artifactId>model</artifactId>
	<version>0.0.1-SNAPSHOT</version>

	<!-- Pin the language level and source encoding so the build does not
	     depend on compiler-plugin defaults or the platform charset. -->
	<properties>
		<maven.compiler.source>1.8</maven.compiler.source>
		<maven.compiler.target>1.8</maven.compiler.target>
		<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
	</properties>

	<dependencies>
		<!-- Weka: classifiers, data set loading and model (de)serialization -->
		<dependency>
			<groupId>nz.ac.waikato.cms.weka</groupId>
			<artifactId>weka-stable</artifactId>
			<version>3.8.5</version>
		</dependency>

		<!-- Gson: used to print the class distribution array as JSON -->
		<dependency>
			<groupId>com.google.code.gson</groupId>
			<artifactId>gson</artifactId>
			<version>2.8.7</version>
		</dependency>
	</dependencies>
</project>



相关阅读
评论:
点击刷新

↓ 广告开始-头部带绿为生活 ↓
↑ 广告结束-尾部支持多点击 ↑