1. 概述

机器学习在现代软件开发中至关重要。我们构建各种架构的模型,使用不同算法训练,并通过神经网络优化系统实现惊人效果。

本教程将深入探索 H2O 平台,它提供了一种简单粗暴的方式创建、训练和调优模型。

2. 安装开源 H2O 平台

官网下载 H2O 后,在目标目录启动平台:

java -jar h2o.jar

启动成功后,访问 http://localhost:54321 进入 Web 控制台:

H2O 控制台界面

主页面展示了 H2O 平台提供的核心功能列表。这些功能也可通过顶部菜单栏访问,点击 HelpAssist Me 即可查看完整操作指南。

3. 准备数据集

训练模型前需要准备合适的数据集。

**我们将解决机器学习中的经典分类问题**,使用广为人知的 iris 数据集。该数据集包含不同花卉的特征数据。

下载 CSV 格式数据集后,文件内容应如下:

"sepal.length","sepal.width","petal.length","petal.width","variety"
5.1,3.5,1.4,.2,"Setosa"
4.9,3,1.4,.2,"Setosa"
4.7,3.2,1.3,.2,"Setosa"
...
5.8,2.7,3.9,1.2,"Versicolor"
6,2.7,5.1,1.6,"Versicolor"
5.4,3,4.5,1.5,"Versicolor"
...
6.5,3,5.2,2,"Virginica"
6.2,3.4,5.4,2.3,"Virginica"
5.9,3,5.1,1.8,"Virginica"

每行数据包含鸢尾花的四个特征(花萼/花瓣的长宽)及其对应品种。

4. 训练模型

使用准备好的数据集开始模型训练。

4.1 导入数据集

在 Web 控制台导入数据集。通过顶部菜单 DataImport Files 选择数据文件所在目录,搜索文件后点击 Import

H2O 导入文件界面

4.2 划分训练集与测试集

数据上传后,使用 getFrames 功能查看。通过顶部菜单 DataList All Frames 找到数据集,点击 Parse

H2O 数据帧查看界面

**遵循机器学习标准实践,需将数据集划分为训练集和测试集**。使用 splitFrame 功能(顶部菜单 DataSplit Frame),按 80/20 比例划分:

H2O 数据集划分界面

4.3 构建模型

开始构建模型。顶部菜单选择 Model,我们选用随机森林算法(分类问题效果显著):

H2O 模型构建界面

必须指定三个关键参数才能训练模型

  • training_frame:训练数据集
  • validation_frame:验证数据集
  • response_column:目标列(即品种列)

根据所选算法,还可调整其他参数优化效果。配置完成后点击 Build Model

H2O 模型参数配置界面

4.4 AutoML 功能

H2O 的 AutoML 是个杀手级功能。顶部菜单 ModelRun AutoML当不确定使用哪种算法时,直接选它就对了。需配置与手动构建相同的参数,额外设置 max_runtime_secs 控制训练时长:

H2O AutoML 界面

训练完成后,平台会展示模型排行榜:从中可直接选用表现最佳模型

H2O AutoML 排行榜

4.5 下载模型

训练完成后可下载模型相关文件

  • 点击 Download Gen Model 获取 Java 应用所需的 JAR 包: H2O 模型下载界面
  • 点击 Download Model Deployment Package (MOJO) 下载模型本身: H2O MOJO 下载界面

5. 在 Java 应用中使用模型

将训练好的模型集成到 Java 应用中。

5.1 添加 H2O 文件

将下载的文件放入项目 libs 目录,然后添加到 classpath。

5.2 配置依赖

pom.xml 中添加模型依赖:

<dependency>
    <groupId>ai.h2o</groupId>
    <artifactId>h2o-genmodel</artifactId>
    <version>1.0</version>
    <scope>system</scope>
    <systemPath>${project.basedir}/libs/h2o-genmodel.jar</systemPath>
</dependency>

注意 groupIdartifactId 是 H2O 预定义的。

