chenjiahe
2022-01-13 c64e1248bfda3ac8c5120e529fd096dfc4846629
AES加密插件
2 文件已重命名
1个文件已删除
4个文件已添加
1个文件已修改
861 ■■■■ 已修改文件
src/main/java/com/hx/common/annotations/MysqlHexAes.java 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/hx/mybatis/aes/handler/GenericStringHandler.java 24 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/hx/mybatis/aes/springbean/ConstantBean.java 35 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/hx/mybatis/aes/springbean/ExportTableAliasVisitor.java 53 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/hx/mybatis/aes/springbean/MySqlInterceptor.java 96 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/hx/mybatis/aes/springbean/SqlUtils.java 481 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/hx/mybatis/aes/springbean/VariableAesKey.java 126 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/hx/springbean/ConstantBean.java 34 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/hx/common/annotations/MysqlHexAes.java
@@ -10,11 +10,13 @@
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface MysqlHexAes {
    /**秘钥字段*/
    String aesKeyField() default "";
    /**秘钥*/
    String aesKey();
    /**查询解密*/
    /**秘钥-没有就是配置文件设置*/
    String aesKey() default "";
    /**xml生成查询解密*/
    boolean selectDec() default false;
    /**xml更新加密*/
    boolean updateDec() default false;
    /**xml新增加密*/
    boolean insertDec() default false;
}
src/main/java/com/hx/mybatis/aes/handler/GenericStringHandler.java
File was renamed from src/main/java/com/hx/mybatis/handler/aes/GenericStringHandler.java
@@ -1,6 +1,6 @@
package com.hx.mybatis.handler.aes;
package com.hx.mybatis.aes.handler;
import com.hx.springbean.VariableAesKey;
import com.hx.mybatis.aes.springbean.VariableAesKey;
import com.hx.util.mysql.aes.MysqlHexAes;
import org.apache.ibatis.type.BaseTypeHandler;
import org.apache.ibatis.type.JdbcType;
@@ -44,12 +44,28 @@
    @Override
    public String getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
        return rs.getString(columnIndex);
        String data = rs.getString(columnIndex);
        if(data != null && data.length()%32==0 && MysqlHexAes.isHexStrValid(data)){
            try{
                data = MysqlHexAes.decryptData(data, VariableAesKey.getAesKey(null),null);
            }catch (Exception e){
                //e.printStackTrace();
            }
        }
        return data;
    }
    @Override
    public String getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
        return cs.getString(columnIndex);
        String data = cs.getString(columnIndex);
        if(data != null && data.length() < 129 && data.length()%32==0 && MysqlHexAes.isHexStrValid(data)){
            try{
                data = MysqlHexAes.decryptData(data, VariableAesKey.getAesKey(null),null);
            }catch (Exception e){
                //e.printStackTrace();
            }
        }
        return data;
    }
}
src/main/java/com/hx/mybatis/aes/springbean/ConstantBean.java
New file
@@ -0,0 +1,35 @@
package com.hx.mybatis.aes.springbean;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
/**
 * 通用常量集中营
 * @author CJH
 */
@Component
public class ConstantBean {
    /**获取AES秘钥的配置(从什么包获取到)*/
    @Value("${mysql.hxe.aes.find.packs:null}")
    private String packPath;
    /**固定AES的秘钥*/
    @Value("${mysql.hxe.aes.fixd.key:null}")
    private String fixedAesKey;
    public String getPackPath() {
        return packPath;
    }
    public void setPackPath(String packPath) {
        this.packPath = packPath;
    }
    public String getFixedAesKey() {
        return fixedAesKey;
    }
    public void setFixedAesKey(String fixedAesKey) {
        this.fixedAesKey = fixedAesKey;
    }
}
src/main/java/com/hx/mybatis/aes/springbean/ExportTableAliasVisitor.java
New file
@@ -0,0 +1,53 @@
package com.hx.mybatis.aes.springbean;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.visitor.SQLASTVisitorAdapter;
import java.util.HashMap;
import java.util.Map;
/**
 * ExportTableAliasVisitor
 * @author Mwg
 * @date 2020/09/08 23:47
 */
