1. Overview
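Java 8 provides an API for creating Javac plugins. Unfortunately, it's hard to find good documentation for it.
In this article, we'll walk through the whole process of creating a compiler extension that adds custom code to *.class files.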
2. Setup
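First, we need to add the JDK's tools.jar as a dependency for our project. This applies to Java 8: tools.jar was removed in Java 9, where the same classes live in the jdk.compiler module.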
<dependency>
    <groupId>com.sun</groupId>
    <artifactId>tools</artifactId>
    <version>1.8.0</version>
    <scope>system</scope>
    <systemPath>${java.home}/../lib/tools.jar</systemPath>
</dependency>
Every compiler extension is a class that implements the com.sun.source.util.Plugin interface. Let's create one for our example:
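package com.baeldung.javac;

import com.sun.source.util.JavacTask;
import com.sun.source.util.Plugin;
import com.sun.tools.javac.api.BasicJavacTask;
import com.sun.tools.javac.util.Context;
import com.sun.tools.javac.util.Log;

public class SampleJavacPlugin implements Plugin {

    // The value expected by -Xplugin:<name>
    public static final String NAME = "MyPlugin";

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public void init(JavacTask task, String... args) {
        // BasicJavacTask and Context are javac internals from tools.jar, not public API
        Context context = ((BasicJavacTask) task).getContext();
        Log.instance(context)
            .printRawLines(Log.WriterKind.NOTICE, "Hello from " + getName());
    }
}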
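For now, we're only printing "Hello" to make sure the plugin is picked up and included in the compilation.
Our end goal is a plugin that adds runtime checks for every numeric parameter marked with a given annotation, and throws an exception if the argument doesn't satisfy the condition.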
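To make the extension discoverable by Javac, it must be exposed through the ServiceLoader framework. To do so, we need to:
- create a file named META-INF/services/com.sun.source.util.Plugin
- put the plugin's fully qualified class name in it (com.baeldung.javac.SampleJavacPlugin)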
After that, we can invoke the plugin with the -Xplugin:MyPlugin switch:
baeldung/tutorials$ javac -cp ./core-java/target/classes -Xplugin:MyPlugin ./core-java/src/main/java/com/baeldung/javac/TestClass.java
Hello from MyPlugin
⚠️ Note: the value passed to -Xplugin must exactly match the string returned by the plugin's getName() method.
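The following sketch shows how a ServiceLoader-based lookup by name works. It assumes nothing beyond the public java.util.ServiceLoader and com.sun.source.util.Plugin APIs. PluginLookup and findPlugin are hypothetical names. Javac's real lookup is only roughly similar, since it also decides which class loader to search.

```java
import com.sun.source.util.Plugin;
import java.util.Optional;
import java.util.ServiceLoader;

public class PluginLookup {

    // Iterates over every Plugin registered in META-INF/services/com.sun.source.util.Plugin
    // files visible to the loader and returns the one whose getName() equals the given name.
    public static Optional<Plugin> findPlugin(String name, ClassLoader loader) {
        for (Plugin plugin : ServiceLoader.load(Plugin.class, loader)) {
            if (plugin.getName().equals(name)) {
                return Optional.of(plugin);
            }
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        // Finds SampleJavacPlugin only if it and its service file are on the classpath
        Optional<Plugin> plugin = findPlugin("MyPlugin", PluginLookup.class.getClassLoader());
        System.out.println(plugin.map(p -> p.getClass().getName()).orElse("not found"));
    }
}
```

Because the match is an exact string comparison, a -Xplugin value that differs from getName() even in case will not find the plugin.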
3. Plugin Lifecycle
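The compiler calls a plugin only once, through its init() method. To be notified about later events, we have to register a callback. Events are fired at each processing stage of every source file: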
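- PARSE – builds an Abstract Syntax Tree (AST)
- ENTER – resolves the source code's imports
- ANALYZE – analyzes the parser output (the AST) for errors
- GENERATE – generates binaries for the target source file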
There are two more event kinds, ANNOTATION_PROCESSING and ANNOTATION_PROCESSING_ROUND, which we won't cover here.
For example, to enhance the compilation once parsing has finished, we can hook into the PARSE finished event:
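public void init(JavacTask task, String... args) {
    task.addTaskListener(new TaskListener() {
        // Java 8's TaskListener declares both methods abstract, so both must be implemented
        @Override
        public void started(TaskEvent e) {
        }

        @Override
        public void finished(TaskEvent e) {
            if (e.getKind() != TaskEvent.Kind.PARSE) {
                return;
            }
            // Perform instrumentation here
        }
    });
}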
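The sketch below makes the event sequence visible. It registers a listener on a JavacTask obtained from the standard javax.tools compiler, compiles one in-memory file and records the kind of every finished event. TaskEventRecorder and finishedEvents are hypothetical names. It assumes a JDK, so that ToolProvider.getSystemJavaCompiler() is not null, and it writes class files to a temporary directory. Newer JDKs may also report a COMPILATION event.

```java
import com.sun.source.util.JavacTask;
import com.sun.source.util.TaskEvent;
import com.sun.source.util.TaskListener;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import javax.tools.JavaCompiler;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.ToolProvider;

public class TaskEventRecorder {

    // Compiles one in-memory source file and records the kind of every finished TaskEvent.
    public static List<TaskEvent.Kind> finishedEvents(String className, String source)
            throws Exception {
        JavaFileObject file = new SimpleJavaFileObject(
                URI.create("string:///" + className + ".java"), JavaFileObject.Kind.SOURCE) {
            @Override
            public CharSequence getCharContent(boolean ignoreEncodingErrors) {
                return source;
            }
        };
        // Class files go to a throwaway directory; -proc:none skips annotation processing
        Path outDir = Files.createTempDirectory("javac-events");
        JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
        JavacTask task = (JavacTask) compiler.getTask(null, null, null,
                Arrays.asList("-d", outDir.toString(), "-proc:none"), null,
                Collections.singletonList(file));

        List<TaskEvent.Kind> kinds = new ArrayList<>();
        task.addTaskListener(new TaskListener() {
            @Override
            public void started(TaskEvent e) {
                // not needed here, but abstract on Java 8
            }

            @Override
            public void finished(TaskEvent e) {
                kinds.add(e.getKind());
            }
        });
        if (!task.call()) {
            throw new IllegalStateException("compilation failed");
        }
        return kinds;
    }

    public static void main(String[] args) throws Exception {
        System.out.println(finishedEvents("Hello", "public class Hello { }"));
    }
}
```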
4. Extracting AST Data
We can get the AST generated by the compiler through TaskEvent.getCompilationUnit() and examine its details through the TreeVisitor interface.
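⚠️ Key point: a visitor is notified only about the Tree element on which accept() is called. For example, calling ClassTree.accept(visitor) triggers only visitClass(); visitMethod() is not called automatically for the methods inside that class.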
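The following sketch demonstrates this with the public com.sun.source.util.SimpleTreeVisitor. It parses (without compiling) a class with two methods and records which visit methods fire. AcceptDemo, parseFirstClass and visited are hypothetical names, and it assumes a JDK at run time.

```java
import com.sun.source.tree.ClassTree;
import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.tree.MethodTree;
import com.sun.source.util.JavacTask;
import com.sun.source.util.SimpleTreeVisitor;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.ToolProvider;

public class AcceptDemo {

    // Parses (without compiling) one in-memory source and returns its first top-level class.
    static ClassTree parseFirstClass(String source) throws Exception {
        JavaFileObject file = new SimpleJavaFileObject(
                URI.create("string:///Demo.java"), JavaFileObject.Kind.SOURCE) {
            @Override
            public CharSequence getCharContent(boolean ignoreEncodingErrors) {
                return source;
            }
        };
        JavacTask task = (JavacTask) ToolProvider.getSystemJavaCompiler()
                .getTask(null, null, null, null, null, Collections.singletonList(file));
        CompilationUnitTree unit = task.parse().iterator().next();
        return (ClassTree) unit.getTypeDecls().get(0);
    }

    // SimpleTreeVisitor handles only the node accept() was called on; it never visits children.
    public static List<String> visited(String source) throws Exception {
        List<String> calls = new ArrayList<>();
        parseFirstClass(source).accept(new SimpleTreeVisitor<Void, Void>() {
            @Override
            public Void visitClass(ClassTree node, Void p) {
                calls.add("class " + node.getSimpleName());
                return null;
            }

            @Override
            public Void visitMethod(MethodTree node, Void p) {
                calls.add("method " + node.getName());
                return null;
            }
        }, null);
        return calls;
    }
}
```

For "class Demo { void a() { } void b() { } }" only the class is recorded, and neither method is visited.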
TreeScanner solves the recursive traversal problem:
public void finished(TaskEvent e) {
    if (e.getKind() != TaskEvent.Kind.PARSE) {
        return;
    }
    e.getCompilationUnit().accept(new TreeScanner<Void, Void>() {
        @Override
        public Void visitClass(ClassTree node, Void aVoid) {
            return super.visitClass(node, aVoid);
        }

        @Override
        public Void visitMethod(MethodTree node, Void aVoid) {
            return super.visitMethod(node, aVoid);
        }
    }, null);
}
✅ We must call super.visitXxx(node, value) so that the current node's children are processed recursively.
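To see why the super call matters, the sketch below collects method names with a TreeScanner. A flag controls whether visitClass() delegates to super.visitClass(). SuperCallDemo and methodNames are hypothetical names, and only the public com.sun.source API is used.

```java
import com.sun.source.tree.ClassTree;
import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.tree.MethodTree;
import com.sun.source.util.JavacTask;
import com.sun.source.util.TreeScanner;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.ToolProvider;

public class SuperCallDemo {

    // Parses (without compiling) a single in-memory source file.
    static CompilationUnitTree parse(String source) throws Exception {
        JavaFileObject file = new SimpleJavaFileObject(
                URI.create("string:///Demo.java"), JavaFileObject.Kind.SOURCE) {
            @Override
            public CharSequence getCharContent(boolean ignoreEncodingErrors) {
                return source;
            }
        };
        JavacTask task = (JavacTask) ToolProvider.getSystemJavaCompiler()
                .getTask(null, null, null, null, null, Collections.singletonList(file));
        return task.parse().iterator().next();
    }

    // Collects method names; when descendIntoClasses is false, visitClass() skips the super call.
    public static List<String> methodNames(String source, boolean descendIntoClasses)
            throws Exception {
        List<String> names = new ArrayList<>();
        parse(source).accept(new TreeScanner<Void, Void>() {
            @Override
            public Void visitClass(ClassTree node, Void p) {
                // Without super.visitClass(), the class members are never scanned
                return descendIntoClasses ? super.visitClass(node, p) : null;
            }

            @Override
            public Void visitMethod(MethodTree node, Void p) {
                names.add(node.getName().toString());
                return super.visitMethod(node, p);
            }
        }, null);
        return names;
    }
}
```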
5. Modifying the AST
To show how to modify the AST, we'll insert runtime checks for all numeric parameters marked with a @Positive annotation.
First, let's define the annotation, which can be applied to method parameters:
import java.lang.annotation.*;

@Documented
@Retention(RetentionPolicy.CLASS)
@Target({ElementType.PARAMETER})
public @interface Positive { }
Here's an example of using it:
public void service(@Positive int i) { }
Our goal is bytecode equivalent to what the following source would compile to:
public void service(@Positive int i) {
    if (i <= 0) {
        throw new IllegalArgumentException("Argument 'i' of type int is marked by @Positive but got '"
          + i + "' for it");
    }
}
In other words, we want an IllegalArgumentException to be thrown whenever a parameter marked with @Positive receives a value less than or equal to 0.
5.1. Where to Instrument
Let's find out how to locate the parameters that should be instrumented:
private static final Set<String> TARGET_TYPES = Stream.of(
  byte.class, short.class, char.class,
  int.class, long.class, float.class, double.class)
  .map(Class::getName)
  .collect(Collectors.toSet());
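For simplicity, we've only included primitive numeric types here. Next, let's define a shouldInstrument() method that checks whether a parameter has one of the TARGET_TYPES and is annotated with @Positive: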
private boolean shouldInstrument(VariableTree parameter) {
    return TARGET_TYPES.contains(parameter.getType().toString())
      && parameter.getModifiers().getAnnotations().stream()
      .anyMatch(a -> Positive.class.getSimpleName()
        .equals(a.getAnnotationType().toString()));
}
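The sketch below runs the same check against parameters of a parsed (not compiled) source, so the filtering can be tried in isolation. PositiveParameterFinder and parametersToInstrument are hypothetical names. The annotation is matched by the literal name "Positive" so that the sketch doesn't depend on the annotation class.

```java
import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.tree.MethodTree;
import com.sun.source.tree.VariableTree;
import com.sun.source.util.JavacTask;
import com.sun.source.util.TreeScanner;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.ToolProvider;

public class PositiveParameterFinder {

    private static final Set<String> TARGET_TYPES = Stream.of(
            byte.class, short.class, char.class,
            int.class, long.class, float.class, double.class)
        .map(Class::getName)
        .collect(Collectors.toSet());

    // Same test as shouldInstrument(): a primitive numeric type plus an annotation written "Positive".
    static boolean shouldInstrument(VariableTree parameter) {
        return TARGET_TYPES.contains(parameter.getType().toString())
            && parameter.getModifiers().getAnnotations().stream()
                .anyMatch(a -> "Positive".equals(a.getAnnotationType().toString()));
    }

    // Parses the source and returns the names of all parameters that would be instrumented.
    public static List<String> parametersToInstrument(String source) throws Exception {
        JavaFileObject file = new SimpleJavaFileObject(
                URI.create("string:///Demo.java"), JavaFileObject.Kind.SOURCE) {
            @Override
            public CharSequence getCharContent(boolean ignoreEncodingErrors) {
                return source;
            }
        };
        JavacTask task = (JavacTask) ToolProvider.getSystemJavaCompiler()
                .getTask(null, null, null, null, null, Collections.singletonList(file));
        CompilationUnitTree unit = task.parse().iterator().next();

        List<String> names = new ArrayList<>();
        new TreeScanner<Void, Void>() {
            @Override
            public Void visitMethod(MethodTree method, Void p) {
                for (VariableTree param : method.getParameters()) {
                    if (shouldInstrument(param)) {
                        names.add(param.getName().toString());
                    }
                }
                return super.visitMethod(method, p);
            }
        }.scan(unit, null);
        return names;
    }
}
```

Because the comparison uses the tree's source text, a fully qualified usage such as @com.baeldung.javac.Positive would print as the full name and not match the simple name.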
Then we apply this check in the finished() method of our SampleJavacPlugin class to every parameter that qualifies:
public void finished(TaskEvent e) {
    if (e.getKind() != TaskEvent.Kind.PARSE) {
        return;
    }
    e.getCompilationUnit().accept(new TreeScanner<Void, Void>() {
        @Override
        public Void visitMethod(MethodTree method, Void v) {
            List<VariableTree> parametersToInstrument
              = method.getParameters().stream()
              .filter(SampleJavacPlugin.this::shouldInstrument)
              .collect(Collectors.toList());

            if (!parametersToInstrument.isEmpty()) {
                // context is the Context obtained in init(), where this listener is registered
                Collections.reverse(parametersToInstrument);
                parametersToInstrument.forEach(p -> addCheck(method, p, context));
            }
            return super.visitMethod(method, v);
        }
    }, null);
}
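🔧 We reverse the parameter list because more than one parameter may be marked with @Positive. Each check is inserted as the method's very first statement, so processing the parameters in reverse order leaves the checks in the same order as the parameters.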
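A small model of this ordering argument follows. Strings stand in for statements, and LinkedList.addFirst() plays the role of javac's immutable List.prepend(). PrependOrder and instrument are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

public class PrependOrder {

    // Each check is put in front of the current statements, like body.stats.prepend(check).
    public static List<String> instrument(List<String> body, List<String> params,
                                          boolean reverseFirst) {
        LinkedList<String> stats = new LinkedList<>(body);
        List<String> toInstrument = new ArrayList<>(params);
        if (reverseFirst) {
            Collections.reverse(toInstrument);
        }
        for (String param : toInstrument) {
            stats.addFirst("check(" + param + ")");  // becomes the new first statement
        }
        return stats;
    }

    public static void main(String[] args) {
        List<String> body = Collections.singletonList("return i;");
        List<String> params = Arrays.asList("a", "b");
        System.out.println(instrument(body, params, true));   // [check(a), check(b), return i;]
        System.out.println(instrument(body, params, false));  // [check(b), check(a), return i;]
    }
}
```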
5.2. How to Instrument
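The catch is that reading the AST is possible through the public API (com.sun.source.tree), while modifying it, for example by adding our checks, requires javac's private API (com.sun.tools.javac.tree).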
To handle this, we create new AST elements through a TreeMaker instance.
First, we need to obtain a Context instance:
@Override
public void init(JavacTask task, String... args) {
    Context context = ((BasicJavacTask) task).getContext();
    // ...
}
Then we obtain a TreeMaker through TreeMaker.instance(Context) and use it to build new elements, for example an if statement:
private static JCTree.JCIf createCheck(VariableTree parameter, Context context) {
    TreeMaker factory = TreeMaker.instance(context);
    Names symbolsTable = Names.instance(context);

    return factory.at(((JCTree) parameter).pos)
      .If(factory.Parens(createIfCondition(factory, symbolsTable, parameter)),
        createIfBlock(factory, symbolsTable, parameter),
        null);
}
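📍 Key trick: factory.at(((JCTree) parameter).pos) gives the generated nodes the parameter's source position, so the line numbers in the exception's stack trace point to the right place.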
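A source position is just a character offset, and javac roughly maps such offsets to line numbers through the compilation unit's line map. The sketch below shows that mapping for each method parameter using only public API (Trees, SourcePositions, LineMap). ParameterLines and parameterLines are hypothetical names. JCTree.pos is javac's "preferred" position (for a variable, typically its name), while getStartPosition() reports the first token. In the test input, both lie on the same line.

```java
import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.tree.LineMap;
import com.sun.source.tree.MethodTree;
import com.sun.source.tree.VariableTree;
import com.sun.source.util.JavacTask;
import com.sun.source.util.SourcePositions;
import com.sun.source.util.TreeScanner;
import com.sun.source.util.Trees;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.ToolProvider;

public class ParameterLines {

    // Returns "name:line" for every method parameter in the parsed source.
    public static List<String> parameterLines(String source) throws Exception {
        JavaFileObject file = new SimpleJavaFileObject(
                URI.create("string:///Demo.java"), JavaFileObject.Kind.SOURCE) {
            @Override
            public CharSequence getCharContent(boolean ignoreEncodingErrors) {
                return source;
            }
        };
        JavacTask task = (JavacTask) ToolProvider.getSystemJavaCompiler()
                .getTask(null, null, null, null, null, Collections.singletonList(file));
        CompilationUnitTree unit = task.parse().iterator().next();
        SourcePositions positions = Trees.instance(task).getSourcePositions();
        LineMap lineMap = unit.getLineMap();

        List<String> result = new ArrayList<>();
        new TreeScanner<Void, Void>() {
            @Override
            public Void visitMethod(MethodTree method, Void p) {
                for (VariableTree param : method.getParameters()) {
                    // Character offset of the parameter's first token, then its 1-based line
                    long start = positions.getStartPosition(unit, param);
                    result.add(param.getName() + ":" + lineMap.getLineNumber(start));
                }
                return super.visitMethod(method, p);
            }
        }.scan(unit, null);
        return result;
    }
}
```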
Next, we build the condition expression (parameterId <= 0):
private static JCTree.JCBinary createIfCondition(TreeMaker factory,
  Names symbolsTable, VariableTree parameter) {
    Name parameterId = symbolsTable.fromString(parameter.getName().toString());
    return factory.Binary(JCTree.Tag.LE,
      factory.Ident(parameterId),
      factory.Literal(TypeTag.INT, 0));
}
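One way to double-check the shape these factory calls should produce is to parse the hand-written check and inspect it through the public com.sun.source.tree API. On the javac versions we're aware of, the parser wraps an if condition in a ParenthesizedTree, which appears to be why createCheck() passes the condition through factory.Parens(...). The comparison is a BinaryTree of kind LESS_THAN_EQUAL, the public counterpart of JCTree.Tag.LE. ConditionShape and describeCondition are hypothetical names.

```java
import com.sun.source.tree.BinaryTree;
import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.tree.ExpressionTree;
import com.sun.source.tree.IfTree;
import com.sun.source.tree.ParenthesizedTree;
import com.sun.source.util.JavacTask;
import com.sun.source.util.TreeScanner;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.ToolProvider;

public class ConditionShape {

    // Parses the source and describes the condition tree of every if statement in it.
    public static List<String> describeCondition(String source) throws Exception {
        JavaFileObject file = new SimpleJavaFileObject(
                URI.create("string:///Demo.java"), JavaFileObject.Kind.SOURCE) {
            @Override
            public CharSequence getCharContent(boolean ignoreEncodingErrors) {
                return source;
            }
        };
        JavacTask task = (JavacTask) ToolProvider.getSystemJavaCompiler()
                .getTask(null, null, null, null, null, Collections.singletonList(file));
        CompilationUnitTree unit = task.parse().iterator().next();

        List<String> description = new ArrayList<>();
        new TreeScanner<Void, Void>() {
            @Override
            public Void visitIf(IfTree node, Void p) {
                ExpressionTree cond = node.getCondition();
                description.add(cond.getKind().toString());  // kind of the outermost node
                while (cond instanceof ParenthesizedTree) {
                    cond = ((ParenthesizedTree) cond).getExpression();
                }
                BinaryTree binary = (BinaryTree) cond;  // assumes a binary comparison inside
                description.add(binary.getKind() + " " + binary.getLeftOperand()
                        + " " + binary.getRightOperand());
                return super.visitIf(node, p);
            }
        }.scan(unit, null);
        return description;
    }
}
```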
Next, we build the block that throws the exception:
private static JCTree.JCBlock createIfBlock(TreeMaker factory,
  Names symbolsTable, VariableTree parameter) {
    String parameterName = parameter.getName().toString();
    Name parameterId = symbolsTable.fromString(parameterName);

    String errorMessagePrefix = String.format(
      "Argument '%s' of type %s is marked by @%s but got '",
      parameterName, parameter.getType(), Positive.class.getSimpleName());
    String errorMessageSuffix = "' for it";

    // { throw new IllegalArgumentException(prefix + <parameter> + suffix); }
    return factory.Block(0, com.sun.tools.javac.util.List.of(
      factory.Throw(
        factory.NewClass(null, com.sun.tools.javac.util.List.nil(),
          factory.Ident(symbolsTable.fromString(
            IllegalArgumentException.class.getSimpleName())),
          com.sun.tools.javac.util.List.of(factory.Binary(JCTree.Tag.PLUS,
            factory.Binary(JCTree.Tag.PLUS,
              factory.Literal(TypeTag.CLASS, errorMessagePrefix),
              factory.Ident(parameterId)),
            factory.Literal(TypeTag.CLASS, errorMessageSuffix))), null))));
}
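At run time, the statement built by createCheck() and createIfBlock() should behave like the hand-written method below. The sketch assumes a single int parameter named i, and GeneratedCheckSketch, PREFIX and SUFFIX are hypothetical names. String.format runs inside the plugin at compile time, so the prefix ends up as a string literal. Only the parameter's value is concatenated when the method executes.

```java
public class GeneratedCheckSketch {

    // Computed once by the plugin at compile time (mirrors errorMessagePrefix / errorMessageSuffix).
    static final String PREFIX = String.format(
            "Argument '%s' of type %s is marked by @%s but got '", "i", "int", "Positive");
    static final String SUFFIX = "' for it";

    // Run-time behaviour of an instrumented service(@Positive int i); the extra parentheses mirror factory.Parens.
    public static int service(int i) {
        if ((i <= 0)) {
            throw new IllegalArgumentException(PREFIX + i + SUFFIX);
        }
        return i;
    }

    public static void main(String[] args) {
        System.out.println(service(5));
        try {
            service(-1);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```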
Finally, we insert the new element into the AST:
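private void addCheck(MethodTree method, VariableTree parameter, Context context) {
    JCTree.JCIf check = createCheck(parameter, context);
    JCTree.JCBlock body = (JCTree.JCBlock) method.getBody();
    if (body == null) {
        return; // abstract or native method: there is no body to instrument
    }
    // javac's List is immutable: prepend() returns a new list with the check first
    body.stats = body.stats.prepend(check);
}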
6. Testing the Plugin
Testing involves two steps:
- compile the test source
- run the compiled binaries and verify that they behave as expected
For this, we need a few helper classes that let us compile in memory.
SimpleSourceFile exposes the given source text to Javac:
public class SimpleSourceFile extends SimpleJavaFileObject {
    private String content;

    public SimpleSourceFile(String qualifiedClassName, String testSource) {
        super(URI.create(String.format(
          "file://%s%s", qualifiedClassName.replaceAll("\\.", "/"),
          Kind.SOURCE.extension)), Kind.SOURCE);
        content = testSource;
    }

    @Override
    public CharSequence getCharContent(boolean ignoreEncodingErrors) {
        return content;
    }
}
SimpleClassFile holds the compilation result as a byte array:
public class SimpleClassFile extends SimpleJavaFileObject {

    private ByteArrayOutputStream out;

    public SimpleClassFile(URI uri) {
        super(uri, Kind.CLASS);
    }

    @Override
    public OutputStream openOutputStream() throws IOException {
        return out = new ByteArrayOutputStream();
    }

    public byte[] getCompiledBinaries() {
        return out.toByteArray();
    }

    // getters
}
SimpleFileManager makes the compiler write its output into SimpleClassFile instances:
public class SimpleFileManager
  extends ForwardingJavaFileManager<StandardJavaFileManager> {

    private List<SimpleClassFile> compiled = new ArrayList<>();

    public SimpleFileManager(StandardJavaFileManager delegate) {
        super(delegate);
    }

    @Override
    public JavaFileObject getJavaFileForOutput(Location location,
      String className, JavaFileObject.Kind kind, FileObject sibling) {
        SimpleClassFile result = new SimpleClassFile(
          URI.create("string://" + className));
        compiled.add(result);
        return result;
    }

    public List<SimpleClassFile> getCompiled() {
        return compiled;
    }
}
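TestCompiler below passes null as the DiagnosticListener argument of getTask(). When a test source fails to compile, the errors therefore appear only as text in the output writer. A possible alternative uses the standard javax.tools.DiagnosticCollector to collect errors as objects. CompileErrors and errors are hypothetical names. To keep the sketch self-contained, class files go to a temporary directory rather than through SimpleFileManager.

```java
import java.io.IOException;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import javax.tools.Diagnostic;
import javax.tools.DiagnosticCollector;
import javax.tools.JavaCompiler;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.ToolProvider;

public class CompileErrors {

    // Compiles one in-memory source and returns "line N: message" for every error javac reports.
    public static List<String> errors(String className, String source) throws IOException {
        JavaFileObject file = new SimpleJavaFileObject(
                URI.create("string:///" + className + ".java"), JavaFileObject.Kind.SOURCE) {
            @Override
            public CharSequence getCharContent(boolean ignoreEncodingErrors) {
                return source;
            }
        };
        DiagnosticCollector<JavaFileObject> diagnostics = new DiagnosticCollector<>();
        Path outDir = Files.createTempDirectory("javac-out");
        JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
        compiler.getTask(null, null, diagnostics,
                Arrays.asList("-d", outDir.toString()), null,
                Collections.singletonList(file)).call();

        List<String> errors = new ArrayList<>();
        for (Diagnostic<? extends JavaFileObject> d : diagnostics.getDiagnostics()) {
            if (d.getKind() == Diagnostic.Kind.ERROR) {
                errors.add("line " + d.getLineNumber() + ": " + d.getMessage(Locale.ROOT));
            }
        }
        return errors;
    }
}
```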
With these helpers in place, the in-memory compilation itself looks like this:
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;

public class TestCompiler {
    public byte[] compile(String qualifiedClassName, String testSource) {
        StringWriter output = new StringWriter();

        JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
        SimpleFileManager fileManager = new SimpleFileManager(
          compiler.getStandardFileManager(null, null, null));
        List<SimpleSourceFile> compilationUnits
          = singletonList(new SimpleSourceFile(qualifiedClassName, testSource));
        List<String> arguments = new ArrayList<>();
        arguments.addAll(asList("-classpath", System.getProperty("java.class.path"),
          "-Xplugin:" + SampleJavacPlugin.NAME));
        JavaCompiler.CompilationTask task
          = compiler.getTask(output, fileManager, null, arguments, null,
          compilationUnits);

        // Fail with javac's own output instead of an empty list of compiled classes
        if (!task.call()) {
            throw new IllegalStateException("Compilation failed: " + output);
        }
        return fileManager.getCompiled().iterator().next().getCompiledBinaries();
    }
}
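To run the compiled result, we define the class from the bytecode with a custom ClassLoader and invoke the method through reflection: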
public class TestRunner {

    public Object run(byte[] byteCode, String qualifiedClassName, String methodName,
      Class<?>[] argumentTypes, Object... args) throws Throwable {
        // Defines the freshly compiled class from the in-memory bytecode
        ClassLoader classLoader = new ClassLoader() {
            @Override
            protected Class<?> findClass(String name) throws ClassNotFoundException {
                return defineClass(name, byteCode, 0, byteCode.length);
            }
        };
        Class<?> clazz;
        try {
            clazz = classLoader.loadClass(qualifiedClassName);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException("Can't load compiled test class", e);
        }

        Method method;
        try {
            method = clazz.getMethod(methodName, argumentTypes);
        } catch (NoSuchMethodException e) {
            throw new RuntimeException(
              "Can't find the '" + methodName + "()' method in the compiled test class", e);
        }

        try {
            // The target method is static, hence the null receiver
            return method.invoke(null, args);
        } catch (InvocationTargetException e) {
            // Rethrow the method's own exception, e.g. IllegalArgumentException
            throw e.getCause();
        }
    }
}
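Method.invoke() wraps any exception thrown by the invoked method in an InvocationTargetException. That is why TestRunner rethrows getCause(): without it, a test declaring expected = IllegalArgumentException.class would see the wrapper instead. The following is a small self-contained illustration, with UnwrapDemo as a hypothetical class.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

public class UnwrapDemo {

    public static int service(int i) {
        if (i <= 0) {
            throw new IllegalArgumentException("non-positive: " + i);
        }
        return i;
    }

    // Same pattern as TestRunner.run(): surface the method's own exception, not the reflective wrapper.
    public static Object invoke(String methodName, Class<?>[] types, Object... args)
            throws Throwable {
        Method method = UnwrapDemo.class.getMethod(methodName, types);
        try {
            return method.invoke(null, args);
        } catch (InvocationTargetException e) {
            throw e.getCause();
        }
    }
}
```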
And here is a sample test:
import org.junit.Test;

public class SampleJavacPluginTest {

    private static final String CLASS_TEMPLATE
      = "package com.baeldung.javac;\n\n" +
        "public class Test {\n" +
        "    public static %1$s service(@Positive %1$s i) {\n" +
        "        return i;\n" +
        "    }\n" +
        "}\n" +
        "";

    private TestCompiler compiler = new TestCompiler();
    private TestRunner runner = new TestRunner();

    @Test(expected = IllegalArgumentException.class)
    public void givenInt_whenNegative_thenThrowsException() throws Throwable {
        compileAndRun(int.class, -1);
    }

    private Object compileAndRun(Class<?> argumentType, Object argument)
      throws Throwable {
        String qualifiedClassName = "com.baeldung.javac.Test";
        byte[] byteCode = compiler.compile(qualifiedClassName,
          String.format(CLASS_TEMPLATE, argumentType.getName()));
        return runner.run(byteCode, qualifiedClassName,
          "service", new Class[] {argumentType}, argument);
    }
}
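Here we compile a Test class whose service() method has a parameter annotated with @Positive, then call it with -1. If the plugin has instrumented the method, an IllegalArgumentException is thrown and the test passes.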
7. Conclusion
In this article, we've walked through the complete process of creating, testing and running a Java compiler plugin.
The full source code of the examples is available on GitHub.