1. 概述
机器学习(ML)和人工智能(AI)正在重塑软件开发,使系统能够从数据中学习并做出智能预测。
Tribuo 是由 Oracle 开发的生产级开源机器学习库。它简化了构建和部署稳健 ML 模型的过程。与 Weka 和 Deeplearning4j 类似,Tribuo 支持多种机器学习任务,并能轻松集成到 Java 应用中。
本教程将介绍 Tribuo 中可用的各种机器学习算法,并使用 UCI 红酒质量数据集构建一个回归模型来预测红酒品质。
2. Tribuo 是什么?
Tribuo 是一个以 Java 为核心的机器学习库,支持:
- 监督学习:回归、分类等
- 无监督学习:聚类
其强类型特性确保了输入输出类型的正确性,有效防止运行时错误并保证模型开发的一致性。
支持以 ONNX 格式导入导出模型,可与 TensorFlow、PyTorch 等 ML 框架无缝集成。
另一大亮点是溯源跟踪功能,该功能记录数据集元数据、模型参数和训练配置,提升透明度和可复现性。
随着 AI 在企业级 Java 应用中的普及,Tribuo 为将智能行为直接嵌入 Java 系统提供了实用工具包。
3. 支持的机器学习算法
Tribuo 支持多种 ML 任务:
- 分类:预测离散类别或标签。例如预测足球队胜负,或根据品质阈值将红酒分为好/坏
- 回归:预测连续值,如红酒品质评分或患者胆固醇水平
- 聚类:在无标签数据中识别分组。例如根据酸度、酒精含量等化学属性对红酒分组(无需预先知道品质评分)
4. 搭建 Tribuo 项目
通过构建回归模型预测红酒品质来实战 Tribuo。
首先在 pom.xml 中添加 Tribuo 依赖:
<dependency>
<groupId>org.tribuo</groupId>
<artifactId>tribuo-all</artifactId>
<version>4.3.2</version>
<type>pom</type>
</dependency>
tribuo-all 依赖提供了加载和训练数据集所需的类。
然后下载 UCI 红酒质量数据集 并放入 src/main/resources/dataset 目录。数据集包含 11 个理化特征(如酸度、酒精含量):
quality 列是连续数值,适合回归任务。
最后创建 WineQualityRegression 类:
public class WineQualityRegression {
}
后续章节将在此类中实现模型训练和保存逻辑。
5. 类级变量
定义以下类级变量:
public static final String DATASET_PATH = "src/main/resources/dataset/winequality-red.csv";
public static final String MODEL_PATH = "src/main/resources/model/winequality-red-regressor.ser";
public Model<Regressor> model;
public Trainer<Regressor> trainer;
public Dataset<Regressor> trainSet;
public Dataset<Regressor> testSet;
代码中定义了数据集路径和模型保存/加载路径。
四个核心变量说明:
- Model – 存储预测模型的类
- Trainer – 可训练预测模型的接口
- Dataset – 存储训练数据的类
显式指定模型输出类型为 Regressor(回归器)。
6. 加载与分割数据集
实现数据加载和分割方法:
void createDatasets() throws Exception {
RegressionFactory regressionFactory = new RegressionFactory();
CSVLoader<Regressor> csvLoader = new CSVLoader<>(';', CSVIterator.QUOTE, regressionFactory);
DataSource<Regressor> dataSource = csvLoader.loadDataSource(Paths.get(DATASET_PATH), "quality");
TrainTestSplitter<Regressor> dataSplitter = new TrainTestSplitter<>(dataSource, 0.7, 1L);
trainSet = new MutableDataset<>(dataSplitter.getTrain());
testSet = new MutableDataset<>(dataSplitter.getTest());
}
使用 CSVLoader 解析分号分隔的 CSV 文件并准备回归数据。RegressionFactory 创建回归输出,指定目标变量 quality 为连续变量。DataSource
通过 TrainTestSplitter 按 70% 训练集、30% 测试集分割数据,用于评估模型泛化能力。
7. 训练回归模型
红酒品质评分为数值,使用 CART(分类回归树)作为基础学习器训练模型:
void createTrainer() {
CARTRegressionTrainer subsamplingTree = new CARTRegressionTrainer(
Integer.MAX_VALUE,
AbstractCARTTrainer.MIN_EXAMPLES,
0.001f,
0.7f,
new MeanSquaredError(),
Trainer.DEFAULT_SEED
);
trainer = new RandomForestTrainer<>(subsamplingTree, new AveragingCombiner(), 10);
model = trainer.train(trainSet);
}
CARTRegressionTrainer 配置决策树:
- 无最大深度限制
- 每个分裂最少 6 个样本
- 使用均方误差作为分裂标准
RandomForestTrainer 组合 10 棵 CART 决策树,通过 AveragingCombiner 平均预测结果。
train() 方法在 trainSet 上训练模型,生成用于预测红酒品质的 *Model
8. 模型评估
使用 RegressionEvaluator 评估回归模型性能:
void evaluate(Model<Regressor> model, String datasetName, Dataset<Regressor> dataset) {
RegressionEvaluator evaluator = new RegressionEvaluator();
RegressionEvaluation evaluation = evaluator.evaluate(model, dataset);
Regressor dimension0 = new Regressor("DIM-0", Double.NaN);
log.info("MAE: " + evaluation.mae(dimension0));
log.info("RMSE: " + evaluation.rmse(dimension0));
log.info("R^2: " + evaluation.r2(dimension0));
}
RegressionEvaluator 计算并输出三个关键指标:
- MAE(平均绝对误差)
- RMSE(均方根误差)
- R²(决定系数)
调用评估方法:
void evaluateModels() throws Exception {
log.info("Training model");
evaluate(model, "trainSet", trainSet);
log.info("Testing model");
evaluate(model, "testSet", testSet);
}
程序执行时的评估结果:
07:10:14.405 [main] INFO tribuo.WineQualityRegression - Training model
07:10:14.406 [main] INFO tribuo.WineQualityRegression - Results for trainSet---------------------
07:10:14.537 [main] INFO tribuo.WineQualityRegression - MAE: 0.25025410332970005
07:10:14.537 [main] INFO tribuo.WineQualityRegression - RMSE: 0.3422557198486092
07:10:14.538 [main] INFO tribuo.WineQualityRegression - R^2: 0.8190947891297661
07:10:14.538 [main] INFO tribuo.WineQualityRegression - Testing model
07:10:14.540 [main] INFO tribuo.WineQualityRegression - Results for testSet---------------------
07:10:14.565 [main] INFO tribuo.WineQualityRegression - MAE: 0.48711029366796743
07:10:14.565 [main] INFO tribuo.WineQualityRegression - RMSE: 0.6584973595553575
07:10:14.565 [main] INFO tribuo.WineQualityRegression - R^2: 0.3444460580874339
指标解读:
- ✅ MAE/RMSE 越低越好:预测值与实际值的平均偏差
- ✅ R² 越高越好:模型解释数据方差的能力(0-1 之间)
⚠️ 训练集 R²=0.82 但测试集仅 0.34,表明模型存在过拟合。
9. 保存模型
将训练好的模型序列化到文件:
void saveModel() throws Exception {
File modelFile = new File(MODEL_PATH);
try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(modelFile))) {
objectOutputStream.writeObject(model);
}
}
使用 ObjectOutputStream 序列化模型,避免重复训练,实现模型复用。
10. 方法调用
在 main() 方法中整合所有步骤:
public static void main(String[] args) throws Exception {
WineQualityRegression wineQualityRegression = new WineQualityRegression();
wineQualityRegression.createDatasets();
wineQualityRegression.createTrainer();
wineQualityRegression.evaluateModels();
wineQualityRegression.saveModel();
}
执行后模型将保存到指定目录。
11. 使用模型
创建 WinePredictor 类加载模型:
class WineQualityPredictor {
private static final Logger log = LoggerFactory.getLogger(WineQualityPredictor.class);
public static void main(String[] args) throws IOException, ClassNotFoundException {
File modelFile = new File("src/main/resources/model/winequality-red-regressor.ser");
Model<Regressor> loadedModel = null;
try (ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(modelFile))) {
loadedModel = (Model<Regressor>) objectInputStream.readObject();
}
}
Tribuo 强类型特性要求显式指定模型类型(Regressor)。
创建红酒样本的 ArrayExample 对象:
ArrayExample<Regressor> wineAttribute = new ArrayExample<Regressor>(new Regressor("quality", Double.NaN));
wineAttribute.add("fixed acidity", 7.4f);
wineAttribute.add("volatile acidity", 0.7f);
wineAttribute.add("citric acid", 0.47f);
wineAttribute.add("residual sugar", 1.9f);
wineAttribute.add("chlorides", 0.076f);
wineAttribute.add("free sulfur dioxide", 11.0f);
wineAttribute.add("total sulfur dioxide", 34.0f);
wineAttribute.add("density", 0.9978f);
wineAttribute.add("pH", 3.51f);
wineAttribute.add("sulphates", 0.56f);
wineAttribute.add("alcohol", 9.4f);
使用 Prediction 类进行预测:
Prediction<Regressor> prediction = loadedModel.predict(wineAttribute);
double predictQuality = prediction.getOutput().getValues()[0];
log.info("Predicted wine quality: " + predictQuality);
预测结果输出:
07:31:05.772 [main] INFO tribuo.WineQualityPredictor - Predicted wine quality: 5.028163673540464
12. 总结
本文介绍了 Tribuo 及其核心特性,概述了支持的机器学习算法,并通过红酒品质预测实战演示了回归模型的完整开发流程。
完整示例代码可在 GitHub 获取。