📃 关联文档

📄 前置文档

定义数据权限注解

/**
 * Enables enterprise-level data-permission filtering for the annotated class or method.
 * <p>
 * A method-level annotation is looked up before the class-level one.
 */
@Target({METHOD, TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface DataScope {

    /**
     * Column of the main table that holds the enterprise id; the injected SQL condition filters on it.
     *
     * @return the column name, {@code ent_id} by default
     */
    String unitField() default "ent_id";

    /**
     * Whether data filtering is applied; set to {@code false} to opt out.
     *
     * @return {@code true} (the default) to filter
     */
    boolean filterData() default true;

    /**
     * Tables that are never filtered, mainly tables that lack the {@link #unitField()} column.
     *
     * @return table names to skip
     */
    String[] ignoreTables() default "sys_file";
}

定义一个对象,用于存储注解中的相关信息

import lombok.AllArgsConstructor;
import lombok.Data;

import java.util.List;
import java.util.Set;

/**
 * Runtime settings taken from a {@code @DataScope} annotation, plus the enterprise ids the SQL is
 * restricted to.
 * <p>
 * Lombok generates getters/setters, equals/hashCode/toString and an all-args constructor whose
 * parameter order is the field order below (unitField, entIdList, filterField, ignoreTables) —
 * callers construct it positionally, so do not reorder the fields.
 *
 * @author ChenQi
 * @since 2022/10/20 17:37
 */
@Data
@AllArgsConstructor
public class DataScopeParam {
    /**
     * Column on the main table that stores the enterprise id (the filter column).
     */
    private String unitField;

    /**
     * Enterprise ids the statement is restricted to (a set despite the "List" in its name).
     */
    private Set<Long> entIdList;

    /**
     * Whether SQL should be filtered at all (mirrors {@code @DataScope#filterData}).
     */
    private boolean filterField;

    /**
     * Table names that are never filtered.
     */
    private List<String> ignoreTables;
}

权限解析器(解析类或方法上的@DataScope注解并缓存解析结果)

import cn.hutool.core.convert.Convert;
import com.lyc.common.base.annotation.DataScope;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.BridgeMethodResolver;
import org.springframework.core.MethodClassKey;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.AnnotationAttributes;
import org.springframework.util.ClassUtils;

import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Proxy;
import java.util.HashSet;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Resolves which {@link DataScope} settings apply to an intercepted method.
 * <p>
 * Results are cached per (method, target class). Lookup order, first hit wins:
 * <ol>
 *     <li>the invoked method itself;</li>
 *     <li>the most specific implementation of that method on the user class (bridge methods resolved);</li>
 *     <li>the user class (only for user-level methods);</li>
 *     <li>the interfaces of the user class (first annotated one);</li>
 *     <li>the class declaring the invoked method, when it differs from the implementation (user-level methods only);</li>
 *     <li>the target class and its superclasses, unless the target is a JDK proxy.</li>
 * </ol>
 * Non-public methods and methods declared by {@link Object} never get a data scope.
 *
 * @author ChenQi
 * @since 2022/10/31 16:16
 */
@Slf4j
public class DataScopeAnnotationClassResolver {

    /**
     * Cache marker meaning "no data scope applies". {@link ConcurrentHashMap} rejects {@code null}
     * values, so caching a {@code null} result directly would throw a NullPointerException.
     */
    private static final DataScopeParam NO_DATA_SCOPE = new DataScopeParam(null, new HashSet<>(), false, null);

    /**
     * Resolved settings keyed by {@link MethodClassKey} (method, target class).
     */
    private final Map<Object, DataScopeParam> dsCache = new ConcurrentHashMap<>();

    public DataScopeAnnotationClassResolver() {
    }

    /**
     * Returns the data-scope settings for {@code method} invoked on {@code targetObject}.
     *
     * @param method       the invoked method
     * @param targetObject the object the method is invoked on
     * @return a fresh copy of the resolved settings, or {@code null} when no {@code @DataScope} applies
     */
    public DataScopeParam findKey(Method method, Object targetObject) {
        if (method.getDeclaringClass() == Object.class) {
            return null;
        }
        Object cacheKey = new MethodClassKey(method, targetObject.getClass());
        DataScopeParam cached = this.dsCache.get(cacheKey);
        if (cached == null) {
            DataScopeParam computed = computeDataScope(method, targetObject);
            cached = computed != null ? computed : NO_DATA_SCOPE;
            this.dsCache.put(cacheKey, cached);
        }
        if (cached == NO_DATA_SCOPE) {
            return null;
        }
        // The cached instance is shared by every thread calling this method, while callers mutate what
        // they receive (entIdList / filterField) per request — hand out a private copy.
        return new DataScopeParam(cached.getUnitField(), new HashSet<>(cached.getEntIdList()),
                cached.isFilterField(), cached.getIgnoreTables());
    }

    /**
     * Looks up the {@code @DataScope} annotation in the order documented on the class.
     *
     * @param method       the invoked method
     * @param targetObject the object the method is invoked on
     * @return the resolved settings, or {@code null} if none apply
     */
    private DataScopeParam computeDataScope(Method method, Object targetObject) {
        if (!Modifier.isPublic(method.getModifiers())) {
            return null;
        }
        // 1. The invoked method (for JDK proxies this is the interface method).
        DataScopeParam attr = findDataScopeAttribute(method);
        if (attr != null) {
            return attr;
        }
        Class<?> userClass = ClassUtils.getUserClass(targetObject.getClass());
        // 2. The implementation method on the user class, with bridge methods resolved to the original.
        Method specificMethod = BridgeMethodResolver.findBridgedMethod(ClassUtils.getMostSpecificMethod(method, userClass));
        attr = findDataScopeAttribute(specificMethod);
        if (attr != null) {
            return attr;
        }
        // 3. The user class itself.
        attr = findDataScopeAttribute(userClass);
        if (attr != null && ClassUtils.isUserLevelMethod(method)) {
            return attr;
        }
        // 4. Interfaces of the user class; the first annotated one wins.
        for (Class<?> interfaceClazz : ClassUtils.getAllInterfacesForClassAsSet(userClass)) {
            attr = findDataScopeAttribute(interfaceClazz);
            if (attr != null) {
                return attr;
            }
        }
        // 5. The class declaring the invoked method, when it is not the implementation checked above.
        //    (The invoked method itself was already checked in step 1, so it is not re-checked here.)
        if (specificMethod != method) {
            attr = findDataScopeAttribute(method.getDeclaringClass());
            if (attr != null && ClassUtils.isUserLevelMethod(method)) {
                return attr;
            }
        }
        // 6. Fall back to the target's class hierarchy.
        return getDefaultDataScopeAttr(targetObject);
    }

    /**
     * Walks from the target's class up to (excluding) {@link Object} looking for a class-level annotation.
     * JDK proxy classes are skipped entirely.
     *
     * @param targetObject the object the method is invoked on
     * @return the first settings found, or {@code null}
     */
    private DataScopeParam getDefaultDataScopeAttr(Object targetObject) {
        Class<?> targetClass = targetObject.getClass();
        if (!Proxy.isProxyClass(targetClass)) {
            Class<?> currentClass = targetClass;
            while (currentClass != Object.class) {
                DataScopeParam attr = findDataScopeAttribute(currentClass);
                if (attr != null) {
                    return attr;
                }
                currentClass = currentClass.getSuperclass();
            }
        }
        return null;
    }

    /**
     * Reads a (possibly meta-annotated) {@code @DataScope} from {@code ae} and maps it to a
     * {@link DataScopeParam} with an empty enterprise-id set.
     *
     * @param ae method or class to inspect
     * @return the mapped settings, or {@code null} if {@code ae} is not annotated
     */
    private DataScopeParam findDataScopeAttribute(AnnotatedElement ae) {
        AnnotationAttributes attributes = AnnotatedElementUtils.getMergedAnnotationAttributes(ae, DataScope.class);
        if (attributes == null) {
            return null;
        }
        return new DataScopeParam(attributes.getString("unitField"), new HashSet<>(), attributes.getBoolean("filterData"),
                Convert.toList(String.class, attributes.get("ignoreTables")));
    }
}

跨线程传递权限对象

import com.alibaba.ttl.TransmittableThreadLocal;

/**
 * Holds the {@link DataScopeParam} of the {@code @DataScope} method currently executing on this thread.
 * <p>
 * Backed by a {@link TransmittableThreadLocal} so the value can be carried across threads
 * (NOTE(review): TTL only transmits through TTL-wrapped executors/tasks or its Java agent — confirm how
 * async work is submitted).
 *
 * @author ChenQi
 * @since 2022/10/31 16:36
 */
public final class DataScopeParamContentHolder {

    /** Per-thread storage of the active data-scope settings. */
    private static final ThreadLocal<DataScopeParam> CURRENT = new TransmittableThreadLocal<>();

    private DataScopeParamContentHolder() {
        // static utility, never instantiated
    }

    /**
     * Binds the given settings to the current thread, replacing any previous binding.
     *
     * @param dataScopeParam settings to apply to SQL run on this thread
     */
    public static void set(DataScopeParam dataScopeParam) {
        CURRENT.set(dataScopeParam);
    }

    /**
     * Returns the settings bound to the current thread.
     *
     * @return the bound settings, or {@code null} when none are bound
     */
    public static DataScopeParam get() {
        return CURRENT.get();
    }

    /**
     * Removes the current thread's binding.
     */
    public static void clear() {
        CURRENT.remove();
    }
}

拦截方法调用:如果类或方法上有@DataScope注解,则将注解内容封装为DataScopeParam并存入DataScopeParamContentHolder(当前线程上下文)中

import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/**
 * AOP interceptor that publishes the {@code @DataScope}-derived {@link DataScopeParam} of the invoked
 * method into {@link DataScopeParamContentHolder} for the duration of the call.
 * <p>
 * Nested {@code @DataScope} calls are supported: the enclosing call's value is restored when the
 * inner call returns, instead of being cleared.
 *
 * @author ChenQi
 * @since 2022/10/31 16:13
 */
public class DataScopeAnnotationIntercept implements MethodInterceptor {

    private final DataScopeAnnotationClassResolver dataScopeAnnotationClassResolver;

    public DataScopeAnnotationIntercept() {
        dataScopeAnnotationClassResolver = new DataScopeAnnotationClassResolver();
    }

    /**
     * Binds the resolved data scope, proceeds, then restores whatever was bound before.
     *
     * @param methodInvocation the intercepted call
     * @return the result of the underlying method
     * @throws Throwable anything thrown by the underlying method
     */
    @Nullable
    @Override
    public Object invoke(@NotNull MethodInvocation methodInvocation) throws Throwable {
        DataScopeParam paramKey = dataScopeAnnotationClassResolver.findKey(methodInvocation.getMethod(), methodInvocation.getThis());
        // Remember the enclosing call's scope: clearing unconditionally in finally would leave the rest of
        // the outer method running without any data-permission filtering.
        DataScopeParam previous = DataScopeParamContentHolder.get();
        DataScopeParamContentHolder.set(paramKey);
        try {
            return methodInvocation.proceed();
        } finally {
            if (previous == null) {
                DataScopeParamContentHolder.clear();
            } else {
                DataScopeParamContentHolder.set(previous);
            }
        }
    }
}

@DataScope切面逻辑

import lombok.NonNull;
import org.aopalliance.aop.Advice;
import org.aopalliance.intercept.MethodInterceptor;
import org.springframework.aop.ClassFilter;
import org.springframework.aop.MethodMatcher;
import org.springframework.aop.Pointcut;
import org.springframework.aop.support.AbstractPointcutAdvisor;
import org.springframework.aop.support.AopUtils;
import org.springframework.aop.support.ComposablePointcut;
import org.springframework.aop.support.StaticMethodMatcher;
import org.springframework.aop.support.annotation.AnnotationMatchingPointcut;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.util.Assert;

import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

/**
 * Spring AOP advisor that applies the given interceptor wherever the given annotation appears,
 * either on the bean's type or on individual methods.
 *
 * @author ChenQi
 * @since 2022/10/31 16:49
 */
public class DataScopeAnnotationAdvisor extends AbstractPointcutAdvisor implements BeanFactoryAware {

    /** Interceptor invoked for matching methods. */
    private final Advice advice;

    /** Union of the type-level and method-level annotation pointcuts. */
    private final Pointcut pointcut;

    /** Annotation that triggers the advice. */
    private final Class<? extends Annotation> annotation;

    public DataScopeAnnotationAdvisor(@NonNull MethodInterceptor advice,
                                      @NonNull Class<? extends Annotation> annotation) {
        this.advice = advice;
        this.annotation = annotation;
        // must run after `annotation` is assigned: the pointcut is built from it
        this.pointcut = buildPointcut();
    }

    @Override
    public Pointcut getPointcut() {
        return pointcut;
    }

    @Override
    public Advice getAdvice() {
        return advice;
    }

    /**
     * Forwards the bean factory to the advice when the advice asks for one.
     */
    @Override
    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
        if (!(advice instanceof BeanFactoryAware)) {
            return;
        }
        ((BeanFactoryAware) advice).setBeanFactory(beanFactory);
    }

    /**
     * Matches a class carrying the annotation (inherited annotations included) OR a method carrying it.
     */
    private Pointcut buildPointcut() {
        Pointcut typeLevel = new AnnotationMatchingPointcut(annotation, true);
        Pointcut methodLevel = new AnnotationMethodPoint(annotation);
        return new ComposablePointcut(typeLevel).union(methodLevel);
    }

    /**
     * Method-level annotation pointcut, kept local for compatibility with Spring versions below 5.0.
     */
    private static class AnnotationMethodPoint implements Pointcut {

        private final Class<? extends Annotation> annotationType;

        public AnnotationMethodPoint(Class<? extends Annotation> annotationType) {
            Assert.notNull(annotationType, "Annotation type must not be null");
            this.annotationType = annotationType;
        }

        @Override
        public ClassFilter getClassFilter() {
            return ClassFilter.TRUE;
        }

        @Override
        public MethodMatcher getMethodMatcher() {
            return new AnnotationMethodMatcher(annotationType);
        }

        /**
         * Matches a method annotated directly, or whose implementation on the target class is annotated.
         */
        private static class AnnotationMethodMatcher extends StaticMethodMatcher {
            private final Class<? extends Annotation> annotationType;

            public AnnotationMethodMatcher(Class<? extends Annotation> annotationType) {
                this.annotationType = annotationType;
            }

            @Override
            public boolean matches(Method method, Class<?> targetClass) {
                if (isAnnotated(method)) {
                    return true;
                }
                // JDK proxy classes never carry annotations on their redeclared methods; otherwise the
                // method may be an interface method whose implementation carries the annotation.
                return !Proxy.isProxyClass(targetClass) && isAnnotatedOnImplementation(method, targetClass);
            }

            private boolean isAnnotatedOnImplementation(Method method, Class<?> targetClass) {
                Method implementation = AopUtils.getMostSpecificMethod(method, targetClass);
                return implementation != method && isAnnotated(implementation);
            }

            private boolean isAnnotated(Method method) {
                return AnnotatedElementUtils.hasAnnotation(method, annotationType);
            }
        }
    }
}

切面逻辑注册

import com.lyc.common.base.annotation.DataScope;
import org.springframework.aop.Advisor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.Ordered;

/**
 * Registers the {@code @DataScope} AOP advisor.
 *
 * @author ChenQi
 * @since 2022/10/31 16:46
 */
@Configuration
public class DataScopeInitConfig {

    /**
     * Builds the advisor that routes {@link DataScope}-annotated classes and methods through
     * {@link DataScopeAnnotationIntercept}.
     *
     * @return the data-scope advisor, ordered with {@link Ordered#HIGHEST_PRECEDENCE}
     */
    @Bean
    public Advisor generateAllDataScopeAdvisor() {
        DataScopeAnnotationAdvisor dataScopeAdvisor =
                new DataScopeAnnotationAdvisor(new DataScopeAnnotationIntercept(), DataScope.class);
        dataScopeAdvisor.setOrder(Ordered.HIGHEST_PRECEDENCE);
        return dataScopeAdvisor;
    }
}

最后通过MyBatis-Plus拦截器对SQL进行增强,实现数据权限拦截

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.convert.Convert;
import cn.hutool.core.util.ReflectUtil;
import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.lyc.admin.oauth.service.SysUser;
import com.lyc.admin.oauth.utils.SecurityUtils;
import com.lyc.common.base.exception.HasNotAuthException;
import com.lyc.common.base.utils.CurrentEntIdSearchContextHolder;
import com.lyc.common.base.vo.EntierVO;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserManager;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SetOperationList;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.aspectj.lang.annotation.Aspect;
import org.springframework.stereotype.Component;

import java.io.StringReader;
import java.sql.Connection;
import java.util.*;
import java.util.stream.Collectors;

/**
 * MyBatis plugin that enforces enterprise-level data permissions by rewriting SQL when
 * {@link StatementHandler#prepare} is intercepted.
 * <p>
 * The effective scope combines the {@link DataScopeParam} published by the {@code @DataScope} AOP
 * interceptor (its enterprise ids are filled from the logged-in user's tiers) with the optional
 * enterprise-id list held by {@link CurrentEntIdSearchContextHolder}; when both exist they are intersected.
 * SELECT / UPDATE / DELETE statements get an extra {@code <unitField> = id} / {@code in (...)} predicate;
 * INSERT statements are validated against the permitted ids instead.
 *
 * @author ChenQi
 * @since 2022/10/20 14:50
 */
// NOTE(review): @Aspect declares no advice on this class; presumably unnecessary — confirm before removing.
@Aspect
@Slf4j
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
@Component
public class UnitDataPermissionInterceptor implements Interceptor {

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        SysUser sysUser = SecurityUtils.getUser();
        // No logged-in user: there is no permission set to apply.
        if (sysUser == null) {
            return invocation.proceed();
        }

        DataScopeParam dataScopeParam = resolveDataScope(sysUser);
        // Neither @DataScope nor a header-provided list, or the annotation disables filtering.
        if (dataScopeParam == null || !dataScopeParam.isFilterField()) {
            return invocation.proceed();
        }

        StatementHandler statementHandler = PluginUtils.realTarget(invocation.getTarget());
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        SqlCommandType commandType = mappedStatement.getSqlCommandType();
        if (SqlCommandType.FLUSH.equals(commandType) || SqlCommandType.UNKNOWN.equals(commandType)) {
            return invocation.proceed();
        }

        BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
        String originalSql = boundSql.getSql();
        Set<Long> entIdList = dataScopeParam.getEntIdList();
        String unitField = dataScopeParam.getUnitField();
        List<String> ignoreTables = dataScopeParam.getIgnoreTables();

        switch (commandType) {
            case INSERT: {
                // INSERT has no WHERE to narrow, so validate the entity's enterprise id instead.
                Object parameterObject = boundSql.getParameterObject();
                if (parameterObject != null) {
                    Long entId = Convert.toLong(ReflectUtil.getFieldValue(parameterObject, StrUtil.toCamelCase(unitField)));
                    if (entId != null && !entIdList.contains(entId)) {
                        throw new HasNotAuthException();
                    }
                }
                return invocation.proceed();
            }
            case UPDATE: {
                // The business implements logical deletes as UPDATEs, so this path covers them too.
                String updateSql = handleUpdateSql(originalSql, entIdList, unitField, ignoreTables);
                log.warn("数据权限处理过后UPDATE的SQL: {}", updateSql);
                metaObject.setValue("delegate.boundSql.sql", updateSql);
                return invocation.proceed();
            }
            case DELETE: {
                String deleteSql = handleDeleteSql(originalSql, entIdList, unitField, ignoreTables);
                log.warn("数据权限处理过后DELETE的SQL: {}", deleteSql);
                metaObject.setValue("delegate.boundSql.sql", deleteSql);
                return invocation.proceed();
            }
            default: {
                String finalSql = this.handleSql(originalSql, entIdList, unitField, ignoreTables);
                log.warn("数据权限处理过后SELECT的SQL: {}", finalSql);
                metaObject.setValue("delegate.boundSql.sql", finalSql);
                return invocation.proceed();
            }
        }
    }

    /**
     * Builds this statement's effective data scope.
     *
     * @param sysUser the logged-in user whose tiers define the permitted enterprises
     * @return a per-statement {@link DataScopeParam}, or {@code null} when no filtering was requested
     */
    private DataScopeParam resolveDataScope(SysUser sysUser) {
        DataScopeParam annotated = DataScopeParamContentHolder.get();
        DataScopeParam dataScopeParam = null;
        if (annotated != null) {
            // Work on a copy: the holder value may be shared with other threads/statements, and
            // mutating it would leak one user's enterprise ids (and filterField = true) into others.
            dataScopeParam = new DataScopeParam(annotated.getUnitField(), collectUserEntIds(sysUser),
                    annotated.isFilterField(), annotated.getIgnoreTables());
        }

        // Enterprise ids requested by the front end (request header); narrowed to the user's permissions
        // when an annotation scope exists.
        Set<Long> requestedEntIds = CurrentEntIdSearchContextHolder.getEntIdList();
        if (requestedEntIds != null) {
            if (dataScopeParam == null) {
                dataScopeParam = new DataScopeParam("ent_id", requestedEntIds, true, CollUtil.newArrayList("sys_file"));
            } else {
                Set<Long> permitted = dataScopeParam.getEntIdList();
                dataScopeParam.setFilterField(true);
                dataScopeParam.setEntIdList(requestedEntIds.stream().filter(permitted::contains).collect(Collectors.toSet()));
            }
        }
        return dataScopeParam;
    }

    /**
     * Flattens the enterprise ids of all the user's tiers.
     *
     * @param sysUser the logged-in user
     * @return the permitted enterprise ids (empty when the user has no tiers)
     */
    private Set<Long> collectUserEntIds(SysUser sysUser) {
        return Optional.ofNullable(sysUser.getTierVos()).orElse(new ArrayList<>()).stream()
                .map(EntierVO::getUnitIdList)
                .flatMap(Collection::stream)
                .collect(Collectors.toSet());
    }

    /**
     * Adds the permission predicate to a SELECT (every branch of a UNION).
     *
     * @param originalSql original sql
     * @param entIdList   enterprises the query is restricted to
     * @param fieldName   enterprise-id column on the main table
     * @param ignores     table names that are never filtered
     * @return the rewritten statement
     * @throws JSQLParserException if the SQL cannot be parsed
     */
    private String handleSql(String originalSql, Set<Long> entIdList, String fieldName, List<String> ignores) throws JSQLParserException {
        CCJSqlParserManager parserManager = new CCJSqlParserManager();
        Select select = (Select) parserManager.parse(new StringReader(originalSql));
        SelectBody selectBody = select.getSelectBody();
        if (selectBody instanceof PlainSelect) {
            this.setWhere((PlainSelect) selectBody, entIdList, fieldName, ignores);
        } else if (selectBody instanceof SetOperationList) {
            List<SelectBody> selectBodyList = ((SetOperationList) selectBody).getSelects();
            selectBodyList.forEach(s -> this.setWhere((PlainSelect) s, entIdList, fieldName, ignores));
        }
        return select.toString();
    }

    /**
     * Adds the permission predicate to an UPDATE.
     *
     * @param originalSql original sql
     * @param entIdList   enterprises the update is restricted to
     * @param fieldName   enterprise-id column of the updated table
     * @param ignores     table names that are never filtered
     * @return the rewritten statement, or {@code originalSql} for ignored tables
     * @throws JSQLParserException if the SQL cannot be parsed
     */
    private String handleUpdateSql(String originalSql, Set<Long> entIdList, String fieldName, List<String> ignores) throws JSQLParserException {
        CCJSqlParserManager parserManager = new CCJSqlParserManager();
        Update update = (Update) parserManager.parse(new StringReader(originalSql));
        if (ignores.contains(update.getTable().getName())) {
            return originalSql;
        }
        update.setWhere(andWhere(update.getWhere(), buildPermissionCondition(fieldName, entIdList)));
        return update.toString();
    }

    /**
     * Adds the permission predicate to a physical DELETE.
     *
     * @param originalSql original sql
     * @param entIdList   enterprises the delete is restricted to
     * @param fieldName   enterprise-id column of the target table
     * @param ignores     table names that are never filtered
     * @return the rewritten statement, or {@code originalSql} for ignored tables
     * @throws JSQLParserException if the SQL cannot be parsed
     */
    private String handleDeleteSql(String originalSql, Set<Long> entIdList, String fieldName, List<String> ignores) throws JSQLParserException {
        CCJSqlParserManager parserManager = new CCJSqlParserManager();
        Delete delete = (Delete) parserManager.parse(new StringReader(originalSql));
        if (ignores.contains(delete.getTable().getName())) {
            return originalSql;
        }
        delete.setWhere(andWhere(delete.getWhere(), buildPermissionCondition(fieldName, entIdList)));
        return delete.toString();
    }

    /**
     * Adds the permission predicate to one plain SELECT, qualifying the column with the main table's
     * alias (or name) to avoid ambiguity in joins.
     *
     * @param plainSelect select to modify in place
     * @param entIdList   enterprises the query is restricted to
     * @param fieldName   enterprise-id column on the main table
     * @param ignores     table names that are never filtered
     */
    @SneakyThrows(Exception.class)
    protected void setWhere(PlainSelect plainSelect, Set<Long> entIdList, String fieldName, List<String> ignores) {
        if (plainSelect.getFromItem() == null) {
            // e.g. "SELECT 1": no table, nothing to restrict.
            return;
        }
        // NOTE(review): a sub-select in FROM still fails this cast, which rejects the statement (fail-closed).
        Table fromItem = (Table) plainSelect.getFromItem();
        if (ignores.contains(fromItem.getName())) {
            return;
        }
        Alias fromItemAlias = fromItem.getAlias();
        String mainTableName = fromItemAlias == null ? fromItem.getName() : fromItemAlias.getName();
        plainSelect.setWhere(andWhere(plainSelect.getWhere(), buildPermissionCondition(mainTableName + "." + fieldName, entIdList)));
    }

    /**
     * Builds {@code column = id} for one id, {@code column in (...)} for several, and
     * {@code column in ( null )} — which matches no row — when there are none.
     *
     * @param column    (optionally qualified) column name
     * @param entIdList permitted enterprise ids; {@code null} is treated as empty
     * @return the condition expression
     * @throws JSQLParserException if the generated condition cannot be parsed
     */
    private Expression buildPermissionCondition(String column, Set<Long> entIdList) throws JSQLParserException {
        if (entIdList != null && entIdList.size() == 1) {
            EqualsTo equalsTo = new EqualsTo();
            equalsTo.setLeftExpression(new Column(column));
            equalsTo.setRightExpression(new LongValue(entIdList.iterator().next()));
            return equalsTo;
        }
        String values = entIdList == null || entIdList.isEmpty()
                ? StringPool.NULL
                : CollUtil.join(entIdList, StringPool.COMMA);
        return CCJSqlParserUtil.parseCondExpression(column + " in ( " + values + " )");
    }

    /**
     * AND-combines an existing WHERE (may be {@code null}) with the permission condition.
     * The existing predicate is parenthesized because AND binds tighter than OR: without it,
     * {@code a OR b AND scope} would let rows matching {@code a} bypass the permission filter.
     *
     * @param where     existing WHERE expression, or {@code null}
     * @param condition permission condition
     * @return the combined WHERE expression
     */
    private static Expression andWhere(Expression where, Expression condition) {
        if (where == null) {
            return condition;
        }
        Expression guarded = where instanceof Parenthesis ? where : new Parenthesis(where);
        return new AndExpression(guarded, condition);
    }

    /**
     * Wraps only {@link StatementHandler} targets with this plugin.
     *
     * @param target object being created by MyBatis
     * @return the proxy, or {@code target} unchanged
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /**
     * Receives plugin properties from the MyBatis configuration; they are only logged.
     *
     * @param properties configured properties
     */
    @Override
    public void setProperties(Properties properties) {
        log.info(properties.toString());
    }
}

说明

  • 该注解可用于类或方法上。用于类上时,该类中所有public方法执行的SQL都会被增强,从而实现数据权限过滤;方法上的注解优先于类上的注解
  • 包含异步处理的方法需要添加@DataScope(filterData = false)以忽略权限过滤,例如将日志异步写入数据库的方法必须添加。