public class ExportTableAliasVisitor extends SQLASTVisitorAdapter {
    private Map<String,String> tableMap = new HashMap<>();
    public Map<String, String> getTableMap() {
        return tableMap;
    }
    public void setTableMap(Map<String, String> tableMap) {
        this.tableMap = tableMap;
    }
    @Override
    public boolean visit(SQLExprTableSource x) {
        //别名,如果有别名,别名保持不变
        //System.out.println("alias:"+x.getAlias());//别名
        //System.out.println("expr:"+x.getExpr());//表名
        tableMap.put(x.getExpr().toString(),x.getAlias());
        //String s = StringUtils.isEmpty(x.getAlias()) ? x.getExpr().toString() : x.getAlias();
        // 修改表名,不包含点才加 select id from c left join d on c.id = d.id 中的c 和 d
        /*if(!x.getExpr().toString().contains(".")) {
            x.setExpr("`" + dbName.get() + "`." + x.getExpr());
        }*/
        //x.setExpr("mymymytable");//修改表名
        //x.setAlias("aa");//修改别名
        //x.setAlias(s);
        return true;
    }
}
src/main/java/com/hx/mybatis/aes/springbean/MySqlInterceptor.java
New file
@@ -0,0 +1,96 @@
package com.hx.mybatis.aes.springbean;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.Configuration;
import org.springframework.stereotype.Component;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.util.List;
import java.util.Properties;
@Component
@Intercepts({
        @Signature(
                type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class
        })
})
public class MySqlInterceptor implements Interceptor {
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // 方法一
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        MetaObject metaObject = MetaObject.forObject(statementHandler, SystemMetaObject.DEFAULT_OBJECT_FACTORY, SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY, new DefaultReflectorFactory());
        //先拦截到RoutingStatementHandler,里面有个StatementHandler类型的delegate变量,其实现类是BaseStatementHandler,然后就到BaseStatementHandler的成员变量mappedStatement
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        //id为执行的mapper方法的全路径名,如com.uv.dao.UserMapper.insertUser
        //String id = mappedStatement.getId();
        //sql语句类型 select、delete、insert、update
        SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();
        BoundSql boundSql = statementHandler.getBoundSql();
        // 获取节点的配置
        Configuration configuration = mappedStatement.getConfiguration();
        // 获取参数
        Object parameterObject = boundSql.getParameterObject();
        // MetaObject主要是封装了originalObject对象,提供了get和set的方法用于获取和设置originalObject的属性值,主要支持对JavaBean、Collection、Map三种类型对象的操作
        // MetaObject metaObject1 = configuration.newMetaObject(parameterObject);
        //获取sql中问号的基本信息
        List<ParameterMapping> parameterMappings = boundSql
                .getParameterMappings();
        /*for (ParameterMapping parameterMapping : parameterMappings) {
            String propertyName = parameterMapping.getProperty();
            System.out.println("propertyName:"+ propertyName);
            System.out.println("parameterObject:"+ parameterObject);
        }*/
        //这里可以进行sql修改
        //获取到原始sql语句
        String sql = boundSql.getSql();
        //新增
        if(sqlCommandType == SqlCommandType.INSERT){
            sql = SqlUtils.insertSql(sql, VariableAesKey.aesKeysTable);
        }else if(sqlCommandType == SqlCommandType.UPDATE){
            sql = SqlUtils.updateSql(sql, VariableAesKey.aesKeysTable);
        }else if(sqlCommandType == SqlCommandType.SELECT){
            if(VariableAesKey.isRun == 1){
                sql = SqlUtils.selectSql(sql, VariableAesKey.aesKeysTable);
            }
        }else if(sqlCommandType == SqlCommandType.DELETE){
            sql = SqlUtils.deleteSql(sql, VariableAesKey.aesKeysTable);
        }
        //通过反射修改sql语句
        Field field = boundSql.getClass().getDeclaredField("sql");
        field.setAccessible(true);
        field.set(boundSql, sql);
        return invocation.proceed();
    }
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        } else {
            return target;
        }
    }
    @Override
    public void setProperties(Properties properties) {
    }
}
src/main/java/com/hx/mybatis/aes/springbean/SqlUtils.java
New file
@@ -0,0 +1,481 @@
package com.hx.mybatis.aes.springbean;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLInListExpr;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.mysql.ast.clause.MySqlSelectIntoStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor;
import com.alibaba.druid.stat.TableStat;
import com.alibaba.druid.util.JdbcConstants;
import com.alibaba.druid.util.JdbcUtils;
import com.hx.util.StringUtils;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * sql语句处理工具
 * @author CJH 2022-01-12
 */
