📃 关联文档

✨ 后续升级

权限数据过滤

定义一个注解用于开启权限过滤功能

这次没参与后台业务部分开发并不清楚哪些业务需要该功能,所以没有默认进行开启,将主动权交于业务开发人员手中

import java.lang.annotation.*;

import static java.lang.annotation.ElementType.*;

/**
 * Marks a method (or type / meta-annotation) whose queries should be restricted
 * by enterprise id.
 *
 * @author ChenQi
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.ANNOTATION_TYPE, ElementType.TYPE})
public @interface DataScope {

    /**
     * Name of the column in the main table that holds the enterprise id
     * used for filtering.
     *
     * @return the enterprise id column name, {@code "ent_id"} by default
     */
    String unitField() default "ent_id";

    /**
     * Whether data filtering should be applied.
     *
     * @return {@code true} (the default) to filter, {@code false} to skip filtering
     */
    boolean filterData() default true;
}

定义一个对象,用于储存每次请求时相关接口进行数据过滤所需使用的数据

import lombok.AllArgsConstructor;
import lombok.Data;

import java.util.Set;

/**
 * Holder for the data-filtering settings that apply to the current request.
 *
 * <p>The field order is significant: Lombok's {@code @AllArgsConstructor}
 * generates the constructor parameters in declaration order.
 *
 * @author ChenQi
 * @since 2022/10/20 17:37
 */
@AllArgsConstructor
@Data
public class DataScopeParam {

    /** Name of the column used to filter by enterprise. */
    private String unitField;

    /** Enterprise ids that make up the data range. */
    private Set<Long> entIdList;

    /** Whether the query should be intercepted and filtered. */
    private boolean filterField;
}

使用阿里开源的TransmittableThreadLocal

<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>transmittable-thread-local</artifactId>
</dependency>

创建拦截器修改sql使其能够将权限过滤的字段代入

import cn.hutool.core.collection.CollUtil;
import com.alibaba.ttl.TransmittableThreadLocal;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.lyc.admin.oauth.service.SysUser;
import com.lyc.admin.oauth.utils.SecurityUtils;
import com.lyc.common.base.annotation.DataScope;
import com.lyc.common.base.constant.CommonConstants;
import com.lyc.common.base.utils.CurrentEntIdSearchContextHolder;
import com.lyc.common.base.vo.EntierVO;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserManager;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SetOperationList;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.After;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component;

