魔法值怎么能忍!JPA的Specification大改造!

2024-08-01#coding#JPA#java阅读时间约 11 分钟

写一个简单的查询接口

如果让你通过JPA做编程式sql构造你会怎么做?一般来说,我们都是通过Specification来构造条件,进行动态条件查询。比如我所在的项目中,很多就是这么写的:

/**
     * <p>根据条件查询订单信息</p>
     * @param orderQueryCondition 订单查询条件
     * @return java.util.List<com.demo.po.OrderPO>
     */
    public List<OrderPO> findByCondition(OrderQueryCondition orderQueryCondition) {

        Specification specification = (root, criteriaQuery, criteriaBuilder) -> {
            List<Predicate> predicates = new ArrayList<>();

            if(!StringUtils.isEmpty(orderQueryCondition.getOrderItemName())) {
                predicates.add(criteriaBuilder.equal(root.get("orderItemName"), orderQueryCondition.getOrderItemName()));
            }

            if(!StringUtils.isEmpty(orderQueryCondition.getOrderNo())) {
                predicates.add(criteriaBuilder.equal(root.get("orderNo"), orderQueryCondition.getOrderNo()));
            }

            return criteriaBuilder.and(predicates.toArray(new Predicate[predicates.size()]));
        };

        return orderDao.findAll(specification);

    }

代码很简单,简单讲解下

其实就是通过lambda实现了Specification接口内的toPredicate()方法,在orderDao.findAll(specification)方法内,它的实现类SimpleJpaRepository会从容器中获取EntityManager对象,分别获取到toPredicate()方法所需要的root, criteriaQuery, criteriaBuilder三个对象,最终构造出Predicate对象,而Predicate对象就是JPA所需要的最终的条件封装对象。

那这个代码有没问题?

从执行角度上来说挺好的,能动态生成sql,很灵活(请暂时对这个例子忽略索引等问题),代码也写得很清晰很好懂。

但是它好不好?

我认为它并不好。因为它有魔法值

魔法值

啥是魔法值?这里指的不是游戏里面的magic point。指的是一种写死的数值或者字符串。魔法值有两个不好的地方

不易于理解

我在项目里面见得最多魔法值的地方是一些if条件判断里面,比如以下的代码:

if ("1".equals(this.medicareFlag) || "2".equals(this.medicareFlag)){
    return 1;
}else if ("0".equals(this.medicareFlag)){
    return 0;
}else {
    return 2;
}

