Java 调用 PMML 模型文件
1. Maven pom.xml 依赖（两个构件的版本需保持一致）：
<dependency> <groupId>org.jpmml</groupId> <artifactId>pmml-evaluator</artifactId> <version>1.5.9</version> <!-- <version>1.4.13</version> --> </dependency> <dependency> <groupId>org.jpmml</groupId> <artifactId>pmml-evaluator-extension</artifactId> <version>1.5.9</version> <!-- <version>1.4.13</version> --> </dependency>
2. PMML 文件（保存为 demo.pmml，放在 classpath 根目录下，例如 src/main/resources，因为代码通过 "/demo.pmml" 加载它）：
<?xml version="1.0" encoding="UTF-8" standalone="yes"?> <PMML xmlns="http://www.dmg.org/PMML-4_4" xmlns:data="http://jpmml.org/jpmml-model/InlineTable" version="4.4"> <Header> <Application name="JPMML-SkLearn" version="1.6.15"/> <Timestamp>2021-03-17T07:46:27Z</Timestamp> </Header> <MiningBuildTask> <Extension>PMMLPipeline(steps=[('classifier', DecisionTreeClassifier(random_state=9))])</Extension> </MiningBuildTask> <DataDictionary> <DataField name="y" optype="categorical" dataType="integer"> <Value value="0"/> <Value value="1"/> <Value value="2"/> </DataField> <DataField name="x3" optype="continuous" dataType="float"/> <DataField name="x4" optype="continuous" dataType="float"/> </DataDictionary> <TransformationDictionary/> <TreeModel functionName="classification" algorithmName="sklearn.tree._classes.DecisionTreeClassifier" missingValueStrategy="nullPrediction"> <MiningSchema> <MiningField name="y" usageType="target"/> <MiningField name="x3"/> <MiningField name="x4"/> </MiningSchema> <Output> <OutputField name="probability(0)" optype="continuous" dataType="double" feature="probability" value="0"/> <OutputField name="probability(1)" optype="continuous" dataType="double" feature="probability" value="1"/> <OutputField name="probability(2)" optype="continuous" dataType="double" feature="probability" value="2"/> </Output> <LocalTransformations> <DerivedField name="double(x3)" optype="continuous" dataType="double"> <FieldRef field="x3"/> </DerivedField> <DerivedField name="double(x4)" optype="continuous" dataType="double"> <FieldRef field="x4"/> </DerivedField> </LocalTransformations> <Node> <True/> <Node> <SimplePredicate field="double(x3)" operator="lessOrEqual" value="3.5"/> <Node score="1" recordCount="1"> <SimplePredicate field="double(x3)" operator="lessOrEqual" value="2.0"/> <ScoreDistribution value="0" recordCount="0"/> <ScoreDistribution value="1" recordCount="1"/> <ScoreDistribution value="2" recordCount="0"/> </Node> <Node score="0" recordCount="2"> <True/> 
<ScoreDistribution value="0" recordCount="2"/> <ScoreDistribution value="1" recordCount="0"/> <ScoreDistribution value="2" recordCount="0"/> </Node> </Node> <Node score="2" recordCount="1"> <SimplePredicate field="double(x4)" operator="lessOrEqual" value="8.0"/> <ScoreDistribution value="0" recordCount="0"/> <ScoreDistribution value="1" recordCount="0"/> <ScoreDistribution value="2" recordCount="1"/> </Node> <Node score="1" recordCount="1"> <True/> <ScoreDistribution value="0" recordCount="0"/> <ScoreDistribution value="1" recordCount="1"/> <ScoreDistribution value="2" recordCount="0"/> </Node> </Node> </TreeModel> </PMML>
3. Java 代码：
package net.highersoft.pmml; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.util.Arrays; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import javax.xml.bind.JAXBException; import org.dmg.pmml.FieldName; import org.dmg.pmml.PMML; import org.jpmml.evaluator.Computable; import org.jpmml.evaluator.Evaluator; import org.jpmml.evaluator.FieldValue; import org.jpmml.evaluator.InputField; import org.jpmml.evaluator.LoadingModelEvaluatorBuilder; import org.jpmml.evaluator.ModelEvaluatorFactory; import org.jpmml.evaluator.TargetField; import org.xml.sax.SAXException; public class PMMLDemo { private Evaluator loadPmml() throws IOException, SAXException, JAXBException { File pmmlFile = new File(Class.class.getResource("/demo.pmml").getFile()); Evaluator evaluator = new LoadingModelEvaluatorBuilder().load(pmmlFile).build(); return evaluator; } public static void main(String args[]) throws IOException, SAXException, JAXBException { PMMLDemo demo = new PMMLDemo(); Evaluator evaluator = demo.loadPmml(); List<InputField> inputFields = evaluator.getInputFields(); Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>(); for (InputField inputField : inputFields) { FieldName inputFieldName = inputField.getName(); //Object rawValue =Arrays.asList(1,2,3,4); Object rawValue =1; System.out.println(inputFieldName.getValue()+":"+rawValue); FieldValue inputFieldValue = inputField.prepare(rawValue); arguments.put(inputFieldName, inputFieldValue); } Map<FieldName, ?> results = evaluator.evaluate(arguments); List<TargetField> targetFields = evaluator.getTargetFields(); TargetField targetField = targetFields.get(0); FieldName targetFieldName = targetField.getName(); Object targetFieldValue = results.get(targetFieldName); System.out.println("target: " + targetFieldName.getValue() + " value: " + targetFieldValue); } }
相关阅读
评论:
↓ 广告开始-头部带绿为生活 ↓
↑ 广告结束-尾部支持多点击 ↑