chenjiahe
2022-06-28 4484acb56810d06ef5c2f78190a93688a61f83a7
src/main/java/com/hx/mybatis/aes/springbean/SqlUtils.java
@@ -3,10 +3,7 @@
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLInListExpr;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.mysql.ast.clause.MySqlSelectIntoStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
@@ -17,6 +14,8 @@
import com.alibaba.druid.util.JdbcConstants;
import com.alibaba.druid.util.JdbcUtils;
import com.hx.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -27,8 +26,11 @@
 * @author CJH 2022-01-12
 */
public class SqlUtils {
    //log4j日志
    private static Logger logger = LoggerFactory.getLogger(SqlUtils.class.getName());
    /**查询加密数据处理,只对查询做处理,select返回不做处理
    /**查询加密数据处理,只对查询做处理
     * @param sql sql语句
     * @param aesKeysTable aes秘钥
     * @return
@@ -37,13 +39,88 @@
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
        //获取格式化的slq语句
        sql = sqlStatement.toString();
        SQLSelect sqlSelect = sqlStatement.getSelect();
        if (sqlSelect.getQuery() instanceof SQLSelectQueryBlock) {
            // 非union的查询语句
            return selectSqlRoutine( sqlStatement,aesKeysTable);
        } else if (sqlSelect.getQuery() instanceof SQLUnionQuery) {
            // union的查询语句
            return selectSqlUnion( sql, sqlStatement, aesKeysTable);
        }else {
            return selectSqlRoutine( sqlStatement,aesKeysTable);
        }
    }
    /**查询加密数据处理,只对查询做处理,select返回不做处理(Union特殊语句)
     * @param sql sql语句
     * @param aesKeysTable aes秘钥
     * @return
     */
    public static String selectSqlUnion(String sql,SQLSelectStatement sqlStatement,Map<String,Map<String,String>> aesKeysTable){
        //获取表和别名
        ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
        sqlStatement.accept(visitorTable);
        Map<String,String> tableMaps = visitorTable.getTableMap();
        //获取所有的字段
        MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
        sqlStatement.accept(visitor);
        //遍历所有字段
        Collection<TableStat.Column> columns= visitor.getColumns();
        //处理需要加密得字段
        if(!StringUtils.isEmpty(sql)){
            Map<String,String> aesKeys = null;
            String aeskey = null;
            //把剩下的拼接上来
            String tableAl = null;
            for(TableStat.Column column:columns){
                aesKeys= aesKeysTable.get(column.getTable());
                if(aesKeys == null){
                    continue;
                }
                aeskey = aesKeys.getOrDefault(column.getName(),null);
                if(StringUtils.isEmpty(aeskey)){
                    continue;
                }
                tableAl = tableMaps.get(column.getTable());
                if(!StringUtils.isEmpty(tableAl)){
                    tableAl = tableAl+"."+column.getName();
                }else{
                    tableAl = column.getName();
                }
                sql = sql.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
            }
        }
        return sql;
    }
    /**查询加密数据处理,只对查询做处理,select返回不做处理(常规语句)
     * @param sqlStatement sql语句
     * @param aesKeysTable aes秘钥
     * @return
     */
    public static String selectSqlRoutine(SQLSelectStatement sqlStatement,Map<String,Map<String,String>> aesKeysTable){
        //解析select查询
        //SQLSelect sqlSelect = sqlStatement.getSelect()
        //获取sql查询块
        SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
        SQLSelectQueryBlock sqlSelectQuery = null;
        boolean b = true;
        try{
            sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
        }catch (Exception e){
            b = false;
            logger.error("解析sql报错:"+e.getMessage());
        }
        if(!b){
            return "err";
        }
        StringBuffer out = new StringBuffer() ;
        //创建sql解析的标准化输出
        SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
@@ -69,31 +146,37 @@
            if(sqlSelect.length() > 7){
                sqlSelect.append(",");
            }
            expr = sqlSelectItem.getExpr().toString();
            if(expr.indexOf("SELECT") == -1){
            out.delete(0, out.length()) ;
            sqlSelectItem.accept(sqlastOutputVisitor) ;
            expr = out.toString();
            sqlSelect.append(expr);
           /* if(expr.indexOf("SELECT") == -1){
                sqlSelect.append(expr);
                if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
                    sqlSelect.append(" AS "+sqlSelectItem.getAlias());
                }
            }else{
                sqlSelect.append("(");
                sqlSelect.append(selectSqlHandle(expr,aesKeysTable,tableMaps,columns));
                sqlSelect.append(")");
                if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
                //sqlSelect.append("(");
                sqlSelect.append(expr);
                //sqlSelect.append(")");
               *//* if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
                    sqlSelect.append(" AS "+sqlSelectItem.getAlias());
                }
            }
                }*//*
            }*/
        }
        //解析from
        out.delete(0, out.length()) ;
        sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
        sqlWhere.append(" FROM "+out);
        if(sqlSelectQuery.getFrom() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" FROM "+out);
        }
        //解析where
        out.delete(0, out.length()) ;
        sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
        sqlWhere.append(" WHERE "+out+" ");
        if(sqlSelectQuery.getWhere() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" WHERE "+out+" ");
        }
        if(sqlSelectQuery.getGroupBy() != null){
            out.delete(0, out.length()) ;
@@ -112,7 +195,7 @@
        }
        //处理where需要加密得字段
        sql = sqlWhere.toString();
        String sql = sqlWhere.toString();
        if(!StringUtils.isEmpty(sql)){
            Map<String,String> aesKeys = null;
            String aeskey = null;
@@ -133,12 +216,20 @@
                }else{
                    tableAl = column.getName();
                }
                sql = sql.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') ");
                sql = sql.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
            }
        }
        return sqlSelect.toString()+sql;
    }
    /**
     * 处理select返回字段的参数
     * @param sql
     * @param aesKeysTable
     * @param tableMaps
     * @param columns
     * @return
     */
    public static String selectSqlHandle(String sql,Map<String,Map<String,String>> aesKeysTable
            ,Map<String,String> tableMaps,Collection<TableStat.Column> columns){
@@ -183,14 +274,18 @@
        }
        //解析from
        out.delete(0, out.length()) ;
        sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
        sqlWhere.append(" FROM "+out);
        if(sqlSelectQuery.getFrom() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" FROM "+out);
        }
        //解析where
        out.delete(0, out.length()) ;
        sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
        sqlWhere.append(" WHERE "+out+" ");
        if(sqlSelectQuery.getWhere() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" WHERE "+out+" ");
        }
        if(sqlSelectQuery.getGroupBy() != null){
            out.delete(0, out.length()) ;
@@ -230,7 +325,7 @@
                }else{
                    tableAl = column.getName();
                }
                sql = sql.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') ");
                sql = sql.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
            }
        }
        return sqlSelect.toString()+sql;
