MyBatis Sql拦截器(自定义注解实现多租户查询)


转自 : https://blog.csdn.net/weixin_44600430/article/details/112108902

MyBatis拦截器(自定义注解+实现多租户查询)
前言:
公司现有运营管理平台上的功能都要增加多租户, 原本功能都是单租户。

就是要做数据隔离, 登录用户只能看到当前登录用户名下数据, 关键数据表都加了个用户ID字段, 之前的功能都已经写好, 所以就在想怎么在最少改动代码的情况下实现给之前的所有查询增加一个查询条件=值, 后来想到利用mybatis拦截器动态修改sql进行拼接多个查询。

下面就开始利用来进行实现。 (技术框架1.4.8, 公司用的版本太低, 好像mybatis-plus在2.1版本 也增加了多租户拦截器, 但是还是不能完全满足我现有需求)

使用到的技术有: ,

1.0 自定义MyBatis拦截器

/**
 * @Author: ZhiHao
 * @Date: 2020/12/16 16:37
 * @Description: 代理商通用sql拼接拦截器(兼容之前查询, 仅对方法标记注解生效)
 * @Versions 1.0
 **/
//@Component
@Intercepts({@Signature(
        type = StatementHandler.class, //拦截构建sql语句的StatementHandler
        method = "prepare",   //里面的prepare方法
        args = {
                Connection.class,  //方法的参数
                Integer.class
        }
)})
public class AgentSqlInterceptor implements Interceptor {

    private Logger logger = LoggerFactory.getLogger(getClass());

 	// 是否进行拦截动态修改sql
    private boolean required;
	// 表别名
    private String tableAlias;

