1. 概述

TensorFlow 是一个用于数据流编程的开源库,最初由 Google 开发,支持多种平台。虽然 TensorFlow 可以在单核上运行,但也可以轻松利用多 CPU、GPU 或 TPU 提升性能。

本教程将带你了解 TensorFlow 的基本概念,并讲解如何在 Java 中使用它。需要注意的是,TensorFlow 的 Java API 目前仍处于实验阶段,因此不提供稳定性保障。我们也会在后续部分介绍一些适用场景。

2. 基础概念

TensorFlow 的计算主要围绕两个核心概念:Graph(图)和 Session(会话)。下面快速了解一下,为后续内容打基础。

2.1. TensorFlow 图(Graph)

TensorFlow 程序的基本组成单位是图(Graph)。计算过程在 TensorFlow 中以图的形式表示。图通常是一个有向无环图(DAG),包含操作(Operation)和数据(Tensor),例如:


TensorFlow-Graph-1-1

上图表示的是如下方程的计算图:

f(x, y) = z = a*x + b*y

TensorFlow 图由两个基本元素构成:

  1. Tensor(张量):TensorFlow 中数据的基本单位,图中的边表示数据流。张量可以是任意维度的,维度数量称为“秩”(rank)。例如:
    • 标量:0 维张量
    • 向量:1 维张量
    • 矩阵:2 维张量
  2. Operation(操作):图中的节点,代表各种计算操作,可以接收张量作为输入,输出新的张量。

2.2. TensorFlow 会话(Session)

图本身只是一个计算结构,不包含实际值。要让图中的张量被计算,必须在 Session 中运行图。Session 可以指定图中要计算的张量,然后从这些张量开始,逆向执行图中必要的节点。

有了这些基础知识,我们就可以开始使用 Java API 来操作 TensorFlow 了。

3. Maven 依赖配置

我们先创建一个 Maven 项目,引入 TensorFlow 的依赖。只需添加如下依赖:

<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.12.0</version>
</dependency>

4. 构建图(Graph)

我们来用 Java API 实现之前提到的图。具体来说,我们要计算如下函数:

z = 3*x + 2*y

第一步是创建一个图对象:

Graph graph = new Graph()

4.1. 定义常量(Constants)

常量操作需要提供一个张量作为值:

Operation a = graph.opBuilder("Const", "a")
  .setAttr("dtype", DataType.fromClass(Double.class))
  .setAttr("value", Tensor.<Double>create(3.0, Double.class))
  .build();        
Operation b = graph.opBuilder("Const", "b")
  .setAttr("dtype", DataType.fromClass(Double.class))
  .setAttr("value", Tensor.<Double>create(2.0, Double.class))
  .build();

我们定义了两个常量 ab,分别代表 3.02.0

4.2. 定义占位符(Placeholders)

占位符在定义时不需要值,而是在运行时通过 Session 传入:

Operation x = graph.opBuilder("Placeholder", "x")
  .setAttr("dtype", DataType.fromClass(Double.class))
  .build();            
Operation y = graph.opBuilder("Placeholder", "y")
  .setAttr("dtype", DataType.fromClass(Double.class))
  .build();

4.3. 定义函数操作

接下来定义乘法和加法操作:

Operation ax = graph.opBuilder("Mul", "ax")
  .addInput(a.output(0))
  .addInput(x.output(0))
  .build();            
Operation by = graph.opBuilder("Mul", "by")
  .addInput(b.output(0))
  .addInput(y.output(0))
  .build();
Operation z = graph.opBuilder("Add", "z")
  .addInput(ax.output(0))
  .addInput(by.output(0))
  .build();

我们定义了三个操作:两个乘法和一个加法。每个操作的输入是之前操作的输出张量。

5. 图的可视化

随着图的复杂度增加,手动追踪变得困难。TensorFlow 提供了 TensorBoard 来可视化图结构

遗憾的是,Java API 无法直接生成 TensorBoard 所需的 event 文件,但可以用 Python 生成:

writer = tf.summary.FileWriter('.')
......
writer.add_graph(tf.get_default_graph())
writer.flush()

然后在命令行中启动 TensorBoard:

tensorboard --logdir .

mul

6. 使用 Session 执行图

现在我们已经构建了图,但还没运行。如果直接打印输出:

System.out.println(z.output(0));

会得到类似如下输出:

<Add 'z:0' shape=<unknown> dtype=DOUBLE>

这只是描述了一个张量,并没有实际值。我们需要创建一个 Session 来运行图:

Session sess = new Session(graph)

然后运行并获取结果:

Tensor<Double> tensor = sess.runner().fetch("z")
  .feed("x", Tensor.<Double>create(3.0, Double.class))
  .feed("y", Tensor.<Double>create(6.0, Double.class))
  .run().get(0).expect(Double.class);
System.out.println(tensor.doubleValue());

输出为:

21.0

✅ 正确结果!

7. Java API 的使用场景

虽然 TensorFlow 用在简单计算上显得有点“杀鸡用牛刀”,但它的真正优势在于处理大型机器学习模型

Java API 并不适合用于构建和训练复杂模型,但我们可以:

✅ 使用 Python 构建并训练模型
✅ 将训练好的模型导出为 Protocol Buffer 文件
✅ 在 Java 中加载并使用该模型

这在 Android 或其他 Java 客户端中特别有用,比如为图片推荐标题等场景。

8. 使用保存的模型(SavedModel)

TensorFlow 支持将模型保存为语言和平台无关的格式(Protocol Buffer),可以在不同语言中加载使用。

8.1. 保存模型

在 Python 中保存模型:

import tensorflow as tf
graph = tf.Graph()
builder = tf.saved_model.builder.SavedModelBuilder('./model')
with graph.as_default():
  a = tf.constant(2, name='a')
  b = tf.constant(3, name='b')
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  z = tf.math.add(a*x, b*y, name='z')
  sess = tf.Session()
  sess.run(z, feed_dict = {x: 2, y: 3})
  builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
  builder.save()

这段代码会生成一个 saved_model.pb 文件。

8.2. 加载模型

Java 中加载模型:

SavedModelBundle model = SavedModelBundle.load("./model", "serve");    
Tensor<Integer> tensor = model.session().runner().fetch("z")
  .feed("x", Tensor.<Integer>create(3, Integer.class))
  .feed("y", Tensor.<Integer>create(3, Integer.class))
  .run().get(0).expect(Integer.class);    
System.out.println(tensor.intValue());

输出结果为:

15

9. 总结

本教程介绍了 TensorFlow 的核心概念:图和会话,并演示了如何使用 Java API 创建、运行和可视化图。我们还讲解了如何保存和加载模型,以便在 Java 中复用 Python 训练的模型。

通过这些内容,你应该对 TensorFlow Java API 有了基本了解,能够在实际项目中合理使用它。


原始标题:Introduction to Tensorflow for Java | Baeldung