1. 引言

逻辑回归(Logistic Regression)是机器学习(ML)开发者工具箱中的核心算法之一。

本文将深入剖析逻辑回归的核心思想,并通过一个经典案例——手写数字识别,带你从零实现一个基于 Java 的完整 ML 流程。

我们会跳过那些“什么是机器学习”的基础科普,直接切入重点。如果你已经熟悉数据预处理、模型训练这些概念,那这篇文章正适合你。

2. 机器学习流程概览

机器学习解决的是那些人类能轻松理解,但难以用规则精确描述的问题。比如:“这张图里是不是一只猫?”——我们一眼就能判断,但写不出明确的 if-else 判断逻辑。

因此,ML 的核心思路是:让模型从数据中自动学习规律

这个过程依赖一个关键阶段:训练(Training)。我们把大量标注好的数据喂给算法,它会不断调整内部参数(也就是“权重”),使得预测结果越来越接近真实标签。

整个 ML 工作流通常是迭代的:

ml1

关键步骤包括:

数据获取:来源多样,需统一格式
数据清洗与划分:确保数据代表性,通常分为训练集和测试集
模型构建:选择合适的算法架构
训练与评估:迭代优化,直到效果达标
⚠️ 踩坑提示:如果训练数据中没有红苹果,模型几乎不可能识别出红苹果——数据决定模型的上限。

3. 机器学习范式

根据输入数据的类型和任务目标,ML 主要分为三类:

  • 监督学习(Supervised Learning):输入数据带标签,目标是学习从输入到标签的映射。
    ✅ 典型任务:图像分类、情感分析、目标检测
  • 无监督学习(Unsupervised Learning):数据无标签,目标是发现数据内在结构。
    ✅ 典型任务:聚类、异常检测
  • 强化学习(Reinforcement Learning):通过试错和奖励机制学习策略。
    ✅ 典型任务:游戏 AI、机器人控制

本文聚焦于 监督学习 场景。

4. 常用机器学习算法

构建模型时,我们有多种“武器”可选:

  • 线性回归(Linear Regression)
  • 逻辑回归(Logistic Regression)
  • 神经网络(Neural Networks)
  • 支持向量机(SVM)
  • k-近邻算法(k-Nearest Neighbours)

实际项目中,往往需要组合多种技术。本文将使用 逻辑回归思想 + 神经网络架构 来实现分类任务。

5. Java 机器学习库选型

虽然 Python 是 ML 主流语言,但 Java 在企业级应用中依然可靠。我们选择以下库:

  • TensorFlow for Java:Google 官方支持,适合生产环境部署
  • Deeplearning4j(DL4J):JVM 上最成熟的深度学习框架,原生 Java API,与 Spring 等生态无缝集成

本文使用 Deeplearning4j 作为核心实现。

6. 手写数字识别实战

目标:构建一个能识别 0-9 手写数字的模型。

核心思想:通过最小化损失函数(Loss Function) 来优化模型参数。损失函数衡量预测值与真实标签之间的差距,训练的目标就是让这个差距尽可能小。

我们基于经典的 MNIST 数据集和 LeNet-5 网络结构进行实现。

6.1 输入数据准备

使用 MNIST 数据库:包含 6 万张 28×28 灰度手写数字图像,每张图都有对应标签。

ml2

数据划分:

DataSetIterator train = new RecordReaderDataSetIterator(
    new MnistRecordReader(0, 60000), // 训练集:60000 张
    batchSize, 
    10 // 10 个类别 (0-9)
);

DataSetIterator test = new RecordReaderDataSetIterator(
    new MnistRecordReader(60000, 10000), // 测试集:10000 张
    batchSize, 
    10
);

⚠️ 注意:RecordReaderDataSetIterator 负责加载、归一化和批处理数据,是 DL4J 的标准做法。

6.2 模型构建

没有“万能模型”,但 LeNet-5 在手写识别上表现优异。它是一个卷积神经网络(CNN),将 28×28 图像映射为 10 维概率向量。

ml3

输出示例:

{0.1, 0.0, 0.3, 0.2, 0.1, 0.1, 0.0, 0.1, 0.1, 0.0}

表示预测为数字 2 的概率最高(0.3),因此最终分类结果为 2。

构建模型代码:

MultiLayerNetwork model = new MultiLayerNetwork(config);

关键在 MultiLayerConfiguration config 的定义。我们定义网络层结构:

// 第一层:卷积层
ConvolutionLayer layer1 = new ConvolutionLayer
    .Builder(5, 5)           // 卷积核 5x5
    .nIn(1)                  // 输入通道数(灰度图=1)
    .stride(1, 1)            // 步长
    .nOut(20)                // 输出通道数(20 个卷积核)
    .activation(Activation.IDENTITY)
    .build();

// 第二层:池化层(最大池化)
SubsamplingLayer layer2 = new SubsamplingLayer
    .Builder(SubsamplingLayer.PoolingType.MAX)
    .kernelSize(2, 2)        // 池化窗口 2x2
    .stride(2, 2)            // 步长 2
    .build();

⚠️ 踩坑提示:nIn 必须显式指定输入通道数,否则会报错。

完整配置构建:

MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
    .seed(123)                           // 随机种子,保证可复现
    .updater(new Adam(1e-3))            // 优化器:Adam
    .list()
    .layer(0, layer1)
    .layer(1, layer2)
    // ... 添加后续层(全连接层、输出层等)
    .layer(5, new OutputLayer
        .Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .activation(Activation.SOFTMAX)
        .nIn(500).nOut(10).build())
    .setInputType(InputType.convolutionalFlat(28, 28, 1)) // 输入类型
    .build();

6.3 模型训练

LeNet-5 约有 43 万个参数。训练过程极其简单:

model.fit(train);

首次训练后即可达到约 99% 的准确率。为追求更高精度,可进行多轮训练(epoch):

int epochs = 5;
for (int i = 0; i < epochs; i++) {
    model.fit(train);
    train.reset();  // 重置迭代器,准备下一轮
    test.reset();
    
    // 每轮后评估
    Evaluation eval = model.evaluate(test);
    System.out.println("Epoch " + i + " - Accuracy: " + eval.accuracy());
}

评估结果输出示例:

INFO  o.d.e.Evaluation - 
=========================Evaluation Metrics=========================
 # of classes:    10
 Accuracy:        0.9878
 Precision:       0.9880
 Recall:          0.9877
 F1 Score:        0.9878

6.4 模型预测

训练完成后,即可对新图像进行预测。创建 MnistPrediction 类加载外部图像:

File file = new File("digit-2.png"); // 假设你画了一个数字 2
INDArray image = new NativeImageLoader(28, 28, 1).asMatrix(file);

// 归一化到 [0,1]
new ImagePreProcessingScaler(0, 1).transform(image);

执行预测:

INDArray output = model.output(image);
System.out.println(output); // 打印概率分布

预测结果可视化:

ml4

输出类似 [0.01, 0.00, 0.99, ...],最大值在索引 2,说明模型正确识别为数字 2 ✅。

7. 总结

本文通过手写数字识别案例,展示了 Java 中实现逻辑回归(基于神经网络)的完整流程:

  • 使用 Deeplearning4j 构建 LeNet-5 模型
  • 完成数据加载、训练、评估与预测
  • 实现了接近 99% 的高准确率

代码已托管至 GitHub,可直接运行验证:

👉 https://github.com/baeldung/tutorials/tree/master/deeplearning4j

逻辑回归虽名为“回归”,实为分类利器。结合神经网络,它在图像识别等任务中依然表现强劲。Java 开发者也能轻松玩转 ML!


原始标题:Logistic Regression in Java | Baeldung