1. 概述
TensorFlow 是一个用于数据流编程的开源库,最初由 Google 开发,支持多种平台。虽然 TensorFlow 可以在单核上运行,但也可以轻松利用多 CPU、GPU 或 TPU 提升性能。
本教程将带你了解 TensorFlow 的基本概念,并讲解如何在 Java 中使用它。需要注意的是,TensorFlow 的 Java API 目前仍处于实验阶段,因此不提供稳定性保障。我们也会在后续部分介绍一些适用场景。
2. 基础概念
TensorFlow 的计算主要围绕两个核心概念:Graph(图)和 Session(会话)。下面快速了解一下,为后续内容打基础。
2.1. TensorFlow 图(Graph)
TensorFlow 程序的基本组成单位是图(Graph)。计算过程在 TensorFlow 中以图的形式表示。图通常是一个有向无环图(DAG),包含操作(Operation)和数据(Tensor),例如:
上图表示的是如下方程的计算图:
f(x, y) = z = a*x + b*y
TensorFlow 图由两个基本元素构成:
- Tensor(张量):TensorFlow 中数据的基本单位,图中的边表示数据流。张量可以是任意维度的,维度数量称为“秩”(rank)。例如:
- 标量:0 维张量
- 向量:1 维张量
- 矩阵:2 维张量
- Operation(操作):图中的节点,代表各种计算操作,可以接收张量作为输入,输出新的张量。
2.2. TensorFlow 会话(Session)
图本身只是一个计算结构,不包含实际值。要让图中的张量被计算,必须在 Session 中运行图。Session 可以指定图中要计算的张量,然后从这些张量开始,逆向执行图中必要的节点。
有了这些基础知识,我们就可以开始使用 Java API 来操作 TensorFlow 了。
3. Maven 依赖配置
我们先创建一个 Maven 项目,引入 TensorFlow 的依赖。只需添加如下依赖:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.12.0</version>
</dependency>
4. 构建图(Graph)
我们来用 Java API 实现之前提到的图。具体来说,我们要计算如下函数:
z = 3*x + 2*y
第一步是创建一个图对象:
Graph graph = new Graph()
4.1. 定义常量(Constants)
常量操作需要提供一个张量作为值:
Operation a = graph.opBuilder("Const", "a")
.setAttr("dtype", DataType.fromClass(Double.class))
.setAttr("value", Tensor.<Double>create(3.0, Double.class))
.build();
Operation b = graph.opBuilder("Const", "b")
.setAttr("dtype", DataType.fromClass(Double.class))
.setAttr("value", Tensor.<Double>create(2.0, Double.class))
.build();
我们定义了两个常量 a
和 b
,分别代表 3.0
和 2.0
。
4.2. 定义占位符(Placeholders)
占位符在定义时不需要值,而是在运行时通过 Session 传入:
Operation x = graph.opBuilder("Placeholder", "x")
.setAttr("dtype", DataType.fromClass(Double.class))
.build();
Operation y = graph.opBuilder("Placeholder", "y")
.setAttr("dtype", DataType.fromClass(Double.class))
.build();
4.3. 定义函数操作
接下来定义乘法和加法操作:
Operation ax = graph.opBuilder("Mul", "ax")
.addInput(a.output(0))
.addInput(x.output(0))
.build();
Operation by = graph.opBuilder("Mul", "by")
.addInput(b.output(0))
.addInput(y.output(0))
.build();
Operation z = graph.opBuilder("Add", "z")
.addInput(ax.output(0))
.addInput(by.output(0))
.build();
我们定义了三个操作:两个乘法和一个加法。每个操作的输入是之前操作的输出张量。
5. 图的可视化
随着图的复杂度增加,手动追踪变得困难。TensorFlow 提供了 TensorBoard 来可视化图结构。
遗憾的是,Java API 无法直接生成 TensorBoard 所需的 event 文件,但可以用 Python 生成:
writer = tf.summary.FileWriter('.')
......
writer.add_graph(tf.get_default_graph())
writer.flush()
然后在命令行中启动 TensorBoard:
tensorboard --logdir .
6. 使用 Session 执行图
现在我们已经构建了图,但还没运行。如果直接打印输出:
System.out.println(z.output(0));
会得到类似如下输出:
<Add 'z:0' shape=<unknown> dtype=DOUBLE>
这只是描述了一个张量,并没有实际值。我们需要创建一个 Session 来运行图:
Session sess = new Session(graph)
然后运行并获取结果:
Tensor<Double> tensor = sess.runner().fetch("z")
.feed("x", Tensor.<Double>create(3.0, Double.class))
.feed("y", Tensor.<Double>create(6.0, Double.class))
.run().get(0).expect(Double.class);
System.out.println(tensor.doubleValue());
输出为:
21.0
✅ 正确结果!
7. Java API 的使用场景
虽然 TensorFlow 用在简单计算上显得有点“杀鸡用牛刀”,但它的真正优势在于处理大型机器学习模型。
Java API 并不适合用于构建和训练复杂模型,但我们可以:
✅ 使用 Python 构建并训练模型
✅ 将训练好的模型导出为 Protocol Buffer 文件
✅ 在 Java 中加载并使用该模型
这在 Android 或其他 Java 客户端中特别有用,比如为图片推荐标题等场景。
8. 使用保存的模型(SavedModel)
TensorFlow 支持将模型保存为语言和平台无关的格式(Protocol Buffer),可以在不同语言中加载使用。
8.1. 保存模型
在 Python 中保存模型:
import tensorflow as tf
graph = tf.Graph()
builder = tf.saved_model.builder.SavedModelBuilder('./model')
with graph.as_default():
a = tf.constant(2, name='a')
b = tf.constant(3, name='b')
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')
z = tf.math.add(a*x, b*y, name='z')
sess = tf.Session()
sess.run(z, feed_dict = {x: 2, y: 3})
builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
builder.save()
这段代码会生成一个 saved_model.pb
文件。
8.2. 加载模型
Java 中加载模型:
SavedModelBundle model = SavedModelBundle.load("./model", "serve");
Tensor<Integer> tensor = model.session().runner().fetch("z")
.feed("x", Tensor.<Integer>create(3, Integer.class))
.feed("y", Tensor.<Integer>create(3, Integer.class))
.run().get(0).expect(Integer.class);
System.out.println(tensor.intValue());
输出结果为:
15
9. 总结
本教程介绍了 TensorFlow 的核心概念:图和会话,并演示了如何使用 Java API 创建、运行和可视化图。我们还讲解了如何保存和加载模型,以便在 Java 中复用 Python 训练的模型。
通过这些内容,你应该对 TensorFlow Java API 有了基本了解,能够在实际项目中合理使用它。