一、训练
1.准备数据文件,如这样的arff文件
@relation decisiontree_weka

@attribute age {'less30','30to40','more40'}
@attribute income {'high','medium','low'}
@attribute student {'yes','no'}
@attribute credit_rating {'fair','excellent'}
@attribute 'buys_computer' {'yes','no'}

@data
less30,high,no,fair,no
less30,high,no,excellent,no
30to40,high,no,fair,yes
more40,medium,no,fair,yes
more40,low,yes,fair,yes
more40,low,yes,excellent,no
30to40,low,yes,excellent,yes
less30,medium,no,fair,no
less30,low,yes,fair,yes
more40,medium,yes,fair,yes
less30,medium,yes,excellent,yes
30to40,medium,no,excellent,yes
30to40,high,yes,fair,yes
more40,medium,no,excellent,no

2.打开GUI,选择参数训练
Explorer -> Preprocess ->Open File
选择刚才的文件
训练:
Classify -> Choose -> 选择多层网络(weka.classifiers.functions.MultilayerPerceptron)
选择start训练
右侧窗口可右键 -> save model 保存模型文件
二、使用代码调用模型,预测分类:
package net.highersoft.weka;

import java.io.File;
import java.io.FileInputStream;

import com.google.gson.Gson;

import weka.classifiers.Classifier;
import weka.classifiers.functions.MultilayerPerceptron;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.converters.ConverterUtils.DataSource;

/**
 * Loads a serialized Weka {@link MultilayerPerceptron} model from disk and uses it to
 * classify a single instance, printing the prediction and the class distribution.
 *
 * <p>Usage: call {@link #init()} once, then {@link #classifyMessage()}.
 */
public class TestMultilayerPerceptron {

    /** Never assigned in this class; kept only so the field remains available to package code. */
    Instances m_Data = null;

    /** The deserialized model; populated by {@link #init()}. */
    Classifier m_Classifier = null;

    /** Dataset read from the ARFF file; supplies the attribute definitions for new instances. */
    Instances ins = null;

    /**
     * Reads the saved model and the ARFF dataset, and marks the last attribute as the class.
     *
     * @throws Exception if the model or the dataset cannot be read
     */
    public void init() throws Exception {
        File inputFile = new File("/Users/chengzhong/code/eclipse-workspace-highersoft/wekamodel/file/bp.model");
        // try-with-resources guarantees the stream is closed even when deserialization fails.
        try (FileInputStream modelStream = new FileInputStream(inputFile)) {
            m_Classifier = (MultilayerPerceptron) SerializationHelper.read(modelStream);
        }

        String file = "/Users/chengzhong/code/eclipse-workspace-highersoft/wekamodel/file/dt3_test.arff";
        ins = DataSource.read(file);
        // Set the class attribute once here so every instance-building helper agrees on it.
        ins.setClassIndex(ins.numAttributes() - 1);
    }

    /**
     * Classifies one instance and prints the predicted value, the instance's attribute
     * values, the per-class distribution (as JSON) and the instance's own class value.
     *
     * @throws IllegalStateException if {@link #init()} has not been called
     * @throws Exception if the classifier fails on the instance
     */
    public void classifyMessage() throws Exception {
        if (m_Classifier == null || ins == null) {
            throw new IllegalStateException("init() must be called before classifyMessage()");
        }

        // Row number that would be used if makeInstance2(index) replaced makeInstance() below;
        // it is still printed so the output format stays unchanged.
        int index = 2;
        Instance instance = makeInstance();

        double predicted = m_Classifier.classifyInstance(instance);
        System.out.println("index:" + index + " predicted:" + predicted);

        double[] vals = m_Classifier.distributionForInstance(instance);

        StringBuilder line = new StringBuilder("instance:");
        for (int i = 0; i < instance.numAttributes(); i++) {
            Attribute attribute = instance.attribute(i);
            line.append(attribute.name())
                .append(':')
                .append(instance.stringValue(attribute))
                .append('\t');
        }
        System.out.println(line);

        System.out.println("distribution predicted:" + new Gson().toJson(vals));
        System.out.println("real class:" + instance.classValue());
    }

    /**
     * Returns the row at {@code index} of the loaded dataset, with the last attribute as class.
     * Currently unused; an alternative to {@link #makeInstance()}.
     */
    private Instance makeInstance2(int index) throws Exception {
        ins.setClassIndex(ins.numAttributes() - 1);
        return ins.get(index);
    }

    /**
     * Builds a hand-written instance bound to the loaded dataset's attribute definitions.
     * The values are given as nominal labels for attribute positions 0 through 4.
     */
    private Instance makeInstance() throws Exception {
        // Size the instance from the dataset instead of a hard-coded attribute count.
        Instance instance = new DenseInstance(ins.numAttributes());
        ins.setClassIndex(ins.numAttributes() - 1);
        // The dataset must be attached before nominal values can be set by label.
        instance.setDataset(ins);
        instance.setValue(0, "more40");
        instance.setValue(1, "high");
        instance.setValue(2, "yes");
        instance.setValue(3, "excellent");
        instance.setValue(4, "yes");
        return instance;
    }

    public static void main(String[] args) throws Exception {
        TestMultilayerPerceptron wTestInstance = new TestMultilayerPerceptron();
        wTestInstance.init();
        wTestInstance.classifyMessage();
    }
}
pom.xml依赖
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>

  <!-- Project coordinates -->
  <groupId>net.highersoft</groupId>
  <artifactId>model</artifactId>
  <version>0.0.1-SNAPSHOT</version>

  <dependencies>
    <!-- Weka: provides the classifier, dataset and serialization classes -->
    <dependency>
      <groupId>nz.ac.waikato.cms.weka</groupId>
      <artifactId>weka-stable</artifactId>
      <version>3.8.5</version>
    </dependency>
    <!-- Gson: used to print the prediction distribution as JSON -->
    <dependency>
      <groupId>com.google.code.gson</groupId>
      <artifactId>gson</artifactId>
      <version>2.8.7</version>
    </dependency>
  </dependencies>
</project>
软件下载 - 联系邮箱 - 关注微博
Copyright © 2010-2024 匠艺软件 蜀ICP备19010796号