1. 概述

机器学习(ML)和人工智能(AI)正在重塑软件开发,使系统能够从数据中学习并做出智能预测。

Tribuo 是由 Oracle 开发的生产级开源机器学习库。它简化了构建和部署稳健 ML 模型的过程。与 Weka 和 Deeplearning4j 类似,Tribuo 支持多种机器学习任务,并能轻松集成到 Java 应用中。

本教程将介绍 Tribuo 中可用的各种机器学习算法,并使用 UCI 红酒质量数据集构建一个回归模型来预测红酒品质。

2. Tribuo 是什么?

Tribuo 是一个以 Java 为核心的机器学习库,支持:

  • 监督学习:回归、分类等
  • 无监督学习:聚类

其强类型特性确保了输入输出类型的正确性,有效防止运行时错误并保证模型开发的一致性。

支持以 ONNX 格式导入导出模型,可与 TensorFlow、PyTorch 等 ML 框架无缝集成。

另一大亮点是溯源跟踪功能,该功能记录数据集元数据、模型参数和训练配置,提升透明度和可复现性。

随着 AI 在企业级 Java 应用中的普及,Tribuo 为将智能行为直接嵌入 Java 系统提供了实用工具包。

3. 支持的机器学习算法

Tribuo 支持多种 ML 任务:

  • 分类:预测离散类别或标签。例如预测足球队胜负,或根据品质阈值将红酒分为好/坏
  • 回归:预测连续值,如红酒品质评分或患者胆固醇水平
  • 聚类:在无标签数据中识别分组。例如根据酸度、酒精含量等化学属性对红酒分组(无需预先知道品质评分)

4. 搭建 Tribuo 项目

通过构建回归模型预测红酒品质来实战 Tribuo。

首先在 pom.xml 中添加 Tribuo 依赖:

<dependency>
    <groupId>org.tribuo</groupId>
    <artifactId>tribuo-all</artifactId>
    <version>4.3.2</version>
    <type>pom</type>
</dependency>

tribuo-all 依赖提供了加载和训练数据集所需的类。

然后下载 UCI 红酒质量数据集 并放入 src/main/resources/dataset 目录。数据集包含 11 个理化特征(如酸度、酒精含量):

uci red wine dataset

quality 列是连续数值,适合回归任务。

最后创建 WineQualityRegression 类:

public class WineQualityRegression {
}

后续章节将在此类中实现模型训练和保存逻辑。

5. 类级变量

定义以下类级变量:

public static final String DATASET_PATH = "src/main/resources/dataset/winequality-red.csv";
public static final String MODEL_PATH = "src/main/resources/model/winequality-red-regressor.ser";
public Model<Regressor> model;
public Trainer<Regressor> trainer;
public Dataset<Regressor> trainSet;
public Dataset<Regressor> testSet;

代码中定义了数据集路径和模型保存/加载路径。

四个核心变量说明:

  • Model – 存储预测模型的类
  • Trainer – 可训练预测模型的接口
  • Dataset – 存储训练数据的类

显式指定模型输出类型为 Regressor(回归器)。

6. 加载与分割数据集

实现数据加载和分割方法:

void createDatasets() throws Exception {
    RegressionFactory regressionFactory = new RegressionFactory();
    CSVLoader<Regressor> csvLoader = new CSVLoader<>(';', CSVIterator.QUOTE, regressionFactory);
    DataSource<Regressor> dataSource = csvLoader.loadDataSource(Paths.get(DATASET_PATH), "quality");

    TrainTestSplitter<Regressor> dataSplitter = new TrainTestSplitter<>(dataSource, 0.7, 1L);
    trainSet = new MutableDataset<>(dataSplitter.getTrain());
    testSet = new MutableDataset<>(dataSplitter.getTest());
}

使用 CSVLoader 解析分号分隔的 CSV 文件并准备回归数据。RegressionFactory 创建回归输出,指定目标变量 quality 为连续变量。DataSource 存储解析后的数据。

通过 TrainTestSplitter 按 70% 训练集、30% 测试集分割数据,用于评估模型泛化能力。

7. 训练回归模型

红酒品质评分为数值,使用 CART(分类回归树)作为基础学习器训练模型:

void createTrainer() {
    CARTRegressionTrainer subsamplingTree = new CARTRegressionTrainer(
      Integer.MAX_VALUE,
      AbstractCARTTrainer.MIN_EXAMPLES,
      0.001f,
      0.7f,
      new MeanSquaredError(),
      Trainer.DEFAULT_SEED
    );

    trainer = new RandomForestTrainer<>(subsamplingTree, new AveragingCombiner(), 10);
    model = trainer.train(trainSet); 
}

CARTRegressionTrainer 配置决策树:

  • 无最大深度限制
  • 每个分裂最少 6 个样本
  • 使用均方误差作为分裂标准

RandomForestTrainer 组合 10 棵 CART 决策树,通过 AveragingCombiner 平均预测结果。

