chenjiahe
2022-01-14 346eb22b5ab622064d1f61819240656529800f2b
MySQL数据库AES加密工具
3个文件已修改
281 ■■■■■ 已修改文件
src/main/java/com/hx/mybatis/aes/springbean/MySqlInterceptor.java 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/hx/mybatis/aes/springbean/SqlUtils.java 163 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/hx/mybatis/aes/springbean/VariableAesKey.java 117 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/hx/mybatis/aes/springbean/MySqlInterceptor.java
@@ -71,6 +71,7 @@
        }else if(sqlCommandType == SqlCommandType.DELETE){
            sql = SqlUtils.deleteSql(sql, VariableAesKey.aesKeysTable);
        }
        //通过反射修改sql语句
        Field field = boundSql.getClass().getDeclaredField("sql");
        field.setAccessible(true);
src/main/java/com/hx/mybatis/aes/springbean/SqlUtils.java
@@ -3,10 +3,7 @@
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLInListExpr;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.mysql.ast.clause.MySqlSelectIntoStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
@@ -86,14 +83,18 @@
        }
        //解析from
        out.delete(0, out.length()) ;
        sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
        sqlWhere.append(" FROM "+out);
        if(sqlSelectQuery.getFrom() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" FROM "+out);
        }
        //解析where
        out.delete(0, out.length()) ;
        sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
        sqlWhere.append(" WHERE "+out+" ");
        if(sqlSelectQuery.getWhere() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" WHERE "+out+" ");
        }
        if(sqlSelectQuery.getGroupBy() != null){
            out.delete(0, out.length()) ;
@@ -183,14 +184,18 @@
        }
        //解析from
        out.delete(0, out.length()) ;
        sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
        sqlWhere.append(" FROM "+out);
        if(sqlSelectQuery.getFrom() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" FROM "+out);
        }
        //解析where
        out.delete(0, out.length()) ;
        sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
        sqlWhere.append(" WHERE "+out+" ");
        if(sqlSelectQuery.getWhere() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" WHERE "+out+" ");
        }
        if(sqlSelectQuery.getGroupBy() != null){
            out.delete(0, out.length()) ;
@@ -241,72 +246,72 @@
     * @param aesKeysTable aes秘钥
     * @return
     */
   public static String insertSql(String sql,Map<String,Map<String,String>> aesKeysTable){
       //装载重写的sql语句
       StringBuilder splicingSql = new StringBuilder();
    public static String insertSql(String sql,Map<String,Map<String,String>> aesKeysTable){
        //装载重写的sql语句
        StringBuilder splicingSql = new StringBuilder();
       sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
       String[] datas = sql.split("VALUES",2);
        sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
        String[] datas = sql.split("VALUES",2);
       splicingSql.append(datas[0]+"VALUES ");
        splicingSql.append(datas[0]+"VALUES ");
       //重新拼接SQL语句
        //重新拼接SQL语句
       //解析sql语句
       MySqlStatementParser parser = new MySqlStatementParser(sql);
       SQLStatement statement = parser.parseStatement();
       MySqlInsertStatement insert = (MySqlInsertStatement)statement;
        //解析sql语句
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLStatement statement = parser.parseStatement();
        MySqlInsertStatement insert = (MySqlInsertStatement)statement;
       String insertName = insert.getTableName().getSimpleName();
        String insertName = insert.getTableName().getSimpleName();
       //根据表名称获取到AES秘钥
       Map<String,String> aesKeys= aesKeysTable.get(insertName);
       if(aesKeys == null){
           return sql;
       }
        //根据表名称获取到AES秘钥
        Map<String,String> aesKeys= aesKeysTable.get(insertName);
        if(aesKeys == null){
            return sql;
        }
       //获取所有的字段
       List<SQLExpr> columns = insert.getColumns();
        //获取所有的字段
        List<SQLExpr> columns = insert.getColumns();
       String fildValue = null;
       String aeskey = null;
       //遍历值
       List<SQLInsertStatement.ValuesClause> vcl = insert.getValuesList();
       for(int j = 0; j<vcl.size(); j++){
           if( j != 0){
               splicingSql.append(",");
           }
           for(int i = 0;i < columns.size();i++){
               //查询改字段是否需要加密
               aeskey = aesKeys.getOrDefault(columns.get(i).toString(),null);
               fildValue = vcl.get(j).getValues().get(i).toString();
               if(i == 0){
                   splicingSql.append("(");
                   if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                       splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                   }else{
                       splicingSql.append(fildValue);
                   }
               }else if(i == columns.size()-1){
                   splicingSql.append(",");
                   if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                       splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                   }else{
                       splicingSql.append(fildValue);
                   }
                   splicingSql.append(")");
               }else{
                   splicingSql.append(",");
                   if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                       splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                   }else{
                       splicingSql.append(fildValue);
                   }
               }
           }
       }
       return splicingSql.toString();
   }
        String fildValue = null;
        String aeskey = null;
        //遍历值
        List<SQLInsertStatement.ValuesClause> vcl = insert.getValuesList();
        for(int j = 0; j<vcl.size(); j++){
            if( j != 0){
                splicingSql.append(",");
            }
            for(int i = 0;i < columns.size();i++){
                //查询改字段是否需要加密
                aeskey = aesKeys.getOrDefault(columns.get(i).toString(),null);
                fildValue = vcl.get(j).getValues().get(i).toString();
                if(i == 0){
                    splicingSql.append("(");
                    if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                        splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                    }else{
                        splicingSql.append(fildValue);
                    }
                }else if(i == columns.size()-1){
                    splicingSql.append(",");
                    if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                        splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                    }else{
                        splicingSql.append(fildValue);
                    }
                    splicingSql.append(")");
                }else{
                    splicingSql.append(",");
                    if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                        splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                    }else{
                        splicingSql.append(fildValue);
                    }
                }
            }
        }
        return splicingSql.toString();
    }
    /**更新加密数据处理
     * @param sql sql语句
@@ -314,6 +319,7 @@
     * @return
     */
    public static String updateSql(String sql,Map<String,Map<String,String>> aesKeysTable){
        //装载重写的sql语句
        StringBuilder splicingSql = new StringBuilder();
@@ -323,7 +329,6 @@
        //获取格式化的slq语句
        sql = sqlStatement.toString();
        MySqlUpdateStatement updateStatement = (MySqlUpdateStatement)sqlStatement;
        String insertName = updateStatement.getTableName().getSimpleName();
@@ -331,9 +336,11 @@
        String[] datas = sql.split("WHERE",2);
        Map<String,String> aesKeys = aesKeysTable.get(insertName);
        if(aesKeys == null){
            return sql;
        }
        splicingSql.append("UPDATE "+insertName+" SET ");
        String aeskey = null;
        String fildValue = null;
        List<SQLUpdateSetItem> items = updateStatement.getItems();
@@ -341,9 +348,7 @@
            if(i != 0){
                splicingSql.append(",");
            }
            SQLUpdateSetItem item = items.get(i);
            //查询改字段是否需要加密
            aeskey = aesKeys.getOrDefault(item.getColumn().toString(),null);
@@ -354,7 +359,6 @@
                splicingSql.append(item.getColumn()+" = "+fildValue);
            }
        }
        String sqlWhere = " WHERE";
        //把剩下的拼接上来
        if(datas.length > 1){
@@ -425,6 +429,9 @@
        String[] datas = sql.split("WHERE",2);
        Map<String,String> aesKeys = aesKeysTable.get(insertName);
        if(aesKeys == null){
            return sql;
        }
        splicingSql.append("DELETE FROM "+insertName);
src/main/java/com/hx/mybatis/aes/springbean/VariableAesKey.java
@@ -3,6 +3,8 @@
import com.gitee.sunchenbin.mybatis.actable.annotation.Table;
import com.hx.common.annotations.MysqlHexAes;
import com.hx.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
@@ -11,15 +13,21 @@
import java.io.FileFilter;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.JarURLConnection;
import java.net.URL;
import java.net.URLDecoder;
import java.util.*;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
/**
 * 获取指定包里面的AES秘钥
 */
@Component
public class VariableAesKey {
    //log4j日志
    private static Logger logger = LoggerFactory.getLogger(VariableAesKey.class.getName());
    @Resource
    private ConstantBean constantBean;
@@ -59,21 +67,22 @@
        isRun = 1;
        //项目启动的时候填入
        System.err.println("扫描获取AES:" + constantBean.getPackPath());
        logger.info("扫描获取AES的包:" + constantBean.getPackPath());
        AES_KEY = constantBean.getFixedAesKey();
        if(!StringUtils.isEmpty(constantBean.getPackPath())){
            Set<Class<?>> classes = classData(constantBean.getPackPath());
            logger.info("扫描获取AES的包classes:" + classes.size());
            Map<String,String> aesKeysFild = new HashMap<>();
            boolean isAes = false;
            String tableName = null;
            for(Class<?> cl:classes){
                //表名称
                Table table = cl.getAnnotation(Table.class);
                if(table == null){
                boolean hasAnnotation = cl.isAnnotationPresent(Table.class);
                if(!hasAnnotation){
                    continue;
                }
                Table table = cl.getAnnotation(Table.class);
                tableName = table.name();
                aesKeysFild = new HashMap<>();
@@ -84,7 +93,7 @@
                fields = getPatentFields(fields,cl);
                for (Field field:fields) {
                    // 判断方法中是否有指定注解类型的注解
                    boolean hasAnnotation = field.isAnnotationPresent(MysqlHexAes.class);
                    hasAnnotation = field.isAnnotationPresent(MysqlHexAes.class);
                    if (hasAnnotation) {
                        // 根据注解类型返回方法的指定类型注解
                        MysqlHexAes mysqlHexAes = field.getAnnotation(MysqlHexAes.class);
@@ -121,37 +130,85 @@
    /**获取包下面的所有文件*/
    public static Set<Class<?>> classData(String packPath){
        Set<Class<?>> classes = new LinkedHashSet<Class<?>>();
        Set<Class<?>> classes = new LinkedHashSet();
        String[] split = packPath.split(",|;");
        String[] var3 = split;
        int var4 = split.length;
        //截取
        String[] packPaths = packPath.split(";|,");
        for( String packageName : packPaths){
            // 是否循环迭代
        label82:
        for(int var5 = 0; var5 < var4; ++var5) {
            String pack = var3[var5];
            boolean recursive = true;
            // 获取包的名字 并进行替换
            String packageDirName = packageName.replace('.', '/');
            // 定义一个枚举的集合 并进行循环来处理这个目录下的things
            Enumeration<URL> dirs;
            try{
                dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
                // 循环迭代下去
                while (dirs.hasMoreElements()){
                    // 获取下一个元素
                    URL url = dirs.nextElement();
                    // 得到协议的名称
                    String protocol = url.getProtocol();
                    // 如果是以文件的形式保存在服务器上
                    if ("file".equals(protocol)) {
                        // 获取包的物理路径
                        String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                        // 以文件的方式扫描整个包下的文件 并添加到集合中
                        findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
            String packageName = pack;
            String packageDirName = pack.replace('.', '/');
            try {
                Enumeration dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
                while(true) {
                    label75:
                    while(true) {
                        if (!dirs.hasMoreElements()) {
                            continue label82;
                        }
                        URL url = (URL)dirs.nextElement();
                        String protocol = url.getProtocol();
                        if ("file".equals(protocol)) {
                            System.err.println("file类型的扫描:" + pack);
                            String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                            findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
                        } else if ("jar".equals(protocol)) {
                            System.err.println("jar类型的扫描");
                            try {
                                JarFile jar = ((JarURLConnection)url.openConnection()).getJarFile();
                                Enumeration entries = jar.entries();
                                while(true) {
                                    JarEntry entry;
                                    String name;
                                    int idx;
                                    do {
                                        do {
                                            if (!entries.hasMoreElements()) {
                                                continue label75;
                                            }
                                            entry = (JarEntry)entries.nextElement();
                                            name = entry.getName();
                                            if (name.charAt(0) == '/') {
                                                name = name.substring(1);
                                            }
                                        } while(!name.startsWith(packageDirName));
                                        idx = name.lastIndexOf(47);
                                        if (idx != -1) {
                                            packageName = name.substring(0, idx).replace('/', '.');
                                        }
                                    } while(idx == -1 && !recursive);
                                    if (name.endsWith(".class") && !entry.isDirectory()) {
                                        String className = name.substring(packageName.length() + 1, name.length() - 6);
                                        try {
                                            classes.add(Class.forName(packageName + '.' + className));
                                        } catch (ClassNotFoundException var20) {
                                            var20.printStackTrace();
                                        }
                                    }
                                }
                            } catch (IOException var21) {
                                var21.printStackTrace();
                            }
                        }
                    }
                }
            }catch (IOException e){
                e.printStackTrace();
            } catch (IOException var22) {
                var22.printStackTrace();
            }
        }
        return classes;
    }