1. 引言

本文将探讨Java 17新增的上下文特定反序列化过滤器功能。我们将构建一个实际场景,通过实践演示如何为应用中的不同情况选择合适的反序列化过滤器。

2. 与JEP 290的关系

JEP 290在Java 9中引入,通过JVM全局过滤器和为每个ObjectInputStream实例定义过滤器的能力,实现了对外部来源反序列化的过滤。这些过滤器基于运行时参数决定是否允许对象被反序列化。

反序列化不受信任数据的危险早已引发讨论,相关机制也在持续改进。现在我们有了更多动态选择反序列化过滤器的选项,创建过滤器也变得更简单。

3. JEP 415在ObjectInputFilter中的新方法

为了提供更多定义反序列化过滤器的时机和方式,Java 17引入的JEP 415允许指定一个JVM全局过滤器工厂,每次反序列化时都会调用它。这样我们的过滤方案就不会过于严格或宽泛。

同时,为了提供更多上下文控制,新增了几个简化过滤器创建和组合的方法:

  • rejectFilter(Predicate<Class<?>> predicate, Status otherStatus):当predicate返回true时拒绝反序列化,否则返回otherStatus
  • allowFilter(Predicate<Class<?>> predicate, Status otherStatus):当predicate返回true时允许反序列化,否则返回otherStatus
  • rejectUndecidedClass(ObjectInputFilter filter):将传入filter的所有UNDECIDED返回值映射为REJECTED(有少数例外情况
  • merge(ObjectInputFilter filter, ObjectInputFilter anotherFilter):尝试测试两个过滤器,遇到第一个REJECTED状态时立即返回。anotherFilter具有空安全特性,当其为空时直接返回filter而非新组合过滤器

*注意:当被反序列化的类信息为null时,rejectFilter()allowFilter()会返回UNDECIDED*

4. 构建场景和设置

为了演示反序列化过滤器工厂的作用,我们的场景将涉及几个在别处序列化的POJO,通过不同的服务类反序列化。我们将用它们模拟阻止外部来源潜在不安全反序列化的情况。最终学习如何定义参数来检测序列化内容中的异常属性。

先从POJO的标记接口开始:

public interface ContextSpecific extends Serializable {}

首先,Sample类包含可通过ObjectInputFilter在反序列化时检查的基本属性,如数组和嵌套对象:

public class Sample implements ContextSpecific, Comparable<Sample> {
    private static final long serialVersionUID = 1L;

    private int[] array;
    private String name;
    private NestedSample nested;

    public Sample(String name) {
        this.name = name;
    }

    public Sample(int[] array) {
        this.array = array;
    }

    public Sample(NestedSample nested) {
        this.nested = nested;
    }

    // 标准getter和setter

    @Override
    public int compareTo(Sample o) {
        if (name == null)
            return -1;

        if (o == null || o.getName() == null)
            return 1;

        return getName().compareTo(o.getName());
    }
}

实现Comparable仅是为了后续将实例添加到TreeSet,这有助于展示代码如何被间接执行。其次,使用NestedSample类改变反序列化对象的深度,我们将用它来设置对象图在反序列化前的深度限制:

public class NestedSample implements ContextSpecific {

    private Sample optional;

    public NestedSample(Sample optional) {
        this.optional = optional;
    }

    // 标准getter和setter
}

*最后创建一个简单的利用示例供后续过滤。其toString()compareTo()方法包含副作用,例如每次向TreeSet添加项时可能被间接调用:*

public class SampleExploit extends Sample {

    public SampleExploit() {
        super("exploit");
    }

    public static void maliciousCode() {
        System.out.println("exploit executed");
    }

    @Override
    public String toString() {
        maliciousCode();
        return "exploit";
    }

    @Override
    public int compareTo(Sample o) {
        maliciousCode();
        return super.compareTo(o);
    }
}

注意:此简单示例仅用于演示,不模拟真实利用场景。

4.1 序列化和反序列化工具

为方便后续测试,创建几个序列化和反序列化对象的工具。先从简单的序列化开始:

public class SerializationUtils {

    public static void serialize(Object object, OutputStream outStream) throws IOException {
        try (ObjectOutputStream objStream = new ObjectOutputStream(outStream)) {
            objStream.writeObject(object);
        }
    }
}

*同样为测试辅助,创建一个将所有非拒绝对象反序列化为集合的方法,以及一个可选接收另一个过滤器的deserialize()方法:*

public class DeserializationUtils {

    public static Object deserialize(InputStream inStream) {
        return deserialize(inStream, null);
    }
    public static Object deserialize(InputStream inStream, ObjectInputFilter filter) {
        try (ObjectInputStream in = new ObjectInputStream(inStream)) {
            if (filter != null) {
                in.setObjectInputFilter(filter);
            }
            return in.readObject();
        } catch (InvalidClassException e) {
            return null;
        }
    }

    public static Set<ContextSpecific> deserializeIntoSet(InputStream... inputStreams) {
        return deserializeIntoSet(null, inputStreams);
    }

    public static Set<ContextSpecific> deserializeIntoSet(
      ObjectInputFilter filter, InputStream... inputStreams) {
        Set<ContextSpecific> set = new TreeSet<>();

        for (InputStream inputStream : inputStreams) {
            Object object = deserialize(inputStream, filter);
            if (object != null) {
                set.add((ContextSpecific) object);
            }
        }

        return set;
    }
}

注意:在我们的场景中,当发生InvalidClassException时返回null。每当任何过滤器拒绝反序列化时都会抛出此异常。这样就不会破坏*deserializeIntoSet()*,因为我们只收集成功的反序列化并丢弃其他结果。

4.2 创建过滤器

在构建过滤器工厂前,需要一些可选的过滤器。我们将使用ObjectInputFilter.Config.createFilter()创建几个简单过滤器。它接收接受或拒绝包的模式,以及对象反序列化前要检查的几个参数:

public class FilterUtils {

    private static final String DEFAULT_PACKAGE_PATTERN = "java.base/*;!*";
    private static final String POJO_PACKAGE = "com.baeldung.deserializationfilters.pojo";

    // ...
}

首先设置DEFAULT_PACKAGE_PATTERN为接受"java.base"模块中任何类并拒绝其他类的模式。然后设置POJO_PACKAGE为应用中需要反序列化的类所在的包。

基于这些信息,创建作为过滤器基础的方法。*使用baseFilter(),我们将接收要检查的参数名和最大值:*

private static ObjectInputFilter baseFilter(String parameter, int max) {
    return ObjectInputFilter.Config.createFilter(String.format(
      "%s=%d;%s.**;%s", parameter, max, POJO_PACKAGE, DEFAULT_PACKAGE_PATTERN));
}

// ...

使用fallbackFilter()创建更严格的过滤器,仅接受DEFAULT_PACKAGE_PATTERN中的类。它将用于服务类外的反序列化:

public static ObjectInputFilter fallbackFilter() {
    return ObjectInputFilter.Config.createFilter(String.format("%s", DEFAULT_PACKAGE_PATTERN));
}

最后编写用于限制读取字节数、对象中数组大小和反序列化对象图最大深度的过滤器:

public static ObjectInputFilter safeSizeFilter(int max) {
    return baseFilter("maxbytes", max);
}

public static ObjectInputFilter safeArrayFilter(int max) {
    return baseFilter("maxarray", max);
}

public static ObjectInputFilter safeDepthFilter(int max) {
    return baseFilter("maxdepth", max);
}

所有设置完成后,就可以开始编写过滤器工厂了。

5. 创建反序列化过滤器工厂

反序列化过滤器工厂允许我们根据被反序列化的内容动态选择特定过滤器,而不是依赖整个应用的单一过滤器,或在每次创建ObjectInputStream实例时设置不同过滤器。现在我们可以拥有许多上下文特定的过滤器,并在运行时选择或组合它们。

实现机制包括实现BinaryOperator,然后通过jdk.serialFilterFactory JVM属性设置其类名,或调用*ObjectInputFilter.Config.setSerialFilterFactory()*。工厂是JVM全局的且只能设置一次。因此如果通过JVM属性设置,就无法以编程方式替换。此外出于安全考虑,不能将其设置为null。

5.1 选择过滤器的策略

过滤器工厂的策略是基于调用的类选择我们创建的过滤器之一。这就是我们的上下文。因此创建几个调用DeserializationUtils.deserializeIntoSet()的服务类。它们都将通过DeserializationService接口标识:

public interface DeserializationService {

    Set<ContextSpecific> process(InputStream... inputStreams);
}

public class LimitedArrayService implements DeserializationService {

    @Override
    public Set<ContextSpecific> process(InputStream... inputStreams) {
        return DeserializationUtils.deserializeIntoSet(inputStreams);
    }
}

public class LowDepthService implements DeserializationService {
    // process...
}

public class SmallObjectService implements DeserializationService {
    // process...
}

5.2 过滤器工厂结构

我们的过滤器工厂将依赖当前线程的堆栈跟踪检查调用是否来自服务类以及具体是哪个类。先为此编写一个工具方法:

public class ContextSpecificDeserializationFilterFactory implements BinaryOperator<ObjectInputFilter> {

    private static Class<?> findInStack(Class<?> superType) {
        for (StackTraceElement element : Thread.currentThread().getStackTrace()) {
            try {
                Class<?> subType = Class.forName(element.getClassName());
                if (superType.isAssignableFrom(subType)) {
                    return subType;
                }
            } catch (ClassNotFoundException e) {
                return null;
            }
        }
        return null;
    }

    // ...
}

最后重写*apply()*方法:

@Override
public ObjectInputFilter apply(ObjectInputFilter current, ObjectInputFilter next) {
    if (current == null) {
        Class<?> caller = findInStack(DeserializationService.class);

        if (caller == null) {
            current = FilterUtils.fallbackFilter();
        } else if (caller.equals(SmallObjectService.class)) {
            current = FilterUtils.safeSizeFilter(190);
        } else if (caller.equals(LowDepthService.class)) {
            current = FilterUtils.safeDepthFilter(2);
        } else if (caller.equals(LimitedArrayService.class)) {
            current = FilterUtils.safeArrayFilter(3);
        }
    }

    return ObjectInputFilter.merge(current, next);
}

此实现中:

  • 检查current过滤器是否已设置
  • 若未设置,尝试在堆栈中查找服务类
  • 若未找到,使用回退过滤器
  • 否则,若调用来自SmallObjectService,使用值为190的safeSizeFilter()
  • 检查其他可能的服务类,应用相应过滤器
  • 最终将结果过滤器与next过滤器合并,保留可能为ObjectOutputStream实例或通过*ObjectInputFilter.Config.setSerialFilter()*设置的过滤器

注意:safeSizeFilter()的值基于序列化实例预期的最大字节数。**由于SampleExploit类因额外内容序列化后体积更大,在反序列化时会被拒绝。**

6. 测试解决方案

先用几个序列化的Sample对象设置测试。最重要的是,用我们的工厂类调用*setSerialFilterFactory()*:

static ByteArrayOutputStream serialSampleA = new ByteArrayOutputStream();
static ByteArrayOutputStream serialBigSampleA = new ByteArrayOutputStream();

static ByteArrayOutputStream serialSampleC = new ByteArrayOutputStream();
static ByteArrayOutputStream serialBigSampleC = new ByteArrayOutputStream();

@BeforeAll
static void setup() throws IOException {
    ObjectInputFilter.Config.setSerialFilterFactory(new ContextSpecificDeserializationFilterFactory());

    SerializationUtils.serialize(new Sample("simple"), serialSampleA);
    SerializationUtils.serialize(new SampleExploit(), serialBigSampleA);

    SerializationUtils.serialize(new Sample(new NestedSample(null)), serialSampleC);
    SerializationUtils.serialize(new Sample(new NestedSample(new Sample("deep"))), serialBigSampleC);
}

private static ByteArrayInputStream bytes(ByteArrayOutputStream stream) {
    return new ByteArrayInputStream(stream.toByteArray());
}

*此测试中结果集仅包含"simple"对象,因为SampleExploit被拒绝,阻止了maliciousCode()的执行:*

@Test
void whenSmallObjectContext_thenCorrectFilterApplied() {
    Set<ContextSpecific> result = new SmallObjectService().process(
      bytes(serialSampleA),
      bytes(serialBigSampleA)
    );

    assertEquals(1, result.size());
    assertEquals(
      "simple", ((Sample) result.iterator().next()).getName());
}

6.1 组合过滤器

例如,使用LowDepthService时,过滤器工厂应用*safeDepthFilter(2)*,拒绝嵌套超过两层的对象:

@Test
void whenLowDepthContext_thenCorrectFilterApplied() {
    Set<ContextSpecific> result = new LowDepthService().process(
      bytes(serialSampleC),
      bytes(serialBigSampleC)
    );

    assertEquals(1, result.size());
}

但修改*LowDepthService.process()*接受自定义过滤器后:

public Set<ContextSpecific> process(ObjectInputFilter filter, InputStream... inputStreams) {
    return DeserializationUtils.deserializeIntoSet(filter, inputStreams);
}

*可以将safeDepthFilter()与任何其他过滤器组合。这里与safeSizeFilter()组合:*

@Test
void givenExtraFilter_whenCombinedContext_thenMergedFiltersApplied() {
    Set<ContextSpecific> result = new LowDepthService().process(
      FilterUtils.safeSizeFilter(190),
      bytes(serialSampleA),
      bytes(serialBigSampleA),
      bytes(serialSampleC),
      bytes(serialBigSampleC)
    );

    assertEquals(1, result.size());
}

结果只有serialSampleA被允许。

7. 结论

本文展示了Java最新增强功能——上下文特定反序列化过滤器(JEP 415)的实际应用。它通过过滤器工厂引入了反序列化操作中动态且上下文感知的过滤方法。我们的实践场景展示了基于服务的策略,其中不同服务类与特定反序列化上下文关联。此策略为开发者提供了增强安全性的健壮机制。

源代码可在GitHub获取。


原始标题:Context-Specific Deserialization Filters in Java 17