1. Overview

Java 8 provides an API for creating Javac plugins. Unfortunately, its documentation is scarce.

In this article, we walk through the whole process of building a compiler extension that injects custom code into .class files.

2. Setup

First, we add the JDK's tools.jar as a project dependency:

<dependency>
    <groupId>com.sun</groupId>
    <artifactId>tools</artifactId>
    <version>1.8.0</version>
    <scope>system</scope>
    <systemPath>${java.home}/../lib/tools.jar</systemPath>
</dependency>
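The tools.jar dependency above applies up to JDK 8 only. From JDK 9 on, tools.jar no longer exists: the public com.sun.source API ships in the jdk.compiler module, and internal packages such as com.sun.tools.javac.* must be exported explicitly (for example with --add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED). As a quick sanity check, a small helper such as the hypothetical CompilerCheck below can confirm that the current runtime provides both the compiler and the plugin API:

```java
import javax.tools.JavaCompiler;
import javax.tools.ToolProvider;

// Hypothetical helper, not part of the original project
final class CompilerCheck {

    // True when the running installation ships the javac API (a JDK, not a bare JRE)
    static boolean compilerAvailable() {
        JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
        return compiler != null;
    }

    // True when com.sun.source.util.Plugin can be loaded:
    // from tools.jar on JDK 8, from the jdk.compiler module on JDK 9+
    static boolean pluginApiAvailable() {
        try {
            Class.forName("com.sun.source.util.Plugin");
            return true;
        } catch (ClassNotFoundException e) {
            return false;
        }
    }
}
```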

Every compiler extension must implement the com.sun.source.util.Plugin interface. Let's create a sample class:

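package com.baeldung.javac;

import com.sun.source.util.JavacTask;
import com.sun.source.util.Plugin;
import com.sun.tools.javac.api.BasicJavacTask;
import com.sun.tools.javac.util.Context;
import com.sun.tools.javac.util.Log;

public class SampleJavacPlugin implements Plugin {

    // Must match the value passed to -Xplugin; also reused by the test harness
    public static final String NAME = "MyPlugin";

    @Override
    public String getName() {
        return NAME;
    }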

    @Override
    public void init(JavacTask task, String... args) {
        Context context = ((BasicJavacTask) task).getContext();
        Log.instance(context)
          .printRawLines(Log.WriterKind.NOTICE, "Hello from " + getName());
    }
}

For now, the plugin only prints "Hello" so we can verify that it is loaded correctly.

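Our end goal is a plugin that adds runtime checks to every numeric parameter marked with a specific annotation and throws an exception when an argument violates the condition.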

For Javac to discover the plugin, it has to be exposed through the ServiceLoader framework:

  1. Create a file named META-INF/services/com.sun.source.util.Plugin
  2. Put the plugin's fully qualified class name in it (e.g. com.baeldung.javac.SampleJavacPlugin)
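To check that the registration is picked up, we can ask ServiceLoader which providers it finds. PluginLister below is a hypothetical helper, not part of the original code; it returns the getName() value of every registered Plugin, which is exactly the value to pass to -Xplugin:

```java
import com.sun.source.util.Plugin;
import java.util.ArrayList;
import java.util.List;
import java.util.ServiceLoader;

// Hypothetical helper for verifying the META-INF/services registration
final class PluginLister {

    // Returns getName() of every Plugin registered through
    // META-INF/services/com.sun.source.util.Plugin visible to the given loader
    static List<String> registeredPluginNames(ClassLoader loader) {
        List<String> names = new ArrayList<>();
        for (Plugin plugin : ServiceLoader.load(Plugin.class, loader)) {
            names.add(plugin.getName());
        }
        return names;
    }
}
```

With the service file from the steps above on the class path, the returned list should contain "MyPlugin".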

We can then invoke the plugin with the -Xplugin:MyPlugin option:

baeldung/tutorials$ javac -cp ./core-java/target/classes -Xplugin:MyPlugin ./core-java/src/main/java/com/baeldung/javac/TestClass.java
Hello from MyPlugin

⚠️ Note: the value of -Xplugin must exactly match the string returned by the plugin's getName().

3. Plugin Lifecycle

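The compiler calls a plugin only once, through its init() method. To be notified about subsequent events, we have to register a callback. Events are fired at every processing stage of each source file: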

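  • PARSE – builds an Abstract Syntax Tree (AST)
  • ENTER – resolves the source code's imports
  • ANALYZE – analyzes the AST for errors
  • GENERATE – generates the binaries for the target source file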

There are also the ANNOTATION_PROCESSING and ANNOTATION_PROCESSING_ROUND event kinds, which we don't cover here.
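A quick way to observe these stages is to attach a TaskListener directly to a JavacTask, without writing a plugin at all. The sketch below assumes that the task returned by JavaCompiler.getTask() can be cast to JavacTask, which holds for the JDK's javac; LifecycleDemo and recordEvents() are hypothetical names:

```java
import com.sun.source.util.JavacTask;
import com.sun.source.util.TaskEvent;
import com.sun.source.util.TaskListener;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.ToolProvider;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

final class LifecycleDemo {

    // Compiles one in-memory source file and records every TaskEvent as
    // "started KIND" / "finished KIND", in the order javac fires them
    static List<String> recordEvents(String className, String code) {
        JavaFileObject source = new SimpleJavaFileObject(
                URI.create("string:///" + className + JavaFileObject.Kind.SOURCE.extension),
                JavaFileObject.Kind.SOURCE) {
            @Override
            public CharSequence getCharContent(boolean ignoreEncodingErrors) {
                return code;
            }
        };

        Path outDir;
        try {
            // GENERATE writes real .class files, so give javac a scratch directory
            outDir = Files.createTempDirectory("javac-events");
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }

        JavacTask task = (JavacTask) ToolProvider.getSystemJavaCompiler().getTask(
                null, null, null,
                Arrays.asList("-d", outDir.toString(), "-proc:none"),
                null, Collections.singletonList(source));

        List<String> events = new ArrayList<>();
        task.addTaskListener(new TaskListener() {
            @Override
            public void started(TaskEvent e) {
                events.add("started " + e.getKind());
            }

            @Override
            public void finished(TaskEvent e) {
                events.add("finished " + e.getKind());
            }
        });
        task.call();
        return events;
    }
}
```

Calling recordEvents("Hello", "public class Hello {}") should report started/finished pairs for PARSE, ENTER, ANALYZE and GENERATE. Newer JDKs may report additional kinds (such as COMPILATION, added in JDK 9), so the exact list can vary by version.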

For example, to enhance the compilation once the PARSE stage has finished:

public void init(JavacTask task, String... args) {
    task.addTaskListener(new TaskListener() {
        public void started(TaskEvent e) {
        }

        public void finished(TaskEvent e) {
            if (e.getKind() != TaskEvent.Kind.PARSE) {
                return;
            }
            // Perform the instrumentation here
        }
    });
}
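Because the listener is registered inside init(), a plugin can also be exercised without ServiceLoader or -Xplugin by calling init() on a JavacTask ourselves. The sketch below is a hypothetical ParseLoggingPlugin that uses only the public com.sun.source API; its parseWith() helper (also hypothetical) runs just the PARSE stage via JavacTask.parse():

```java
import com.sun.source.util.JavacTask;
import com.sun.source.util.Plugin;
import com.sun.source.util.TaskEvent;
import com.sun.source.util.TaskListener;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.ToolProvider;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical plugin: records each source file once its PARSE stage finishes.
// (A plugin loaded via ServiceLoader would need to be a public class.)
final class ParseLoggingPlugin implements Plugin {

    final List<String> parsedFiles = new ArrayList<>();

    @Override
    public String getName() {
        return "ParseLogger";
    }

    @Override
    public void init(JavacTask task, String... args) {
        task.addTaskListener(new TaskListener() {
            @Override
            public void started(TaskEvent e) {
            }

            @Override
            public void finished(TaskEvent e) {
                if (e.getKind() != TaskEvent.Kind.PARSE) {
                    return;
                }
                parsedFiles.add(e.getSourceFile().getName());
            }
        });
    }

    // Test helper: calls init() directly (no ServiceLoader, no -Xplugin),
    // then runs only the PARSE stage through JavacTask.parse()
    static List<String> parseWith(String fileName, String code) {
        JavaFileObject source = new SimpleJavaFileObject(
                URI.create("string:///" + fileName), JavaFileObject.Kind.SOURCE) {
            @Override
            public CharSequence getCharContent(boolean ignoreEncodingErrors) {
                return code;
            }
        };
        JavacTask task = (JavacTask) ToolProvider.getSystemJavaCompiler()
                .getTask(null, null, null, null, null, Collections.singletonList(source));
        ParseLoggingPlugin plugin = new ParseLoggingPlugin();
        plugin.init(task);
        try {
            task.parse();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return plugin.parsedFiles;
    }
}
```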

4. Extracting AST Data

We can get the AST built by the compiler from TaskEvent.getCompilationUnit() and inspect its details through the TreeVisitor interface.

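⚠️ Key point: only the Tree element whose accept() method is called dispatches visitor events. For example, ClassTree.accept(visitor) triggers only visitClass(); it does not automatically trigger visitMethod() for the methods declared in that class.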
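The difference is easy to observe with JavacTask.parse(), which returns the compilation units without running later stages. In the sketch below (VisitorDepthDemo is a hypothetical name), a SimpleTreeVisitor accepted on a ClassTree never reaches the methods, whereas a TreeScanner does:

```java
import com.sun.source.tree.ClassTree;
import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.tree.MethodTree;
import com.sun.source.util.JavacTask;
import com.sun.source.util.SimpleTreeVisitor;
import com.sun.source.util.TreeScanner;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.ToolProvider;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.util.Collections;

final class VisitorDepthDemo {

    // Parses source text into a compilation unit (PARSE only, no type checking)
    static CompilationUnitTree parse(String code) {
        JavaFileObject source = new SimpleJavaFileObject(
                URI.create("string:///Demo.java"), JavaFileObject.Kind.SOURCE) {
            @Override
            public CharSequence getCharContent(boolean ignoreEncodingErrors) {
                return code;
            }
        };
        JavacTask task = (JavacTask) ToolProvider.getSystemJavaCompiler()
                .getTask(null, null, null, null, null, Collections.singletonList(source));
        try {
            return task.parse().iterator().next();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // SimpleTreeVisitor only dispatches on the node accept() is called on:
    // visitClass() runs (default action), but no method is ever visited
    static int methodsSeenBySimpleVisitor(ClassTree clazz) {
        int[] count = {0};
        clazz.accept(new SimpleTreeVisitor<Void, Void>() {
            @Override
            public Void visitMethod(MethodTree node, Void unused) {
                count[0]++;
                return null;
            }
        }, null);
        return count[0];
    }

    // TreeScanner descends into child nodes, so every method is visited
    static int methodsSeenByScanner(ClassTree clazz) {
        int[] count = {0};
        clazz.accept(new TreeScanner<Void, Void>() {
            @Override
            public Void visitMethod(MethodTree node, Void unused) {
                count[0]++;
                return super.visitMethod(node, unused);
            }
        }, null);
        return count[0];
    }
}
```

For class Demo { void a() {} void b() {} }, the first helper reports 0 methods and the second reports 2.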

TreeScanner solves this by traversing the tree recursively:

public void finished(TaskEvent e) {
    if (e.getKind() != TaskEvent.Kind.PARSE) {
        return;
    }
    e.getCompilationUnit().accept(new TreeScanner<Void, Void>() {
        @Override
        public Void visitClass(ClassTree node, Void aVoid) {
            return super.visitClass(node, aVoid);
        }

        @Override
        public Void visitMethod(MethodTree node, Void aVoid) {
            return super.visitMethod(node, aVoid);
        }
    }, null);
}

Calling super.visitXxx(node, value) is essential: it is what makes the scanner recurse into child nodes.
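To illustrate what happens when that call is dropped, the hypothetical SuperCallDemo below toggles it in visitClass(). Without it, scanning stops at the class node and no method is ever visited:

```java
import com.sun.source.tree.ClassTree;
import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.tree.MethodTree;
import com.sun.source.util.JavacTask;
import com.sun.source.util.TreeScanner;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.ToolProvider;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

final class SuperCallDemo {

    // Parses source text into a compilation unit (PARSE only)
    static CompilationUnitTree parse(String code) {
        JavaFileObject source = new SimpleJavaFileObject(
                URI.create("string:///Demo.java"), JavaFileObject.Kind.SOURCE) {
            @Override
            public CharSequence getCharContent(boolean ignoreEncodingErrors) {
                return code;
            }
        };
        JavacTask task = (JavacTask) ToolProvider.getSystemJavaCompiler()
                .getTask(null, null, null, null, null, Collections.singletonList(source));
        try {
            return task.parse().iterator().next();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Collects method names. With callSuper == false, visitClass() returns
    // without scanning the class members, so visitMethod() is never reached.
    static List<String> methodNames(CompilationUnitTree unit, boolean callSuper) {
        List<String> names = new ArrayList<>();
        unit.accept(new TreeScanner<Void, Void>() {
            @Override
            public Void visitClass(ClassTree node, Void unused) {
                return callSuper ? super.visitClass(node, unused) : null;
            }

            @Override
            public Void visitMethod(MethodTree node, Void unused) {
                names.add(node.getName().toString());
                return super.visitMethod(node, unused);
            }
        }, null);
        return names;
    }
}
```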

5. Modifying the AST

Now let's see how to modify the AST by inserting a runtime check for every numeric parameter annotated with @Positive.

First, we define the annotation:

import java.lang.annotation.*;

@Documented
@Retention(RetentionPolicy.CLASS)
@Target({ElementType.PARAMETER})
public @interface Positive { }

Here's how it's used:

public void service(@Positive int i) { }
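Note that RetentionPolicy.CLASS keeps the annotation in the .class file but hides it from reflection at runtime. That is enough here, because the plugin reads annotations from the source AST during compilation. The sketch below uses two hypothetical annotation types to show the difference through Method.getParameterAnnotations():

```java
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

final class RetentionDemo {

    // Hypothetical annotation with the same retention as @Positive
    @Retention(RetentionPolicy.CLASS)
    @Target(ElementType.PARAMETER)
    @interface ClassRetained { }

    // Hypothetical counterpart that is kept until runtime
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.PARAMETER)
    @interface RuntimeRetained { }

    static void service(@ClassRetained int a, @RuntimeRetained int b) { }

    // Returns how many annotations reflection reports on each parameter of service()
    static int[] visibleAnnotationCounts() {
        try {
            Method method = RetentionDemo.class.getDeclaredMethod("service", int.class, int.class);
            Annotation[][] perParameter = method.getParameterAnnotations();
            return new int[] {perParameter[0].length, perParameter[1].length};
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

visibleAnnotationCounts() returns {0, 1}: only the RUNTIME-retained annotation is visible via reflection.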

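Our goal is for the instrumented bytecode to behave as if it had been compiled from the following source (only the exact exception message differs):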

public void service(@Positive int i) {
    if (i <= 0) {
        throw new IllegalArgumentException("A non-positive argument ("
          + i + ") is given as a @Positive parameter 'i'");
    }
}

In short: throw an IllegalArgumentException whenever an argument marked with @Positive is ≤ 0.

5.1. Finding the Instrumentation Points

To locate the parameters that need instrumentation, we first define the target types:

private static final Set<String> TARGET_TYPES = Stream.of(
  byte.class, short.class, char.class,
  int.class, long.class, float.class, double.class)
  .map(Class::getName)
  .collect(Collectors.toSet());

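For simplicity, we only include primitive numeric types. Next, we define a method that decides whether a parameter qualifies: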

private boolean shouldInstrument(VariableTree parameter) {
    return TARGET_TYPES.contains(parameter.getType().toString())
      && parameter.getModifiers().getAnnotations().stream()
      .anyMatch(a -> Positive.class.getSimpleName()
        .equals(a.getAnnotationType().toString()));
}
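Because the AST is inspected right after PARSE, types and annotations are not resolved yet, so shouldInstrument() compares their source text. The sketch below repeats that logic against a parsed snippet; InstrumentationFilter and the ANNOTATION_NAME constant are hypothetical stand-ins for the plugin class and Positive.class.getSimpleName(). One consequence worth knowing: a fully qualified usage such as @com.example.Positive is not matched.

```java
import com.sun.source.tree.CompilationUnitTree;
import com.sun.source.tree.MethodTree;
import com.sun.source.tree.VariableTree;
import com.sun.source.util.JavacTask;
import com.sun.source.util.TreeScanner;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.ToolProvider;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

final class InstrumentationFilter {

    private static final Set<String> TARGET_TYPES = Stream.of(
            byte.class, short.class, char.class,
            int.class, long.class, float.class, double.class)
        .map(Class::getName)
        .collect(Collectors.toSet());

    // Stands in for Positive.class.getSimpleName()
    private static final String ANNOTATION_NAME = "Positive";

    // Same logic as shouldInstrument(): compares the source text of the
    // parameter type and of each annotation type (unresolved at PARSE time)
    static boolean shouldInstrument(VariableTree parameter) {
        return TARGET_TYPES.contains(parameter.getType().toString())
            && parameter.getModifiers().getAnnotations().stream()
                .anyMatch(a -> ANNOTATION_NAME.equals(a.getAnnotationType().toString()));
    }

    // Parses the source and returns the names of the parameters that would be instrumented
    static List<String> instrumentedParameters(String code) {
        JavaFileObject source = new SimpleJavaFileObject(
                URI.create("string:///Demo.java"), JavaFileObject.Kind.SOURCE) {
            @Override
            public CharSequence getCharContent(boolean ignoreEncodingErrors) {
                return code;
            }
        };
        JavacTask task = (JavacTask) ToolProvider.getSystemJavaCompiler()
                .getTask(null, null, null, null, null, Collections.singletonList(source));
        CompilationUnitTree unit;
        try {
            unit = task.parse().iterator().next();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        List<String> result = new ArrayList<>();
        unit.accept(new TreeScanner<Void, Void>() {
            @Override
            public Void visitMethod(MethodTree method, Void unused) {
                for (VariableTree parameter : method.getParameters()) {
                    if (shouldInstrument(parameter)) {
                        result.add(parameter.getName().toString());
                    }
                }
                return super.visitMethod(method, unused);
            }
        }, null);
        return result;
    }
}
```

For void m(@Positive int a, int b, @Positive String s, @com.example.Positive long c), only a is selected: b has no annotation, String is not a target type, and the qualified annotation name does not equal "Positive".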

finished() 方法中应用检查:

public void finished(TaskEvent e) {
    if (e.getKind() != TaskEvent.Kind.PARSE) {
        return;
    }
    e.getCompilationUnit().accept(new TreeScanner<Void, Void>() {
        @Override
        public Void visitMethod(MethodTree method, Void v) {
            List<VariableTree> parametersToInstrument
              = method.getParameters().stream()
              .filter(SampleJavacPlugin.this::shouldInstrument)
              .collect(Collectors.toList());

            if (!parametersToInstrument.isEmpty()) {
                // Reverse so that prepending keeps the checks in declaration order
                Collections.reverse(parametersToInstrument);
                // context: the Context obtained in init() and captured by this listener
                parametersToInstrument.forEach(p -> addCheck(method, p, context));
            }
            return super.visitMethod(method, v);
        }
    }, null);
}

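🔧 Why reverse the list: each check is inserted as the method's first statement, so when several parameters carry @Positive, processing them in reverse order leaves the checks in the order in which the parameters are declared.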
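The effect of the reversal can be reproduced with a plain list. In the hypothetical PrependOrderDemo below, each check is added at the head of the statement list, just like body.stats.prepend(check); the reverseFirst flag shows what happens with and without the reversal:

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.List;

final class PrependOrderDemo {

    // Simulates body.stats.prepend(check): every check becomes the new first
    // statement. Reversing the parameters first keeps the checks in
    // declaration order; skipping the reversal flips them.
    static List<String> instrument(List<String> body, List<String> annotatedParams,
                                   boolean reverseFirst) {
        Deque<String> stats = new ArrayDeque<>(body);
        List<String> params = new ArrayList<>(annotatedParams);
        if (reverseFirst) {
            Collections.reverse(params);
        }
        for (String param : params) {
            stats.addFirst("check(" + param + ")");
        }
        return new ArrayList<>(stats);
    }
}
```

For a body ["return i;"] and parameters [a, b], reverseFirst = true yields [check(a), check(b), return i;], while false yields [check(b), check(a), return i;].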

5.2. Performing the Instrumentation

The tricky part: reading the AST is possible through the public API, but modifying it (e.g. adding checks) requires the private API.

The solution is to create new AST elements through a TreeMaker instance.

First, we obtain the Context instance:

@Override
public void init(JavacTask task, String... args) {
    Context context = ((BasicJavacTask) task).getContext();
    // ...
}

We then get a TreeMaker via TreeMaker.instance(Context). Here's how to build the if statement with it:

private static JCTree.JCIf createCheck(VariableTree parameter, Context context) {
    TreeMaker factory = TreeMaker.instance(context);
    Names symbolsTable = Names.instance(context);
        
    return factory.at(((JCTree) parameter).pos)
      .If(factory.Parens(createIfCondition(factory, symbolsTable, parameter)),
        createIfBlock(factory, symbolsTable, parameter), 
        null);
}

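📍 Key trick: factory.at(((JCTree) parameter).pos) assigns the parameter's source position to the generated nodes, so the line numbers in the exception's stack trace are correct.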

Next, we build the condition expression (parameterId <= 0):

private static JCTree.JCBinary createIfCondition(TreeMaker factory, 
  Names symbolsTable, VariableTree parameter) {
    Name parameterId = symbolsTable.fromString(parameter.getName().toString());
    return factory.Binary(JCTree.Tag.LE, 
      factory.Ident(parameterId), 
      factory.Literal(TypeTag.INT, 0));
}
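createIfCondition() always compares against the int literal 0, so the generated check relies on binary numeric promotion for the other target types. The hypothetical PromotionDemo below spells out the same comparison per type. Note that for double and float a NaN argument compares false and therefore passes the check:

```java
final class PromotionDemo {

    // Each overload evaluates the same "parameter <= 0" test the plugin generates.
    // byte and char operands are promoted to int; for long and double the int
    // literal 0 is widened to the parameter's type.
    static boolean violates(byte value) {
        return value <= 0;
    }

    // For char, only the character with code 0 violates the check
    static boolean violates(char value) {
        return value <= 0;
    }

    static boolean violates(long value) {
        return value <= 0;
    }

    // NaN <= 0 is false, so NaN is not rejected
    static boolean violates(double value) {
        return value <= 0;
    }
}
```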

And the block that throws the exception:

// nil() is assumed to come from: import static com.sun.tools.javac.util.List.nil;
private static JCTree.JCBlock createIfBlock(TreeMaker factory,
  Names symbolsTable, VariableTree parameter) {
    String parameterName = parameter.getName().toString();
    Name parameterId = symbolsTable.fromString(parameterName);

    String errorMessagePrefix = String.format(
      "Argument '%s' of type %s is marked by @%s but got '",
      parameterName, parameter.getType(), Positive.class.getSimpleName());
    String errorMessageSuffix = "' for it";

    // errorMessagePrefix + <parameter> + errorMessageSuffix
    JCTree.JCExpression message = factory.Binary(JCTree.Tag.PLUS,
      factory.Binary(JCTree.Tag.PLUS,
        factory.Literal(TypeTag.CLASS, errorMessagePrefix),
        factory.Ident(parameterId)),
      factory.Literal(TypeTag.CLASS, errorMessageSuffix));

    // { throw new IllegalArgumentException(message); }
    return factory.Block(0, com.sun.tools.javac.util.List.of(
      factory.Throw(
        factory.NewClass(null, nil(),
          factory.Ident(symbolsTable.fromString(
            IllegalArgumentException.class.getSimpleName())),
          com.sun.tools.javac.util.List.of(message), null))));
}
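To preview the message built by createIfBlock(), we can evaluate the same concatenation in plain Java. MessagePreview is a hypothetical helper, and the literal "Positive" stands in for Positive.class.getSimpleName():

```java
final class MessagePreview {

    // Evaluates the same concatenation that createIfBlock() generates:
    // prefix + <parameter value> + suffix
    static String message(String parameterName, String typeName, Object value) {
        String prefix = String.format(
            "Argument '%s' of type %s is marked by @%s but got '",
            parameterName, typeName, "Positive");
        String suffix = "' for it";
        return prefix + value + suffix;
    }
}
```

For example, message("i", "int", -1) returns "Argument 'i' of type int is marked by @Positive but got '-1' for it".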

Finally, we insert the new element into the AST:

private void addCheck(MethodTree method, VariableTree parameter, Context context) {
    JCTree.JCBlock body = (JCTree.JCBlock) method.getBody();
    if (body == null) {
        return; // abstract or native method: there is no body to instrument
    }
    JCTree.JCIf check = createCheck(parameter, context);
    body.stats = body.stats.prepend(check);
}

6. Testing the Plugin

Testing involves two steps:

  1. Compile the test source
  2. Run the compiled code and verify its behavior

For that, we need a few helper classes that compile in memory.

SimpleSourceFile exposes the source text to Javac:

public class SimpleSourceFile extends SimpleJavaFileObject {
    private String content;

    public SimpleSourceFile(String qualifiedClassName, String testSource) {
        // e.g. string:///com/baeldung/javac/Test.java
        super(URI.create("string:///"
          + qualifiedClassName.replace('.', '/')
          + Kind.SOURCE.extension), Kind.SOURCE);
        content = testSource;
    }

    @Override
    public CharSequence getCharContent(boolean ignoreEncodingErrors) {
        return content;
    }
}

SimpleClassFile holds the compilation result:

public class SimpleClassFile extends SimpleJavaFileObject {

    private ByteArrayOutputStream out;

    public SimpleClassFile(URI uri) {
        super(uri, Kind.CLASS);
    }

    @Override
    public OutputStream openOutputStream() throws IOException {
        return out = new ByteArrayOutputStream();
    }

    public byte[] getCompiledBinaries() {
        return out.toByteArray();
    }

    // getters
}

SimpleFileManager collects the compiler output:

public class SimpleFileManager
  extends ForwardingJavaFileManager<StandardJavaFileManager> {

    private List<SimpleClassFile> compiled = new ArrayList<>();

    public SimpleFileManager(StandardJavaFileManager delegate) {
        super(delegate);
    }

    @Override
    public JavaFileObject getJavaFileForOutput(Location location,
      String className, JavaFileObject.Kind kind, FileObject sibling) {
        SimpleClassFile result = new SimpleClassFile(
          URI.create("string://" + className));
        compiled.add(result);
        return result;
    }

    public List<SimpleClassFile> getCompiled() {
        return compiled;
    }
}

The core in-memory compilation logic:

public class TestCompiler {
    public byte[] compile(String qualifiedClassName, String testSource) {
        StringWriter output = new StringWriter();

        JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
        SimpleFileManager fileManager = new SimpleFileManager(
          compiler.getStandardFileManager(null, null, null));
        List<SimpleSourceFile> compilationUnits 
          = singletonList(new SimpleSourceFile(qualifiedClassName, testSource));
        List<String> arguments = new ArrayList<>();
        arguments.addAll(asList("-classpath", System.getProperty("java.class.path"),
          "-Xplugin:" + SampleJavacPlugin.NAME));
        JavaCompiler.CompilationTask task 
          = compiler.getTask(output, fileManager, null, arguments, null,
          compilationUnits);
        
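        // Fail fast with javac's diagnostics instead of a NoSuchElementException
        if (!task.call()) {
            throw new IllegalStateException("Compilation failed:\n" + output);
        }
        return fileManager.getCompiled().iterator().next().getCompiledBinaries();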
    }
}

Running the compiled code:

public class TestRunner {

    public Object run(byte[] byteCode, String qualifiedClassName, String methodName,
      Class<?>[] argumentTypes, Object... args) throws Throwable {
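        ClassLoader classLoader = new ClassLoader() {
            @Override
            protected Class<?> findClass(String name) throws ClassNotFoundException {
                // Only the freshly compiled class is defined here;
                // every other name is left to the parent class loader
                if (!name.equals(qualifiedClassName)) {
                    throw new ClassNotFoundException(name);
                }
                return defineClass(name, byteCode, 0, byteCode.length);
            }
        };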
        Class<?> clazz;
        try {
            clazz = classLoader.loadClass(qualifiedClassName);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException("Can't load compiled test class", e);
        }

        Method method;
        try {
            method = clazz.getMethod(methodName, argumentTypes);
        } catch (NoSuchMethodException e) {
            throw new RuntimeException(
              "Can't find the '" + methodName + "()' method in the compiled test class", e);
        }

        try {
            return method.invoke(null, args);
        } catch (InvocationTargetException e) {
            throw e.getCause();
        }
    }
}

An example test case:

public class SampleJavacPluginTest {

    private static final String CLASS_TEMPLATE
      = "package com.baeldung.javac;\n\n" +
        "public class Test {\n" +
        "    public static %1$s service(@Positive %1$s i) {\n" +
        "        return i;\n" +
        "    }\n" +
        "}\n" +
        "";

    private TestCompiler compiler = new TestCompiler();
    private TestRunner runner = new TestRunner();

    @Test(expected = IllegalArgumentException.class)
    public void givenInt_whenNegative_thenThrowsException() throws Throwable {
        compileAndRun(int.class, -1);
    }

    private Object compileAndRun(Class<?> argumentType, Object argument)
      throws Throwable {
        String qualifiedClassName = "com.baeldung.javac.Test";
        byte[] byteCode = compiler.compile(qualifiedClassName,
          String.format(CLASS_TEMPLATE, argumentType.getName()));
        return runner.run(byteCode, qualifiedClassName,
          "service", new Class[] {argumentType}, argument);
    }
}

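The test compiles a Test class whose service() parameter is annotated with @Positive and passes -1 to it. With the plugin active, the call throws an IllegalArgumentException, which is exactly what the test expects.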

7. Conclusion

In this article, we walked through the complete process of creating, testing, and running a Java compiler plugin.

The complete sample code is available on GitHub.