1. 概述
在本教程中,我们将探讨如何在 Java 中实现两个矩阵的乘法运算。
由于 Java 本身并未原生支持矩阵数据结构,因此我们会手动实现一个基础版本,并使用几个流行的库来完成矩阵乘法。最后,我们还将对不同实现方式进行性能基准测试,看看哪种方式最快。
2. 示例说明
我们以一个 3×2 的矩阵为例:
再来看一个 2×4 的矩阵:
将它们相乘后,得到一个 3×4 的结果矩阵:
矩阵乘法的基本公式如下:
其中 r 是矩阵 A 的行数,c 是矩阵 B 的列数,n 是矩阵 A 的列数(必须等于矩阵 B 的行数)。
3. 矩阵乘法实现
3.1. 自定义实现
我们使用二维 double
数组来表示矩阵:
double[][] firstMatrix = {
new double[]{1d, 5d},
new double[]{2d, 3d},
new double[]{1d, 7d}
};
double[][] secondMatrix = {
new double[]{1d, 2d, 3d, 7d},
new double[]{5d, 2d, 8d, 1d}
};
定义预期结果:
double[][] expected = {
new double[]{26d, 12d, 43d, 12d},
new double[]{17d, 10d, 30d, 17d},
new double[]{36d, 16d, 59d, 14d}
};
接下来是矩阵乘法的实现:
double[][] multiplyMatrices(double[][] firstMatrix, double[][] secondMatrix) {
double[][] result = new double[firstMatrix.length][secondMatrix[0].length];
for (int row = 0; row < result.length; row++) {
for (int col = 0; col < result[row].length; col++) {
result[row][col] = multiplyMatricesCell(firstMatrix, secondMatrix, row, col);
}
}
return result;
}
计算单个单元格的值:
double multiplyMatricesCell(double[][] firstMatrix, double[][] secondMatrix, int row, int col) {
double cell = 0;
for (int i = 0; i < secondMatrix.length; i++) {
cell += firstMatrix[row][i] * secondMatrix[i][col];
}
return cell;
}
验证结果是否符合预期:
double[][] actual = multiplyMatrices(firstMatrix, secondMatrix);
assertThat(actual).isEqualTo(expected);
✅ 实现简单,适合教学或小规模数据。
❌ 性能一般,不适合大规模数据。
3.2. EJML (Efficient Java Matrix Library)
EJML 是一个专注于性能的 Java 矩阵库。
添加依赖:
<dependency>
<groupId>org.ejml</groupId>
<artifactId>ejml-all</artifactId>
<version>0.38</version>
</dependency>
创建矩阵:
SimpleMatrix firstMatrix = new SimpleMatrix(new double[][] {
{1d, 5d},
{2d, 3d},
{1d, 7d}
});
SimpleMatrix secondMatrix = new SimpleMatrix(new double[][] {
{1d, 2d, 3d, 7d},
{5d, 2d, 8d, 1d}
});
定义预期结果:
SimpleMatrix expected = new SimpleMatrix(new double[][] {
{26d, 12d, 43d, 12d},
{17d, 10d, 30d, 17d},
{36d, 16d, 59d, 14d}
});
执行乘法并验证:
SimpleMatrix actual = firstMatrix.mult(secondMatrix);
assertThat(actual).matches(m -> m.isIdentical(expected, 0d));
✅ 性能优秀,适合中小型数据。
❌ API 略显繁琐。
3.3. ND4J (Numerical Computing for Java)
ND4J 是 deeplearning4j 生态中的数值计算库,支持矩阵运算。
添加依赖:
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>1.0.0-beta4</version>
</dependency>
创建矩阵:
INDArray firstMatrix = Nd4j.create(new double[][] {
{1d, 5d},
{2d, 3d},
{1d, 7d}
});
INDArray secondMatrix = Nd4j.create(new double[][] {
{1d, 2d, 3d, 7d},
{5d, 2d, 8d, 1d}
});
定义预期结果:
INDArray expected = Nd4j.create(new double[][] {
{26d, 12d, 43d, 12d},
{17d, 10d, 30d, 17d},
{36d, 16d, 59d, 14d}
});
执行乘法并验证:
INDArray actual = firstMatrix.mmul(secondMatrix);
assertThat(actual).isEqualTo(expected);
✅ 大规模数据性能极佳,适合深度学习场景。
❌ 依赖较大,学习曲线陡峭。
3.4. Apache Commons Math3
Apache Commons Math3 是一个通用数学库,支持矩阵操作。
添加依赖:
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.6.1</version>
</dependency>
创建矩阵:
RealMatrix firstMatrix = new Array2DRowRealMatrix(new double[][] {
{1d, 5d},
{2d, 3d},
{1d, 7d}
});
RealMatrix secondMatrix = new Array2DRowRealMatrix(new double[][] {
{1d, 2d, 3d, 7d},
{5d, 2d, 8d, 1d}
});
定义预期结果:
RealMatrix expected = new Array2DRowRealMatrix(new double[][] {
{26d, 12d, 43d, 12d},
{17d, 10d, 30d, 17d},
{36d, 16d, 59d, 14d}
});
执行乘法并验证:
RealMatrix actual = firstMatrix.multiply(secondMatrix);
assertThat(actual).isEqualTo(expected);
✅ 稳定性好,适合通用场景。
❌ 性能中等,不适合大规模数据。
3.5. LA4J (Linear Algebra for Java)
LA4J 是一个专注于线性代数的库。
添加依赖:
<dependency>
<groupId>org.la4j</groupId>
<artifactId>la4j</artifactId>
<version>0.6.0</version>
</dependency>
创建矩阵:
Matrix firstMatrix = new Basic2DMatrix(new double[][] {
{1d, 5d},
{2d, 3d},
{1d, 7d}
});
Matrix secondMatrix = new Basic2DMatrix(new double[][] {
{1d, 2d, 3d, 7d},
{5d, 2d, 8d, 1d}
});
定义预期结果:
Matrix expected = new Basic2DMatrix(new double[][] {
{26d, 12d, 43d, 12d},
{17d, 10d, 30d, 17d},
{36d, 16d, 59d, 14d}
});
执行乘法并验证:
Matrix actual = firstMatrix.multiply(secondMatrix);
assertThat(actual).isEqualTo(expected);
✅ 简洁易用,适合教学或中等规模数据。
❌ 性能不如 EJML 和 ND4J。
3.6. Colt
Colt 是由 CERN 开发的高性能科学计算库。
添加依赖:
<dependency>
<groupId>colt</groupId>
<artifactId>colt</artifactId>
<version>1.2.0</version>
</dependency>
创建矩阵:
DoubleFactory2D doubleFactory2D = DoubleFactory2D.dense;
DoubleMatrix2D firstMatrix = doubleFactory2D.make(new double[][] {
{1d, 5d},
{2d, 3d},
{1d, 7d}
});
DoubleMatrix2D secondMatrix = doubleFactory2D.make(new double[][] {
{1d, 2d, 3d, 7d},
{5d, 2d, 8d, 1d}
});
定义预期结果:
DoubleMatrix2D expected = doubleFactory2D.make(new double[][] {
{26d, 12d, 43d, 12d},
{17d, 10d, 30d, 17d},
{36d, 16d, 59d, 14d}
});
执行乘法并验证:
Algebra algebra = new Algebra();
DoubleMatrix2D actual = algebra.mult(firstMatrix, secondMatrix);
assertThat(actual).isEqualTo(expected);
✅ 性能较好,适合科学计算。
❌ 文档较少,社区活跃度一般。
4. 性能测试
4.1. 小型矩阵(3×2 和 2×4)
库 | 耗时(μs/op) |
---|---|
自定义实现 | 0.389 |
EJML | 0.226 |
Colt | 0.219 |
LA4J | 0.427 |
Apache Commons Math3 | 1.008 |
ND4J | 12.670 |
✅ EJML 和 Colt 表现最佳。
❌ ND4J 在小型矩阵上表现不佳。
4.2. 大型矩阵(3000×3000)
库 | 耗时(秒/op) |
---|---|
自定义实现 | 497.493 |
Apache Commons Math3 | 511.140 |
Colt | 197.914 |
LA4J | 35.523 |
EJML | 25.830 |
ND4J | 0.548 |
✅ ND4J 在大型矩阵上遥遥领先。
❌ 自定义实现和 Apache Commons 性能极差。
5. 总结
场景 | 推荐库 |
---|---|
教学/小规模数据 | 自定义实现、EJML、LA4J |
科学计算 | Colt |
大规模/深度学习 | ND4J |
通用数学计算 | Apache Commons Math3 |
⚠️ 选择库时需根据矩阵规模和用途综合考虑。
完整示例代码可在 GitHub 获取。