public class SqlUtils {
    /**查询加密数据处理,只对查询做处理,select返回不做处理
     * @param sql sql语句
     * @param aesKeysTable aes秘钥
     * @return
     */
    public static String selectSql(String sql,Map<String,Map<String,String>> aesKeysTable){
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
        //获取格式化的slq语句
        sql = sqlStatement.toString();
        //解析select查询
        //SQLSelect sqlSelect = sqlStatement.getSelect()
        //获取sql查询块
        SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
        StringBuffer out = new StringBuffer() ;
        //创建sql解析的标准化输出
        SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
        //获取表和别名
        ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
        sqlStatement.accept(visitorTable);
        Map<String,String> tableMaps = visitorTable.getTableMap();
        //获取所有的字段
        MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
        sqlStatement.accept(visitor);
        //遍历所有字段
        Collection<TableStat.Column> columns= visitor.getColumns();
        StringBuilder sqlWhere = new StringBuilder();
        StringBuilder sqlSelect = new StringBuilder();
        String expr = null;
        sqlSelect.append("SELECT ");
        //解析select返回的数据字段项
        for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
            if(sqlSelect.length() > 7){
                sqlSelect.append(",");
            }
            expr = sqlSelectItem.getExpr().toString();
            if(expr.indexOf("SELECT") == -1){
                sqlSelect.append(expr);
                if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
                    sqlSelect.append(" AS "+sqlSelectItem.getAlias());
                }
            }else{
                sqlSelect.append("(");
                sqlSelect.append(selectSqlHandle(expr,aesKeysTable,tableMaps,columns));
                sqlSelect.append(")");
                if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
                    sqlSelect.append(" AS "+sqlSelectItem.getAlias());
                }
            }
        }
        //解析from
        out.delete(0, out.length()) ;
        sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
        sqlWhere.append(" FROM "+out);
        //解析where
        out.delete(0, out.length()) ;
        sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
        sqlWhere.append(" WHERE "+out+" ");
        if(sqlSelectQuery.getGroupBy() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" "+out);
        }
        if(sqlSelectQuery.getOrderBy() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" "+out);
        }
        if(sqlSelectQuery.getLimit() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" "+out);
        }
        //处理where需要加密得字段
        sql = sqlWhere.toString();
        if(!StringUtils.isEmpty(sql)){
            Map<String,String> aesKeys = null;
            String aeskey = null;
            //把剩下的拼接上来
            String tableAl = null;
            for(TableStat.Column column:columns){
                aesKeys= aesKeysTable.get(column.getTable());
                if(aesKeys == null){
                    continue;
                }
                aeskey = aesKeys.getOrDefault(column.getName(),null);
                if(StringUtils.isEmpty(aeskey)){
                    continue;
                }
                tableAl = tableMaps.get(column.getTable());
                if(!StringUtils.isEmpty(tableAl)){
                    tableAl = tableAl+"."+column.getName();
                }else{
                    tableAl = column.getName();
                }
                sql = sql.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') ");
            }
        }
        return sqlSelect.toString()+sql;
    }
    public static String selectSqlHandle(String sql,Map<String,Map<String,String>> aesKeysTable
            ,Map<String,String> tableMaps,Collection<TableStat.Column> columns){
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
        //获取格式化的slq语句
        sql = sqlStatement.toString();
        //解析select查询
        //SQLSelect sqlSelect = sqlStatement.getSelect() ;
        //获取sql查询块
        SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
        StringBuffer out = new StringBuffer() ;
        //创建sql解析的标准化输出
        SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
        StringBuilder sqlWhere = new StringBuilder();
        StringBuilder sqlSelect = new StringBuilder();
        String expr = null;
        sqlSelect.append("SELECT ");
        //解析select返回的数据字段项
        for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
            if(sqlSelect.length() > 7){
                sqlSelect.append(",");
            }
            expr = sqlSelectItem.getExpr().toString();
            if(expr.indexOf("SELECT") == -1){
                sqlSelect.append(expr);
                if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
                    sqlSelect.append(" AS "+sqlSelectItem.getAlias());
                }
            }else{
                sqlSelect.append("(");
                selectSqlHandle(expr,aesKeysTable,tableMaps,columns);
                sqlSelect.append(")");
                if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
                    sqlSelect.append(" AS "+sqlSelectItem.getAlias());
                }
            }
        }
        //解析from
        out.delete(0, out.length()) ;
        sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
        sqlWhere.append(" FROM "+out);
        //解析where
        out.delete(0, out.length()) ;
        sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
        sqlWhere.append(" WHERE "+out+" ");
        if(sqlSelectQuery.getGroupBy() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" "+out);
        }
        if(sqlSelectQuery.getOrderBy() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" "+out);
        }
        if(sqlSelectQuery.getLimit() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" "+out);
        }
        sql = sqlWhere.toString();
        if(!StringUtils.isEmpty(sql)){
            Map<String,String> aesKeys = null;
            String aeskey = null;
            //把剩下的拼接上来
            String tableAl = null;
            for(TableStat.Column column:columns){
                aesKeys= aesKeysTable.get(column.getTable());
                if(aesKeys == null){
                    continue;
                }
                aeskey = aesKeys.getOrDefault(column.getName(),null);
                if(StringUtils.isEmpty(aeskey)){
                    continue;
                }
                tableAl = tableMaps.get(column.getTable());
                if(!StringUtils.isEmpty(tableAl)){
                    tableAl = tableAl+"."+column.getName();
                }else{
                    tableAl = column.getName();
                }
                sql = sql.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') ");
            }
        }
        return sqlSelect.toString()+sql;
    }
    /**新增加密数据处理
     * @param sql sql语句
     * @param aesKeysTable aes秘钥
     * @return
     */
   public static String insertSql(String sql,Map<String,Map<String,String>> aesKeysTable){
       //装载重写的sql语句
       StringBuilder splicingSql = new StringBuilder();
       sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
       String[] datas = sql.split("VALUES",2);
       splicingSql.append(datas[0]+"VALUES ");
       //重新拼接SQL语句
       //解析sql语句
       MySqlStatementParser parser = new MySqlStatementParser(sql);
       SQLStatement statement = parser.parseStatement();
       MySqlInsertStatement insert = (MySqlInsertStatement)statement;
       String insertName = insert.getTableName().getSimpleName();
       //根据表名称获取到AES秘钥
       Map<String,String> aesKeys= aesKeysTable.get(insertName);
       if(aesKeys == null){
           return sql;
       }
       //获取所有的字段
       List<SQLExpr> columns = insert.getColumns();
       String fildValue = null;
       String aeskey = null;
       //遍历值
       List<SQLInsertStatement.ValuesClause> vcl = insert.getValuesList();
       for(int j = 0; j<vcl.size(); j++){
           if( j != 0){
               splicingSql.append(",");
           }
           for(int i = 0;i < columns.size();i++){
               //查询改字段是否需要加密
               aeskey = aesKeys.getOrDefault(columns.get(i).toString(),null);
               fildValue = vcl.get(j).getValues().get(i).toString();
               if(i == 0){
                   splicingSql.append("(");
                   if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                       splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                   }else{
                       splicingSql.append(fildValue);
                   }
               }else if(i == columns.size()-1){
                   splicingSql.append(",");
                   if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                       splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                   }else{
                       splicingSql.append(fildValue);
                   }
                   splicingSql.append(")");
               }else{
                   splicingSql.append(",");
                   if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                       splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                   }else{
                       splicingSql.append(fildValue);
                   }
               }
           }
       }
       return splicingSql.toString();
   }
    /**更新加密数据处理
     * @param sql sql语句
     * @param aesKeysTable aes秘钥
     * @return
     */
    public static String updateSql(String sql,Map<String,Map<String,String>> aesKeysTable){
        //装载重写的sql语句
        StringBuilder splicingSql = new StringBuilder();
        //sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLStatement sqlStatement = parser.parseStatement();
        //获取格式化的slq语句
        sql = sqlStatement.toString();
        MySqlUpdateStatement updateStatement = (MySqlUpdateStatement)sqlStatement;
        String insertName = updateStatement.getTableName().getSimpleName();
        String[] datas = sql.split("WHERE",2);
        Map<String,String> aesKeys = aesKeysTable.get(insertName);
        splicingSql.append("UPDATE "+insertName+" SET ");
        String aeskey = null;
        String fildValue = null;
        List<SQLUpdateSetItem> items = updateStatement.getItems();
        for(int i = 0;i<items.size();i++){
            if(i != 0){
                splicingSql.append(",");
            }
            SQLUpdateSetItem item = items.get(i);
            //查询改字段是否需要加密
            aeskey = aesKeys.getOrDefault(item.getColumn().toString(),null);
            fildValue = item.getValue().toString();
            if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                splicingSql.append(item.getColumn()+" = HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
            }else{
                splicingSql.append(item.getColumn()+" = "+fildValue);
            }
        }
        String sqlWhere = " WHERE";
        //把剩下的拼接上来
        if(datas.length > 1){
            for(int i =1;i<datas.length;i++){
                sqlWhere = sqlWhere+datas[i];
            }
            parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
            sqlStatement = parser.parseStatement();
            ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
            sqlStatement.accept(visitorTable);
            //获取表和别名
            Map<String,String> tableMaps = visitorTable.getTableMap();
            tableMaps.put(insertName,null);
            //获取所有的字段
            MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
            sqlStatement.accept(visitor);
            String tableAl = null;
            //遍历所有字段
            Collection<TableStat.Column> columns= visitor.getColumns();
            for(TableStat.Column column:columns){
                aesKeys= aesKeysTable.get(column.getTable());
                if(aesKeys == null){
                    continue;
                }
                aeskey = aesKeys.getOrDefault(column.getName(),null);
                if(StringUtils.isEmpty(aeskey)){
                    continue;
                }
                tableAl = tableMaps.get(column.getTable());
                if(!StringUtils.isEmpty(tableAl)){
                    tableAl = tableAl+"."+column.getName();
                }else{
                    tableAl = column.getName();
                }
                sqlWhere = sqlWhere.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') ");
            }
        }
        splicingSql.append(sqlWhere.toString());
        return splicingSql.toString();
    }
    /**删除加密数据处理
     * @param sql sql语句
     * @param aesKeysTable aes秘钥
     * @return
     */
    public static String deleteSql(String sql,Map<String,Map<String,String>> aesKeysTable){
        //装载重写的sql语句
        StringBuilder splicingSql = new StringBuilder();
        //sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLStatement sqlStatement = parser.parseStatement();
        //获取格式化的slq语句
        sql = sqlStatement.toString();
        MySqlDeleteStatement deleteStatement = (MySqlDeleteStatement)sqlStatement;
        String insertName = deleteStatement.getTableName().getSimpleName();
        String[] datas = sql.split("WHERE",2);
        Map<String,String> aesKeys = aesKeysTable.get(insertName);
        splicingSql.append("DELETE FROM "+insertName);
        String aeskey = null;
        String sqlWhere = " WHERE";
        //把剩下的拼接上来
        if(datas.length > 1){
            for(int i =1;i<datas.length;i++){
                sqlWhere = sqlWhere+datas[i];
            }
            parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
            sqlStatement = parser.parseStatement();
            ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
            sqlStatement.accept(visitorTable);
            //获取表和别名
            Map<String,String> tableMaps = visitorTable.getTableMap();
            tableMaps.put(insertName,null);
            //获取所有的字段
            MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
            sqlStatement.accept(visitor);
            String tableAl = null;
            //遍历所有字段
            Collection<TableStat.Column> columns= visitor.getColumns();
            for(TableStat.Column column:columns){
                aesKeys= aesKeysTable.get(column.getTable());
                if(aesKeys == null){
                    continue;
                }
                aeskey = aesKeys.getOrDefault(column.getName(),null);
                if(StringUtils.isEmpty(aeskey)){
                    continue;
                }
                tableAl = tableMaps.get(column.getTable());
                if(!StringUtils.isEmpty(tableAl)){
                    tableAl = tableAl+"."+column.getName();
                }else{
                    tableAl = column.getName();
                }
                sqlWhere = sqlWhere.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') ");
            }
        }
        splicingSql.append(sqlWhere.toString());
        return splicingSql.toString();
    }
}
src/main/java/com/hx/mybatis/aes/springbean/VariableAesKey.java
File was renamed from src/main/java/com/hx/springbean/VariableAesKey.java
@@ -1,8 +1,8 @@
package com.hx.springbean;
package com.hx.mybatis.aes.springbean;
import com.gitee.sunchenbin.mybatis.actable.annotation.Table;
import com.hx.common.annotations.MysqlHexAes;
import com.hx.util.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
@@ -24,8 +24,16 @@
    @Resource
    private ConstantBean constantBean;
    /**存储AES的秘钥*/
    /**是否已经启动完*/
    public static int isRun = 0;
    /**存储所有AES的秘钥*/
    public static Map<String,String> aesKeys = new HashMap<>();
    /**根据表明来存储AES秘钥*/
    public static Map<String,Map<String,String>> aesKeysTable = new HashMap<>();
    /**固定的aes秘钥*/
    public static String AES_KEY = null;
    /**存储AES秘钥*/
    public static void setAesKey(String aesKeyFild,String aesKey){
@@ -33,7 +41,14 @@
    }
    /**获取AES秘钥*/
    public static String getAesKey(String aesKeyFild){
        return aesKeys.get(aesKeyFild);
        if(aesKeyFild == null){
            return AES_KEY;
        }
        if(StringUtils.isEmpty(aesKeys.get(aesKeyFild))){
            return AES_KEY;
        }else {
            return  aesKeys.get(aesKeyFild);
        }
    }
    /**
@@ -41,11 +56,29 @@
     */
    @PostConstruct
    public void VariableAesKey(){
        isRun = 1;
        //项目启动的时候填入
        System.err.println("扫描获取AES:" + constantBean.getPackPath());
        AES_KEY = constantBean.getFixedAesKey();
        if(StringUtils.noNull(constantBean.getPackPath())){
            Set<Class<?>> classes = classData(constantBean.getPackPath());
            Map<String,String> aesKeysFild = new HashMap<>();
            boolean isAes = false;
            String tableName = null;
            for(Class<?> cl:classes){
                //表名称
                Table table = cl.getAnnotation(Table.class);
                if(table == null){
                    continue;
                }
                tableName = table.name();
                aesKeysFild = new HashMap<>();
                isAes = false;
                // 取得本类的全部属性
                Field[] fields = cl.getDeclaredFields();
                fields = getPatentFields(fields,cl);
@@ -56,32 +89,31 @@
                        // 根据注解类型返回方法的指定类型注解
                        MysqlHexAes mysqlHexAes = field.getAnnotation(MysqlHexAes.class);
                        String aesKeyField = mysqlHexAes.aesKeyField();
                        //String aesKeyField = mysqlHexAes.aesKeyField();
                        String aesKey = mysqlHexAes.aesKey();
                        if(StringUtils.isEmpty(aesKey)){
                            throw new RuntimeException("mysql的AES秘钥不能为空:"+field.getName());
                        }
                        if(StringUtils.noNull(aesKeyField)){
                            String key = aesKeys.get(aesKeyField);
                            if(StringUtils.isEmpty(key)){
                                aesKeys.put(aesKeyField,aesKey);
                            }else{
                                if(!aesKey.equals(key)){
                                    throw new RuntimeException("字段/定义的AES秘钥字段【"+field.getName()+"】多个一样,但是AES秘钥不一样");
                                }
                            aesKey = constantBean.getFixedAesKey();
                            if(StringUtils.isEmpty(aesKey)){
                                throw new RuntimeException("mysql的AES秘钥不能为空:"+field.getName());
                            }
                        }
                        String key = aesKeys.get(field.getName());
                        if(StringUtils.isEmpty(key)){
                            aesKeys.put(field.getName(),aesKey);
                            aesKeysFild.put(field.getName(),aesKey);
                            isAes = true;
                        }else{
                            String key = aesKeys.get(field.getName());
                            if(StringUtils.isEmpty(key)){
                                aesKeys.put(field.getName(),aesKey);
                            }else{
                                if(!aesKey.equals(key)){
                                    throw new RuntimeException("字段/定义的AES秘钥字段【"+field.getName()+"】多个一样,但是AES秘钥不一样");
                                }
                            isAes = true;
                            aesKeysFild.put(field.getName(),aesKey);
                            if(!aesKey.equals(key)){
                                throw new RuntimeException("字段/定义的AES秘钥字段【"+field.getName()+"】多个一样,但是AES秘钥不一样");
                            }
                        }
                    }
                }
                if(isAes){
                    aesKeysTable.put(tableName,aesKeysFild);
                }
            }
        }
@@ -90,31 +122,35 @@
    /**获取包下面的所有文件*/
    public static Set<Class<?>> classData(String packPath){
        Set<Class<?>> classes = new LinkedHashSet<Class<?>>();
        // 是否循环迭代
        boolean recursive = true;
        // 获取包的名字 并进行替换
        String packageName = packPath;
        String packageDirName = packageName.replace('.', '/');
        // 定义一个枚举的集合 并进行循环来处理这个目录下的things
        Enumeration<URL> dirs;
        try{
            dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
            // 循环迭代下去
            while (dirs.hasMoreElements()){
                // 获取下一个元素
                URL url = dirs.nextElement();
                // 得到协议的名称
                String protocol = url.getProtocol();
                // 如果是以文件的形式保存在服务器上
                if ("file".equals(protocol)) {
                    // 获取包的物理路径
                    String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                    // 以文件的方式扫描整个包下的文件 并添加到集合中
                    findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
        //截取
        String[] packPaths = packPath.split(";|,");
        for( String packageName : packPaths){
            // 是否循环迭代
            boolean recursive = true;
            // 获取包的名字 并进行替换
            String packageDirName = packageName.replace('.', '/');
            // 定义一个枚举的集合 并进行循环来处理这个目录下的things
            Enumeration<URL> dirs;
            try{
                dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
                // 循环迭代下去
                while (dirs.hasMoreElements()){
                    // 获取下一个元素
                    URL url = dirs.nextElement();
                    // 得到协议的名称
                    String protocol = url.getProtocol();
                    // 如果是以文件的形式保存在服务器上
                    if ("file".equals(protocol)) {
                        // 获取包的物理路径
                        String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                        // 以文件的方式扫描整个包下的文件 并添加到集合中
                        findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
                    }
                }
            }catch (IOException e){
                e.printStackTrace();
            }
        }catch (IOException e){
            e.printStackTrace();
        }
        return classes;
    }
src/main/java/com/hx/springbean/ConstantBean.java
File was deleted