@@ -241,72 +336,72 @@
     * @param aesKeysTable aes秘钥
     * @return
     */
   public static String insertSql(String sql,Map<String,Map<String,String>> aesKeysTable){
       //装载重写的sql语句
       StringBuilder splicingSql = new StringBuilder();
    public static String insertSql(String sql,Map<String,Map<String,String>> aesKeysTable){
        //装载重写的sql语句
        StringBuilder splicingSql = new StringBuilder();
       sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
       String[] datas = sql.split("VALUES",2);
        sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
        String[] datas = sql.split("VALUES",2);
       splicingSql.append(datas[0]+"VALUES ");
        splicingSql.append(datas[0]+"VALUES ");
       //重新拼接SQL语句
        //重新拼接SQL语句
       //解析sql语句
       MySqlStatementParser parser = new MySqlStatementParser(sql);
       SQLStatement statement = parser.parseStatement();
       MySqlInsertStatement insert = (MySqlInsertStatement)statement;
        //解析sql语句
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLStatement statement = parser.parseStatement();
        MySqlInsertStatement insert = (MySqlInsertStatement)statement;
       String insertName = insert.getTableName().getSimpleName();
        String insertName = insert.getTableName().getSimpleName();
       //根据表名称获取到AES秘钥
       Map<String,String> aesKeys= aesKeysTable.get(insertName);
       if(aesKeys == null){
           return sql;
       }
        //根据表名称获取到AES秘钥
        Map<String,String> aesKeys= aesKeysTable.get(insertName);
        if(aesKeys == null){
            return sql;
        }
       //获取所有的字段
       List<SQLExpr> columns = insert.getColumns();
        //获取所有的字段
        List<SQLExpr> columns = insert.getColumns();
       String fildValue = null;
       String aeskey = null;
       //遍历值
       List<SQLInsertStatement.ValuesClause> vcl = insert.getValuesList();
       for(int j = 0; j<vcl.size(); j++){
           if( j != 0){
               splicingSql.append(",");
           }
           for(int i = 0;i < columns.size();i++){
               //查询改字段是否需要加密
               aeskey = aesKeys.getOrDefault(columns.get(i).toString(),null);
               fildValue = vcl.get(j).getValues().get(i).toString();
               if(i == 0){
                   splicingSql.append("(");
                   if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                       splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                   }else{
                       splicingSql.append(fildValue);
                   }
               }else if(i == columns.size()-1){
                   splicingSql.append(",");
                   if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                       splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                   }else{
                       splicingSql.append(fildValue);
                   }
                   splicingSql.append(")");
               }else{
                   splicingSql.append(",");
                   if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                       splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                   }else{
                       splicingSql.append(fildValue);
                   }
               }
           }
       }
       return splicingSql.toString();
   }
        String fildValue = null;
        String aeskey = null;
        //遍历值
        List<SQLInsertStatement.ValuesClause> vcl = insert.getValuesList();
        for(int j = 0; j<vcl.size(); j++){
            if( j != 0){
                splicingSql.append(",");
            }
            for(int i = 0;i < columns.size();i++){
                //查询改字段是否需要加密
                aeskey = aesKeys.getOrDefault(columns.get(i).toString(),null);
                fildValue = vcl.get(j).getValues().get(i).toString();
                if(i == 0){
                    splicingSql.append("(");
                    if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                        splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                    }else{
                        splicingSql.append(fildValue);
                    }
                }else if(i == columns.size()-1){
                    splicingSql.append(",");
                    if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                        splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                    }else{
                        splicingSql.append(fildValue);
                    }
                    splicingSql.append(")");
                }else{
                    splicingSql.append(",");
                    if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                        splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                    }else{
                        splicingSql.append(fildValue);
                    }
                }
            }
        }
        return splicingSql.toString();
    }
    /**更新加密数据处理
     * @param sql sql语句
@@ -314,6 +409,7 @@
     * @return
     */
    public static String updateSql(String sql,Map<String,Map<String,String>> aesKeysTable){
        //装载重写的sql语句
        StringBuilder splicingSql = new StringBuilder();
@@ -323,7 +419,6 @@
        //获取格式化的slq语句
        sql = sqlStatement.toString();
        MySqlUpdateStatement updateStatement = (MySqlUpdateStatement)sqlStatement;
        String insertName = updateStatement.getTableName().getSimpleName();
@@ -331,9 +426,11 @@
        String[] datas = sql.split("WHERE",2);
        Map<String,String> aesKeys = aesKeysTable.get(insertName);
        if(aesKeys == null){
            return sql;
        }
        splicingSql.append("UPDATE "+insertName+" SET ");
        String aeskey = null;
        String fildValue = null;
        List<SQLUpdateSetItem> items = updateStatement.getItems();
@@ -341,9 +438,7 @@
            if(i != 0){
                splicingSql.append(",");
            }
            SQLUpdateSetItem item = items.get(i);
            //查询改字段是否需要加密
            aeskey = aesKeys.getOrDefault(item.getColumn().toString(),null);
@@ -354,7 +449,6 @@
                splicingSql.append(item.getColumn()+" = "+fildValue);
            }
        }
        String sqlWhere = " WHERE";
        //把剩下的拼接上来
        if(datas.length > 1){
@@ -395,11 +489,11 @@
                }else{
                    tableAl = column.getName();
                }
                sqlWhere = sqlWhere.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') ");
                sqlWhere = sqlWhere.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
            }
        }
        splicingSql.append(sqlWhere.toString());
        splicingSql.append(sqlWhere);
        return splicingSql.toString();
    }
@@ -425,6 +519,9 @@
        String[] datas = sql.split("WHERE",2);
        Map<String,String> aesKeys = aesKeysTable.get(insertName);
        if(aesKeys == null){
            return sql;
        }
        splicingSql.append("DELETE FROM "+insertName);
@@ -470,11 +567,11 @@
                }else{
                    tableAl = column.getName();
                }
                sqlWhere = sqlWhere.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') ");
                sqlWhere = sqlWhere.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
            }
        }
        splicingSql.append(sqlWhere.toString());
        splicingSql.append(sqlWhere);
        return splicingSql.toString();
    }