5.3 使用手动构建的模型预测

Java 代码调用模型示例:

public class H2OModelLiveTest {

    Logger logger = LoggerFactory.getLogger(H2OModelLiveTest.class);

    @Test
    public void givenH2OTrainedModel_whenPredictTheIrisByFeatures_thenExpectedFlowerShouldBeReturned() throws IOException, PredictException {
        String mojoFilePath = "libs/mojo.zip";

        MojoModel mojoModel = MojoModel.load(mojoFilePath);
        EasyPredictModelWrapper model = new EasyPredictModelWrapper(mojoModel);

        RowData row = new RowData();
        row.put("sepal.length", 5.1);
        row.put("sepal.width", 3.4);
        row.put("petal.length", 4.6);
        row.put("petal.width", 1.2);

        MultinomialModelPrediction prediction = model.predictMultinomial(row);

        Assertions.assertEquals("Versicolor", prediction.label);

        logger.info("Class probabilities: ");
        for (int i = 0; i < prediction.classProbabilities.length; i++) {
            logger.info("Class " + i + ": " + prediction.classProbabilities[i]);
        }
    }
}

核心步骤:

  1. 从 MOJO 文件加载 MojoModel
  2. EasyPredictModelWrapper 包装模型
  3. 创建特征数据 RowData
  4. 调用 predictMultinomial() 获取预测结果
  5. 验证预测品种为 Versicolor

输出显示分类准确率高达 95.97% ✅:

19:33:48.648 [main] INFO  com.baeldung.h2o.H2OModelLiveTest - Class probabilities: 
19:33:48.653 [main] INFO  com.baeldung.h2o.H2OModelLiveTest - Class 0: 0.016846955011789237
19:33:48.653 [main] INFO  com.baeldung.h2o.H2OModelLiveTest - Class 1: 0.9597659357519948
19:33:48.653 [main] INFO  com.baeldung.h2o.H2OModelLiveTest - Class 2: 0.023387109236216036

5.4 使用 AutoML 模型预测

使用 AutoML 选出的最佳模型进行预测:

@Test
public void givenH2OTrainedAutoMLModel_whenPredictTheIrisByFeatures_thenExpectedFlowerShouldBeReturned() throws IOException, PredictException {
    String mojoFilePath = "libs/automl_model.zip";

    MojoModel mojoModel = MojoModel.load(mojoFilePath);
    EasyPredictModelWrapper model = new EasyPredictModelWrapper(mojoModel);

    RowData row = new RowData();
    row.put("sepal.length", 5.1);
    row.put("sepal.width", 3.4);
    row.put("petal.length", 4.6);
    row.put("petal.width", 1.2);

    MultinomialModelPrediction prediction = model.predictMultinomial(row);

    Assertions.assertEquals("Versicolor", prediction.label);

    logger.info("Class probabilities: ");
    for (int i = 0; i < prediction.classProbabilities.length; i++) {
        logger.info("Class " + i + ": " + prediction.classProbabilities[i]);
    }
}

虽然预测结果相同,但概率略低(84.52%)⚠️:

20:28:06.440 [main] INFO  com.baeldung.h2o.H2OModelLiveTest - Class probabilities: 
20:28:06.443 [main] INFO  com.baeldung.h2o.H2OModelLiveTest - Class 0: 0.08536296008169375
20:28:06.443 [main] INFO  com.baeldung.h2o.H2OModelLiveTest - Class 1: 0.8451806663486182
20:28:06.443 [main] INFO  com.baeldung.h2o.H2OModelLiveTest - Class 2: 0.06945637356968806

6. 总结

本文深入实践了 H2O 平台的核心功能:通过该工具,我们能高效训练神经网络并生成 Java 应用可用的模型文件。这对不想深入 Python 技术栈的开发者尤其有用,避免了学习额外 ML 库的麻烦。


原始标题:Introduction to H2O | Baeldung