train() 方法在 trainSet 上训练模型,生成用于预测红酒品质的 *Model*。

8. 模型评估

使用 RegressionEvaluator 评估回归模型性能:

void evaluate(Model<Regressor> model, String datasetName, Dataset<Regressor> dataset) {
    RegressionEvaluator evaluator = new RegressionEvaluator();
    RegressionEvaluation evaluation = evaluator.evaluate(model, dataset);
    Regressor dimension0 = new Regressor("DIM-0", Double.NaN);

    log.info("MAE: " + evaluation.mae(dimension0));
    log.info("RMSE: " + evaluation.rmse(dimension0));
    log.info("R^2: " + evaluation.r2(dimension0));
}

RegressionEvaluator 计算并输出三个关键指标:

  • MAE(平均绝对误差)
  • RMSE(均方根误差)
  • (决定系数)

调用评估方法:

void evaluateModels() throws Exception {
    log.info("Training model");
    evaluate(model, "trainSet", trainSet);

    log.info("Testing model");
    evaluate(model, "testSet", testSet);
}

程序执行时的评估结果:

07:10:14.405 [main] INFO  tribuo.WineQualityRegression - Training model
07:10:14.406 [main] INFO  tribuo.WineQualityRegression - Results for trainSet---------------------
07:10:14.537 [main] INFO  tribuo.WineQualityRegression - MAE: 0.25025410332970005
07:10:14.537 [main] INFO  tribuo.WineQualityRegression - RMSE: 0.3422557198486092
07:10:14.538 [main] INFO  tribuo.WineQualityRegression - R^2: 0.8190947891297661
07:10:14.538 [main] INFO  tribuo.WineQualityRegression - Testing model
07:10:14.540 [main] INFO  tribuo.WineQualityRegression - Results for testSet---------------------
07:10:14.565 [main] INFO  tribuo.WineQualityRegression - MAE: 0.48711029366796743
07:10:14.565 [main] INFO  tribuo.WineQualityRegression - RMSE: 0.6584973595553575
07:10:14.565 [main] INFO  tribuo.WineQualityRegression - R^2: 0.3444460580874339

指标解读

  • MAE/RMSE 越低越好:预测值与实际值的平均偏差
  • R² 越高越好:模型解释数据方差的能力(0-1 之间)

⚠️ 训练集 R²=0.82 但测试集仅 0.34,表明模型存在过拟合。

9. 保存模型

将训练好的模型序列化到文件:

void saveModel() throws Exception {
    File modelFile = new File(MODEL_PATH);
    try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(modelFile))) {
        objectOutputStream.writeObject(model);
    }
}

使用 ObjectOutputStream 序列化模型,避免重复训练,实现模型复用

10. 方法调用

main() 方法中整合所有步骤:

public static void main(String[] args) throws Exception {
    WineQualityRegression wineQualityRegression = new WineQualityRegression();

    wineQualityRegression.createDatasets();
    wineQualityRegression.createTrainer();
    wineQualityRegression.evaluateModels();
    wineQualityRegression.saveModel();
}

执行后模型将保存到指定目录。

11. 使用模型

创建 WinePredictor 类加载模型:

class WineQualityPredictor {
    private static final Logger log = LoggerFactory.getLogger(WineQualityPredictor.class);

    public static void main(String[] args) throws IOException, ClassNotFoundException {
        File modelFile = new File("src/main/resources/model/winequality-red-regressor.ser");
        Model<Regressor> loadedModel = null;

        try (ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(modelFile))) {
            loadedModel = (Model<Regressor>) objectInputStream.readObject();
        }
}

Tribuo 强类型特性要求显式指定模型类型(Regressor)。

创建红酒样本的 ArrayExample 对象:

ArrayExample<Regressor> wineAttribute = new ArrayExample<Regressor>(new Regressor("quality", Double.NaN));
wineAttribute.add("fixed acidity", 7.4f);
wineAttribute.add("volatile acidity", 0.7f);
wineAttribute.add("citric acid", 0.47f);
wineAttribute.add("residual sugar", 1.9f);
wineAttribute.add("chlorides", 0.076f);
wineAttribute.add("free sulfur dioxide", 11.0f);
wineAttribute.add("total sulfur dioxide", 34.0f);
wineAttribute.add("density", 0.9978f);
wineAttribute.add("pH", 3.51f);
wineAttribute.add("sulphates", 0.56f);
wineAttribute.add("alcohol", 9.4f);

使用 Prediction 类进行预测:

Prediction<Regressor> prediction = loadedModel.predict(wineAttribute);
double predictQuality = prediction.getOutput().getValues()[0];
log.info("Predicted wine quality: " + predictQuality);

预测结果输出:

07:31:05.772 [main] INFO  tribuo.WineQualityPredictor - Predicted wine quality: 5.028163673540464

12. 总结

本文介绍了 Tribuo 及其核心特性,概述了支持的机器学习算法,并通过红酒品质预测实战演示了回归模型的完整开发流程。

完整示例代码可在 GitHub 获取。