    public AgentSqlInterceptor(boolean required, String tableAlias) {
        this.required = required;
        this.tableAlias = tableAlias;
    }

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // 需要修改sql语句才拦截
        if (required) {
            StatementHandler statementHandler = (StatementHandler) PluginUtils.realTarget(invocation.getTarget());
            MetaObject metaStatementHandler = SystemMetaObject.forObject(statementHandler);
            // 先判断是不是SELECT操作
            MappedStatement mappedStatement = (MappedStatement) metaStatementHandler.getValue("delegate.mappedStatement");
            if (!SqlCommandType.SELECT.equals(mappedStatement.getSqlCommandType())) {
                return invocation.proceed();
            }
            BoundSql boundSql = (BoundSql) metaStatementHandler.getValue("delegate.boundSql");
            String sql = boundSql.getSql();
            logger.info("之前sql语句:{}", sql);
            // 判断是否符合需要增加区分代理商查询条件
            sql = this.ifAgentQuery(sql);
            logger.info("代理商查询sql语句:{}", sql);
            // 最终将修改好的sql语句设置回去执行
            metaStatementHandler.setValue("delegate.boundSql.sql", sql);
        }
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        // 返回拦截器本身, 还是返回目标本身
        return target instanceof StatementHandler ? Plugin.wrap(target, this) : target;
    }

    @Override
    public void setProperties(Properties properties) {

    }

    /**
     * 判断是否需要进行拼接代理商查询条件
     * 有以下几种情况: (0 运营商, 不进行拼接查询全部) , (1 代理商并且是总仓管理员, 仅拼接代理商字段查询)
     * (2 代理商并且是分仓管理员, 需拼接代理商字段与仓库字段查询) , (3 代理商并且是总仓管理员, 需拼接代理商字段查询)
     *
     * @param sql
     * @return java.lang.String
     * @author: ZhiHao
     * @date: 2020/12/17
     */
    private String ifAgentQuery(String sql) {
        // sql解析器解析查询语句(也只有查询能进来)
        Select selectStatement = (Select) CCJSqlParserUtil.parse(sql);
        // 不是单表与多表直接查询
        if (!(selectStatement.getSelectBody() instanceof PlainSelect)) {
            return sql;
        }
        PlainSelect selectBody = (PlainSelect) selectStatement.getSelectBody();
        // 仅对枚举符合表, 并没有代理商字段的进行拼接sql
        Table table = this.ifConform(selectBody);
        if (table != null && this.doesItExist(selectBody.getWhere())) {
            // 进行拼接, 判断是否有where
            Expression where = selectBody.getWhere();
            // 获取查询语句表别名(兼容之前做好功能一个方法表别名不一致)
            String sqlTableAlias = table.getAlias() != null ? table.getAlias().getName() : null;
            String queryConditionsAndValues;
            if (where != null) {
                // 获取代理商查询条件与值
                queryConditionsAndValues = this.getQueryConditionsAndValues( table.getName(),
                        StringUtils.isNotBlank(sqlTableAlias) ? sqlTableAlias : this.tableAlias);
                 // 解析之前表达式, 使用弱引用
                WeakReference weakReference = new WeakReference<>(new ExpressionDeParser());
                where.accept(weakReference.get());
                // 获取之前表达式
                StringBuilder buffer = weakReference.get().getBuffer();
                // 拼接sql
                buffer.append(queryConditionsAndValues);
                // 设置回去
                selectBody.setWhere(CCJSqlParserUtil.parseCondExpression(buffer.toString()));
                weakReference.clear();
            } else {
                // 获取查询条件与值
                queryConditionsAndValues = " 1 = 1 " + this.getQueryConditionsAndValues(table.getName(),
                        StringUtils.isNotBlank(sqlTableAlias) ? sqlTableAlias : this.tableAlias);
                //没有where情况
                selectBody.setWhere(CCJSqlParserUtil.parseCondExpression(queryConditionsAndValues));
            }
            sql = selectStatement.toString();
        }
        return sql;
    }
    
    /**
     * 仅对dms_battery表查询并没有代理商字段的进行拼接sql,
     * 防止拦截器开启时, 兼容以后自定义查询语句包含查询字段的进行了拼接出错
     *
     * @param selectBody
     * @return boolean
     * @author: ZhiHao
     * @date: 2020/12/18
     */
    private Table ifConform(PlainSelect selectBody) {
        // 判断from后面是否符合需要拼接的表名
        FromItem fromItem = selectBody.getFromItem();
        if (fromItem instanceof Table) {
            Table table = (Table) fromItem;
            if (TenantTable.tableMap.get(table.getName()) != null) {
                return table;
            }
        }
        // 上面from后面不满足则判断多表情况是否包含
        List joins = selectBody.getJoins();
        Table table = null;
        if (joins != null && joins.size() > 0) {
            Optional any = joins.stream().filter((join) -> {
                FromItem rightItem = join.getRightItem();
                if (rightItem instanceof Table) {
                    return TenantTable.tableMap.get(((Table) rightItem).getName()) != null ? true : false;
                }
                return false;
            }).findAny();
            table = any.isPresent() ? (Table) any.get().getRightItem() : null;
        }
        return table != null ? table : null;
    }

    /**
     * 判断之前sql是否存在了代理商查询字段
     *
     * @param sql
     * @return boolean
     * @author: ZhiHao
     * @date: 2020/12/21
     */
    private boolean doesItExist(Expression sql) {
        String str = sql != null ? sql.toString() : null;
        if (StringUtils.containsIgnoreCase(str, TenantTable.AGENT_ID.getColumns(null))
                || StringUtils.containsIgnoreCase(str, TenantTable.DEPOT_ID.getColumns(null))) {
            return false;
        }
        return true;
    }

    private final Integer AGENT = 0; //代理商
    private final Integer OPERATOR = 1; //运营商

    /**
     * 根据表名构建条件
     *
     * @param tableName  表名
     * @param tableAlias 别名
     * @return java.lang.String
     * @author: ZhiHao
     * @date: 2020/12/21
     */
    public String getQueryConditionsAndValues(String tableName, String tableAlias) {
        StringBuilder builder = new StringBuilder();
        // 获取登录用户
        IDmsUserService dmsUserService = SpringUtils.getBean(IDmsUserService.class);
        DmsUser dmsUser = dmsUserService.getCurrentUser();
        Integer type = dmsUser.getType();
        Integer terminal = dmsUser.getTerminal();
        Integer agentOrOperatorId = dmsUser.getAgentOrOperatorId();
        Integer depotId = dmsUser.getDepotId();
        switch (TenantTable.tableMap.get(tableName)) {
            // 电池表
            case DMS_BATTERY:
                // 是运营商并是总仓不做拼接可见全部
                if (OPERATOR.equals(type) && TerminalEnums.WEB_ADMIN.getCode().equals(terminal)) {
                    return "";
                }
                // 是运营商并是分仓拼接可见分仓
                if (OPERATOR.equals(type) && TerminalEnums.WEB.getCode().equals(terminal)) {
                    builder.append(" AND ")
                            .append(TenantTable.DEPOT_ID.getColumns(tableAlias))
                            .append(" = ")
                            .append(depotId)
                            .append(" ");
                    return builder.toString();
                }
                // 是代理商并是总仓拼接可见总+分仓
                if (AGENT.equals(type) && TerminalEnums.WEB_ADMIN.getCode().equals(terminal)) {
                    builder.append(" AND ")
                            .append(TenantTable.AGENT_ID.getColumns(tableAlias))
                            .append(" = ")
                            .append(agentOrOperatorId)
                            .append(" ");
                    return builder.toString();
                }
                // 是代理商并是分仓拼接可见分仓
                if (AGENT.equals(type) && TerminalEnums.WEB.getCode().equals(terminal)) {
                    builder.append(" AND ")
                            .append(TenantTable.AGENT_ID.getColumns(tableAlias))
                            .append(" = ")
                            .append(agentOrOperatorId)
                            .append(" AND ")
                            .append(TenantTable.DEPOT_ID.getColumns(tableAlias))
                            .append(" = ")
                            .append(depotId)
                            .append(" ");
                    return builder.toString();
                }
                // 设备表
            case DEVICE:
                if (AGENT.equals(type)) {
                    builder.append(" AND ")
                            .append(TenantTable.AGENT_ID.getColumns(tableAlias))
                            .append(" = ")
                            .append(agentOrOperatorId)
                            .append(" AND ")
                            // 只查询柜子
                            .append(TenantTable.DEVICE_TYPE.getColumns(tableAlias))
                            .append(" = 1 ");
                    return builder.toString();
                }
            default:
                break;
        }

        return null;
    }

    public void setRequired(boolean required) {
        this.required = required;
    }

    public void setTableAlias(String tableAlias) {
        this.tableAlias = tableAlias;
    }
}