import java.io.StringReader;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.util.Collection;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Spring AOP aspect and MyBatis plugin that appends an enterprise-id condition
 * to SELECT statements.
 *
 * <p>Flow, as implemented below:
 * <ol>
 *   <li>{@link #doBefore(JoinPoint)} runs before a method annotated with {@link DataScope},
 *       collects the enterprise ids of the current user and stores them in a thread-local.</li>
 *   <li>{@link #intercept(Invocation)} runs on {@code StatementHandler.prepare}; it combines the
 *       thread-local settings with the enterprise ids stored in
 *       {@link CurrentEntIdSearchContextHolder} and rewrites the SQL.</li>
 *   <li>{@link #clearThreadLocal()} removes the thread-local after the annotated method.</li>
 * </ol>
 *
 * <p>Only the annotation on the invoked method itself is read (see {@link #getAnnotationLog(JoinPoint)}).
 *
 * @author ChenQi
 * @since 2022/10/20 14:50
 */
@Aspect
@Slf4j
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
@Component
public class UnitDataPermissionInterceptor implements Interceptor {

    /**
     * Column used when filtering is triggered only by the enterprise list from
     * {@link CurrentEntIdSearchContextHolder}, i.e. when no {@link DataScope} settings are present.
     */
    private static final String DEFAULT_UNIT_FIELD = "ent_id";

    /**
     * Always-false condition appended when the set of visible enterprises is empty,
     * so the statement stays valid SQL and returns no rows.
     */
    private static final String MATCH_NOTHING_CONDITION = "1 = 0";

    ThreadLocal<DataScopeParam> threadLocal = new TransmittableThreadLocal<>();

    /**
     * Removes the permission settings stored for the current thread once the
     * annotated method has finished.
     */
    @After("dataScopePointCut()")
    public void clearThreadLocal() {
        threadLocal.remove();
    }

    /**
     * Pointcut matching methods annotated with {@link DataScope}.
     */
    @Pointcut("@annotation(com.lyc.common.base.annotation.DataScope)")
    public void dataScopePointCut() {
    }

    /**
     * Stores the filtering settings for the current user before the annotated method runs.
     * Nothing is stored when the annotation cannot be read or there is no current user.
     *
     * @param point the intercepted join point
     */
    @Before("dataScopePointCut()")
    public void doBefore(JoinPoint point) {
        DataScope controllerDataScope = getAnnotationLog(point);
        if (controllerDataScope == null) {
            return;
        }
        SysUser sysUser = SecurityUtils.getUser();
        if (sysUser == null) {
            return;
        }

        // Flatten the enterprise ids of every tier the user owns into one set.
        Set<Long> dataScope = sysUser.getTierVos().stream()
                .map(EntierVO::getUnitIdList)
                .flatMap(Collection::stream)
                .collect(Collectors.toSet());

        // No filtering when the annotation disables it or the user is the super admin.
        boolean filterEnabled = controllerDataScope.filterData()
                && !CommonConstants.SUPER_ADMIN.equals(sysUser.getId());

        threadLocal.set(new DataScopeParam(controllerDataScope.unitField(), dataScope, filterEnabled));
        log.debug("当前用户可以查看的企业列表数据 = {}", dataScope);
    }

    /**
     * Reads the {@link DataScope} annotation from the intercepted method.
     *
     * @param joinPoint the intercepted join point
     * @return the annotation on the method, or {@code null} if the method is unavailable
     *         or not annotated
     */
    private DataScope getAnnotationLog(JoinPoint joinPoint) {
        org.aspectj.lang.Signature signature = joinPoint.getSignature();
        MethodSignature methodSignature = (MethodSignature) signature;
        Method method = methodSignature.getMethod();
        if (method != null) {
            return method.getAnnotation(DataScope.class);
        }
        return null;
    }

    /**
     * Rewrites SELECT statements so they only return rows of the allowed enterprises.
     *
     * <p>The statement is passed through untouched when there are no filtering settings,
     * filtering is disabled, there is no current user, or the statement is not a SELECT.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call
     * @return the result of {@code invocation.proceed()}
     * @throws Throwable whatever the SQL rewrite or the proceeding call throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {

        DataScopeParam dataScopeParam = threadLocal.get();

        // Enterprise ids explicitly requested for this request (stored by the request filter).
        Set<Long> requestedEntIds = CurrentEntIdSearchContextHolder.getEntIdList();
        if (CollUtil.isNotEmpty(requestedEntIds)) {
            if (dataScopeParam == null) {
                // No annotation settings: filter by the requested enterprises only.
                dataScopeParam = new DataScopeParam(DEFAULT_UNIT_FIELD, requestedEntIds, true);
            } else {
                // Restrict to the requested enterprises the user is also allowed to see.
                // A new object is created instead of mutating the one held in the thread-local,
                // so the stored settings stay unchanged for later statements.
                Set<Long> permissionEntList = dataScopeParam.getEntIdList();
                Set<Long> allowedRequestedIds = requestedEntIds.stream()
                        .filter(permissionEntList::contains)
                        .collect(Collectors.toSet());
                dataScopeParam = new DataScopeParam(dataScopeParam.getUnitField(), allowedRequestedIds, true);
            }
        }

        // Neither annotation settings nor requested enterprises: nothing to do.
        if (dataScopeParam == null) {
            return invocation.proceed();
        }

        // Filtering disabled for this call.
        if (!dataScopeParam.isFilterField()) {
            return invocation.proceed();
        }

        // No current user: leave the statement untouched.
        if (SecurityUtils.getUser() == null) {
            return invocation.proceed();
        }

        StatementHandler statementHandler = PluginUtils.realTarget(invocation.getTarget());
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        // Only SELECT statements are rewritten.
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        if (!SqlCommandType.SELECT.equals(mappedStatement.getSqlCommandType())) {
            return invocation.proceed();
        }

        BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
        String originalSql = boundSql.getSql();
        String finalSql = this.handleSql(originalSql, dataScopeParam.getEntIdList(), dataScopeParam.getUnitField());
        log.debug("数据权限处理过后的SQL: {}", finalSql);

        // Install the rewritten SQL before the statement is prepared.
        metaObject.setValue("delegate.boundSql.sql", finalSql);
        return invocation.proceed();
    }


    /**
     * Parses the SQL and adds the enterprise condition to it.
     *
     * <p>A plain select gets the condition directly; for a set operation (e.g. UNION) the
     * condition is added to every member, each of which is cast to {@link PlainSelect}.
     * Any other select body is returned unchanged.
     *
     * @param originalSql the original SQL
     * @param entIdList   enterprise ids the result must be restricted to
     * @param fieldName   name of the enterprise id column in the main table
     * @return the rewritten SQL
     * @throws JSQLParserException if the original SQL cannot be parsed
     */
    private String handleSql(String originalSql, Set<Long> entIdList, String fieldName) throws JSQLParserException {
        CCJSqlParserManager parserManager = new CCJSqlParserManager();
        Select select = (Select) parserManager.parse(new StringReader(originalSql));
        SelectBody selectBody = select.getSelectBody();
        if (selectBody instanceof PlainSelect) {
            this.setWhere((PlainSelect) selectBody, entIdList, fieldName);
        } else if (selectBody instanceof SetOperationList) {
            SetOperationList setOperationList = (SetOperationList) selectBody;
            List<SelectBody> selectBodyList = setOperationList.getSelects();
            selectBodyList.forEach(s -> this.setWhere((PlainSelect) s, entIdList, fieldName));
        }
        return select.toString();
    }

    /**
     * Adds the enterprise condition to the WHERE clause of a select, AND-ing it with any
     * existing condition.
     *
     * <p>The FROM item is cast to {@link Table}; a select whose FROM item is not a plain
     * table therefore fails with a {@link ClassCastException} instead of running unfiltered.
     *
     * @param plainSelect the select to modify
     * @param entIdList   enterprise ids the result must be restricted to; when {@code null} or
     *                    empty an always-false condition is added so that no rows are returned
     * @param fieldName   name of the enterprise id column in the main table
     */
    @SneakyThrows(Exception.class)
    protected void setWhere(PlainSelect plainSelect, Set<Long> entIdList, String fieldName) {
        Table fromItem = (Table) plainSelect.getFromItem();
        // Prefer the alias over the table name so the column reference is unambiguous.
        Alias fromItemAlias = fromItem.getAlias();
        String mainTableName = fromItemAlias == null ? fromItem.getName() : fromItemAlias.getName();

        String dataPermissionSql;
        if (CollUtil.isEmpty(entIdList)) {
            // An empty "in ( )" list is not valid SQL; no visible enterprise means no rows.
            dataPermissionSql = MATCH_NOTHING_CONDITION;
        } else if (entIdList.size() == 1) {
            // A single id is expressed as a plain equality.
            EqualsTo selfEqualsTo = new EqualsTo();
            selfEqualsTo.setLeftExpression(new Column(mainTableName + "." + fieldName));
            selfEqualsTo.setRightExpression(new LongValue(entIdList.iterator().next()));
            dataPermissionSql = selfEqualsTo.toString();
        } else {
            dataPermissionSql = mainTableName + "." + fieldName + " in ( " + CollUtil.join(entIdList, StringPool.COMMA) + " )";
        }

        if (plainSelect.getWhere() == null) {
            plainSelect.setWhere(CCJSqlParserUtil.parseCondExpression(dataPermissionSql));
        } else {
            plainSelect.setWhere(new AndExpression(plainSelect.getWhere(), CCJSqlParserUtil.parseCondExpression(dataPermissionSql)));
        }
    }

    /**
     * Wraps {@link StatementHandler} targets with this plugin; other targets are returned as is.
     *
     * @param target the object MyBatis offers for wrapping
     * @return the plugin proxy for statement handlers, otherwise {@code target} itself
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /**
     * Receives the properties configured for this plugin in MyBatis; none are used.
     *
     * @param properties properties configured in MyBatis
     */
    @Override
    public void setProperties(Properties properties) {

    }
}

考虑到机构用户会指定查询某企业的数据,将以上权限过滤部分改写使其满足新的需求

添加holder,用于储存接口请求中需要过滤的企业列表

import com.alibaba.ttl.TransmittableThreadLocal;
import lombok.experimental.UtilityClass;

import java.util.Set;

/**
 * Thread-bound holder for the enterprise ids requested by the current call.
 *
 * @author ChenQi
 * @since 2022/10/21 10:13
 */
@UtilityClass
public class CurrentEntIdSearchContextHolder {

    /** Per-thread storage of the requested enterprise ids. */
    private final ThreadLocal<Set<Long>> ENT_ID_HOLDER = new TransmittableThreadLocal<>();

    /**
     * Returns the enterprise ids stored for the current thread.
     *
     * @return the stored enterprise ids, as last passed to {@link #setEntIdList(Set)}
     */
    public Set<Long> getEntIdList() {
        return ENT_ID_HOLDER.get();
    }

    /**
     * Stores the enterprise ids to query for the current thread.
     *
     * @param entIdList enterprise ids to query
     */
    public void setEntIdList(Set<Long> entIdList) {
        ENT_ID_HOLDER.set(entIdList);
    }

    /**
     * Removes the value stored for the current thread.
     */
    public void clear() {
        ENT_ID_HOLDER.remove();
    }
}

添加过滤器获取并储存待过滤的企业列表

import cn.hutool.core.convert.Convert;
import cn.hutool.core.util.StrUtil;
import com.lyc.common.base.constant.CommonConstants;
import com.lyc.common.base.utils.CurrentEntIdSearchContextHolder;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.GenericFilterBean;

import javax.servlet.FilterChain;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.HashSet;
import java.util.Set;

/**
 * Servlet filter that reads the enterprise-id list from the request header named by
 * {@code CommonConstants.ENT_ID_LIST} and stores it in
 * {@link CurrentEntIdSearchContextHolder} for the duration of the request.
 *
 * @author ChenQi
 * @since 2022/10/21 10:21
 */
@Slf4j
@Component
@Order(Ordered.HIGHEST_PRECEDENCE)
public class EntIdContextHolderFilter extends GenericFilterBean {

    /**
     * Stores the requested enterprise ids, runs the rest of the filter chain and always
     * clears the holder afterwards.
     *
     * <p>An empty set is stored when the header is missing or blank.
     *
     * @param servletRequest  the incoming request, cast to {@link HttpServletRequest}
     * @param servletResponse the outgoing response, cast to {@link HttpServletResponse}
     * @param filterChain     the remaining filter chain
     */
    @Override
    @SneakyThrows
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        Set<Long> entIdList = new HashSet<>();
        String entIdListStr = request.getHeader(CommonConstants.ENT_ID_LIST);

        if (StrUtil.isNotBlank(entIdListStr)) {
            entIdList = Convert.toSet(Long.class, entIdListStr);
            log.debug("获取header中的企业列表为:{}", entIdList);
        }
        CurrentEntIdSearchContextHolder.setEntIdList(entIdList);

        try {
            filterChain.doFilter(request, response);
        } finally {
            // Clear even when the chain throws; otherwise the list would stay bound to the
            // thread and be seen by whatever that thread handles next.
            CurrentEntIdSearchContextHolder.clear();
        }
    }
}

使用方式

添加注解用于过滤数据

同时支持mybatis plus的api和xml中的sql,但是@DataScope中设定的unitField的过滤字段必须在sql的主表中

注解添加在controller中,用于使用mybatis plus api的情况

image-20221021133318823

注解添加在controller或者dao层方法上,用于使用xml中自定义sql的情况

image-20221021134449090

指定查询部分企业列表

在header中添加entIdList

image-20221021134613580