1. 概述
机器学习在现代软件开发中至关重要。我们构建各种架构的模型,使用不同算法训练,并通过神经网络优化系统实现惊人效果。
本教程将深入探索 H2O 平台,它提供了一种简单粗暴的方式创建、训练和调优模型。
2. 安装开源 H2O 平台
从官网下载 H2O 后,在目标目录启动平台:
java -jar h2o.jar
启动成功后,访问 http://localhost:54321 进入 Web 控制台:
主页面展示了 H2O 平台提供的核心功能列表。这些功能也可通过顶部菜单栏访问,点击 Help → Assist Me 即可查看完整操作指南。
3. 准备数据集
训练模型前需要准备合适的数据集。
**我们将解决机器学习中的经典分类问题**,使用广为人知的 iris 数据集。该数据集包含不同花卉的特征数据。
下载 CSV 格式数据集后,文件内容应如下:
"sepal.length","sepal.width","petal.length","petal.width","variety"
5.1,3.5,1.4,.2,"Setosa"
4.9,3,1.4,.2,"Setosa"
4.7,3.2,1.3,.2,"Setosa"
...
5.8,2.7,3.9,1.2,"Versicolor"
6,2.7,5.1,1.6,"Versicolor"
5.4,3,4.5,1.5,"Versicolor"
...
6.5,3,5.2,2,"Virginica"
6.2,3.4,5.4,2.3,"Virginica"
5.9,3,5.1,1.8,"Virginica"
每行数据包含鸢尾花的四个特征(花萼/花瓣的长宽)及其对应品种。
4. 训练模型
使用准备好的数据集开始模型训练。
4.1 导入数据集
在 Web 控制台导入数据集。通过顶部菜单 Data → Import Files 选择数据文件所在目录,搜索文件后点击 Import:
4.2 划分训练集与测试集
数据上传后,使用 getFrames 功能查看。通过顶部菜单 Data → List All Frames 找到数据集,点击 Parse:
**遵循机器学习标准实践,需将数据集划分为训练集和测试集**。使用 splitFrame 功能(顶部菜单 Data → Split Frame),按 80/20 比例划分:
4.3 构建模型
开始构建模型。顶部菜单选择 Model,我们选用随机森林算法(分类问题效果显著):
必须指定三个关键参数才能训练模型:
training_frame
:训练数据集validation_frame
:验证数据集response_column
:目标列(即品种列)
根据所选算法,还可调整其他参数优化效果。配置完成后点击 Build Model:
4.4 AutoML 功能
H2O 的 AutoML 是个杀手级功能。顶部菜单 Model → Run AutoML:当不确定使用哪种算法时,直接选它就对了。需配置与手动构建相同的参数,额外设置 max_runtime_secs
控制训练时长:
训练完成后,平台会展示模型排行榜:从中可直接选用表现最佳模型:
4.5 下载模型
训练完成后可下载模型相关文件:
5. 在 Java 应用中使用模型
将训练好的模型集成到 Java 应用中。
5.1 添加 H2O 文件
将下载的文件放入项目 libs
目录,然后添加到 classpath。
5.2 配置依赖
在 pom.xml
中添加模型依赖:
<dependency>
<groupId>ai.h2o</groupId>
<artifactId>h2o-genmodel</artifactId>
<version>1.0</version>
<scope>system</scope>
<systemPath>${project.basedir}/libs/h2o-genmodel.jar</systemPath>
</dependency>
注意 groupId
和 artifactId
是 H2O 预定义的。
5.3 使用手动构建的模型预测
Java 代码调用模型示例:
public class H2OModelLiveTest {
Logger logger = LoggerFactory.getLogger(H2OModelLiveTest.class);
@Test
public void givenH2OTrainedModel_whenPredictTheIrisByFeatures_thenExpectedFlowerShouldBeReturned() throws IOException, PredictException {
String mojoFilePath = "libs/mojo.zip";
MojoModel mojoModel = MojoModel.load(mojoFilePath);
EasyPredictModelWrapper model = new EasyPredictModelWrapper(mojoModel);
RowData row = new RowData();
row.put("sepal.length", 5.1);
row.put("sepal.width", 3.4);
row.put("petal.length", 4.6);
row.put("petal.width", 1.2);
MultinomialModelPrediction prediction = model.predictMultinomial(row);
Assertions.assertEquals("Versicolor", prediction.label);
logger.info("Class probabilities: ");
for (int i = 0; i < prediction.classProbabilities.length; i++) {
logger.info("Class " + i + ": " + prediction.classProbabilities[i]);
}
}
}
核心步骤:
- 从 MOJO 文件加载
MojoModel
- 用
EasyPredictModelWrapper
包装模型 - 创建特征数据
RowData
- 调用
predictMultinomial()
获取预测结果 - 验证预测品种为 Versicolor
输出显示分类准确率高达 95.97% ✅:
19:33:48.648 [main] INFO com.baeldung.h2o.H2OModelLiveTest - Class probabilities:
19:33:48.653 [main] INFO com.baeldung.h2o.H2OModelLiveTest - Class 0: 0.016846955011789237
19:33:48.653 [main] INFO com.baeldung.h2o.H2OModelLiveTest - Class 1: 0.9597659357519948
19:33:48.653 [main] INFO com.baeldung.h2o.H2OModelLiveTest - Class 2: 0.023387109236216036
5.4 使用 AutoML 模型预测
使用 AutoML 选出的最佳模型进行预测:
@Test
public void givenH2OTrainedAutoMLModel_whenPredictTheIrisByFeatures_thenExpectedFlowerShouldBeReturned() throws IOException, PredictException {
String mojoFilePath = "libs/automl_model.zip";
MojoModel mojoModel = MojoModel.load(mojoFilePath);
EasyPredictModelWrapper model = new EasyPredictModelWrapper(mojoModel);
RowData row = new RowData();
row.put("sepal.length", 5.1);
row.put("sepal.width", 3.4);
row.put("petal.length", 4.6);
row.put("petal.width", 1.2);
MultinomialModelPrediction prediction = model.predictMultinomial(row);
Assertions.assertEquals("Versicolor", prediction.label);
logger.info("Class probabilities: ");
for (int i = 0; i < prediction.classProbabilities.length; i++) {
logger.info("Class " + i + ": " + prediction.classProbabilities[i]);
}
}
虽然预测结果相同,但概率略低(84.52%)⚠️:
20:28:06.440 [main] INFO com.baeldung.h2o.H2OModelLiveTest - Class probabilities:
20:28:06.443 [main] INFO com.baeldung.h2o.H2OModelLiveTest - Class 0: 0.08536296008169375
20:28:06.443 [main] INFO com.baeldung.h2o.H2OModelLiveTest - Class 1: 0.8451806663486182
20:28:06.443 [main] INFO com.baeldung.h2o.H2OModelLiveTest - Class 2: 0.06945637356968806
6. 总结
本文深入实践了 H2O 平台的核心功能:通过该工具,我们能高效训练神经网络并生成 Java 应用可用的模型文件。这对不想深入 Python 技术栈的开发者尤其有用,避免了学习额外 ML 库的麻烦。