2.0 利用AOP+注解实现标记方法才进行拦截

PS: 如果查询方法都是写在DAO层接口里面的, 可以不使用AOP (具体看扩展) , 因为使用到了mybatis-plus很多查询方法都是使用其提供的, 所以注解只能标记到service层,

/**
 * @Author: ZhiHao
 * @Date: 2020/12/17 15:08
 * @Description: 需要增加代理商查询条件 agent_id = xx 的表名
 * @Versions 1.0
 **/
@Aspect
@Component
public class AgentMethodAspect {

    private Logger log = LoggerFactory.getLogger(getClass());

    //会话工厂
    @Autowired
    private SqlSessionFactory sqlSessionFactory;
    
    private final StampedLock lock = new StampedLock();

    /**
     * 切入点
     *
     * @author: ZhiHao
     * @date: 2020/12/17
     */
    @Pointcut("@annotation(com.xxx.xxx.multitenant.RequiredTenant)")
    public void requiredTenant() {

    }

    /**
     * 环绕通知
     *
     * @param point
     * @return java.lang.Object
     * @author: ZhiHao
     * @date: 2020/12/17
     */
    @Around("requiredTenant()")
    public Object around(ProceedingJoinPoint point) {
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        RequiredTenant requiredTenant = method.getAnnotation(RequiredTenant.class);
        String alias = requiredTenant.tableAlias();
        // 添加过滤器进sql会话工厂配置
        Configuration configuration = sqlSessionFactory.getConfiguration();
        // 判断是否有别名
        AgentSqlInterceptor agentSqlInterceptor = (AgentSqlInterceptor) configuration.getInterceptors().stream().filter(
                interceptor -> interceptor instanceof AgentSqlInterceptor ? true : false
        ).findAny().get();
        
        try {
             // 进行加锁, 控制并发将拦截设置为取消
             if (lock.tryWriteLock(30, TimeUnit.SECONDS) != 0) {
                // 设置拦截与别名
                agentSqlInterceptor.setRequired(true);
                agentSqlInterceptor.setTableAlias(StringUtils.isNotBlank(alias) ? alias : null);
                // 继续执行方法
                return point.proceed();
            }
        } catch (Throwable throwable) {
            throwable.printStackTrace();
            log.info("加锁失败:{}",throwable.getMessage());
        } finally {
            // 执行完毕都将其修改回未拦截标记注解其他请求
            agentSqlInterceptor.setRequired(false);
            // 释放锁
            lock.tryUnlockWrite();
        }
        return null;
    }
    