如果你是维护这个代码的人员,看到这个代码你头痛不?1是什么,2是什么,你是谁,我又是谁(

不容易维护

比如说你有一个订单状态要判断,如果你没有把这个状态声明成变量,而是每个地方都像上面那样判断,那么如果你需要修改这个状态名字的时候,就会出现到处都要改。比如下面这样:

public void 方法1(Order order) {
    if("已支付".equals(order)) {
        // blabla...
    }
}
public void 方法2(Order order) {
    if("已支付".equals(order)) {
        // blabla...
    }
}

像上面这样写,要是客户有天突然说我不想叫已支付了,想改成已付款,那就改死人了。

去掉魔法值!

通过上面的例子,魔法值的坏处你应该知道了。同理的,最开始的查询例子,也是用的魔法值。虽然说字段名这种东西一般很少改动,但是这样写一来很容易写错,二来不怕一万只怕万一,要是真要改字段名那就只能哭去了。

其实mybatis-plus就有类似的功能,它定义了LambdaQueryWrapperLambdaQueryChainWrapper来通过lambda的方式构造查询条件。LambdaQueryChainWrapper甚至能链式构造。

一不做二不休,我决定要模仿LambdaQueryChainWrapper做一个能lambda做字段名的同时还能链式构造的工具。

动手做起来!

做之前再次看看原理

还记得上面讲解过Specification构造条件的原理吗?我们可以简单看看SimpleJpaRepository里的源码

展开查看:SimpleJpaRepository部分源码

private final EntityManager em;

/*
 * (non-Javadoc)
 * @see org.springframework.data.jpa.repository.JpaSpecificationExecutor#findAll(org.springframework.data.jpa.domain.Specification)
 */
@Override
public List<T> findAll(@Nullable Specification<T> spec) {
    // 这个方法就是通过Dao调用的入口
    return getQuery(spec, Sort.unsorted()).getResultList();
}

/**
 * Creates a {@link TypedQuery} for the given {@link Specification} and {@link Sort}.
 *
 * @param spec can be {@literal null}.
 * @param sort must not be {@literal null}.
 * @return
 */
protected TypedQuery<T> getQuery(@Nullable Specification<T> spec, Sort sort) {
    return getQuery(spec, getDomainClass(), sort);
}

/**
 * Creates a {@link TypedQuery} for the given {@link Specification} and {@link Sort}.
 *
 * @param spec can be {@literal null}.
 * @param domainClass must not be {@literal null}.
 * @param sort must not be {@literal null}.
 * @return
 */
protected <S extends T> TypedQuery<S> getQuery(@Nullable Specification<S> spec, Class<S> domainClass, Sort sort) {

    // 通过容器获取到的EntityManager获取CriteriaBuilder对象和CriteriaQuery对象
    CriteriaBuilder builder = em.getCriteriaBuilder();
    CriteriaQuery<S> query = builder.createQuery(domainClass);

    Root<S> root = applySpecificationToCriteria(spec, domainClass, query);
    query.select(root);

    if (sort.isSorted()) {
        query.orderBy(toOrders(sort, root, builder));
    }

    return applyRepositoryMethodMetadata(em.createQuery(query));
}

/**
 * Applies the given {@link Specification} to the given {@link CriteriaQuery}.
 *
 * @param spec can be {@literal null}.
 * @param domainClass must not be {@literal null}.
 * @param query must not be {@literal null}.
 * @return
 */
private <S, U extends T> Root<U> applySpecificationToCriteria(@Nullable Specification<U> spec, Class<U> domainClass,
        CriteriaQuery<S> query) {

    Assert.notNull(domainClass, "Domain class must not be null!");
    Assert.notNull(query, "CriteriaQuery must not be null!");

    Root<U> root = query.from(domainClass);

    if (spec == null) {
        return root;
    }

    // 最核心部分,通过EntityManager获取CriteriaBuilder,同时调用Specification的toPredicate()方法
    CriteriaBuilder builder = em.getCriteriaBuilder();
    Predicate predicate = spec.toPredicate(root, query, builder);

    if (predicate != null) {
        query.where(predicate);
    }

    return root;
}

可以看到Specification对象在源码中的applySpecificationToCriteria()方法中被调用了toPredicate()方法,这就是它最核心的部分。这么说,我们只需要实现了toPredicate()这个方法,在这里面做文章,就可以完成我们的工具了。

实现Specification接口,定义一个LambdaSpecification

首先我们实现Specification接口,来分析一下我们还要准备什么。

通过分析原生的写法,它是通过List<Predicate>将所有条件放在一个集合中最后再通过criteriaBuilder整合构造出一个Predicate对象的。

那么我们也定义一个List<Predicate>

不行,因为通过上面的原理,我们知道Specification是通过lambda构造的,实际的root, criteriaQuery, criteriaBuilder三个对象只有在SimpleJpaRepository对象执行的时候才会传入,这样我们在写我们的链式方法的时候就无法获取这几个对象了。

那怎么办?

好办啊,像它一样搞个Function不就好了。我们需要接收RootCriteriaBuilder两个类对象,所以我们定义一个BiFunction来接收。

private List<BiFunction<Root, CriteriaBuilder, Predicate>> predicateFunctions;

那么我们在定义链式条件方法时只需要这么写,我们的等于条件用eq()命名。再加上判空处理:

public LambdaSpecification<T> eq(Boolean ignoreNull, String columnName, Object value) {
    // 自主选择是否忽略null值,也就是null时不加入该条件
    if(ignoreNull && value == null) {
        return this;
    }
    // 将function放入List中,后续给toPredicate()处理
    predicateFunctions.add((root, criteriaBuilder) -> criteriaBuilder.equal(root.get(columnName), value));
    return this;
}

于是,我们的LambdaSpecification类就完成了:

public class LambdaSpecification<T> implements Specification<T> {

    private List<BiFunction<Root, CriteriaBuilder, Predicate>> predicateFunctions;

    private LambdaSpecification() {
        this.predicateFunctions = new ArrayList<>();
    }

    public static <T> LambdaSpecification<T> query() {
        return new LambdaSpecification<>();
    }

    public LambdaSpecification<T> eq(Boolean ignoreNull, String columnName, Object value) {
        // 自主选择是否忽略null值,也就是null时不加入该条件
        if(ignoreNull && value == null) {
            return this;
        }
        // 将function放入List中,后续给toPredicate()处理
        predicateFunctions.add((root, criteriaBuilder) -> criteriaBuilder.equal(root.get(columnName), value));
        return this;
    }

    @Override
    public Predicate toPredicate(Root<T> root, CriteriaQuery<?> query, CriteriaBuilder criteriaBuilder) {
        // 前面构造好的predicateFunctions执行并获得Predicate对象
        List<Predicate> predicates = predicateFunctions.stream()
                    .map(function -> function.apply(root, criteriaBuilder))
                    .collect(Collectors.toList());
        // 整合构造出Predicate对象
        return criteriaBuilder.and(predicates.toArray(new Predicate[predicates.size()]));
    }

}

然后改造下我们的调用代码:

Specification specification = LambdaSpecification.query()
        .eq(true, "orderNo", orderQueryCondition.getOrderNo())
        .eq(true, "orderItemName", orderQueryCondition.getOrderItemName());
return orderDao.findAll(specification);

写起来又清爽又简单对不对,很完美,执行下看结果,我分别调了两个例子,一次只传order_no,一次两个都传了,很完美实现了我们的功能,这是执行的sql:

// 两个都传的sql
Hibernate: select orderpo0_.pid as pid1_0_, orderpo0_.order_item_name as order_item_name2_0_, orderpo0_.order_no as order_no3_0_ from ipn_order orderpo0_ where ( orderpo0_.is_available = 'Y') and orderpo0_.order_no=250037 and orderpo0_.order_item_name=?

// 只传order_no的sql
Hibernate: select orderpo0_.pid as pid1_0_, orderpo0_.order_item_name as order_item_name2_0_, orderpo0_.order_no as order_no3_0_ from ipn_order orderpo0_ where ( orderpo0_.is_available = 'Y') and orderpo0_.order_no=250037

加入lambda的支持

链式调用很完美地实现了,现在来到重头戏,lambda的实现。我们的目标是可以通过OrderPO::getOrderNo传入,可以识别出orderNo字段名。其实也简单,我们只需要将eq()方法里面的String columnName改成Function就可以了。

首先我们如果用到Function,我们要先看如何可以通过Function获取到方法名。

实际上Function对象本身并没有方法获取到方法名,如果要获取方法名,我们得让它支持序列化,转换成SerializedLambda对象才可以获取到。这是因为lambda表达式在支持序列化的时候,java会给它实现一个writeReplace()方法,这个方法返回的对象就是SerializedLambda对象,顾名思义的序列化lambda对象,里面就包含了我们想要的方法名。

因此,我们不能直接用Function类,要定义一个继承Function类的,支持序列化的SerializableFunction

/**
 * <p>能序列化的Function</p>
 *
 * @author VincentHo
 * @date 2024-08-01
 */
public interface SerializableFunction<T, R> extends Function<T, R>, Serializable {
}

然后我们就可以愉快地通过反射获取具体的方法名和字段名了:


/**
 * <p>获取字段名</p>
 * @author VincentHo
 * @date 2024/8/1
 * @param columnNameGetter
 * @return java.lang.String
 */
private <T> String getColumnName(SerializableFunction<T, Object> columnNameGetter) {
    String methodName = getMethodName(columnNameGetter);
    if (methodName.startsWith("get")) {
        String filedName = methodName.substring(3, 4).toLowerCase() + methodName.substring(4);
        return filedName;
    } else {
        throw new RuntimeException(String.format("动态查询生成失败,方法名必须为get方法,当前方法名为:%s", methodName));
    }
}

/**
 * <p>获取方法名</p>
 * 通过writeReplace()方法获取到SerializedLambda对象,从而获取方法名
 * @author VincentHo
 * @date 2024/8/1
 * @param columnNameGetter
 * @return java.lang.String
 */
@SneakyThrows
private <T> String getMethodName(SerializableFunction<T, Object> columnNameGetter) {
    Method writeReplace = columnNameGetter.getClass().getDeclaredMethod("writeReplace");
    writeReplace.setAccessible(true);
    Object sl = writeReplace.invoke(columnNameGetter);
    SerializedLambda serializedLambda = (SerializedLambda)sl;
    return serializedLambda.getImplMethodName();
}

另外,由于泛型擦写的关系,如果直接通过泛型传入OrderPO::getOrderNo,它因无法识别出类型而报错。因此我们还得让我们的LambdaSpecification对象知道泛型T到底是个什么类型。我们加入Class<T>对象用来指定类型

// 增加Class<T>对象
private Class<T> poClass;

// 修改构造方法
private LambdaSpecification(Class<T> poClass) {
    this.predicateFunctions = new ArrayList<>();
    this.poClass = poClass;
}

最后修改完成的代码如下:

展开查看:LambdaSpecification最终完成代码
public class LambdaSpecification<T> implements Specification<T> {

    private List<BiFunction<Root, CriteriaBuilder, Predicate>> predicateFunctions;

    private Class<T> poClass;

    private LambdaSpecification(Class<T> poClass) {
        this.predicateFunctions = new ArrayList<>();
        this.poClass = poClass;
    }

    public static <T> LambdaSpecification<T> query(Class<T> poClazz) {
        return new LambdaSpecification<>(poClazz);
    }

    public LambdaSpecification<T> eq(Boolean ignoreNull, SerializableFunction<T, Object> columnNameGetter, Object value) {
        // 自主选择是否忽略null值,也就是null时不加入该条件
        if(ignoreNull && value == null) {
            return this;
        }
        predicateFunctions.add((root, criteriaBuilder) -> criteriaBuilder.equal(root.get(getColumnName(columnNameGetter)), value));
        return this;
    }

    @Override
    public Predicate toPredicate(Root<T> root, CriteriaQuery<?> query, CriteriaBuilder criteriaBuilder) {
        // 前面构造好的predicateFunctions执行并获得Predicate对象
        List<Predicate> predicates = predicateFunctions.stream()
                    .map(function -> function.apply(root, criteriaBuilder))
                    .collect(Collectors.toList());
        // 整合构造出Predicate对象
        return criteriaBuilder.and(predicates.toArray(new Predicate[predicates.size()]));
    }

    /**
     * <p>获取字段名</p>
     * @author VincentHo
     * @date 2024/8/1
     * @param columnNameGetter
     * @return java.lang.String
     */
    private <T> String getColumnName(SerializableFunction<T, Object> columnNameGetter) {
        String methodName = getMethodName(columnNameGetter);
        if (methodName.startsWith("get")) {
            String filedName = methodName.substring(3, 4).toLowerCase() + methodName.substring(4);
            return filedName;
        } else {
            throw new RuntimeException(String.format("动态查询生成失败,方法名必须为get方法,当前方法名为:%s", methodName));
        }
    }

    /**
     * <p>获取方法名</p>
     * 通过writeReplace()方法获取到SerializedLambda对象,从而获取方法名
     * @author VincentHo
     * @date 2024/8/1
     * @param columnNameGetter
     * @return java.lang.String
     */
    @SneakyThrows
    private <T> String getMethodName(SerializableFunction<T, Object> columnNameGetter) {
        Method writeReplace = columnNameGetter.getClass().getDeclaredMethod("writeReplace");
        writeReplace.setAccessible(true);
        Object sl = writeReplace.invoke(columnNameGetter);
        SerializedLambda serializedLambda = (SerializedLambda)sl;
        return serializedLambda.getImplMethodName();
    }

}

再来改一下我们的调用代码:

public List<OrderPO> findByCondition(OrderQueryCondition orderQueryCondition) {
    Specification specification = LambdaSpecification.query(OrderPO.class)
            .eq(true, OrderPO::getOrderNo, orderQueryCondition.getOrderNo())
            .eq(true, OrderPO::getOrderItemName, orderQueryCondition.getOrderItemName());
    return orderDao.findAll(specification);
}

这样一来就改造好了,终于可以和魔法值说拜拜了!执行结果也很正确,好耶!

Hibernate: select orderpo0_.pid as pid1_0_, orderpo0_.order_item_name as order_item_name2_0_, orderpo0_.order_no as order_no3_0_ from ipn_order orderpo0_ where ( orderpo0_.is_available = 'Y') and orderpo0_.order_no=250037

上面只演示了等于条件,像in条件,like条件之类的,根据所需要扩展即可,比如in条件:

public LambdaSpecification<T> in(Boolean ignoreNull, SerializableFunction<T, Object> columnNameGetter, Object ... values) {
    if(ignoreNull && values == null) {
        return this;
    }
    predicateFunctions.add(
            (root, criteriaBuilder) -> {
                CriteriaBuilder.In in = criteriaBuilder.in(root.get(getColumnName(columnNameGetter)));
                for(Object value : values) {
                    in.value(value);
                }
                return in;
            }
    );
    return this;
}

其他不再赘述了,自己扩展即可。

完成后的源代码:https://github.com/VincenttHo/perfect-jpa-specification

以上就是这次改造的过程,希望能帮到大家,谢谢!