作为一名Java开发者,如果要训练自己的预测模型,是不是第一想到的还是把Python拿起来?其实不一定非要拿起Python,在Java领域也有自己的生产级机器学习工具,它支持分类、回归、聚类等常见任务,还能无缝对接 TensorFlow 等框架,用 Java 就能直接训模型、做预测!它就是:Tribuo。
Tribuo 是 Oracle 推出的面向生产环境的开源机器学习库,极大简化了健壮 ML 模型的构建与部署。与 Weka 和 Deeplearning4j 类似,Tribuo 支持多种机器学习任务,并能轻松集成到 Java 应用中。
本文我们将了解 Tribuo 支持的多种机器学习算法,并以 UCI 红葡萄酒质量数据集为例,构建一个用于预测葡萄酒质量的回归模型。
Tribuo 是一个以 Java 为核心的机器学习库,支持:
此外,Tribuo 拥有强类型特性,能够强制输入输出类型一致,有效防止运行时错误,确保模型开发过程的规范性。
它支持以 ONNX(开放神经网络交换)格式导入和导出模型,便于与 TensorFlow、PyTorch 等主流 ML 框架集成。
另一个亮点是 provenance(溯源)追踪功能,可记录数据集、模型参数和训练配置等元数据,提升透明度和可复现性。
随着 AI 在企业级 Java 应用中的普及,Tribuo 为在 Java 系统中直接嵌入智能行为提供了实用工具包。
Tribuo 支持多种机器学习任务,包括:
我们将通过构建一个葡萄酒质量回归预测模型,体验 Tribuo 的实际应用。
首先,在 pom.xml
中添加Tribuo 依赖:
<dependency>
<groupId>org.tribuo</groupId>
<artifactId>tribuo-all</artifactId>
<version>4.3.2</version>
</dependency>
tribuo-all
依赖提供了加载和训练数据集所需的相关类。
下载UCI 红葡萄酒质量数据集:
https://archive.ics.uci.edu/dataset/186/wine+quality
放置到 src/main/resources/dataset
目录下。该数据集包含 11 个理化特征,如酸度和酒精含量:
quality
列是一个适合回归任务的连续数值。
最后,创建一个名为 WineQualityRegression
的类:
public class WineQualityRegression {
}
后续章节将在该类中实现训练和保存模型的相关逻辑。
接下来,定义如下类级变量:
public staticfinalStringDATASET_PATH="src/main/resources/dataset/winequality-red.csv";
publicstaticfinalStringMODEL_PATH="src/main/resources/model/winequality-red-regressor.ser";
public Model<Regressor> model;
public Trainer<Regressor> trainer;
public Dataset<Regressor> trainSet;
public Dataset<Regressor> testSet;
上述代码中,我们定义了数据集路径和训练模型保存/加载路径。
随后,定义了四个变量,分别代表:
Model
—— 存储预测模型的类Trainer
—— 可训练预测模型的接口Dataset
—— 用于训练的数据集类此外,我们显式指定了模型输出类型为 Regressor
。
定义一个方法用于加载并划分数据集:
void createDatasets()throws Exception {
RegressionFactoryregressionFactory=newRegressionFactory();
CSVLoader<Regressor> csvLoader = newCSVLoader<>(';', CSVIterator.QUOTE, regressionFactory);
DataSource<Regressor> dataSource = csvLoader.loadDataSource(Paths.get(DATASET_PATH), "quality");
TrainTestSplitter<Regressor> dataSplitter = newTrainTestSplitter<>(dataSource, 0.7, 1L);
trainSet = newMutableDataset<>(dataSplitter.getTrain());
testSet = newMutableDataset<>(dataSplitter.getTest());
}
这里,我们用 CSVLoader
解析分号分隔的 CSV 文件并为回归任务做准备。RegressionFactory
用于创建回归输出,指定目标变量 quality
为连续变量。DataSource<Regressor>
保存解析后的数据。
随后,为了评估模型的泛化能力和表现,使用 TrainTestSplitter
将数据集按 7:3 划分为训练集和测试集。
由于葡萄酒质量分数为数值型,我们采用分类与回归树(CART)作为基学习器进行训练:
void createTrainer() {
CARTRegressionTrainersubsamplingTree=newCARTRegressionTrainer(
Integer.MAX_VALUE,
AbstractCARTTrainer.MIN_EXAMPLES,
0.001f,
0.7f,
newMeanSquaredError(),
Trainer.DEFAULT_SEED
);
trainer = newRandomForestTrainer<>(subsamplingTree, newAveragingCombiner(), );
model = trainer.train(trainSet);
}
上述方法中,CARTRegressionTrainer
配置了无最大深度、每次分裂最少 6 个样本、以均方误差为分裂标准。随后,RandomForestTrainer
结合 10 棵 CART 决策树,并用 AveragingCombiner
平均预测结果。
train()
方法在 trainSet
数据集上训练模型,生成用于预测葡萄酒质量分数的 Model<Regressor>
。
接下来,使用 RegressionEvaluator
评估模型在数据集上的表现,计算相关指标:
void evaluate(Model<Regressor> model, String datasetName, Dataset<Regressor> dataset) {
RegressionEvaluatorevaluator=newRegressionEvaluator();
RegressionEvaluationevaluation= evaluator.evaluate(model, dataset);
Regressordimension0=newRegressor("DIM-0", Double.NaN);
log.info("MAE: " + evaluation.mae(dimension0));
log.info("RMSE: " + evaluation.rmse(dimension0));
log.info("R^2: " + evaluation.r2(dimension0));
}
RegressionEvaluator
用于评估模型在数据集上的表现。我们将 MAE(平均绝对误差)、RMSE(均方根误差)和 R^2(决定系数)输出到控制台。
随后,调用 evaluate()
方法评估模型和数据集:
void evaluateModels() throws Exception {
log.info("Training model");
evaluate(model, "trainSet", trainSet);
log.info("Testing model");
evaluate(model, "testSet", testSet);
}
执行程序后,训练集和测试集的评估结果如下:
07:10:14.405 [main] INFO tribuo.WineQualityRegression - Training model
07:10:14.406 [main] INFO tribuo.WineQualityRegression - Results for trainSet---------------------
07:10:14.537 [main] INFO tribuo.WineQualityRegression - MAE: 0.25025410332970005
07:10:14.537 [main] INFO tribuo.WineQualityRegression - RMSE: 0.3422557198486092
07:10:14.538 [main] INFO tribuo.WineQualityRegression - R^2: 0.8190947891297661
07:10:14.538 [main] INFO tribuo.WineQualityRegression - Testing model
07:10:14.540 [main] INFO tribuo.WineQualityRegression - Results for testSet---------------------
07:10:14.565 [main] INFO tribuo.WineQualityRegression - MAE: 0.48711029366796743
07:10:14.565 [main] INFO tribuo.WineQualityRegression - RMSE: 0.6584973595553575
07:10:14.565 [main] INFO tribuo.WineQualityRegression - R^2: 0.3444460580874339
MAE
表示预测值与实际值的绝对差异,RMSE
表示预测值与实际值的平方差均值的平方根,R^2
表示模型对训练和测试数据方差的解释能力。
更低的 MAE
和 RMSE
,以及更高的 R^2
,意味着模型预测性能更优。
最后,将模型保存为文件以便后续复用:
void saveModel()throws Exception {
FilemodelFile=newFile(MODEL_PATH);
try (ObjectOutputStreamobjectOutputStream=newObjectOutputStream(newFileOutputStream(modelFile))) {
objectOutputStream.writeObject(model);
}
}
上述代码通过 ObjectOutputStream 类将训练好的模型序列化保存到文件。这样,我们可以在后续预测中直接复用模型,无需重新训练。
现在,在 main()
方法中调用前面创建的方法:
public static void main(String[] args) throws Exception {
WineQualityRegression wineQualityRegression = new WineQualityRegression();
wineQualityRegression.createDatasets();
wineQualityRegression.createTrainer();
wineQualityRegression.evaluateModels();
wineQualityRegression.saveModel();
}
编译代码后,模型会被保存到指定目录。
新建一个 WinePredictor
类,在 main()
方法中加载已保存的模型:
class WineQualityPredictor {
privatestaticfinalLoggerlog= LoggerFactory.getLogger(WineQualityPredictor.class);
publicstaticvoidmain(String[] args)throws IOException, ClassNotFoundException {
FilemodelFile=newFile("src/main/resources/model/winequality-red-regressor.ser");
Model<Regressor> loadedModel = null;
try (ObjectInputStreamobjectInputStream=newObjectInputStream(newFileInputStream(modelFile))) {
loadedModel = (Model<Regressor>) objectInputStream.readObject();
}
}
如前所述,Tribuo 对类型敏感,因此我们指定模型类型为 Regressor
。
通过创建 ObjectInputStream
并传入模型路径来加载模型。
然后,创建一个 ArrayExample
对象,表示单个葡萄酒样本:
ArrayExample<Regressor> wineAttribute = new ArrayExample<Regressor>(new Regressor("quality", Double.NaN));
wineAttribute.add("fixed acidity", 7.4f);
wineAttribute.add("volatile acidity", 0.7f);
wineAttribute.add("citric acid", 0.47f);
wineAttribute.add("residual sugar", 1.9f);
wineAttribute.add("chlorides", 0.076f);
wineAttribute.add("free sulfur dioxide", 11.0f);
wineAttribute.add("total sulfur dioxide", 34.0f);
wineAttribute.add("density", 0.9978f);
wineAttribute.add("pH", 3.51f);
wineAttribute.add("sulphates", 0.56f);
wineAttribute.add("alcohol", 9.4f);
最后,使用 Prediction
类进行预测:
Prediction<Regressor> prediction = loadedModel.predict(wineAttribute);
double predictQuality = prediction.getOutput().getValues()[];
log.info("Predicted wine quality: " + predictQuality);
预测结果如下:
07:31:05.772 [main] INFO tribuo.WineQualityPredictor - Predicted wine quality: 5.028163673540464
在本文中,我们学习了 Tribuo 及其特性,了解了其支持的部分机器学习算法,并通过回归算法训练了葡萄酒质量预测模型。
扫码关注腾讯云开发者
领取腾讯云代金券
Copyright © 2013 - 2025 Tencent Cloud. All Rights Reserved. 腾讯云 版权所有
深圳市腾讯计算机系统有限公司 ICP备案/许可证号:粤B2-20090059 深公网安备号 44030502008569
腾讯云计算(北京)有限责任公司 京ICP证150476号 | 京ICP备11018762号 | 京公网安备号11010802020287
Copyright © 2013 - 2025 Tencent Cloud.
All Rights Reserved. 腾讯云 版权所有