    /**
     * 仅做首次添加拦截器
     *
     * @author: ZhiHao
     * @date: 2020/12/24
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        AgentSqlInterceptor agentSqlInterceptor = new AgentSqlInterceptor(false, null);
        sqlSessionFactory.getConfiguration().addInterceptor(agentSqlInterceptor);
        log.info("首次添加AgentSqlInterceptor拦截器:{}", agentSqlInterceptor);
    }
}

3.0 注解

/**
 * @Author: ZhiHao
 * @Date: 2020/12/16 19:01
 * @Description: 是否需要代理商查询
 * @Versions 1.0
 **/
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RequiredTenant {

    /**
     * 查询语句表别名(按需选择)
     * @return
     */
    String tableAlias() default "";
}

4.0 枚举

/**
 * @Author: ZhiHao
 * @Date: 2020/12/17 10:08
 * @Description: 需要增加代理商查询条件 agent_id = xx 的表名
 * @Versions 1.0
 **/
public enum TenantTable {

    DMS_BATTERY("dms_battery", null),
    DEVICE("device", null),
    // 代理商查询字段
    AGENT_ID(null, "agent_id"),
    // 仓库查询字段
    DEPOT_ID(null, "depot_id"),
    // 柜子类型查询字段
    DEVICE_TYPE(null, "type"),
    ;

    private String tableName;

    private String columns;

    TenantTable(String tableName, String columns) {
        this.tableName = tableName;
        this.columns = columns;
    }

    public static Map tableMap;

    static {
        tableMap = Arrays.stream(TenantTable.values())
                .collect(Collectors.toMap(tenantTable -> tenantTable.tableName,
                        tenantTable -> tenantTable, (tenantTable, tenantTable2) -> tenantTable2));
    }

    /**
     * 获取表名
     *
     * @return java.lang.String
     * @author: ZhiHao
     * @date: 2020/12/18
     */
    public String getTableName() {
        return tableName;
    }

    /**
     * 获取查询字段
     *
     * @param tableAlias 表别名
     * @return java.lang.String
     * @author: ZhiHao
     * @date: 2020/12/18
     */
    public String getColumns(String tableAlias) {
        if (StringUtils.isNotBlank(tableAlias)) {
            String str = tableAlias + "." + this.columns;
            return str;
        }
        return columns;
    }
}

5.0 Service层方法使用 (测试)

	@RequiredTenant
    @Override
    public List getDepreciation(xxxx indexDataQueryDto) { 
    	// 自定义方法查询
        List listDto = batteryDao.getDepreciation(xxx);
        // mybatis-plus提供方法
        xxxxxx Battery = selectById(xxx);
}

结果:

首次添加AgentSqlInterceptor拦截器:com.gizwits.lease.multitenant.AgentSqlInterceptor@4bd8a2c7
之前sql语句:SELECT IFNULL( db.life - TIMESTAMPDIFF(MONTH,db.initial_time,NOW()) , IFNULL(db.life - TIMESTAMPDIFF(MONTH,db.first_service_time,NOW()),IFNULL(db.life - TIMESTAMPDIFF(MONTH,db.ctime,NOW()),null))) month,
        count(1) number
        FROM dms_battery db WHERE db.status in (0,2,3,6,7,8) AND db.is_cancellation = 1    
        GROUP BY month
 代理商查询sql语句:SELECT IFNULL( db.life - TIMESTAMPDIFF(MONTH,db.initial_time,NOW()) , IFNULL(db.life - TIMESTAMPDIFF(MONTH,db.first_service_time,NOW()),IFNULL(db.life - TIMESTAMPDIFF(MONTH,db.ctime,NOW()),null))) month,
        count(1) number
        FROM dms_battery db WHERE 1 = 1  AND db.agent_id = 8  AND  db.status in (0,2,3,6,7,8) AND db.is_cancellation = 1
        GROUP BY month
之前sql语句:SELECT COUNT(1) FROM dms_battery WHERE (is_cancellation = ? AND is_deleted = ? AND status NOT IN (?,?,?))
代理商查询sql语句:SELECT COUNT(1) FROM dms_battery WHERE 1 = 1  AND agent_id = 8  AND  (is_cancellation = ? AND is_deleted = ? AND status NOT IN (?,?,?))

扩展:

Mybatis-自定义注解加拦截器

最后加锁控制临界资源, 可以更换为使用ThreadLocal