chenjiahe
2022-04-06 591b9669678180c4ef9171169a88f69df84496e7
src/main/java/com/hx/mybatis/aes/springbean/SqlUtils.java
@@ -2,11 +2,9 @@
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLInListExpr;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.mysql.ast.clause.MySqlSelectIntoStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
@@ -18,6 +16,7 @@
import com.alibaba.druid.util.JdbcUtils;
import com.hx.util.StringUtils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
@@ -69,31 +68,37 @@
            if(sqlSelect.length() > 7){
                sqlSelect.append(",");
            }
            expr = sqlSelectItem.getExpr().toString();
            if(expr.indexOf("SELECT") == -1){
            out.delete(0, out.length()) ;
            sqlSelectItem.accept(sqlastOutputVisitor) ;
            expr = out.toString();
            sqlSelect.append(expr);
           /* if(expr.indexOf("SELECT") == -1){
                sqlSelect.append(expr);
                if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
                    sqlSelect.append(" AS "+sqlSelectItem.getAlias());
                }
            }else{
                sqlSelect.append("(");
                sqlSelect.append(selectSqlHandle(expr,aesKeysTable,tableMaps,columns));
                sqlSelect.append(")");
                if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
                //sqlSelect.append("(");
                sqlSelect.append(expr);
                //sqlSelect.append(")");
               *//* if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
                    sqlSelect.append(" AS "+sqlSelectItem.getAlias());
                }
            }
                }*//*
            }*/
        }
        //解析from
        out.delete(0, out.length()) ;
        sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
        sqlWhere.append(" FROM "+out);
        if(sqlSelectQuery.getFrom() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" FROM "+out);
        }
        //解析where
        out.delete(0, out.length()) ;
        sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
        sqlWhere.append(" WHERE "+out+" ");
        if(sqlSelectQuery.getWhere() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" WHERE "+out+" ");
        }
        if(sqlSelectQuery.getGroupBy() != null){
            out.delete(0, out.length()) ;
@@ -133,12 +138,20 @@
                }else{
                    tableAl = column.getName();
                }
                sql = sql.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') ");
                sql = sql.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
            }
        }
        return sqlSelect.toString()+sql;
    }
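    /**
     * Handles the fields returned by a SELECT statement (the select list).
     * @param sql the SELECT statement to process
     * @param aesKeysTable AES keys; presumably table name -> (column name -> AES key) -- TODO confirm against callers
     * @param tableMaps table lookup map; presumably maps aliases to table names -- TODO confirm against callers
     * @param columns the {@code TableStat.Column} entries to consider when rewriting
     * @return the rewritten SQL string
     */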
    public static String selectSqlHandle(String sql,Map<String,Map<String,String>> aesKeysTable
            ,Map<String,String> tableMaps,Collection<TableStat.Column> columns){
@@ -183,14 +196,18 @@
        }
        //解析from
        out.delete(0, out.length()) ;
        sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
        sqlWhere.append(" FROM "+out);
        if(sqlSelectQuery.getFrom() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" FROM "+out);
        }
        //解析where
        out.delete(0, out.length()) ;
        sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
        sqlWhere.append(" WHERE "+out+" ");
        if(sqlSelectQuery.getWhere() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" WHERE "+out+" ");
        }
        if(sqlSelectQuery.getGroupBy() != null){
            out.delete(0, out.length()) ;
@@ -230,7 +247,7 @@
                }else{
                    tableAl = column.getName();
                }
                sql = sql.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') ");
                sql = sql.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
            }
        }
        return sqlSelect.toString()+sql;
@@ -241,72 +258,72 @@
     * @param aesKeysTable aes秘钥
     * @return
     */
   public static String insertSql(String sql,Map<String,Map<String,String>> aesKeysTable){
       //装载重写的sql语句
       StringBuilder splicingSql = new StringBuilder();
    public static String insertSql(String sql,Map<String,Map<String,String>> aesKeysTable){
        //装载重写的sql语句
        StringBuilder splicingSql = new StringBuilder();
       sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
       String[] datas = sql.split("VALUES",2);
        sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
        String[] datas = sql.split("VALUES",2);
       splicingSql.append(datas[0]+"VALUES ");
        splicingSql.append(datas[0]+"VALUES ");
       //重新拼接SQL语句
        //重新拼接SQL语句
       //解析sql语句
       MySqlStatementParser parser = new MySqlStatementParser(sql);
       SQLStatement statement = parser.parseStatement();
       MySqlInsertStatement insert = (MySqlInsertStatement)statement;
        //解析sql语句
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLStatement statement = parser.parseStatement();
        MySqlInsertStatement insert = (MySqlInsertStatement)statement;
       String insertName = insert.getTableName().getSimpleName();
        String insertName = insert.getTableName().getSimpleName();
       //根据表名称获取到AES秘钥
       Map<String,String> aesKeys= aesKeysTable.get(insertName);
       if(aesKeys == null){
           return sql;
       }
        //根据表名称获取到AES秘钥
        Map<String,String> aesKeys= aesKeysTable.get(insertName);
        if(aesKeys == null){
            return sql;
        }
       //获取所有的字段
       List<SQLExpr> columns = insert.getColumns();
        //获取所有的字段
        List<SQLExpr> columns = insert.getColumns();
       String fildValue = null;
       String aeskey = null;
       //遍历值
       List<SQLInsertStatement.ValuesClause> vcl = insert.getValuesList();
       for(int j = 0; j<vcl.size(); j++){
           if( j != 0){
               splicingSql.append(",");
           }
           for(int i = 0;i < columns.size();i++){
               //查询改字段是否需要加密
               aeskey = aesKeys.getOrDefault(columns.get(i).toString(),null);
               fildValue = vcl.get(j).getValues().get(i).toString();
               if(i == 0){
                   splicingSql.append("(");
                   if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                       splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                   }else{
                       splicingSql.append(fildValue);
                   }
               }else if(i == columns.size()-1){
                   splicingSql.append(",");
                   if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                       splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                   }else{
                       splicingSql.append(fildValue);
                   }
                   splicingSql.append(")");
               }else{
                   splicingSql.append(",");
                   if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                       splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                   }else{
                       splicingSql.append(fildValue);
                   }
               }
           }
       }
       return splicingSql.toString();
   }
        String fildValue = null;
        String aeskey = null;
        //遍历值
        List<SQLInsertStatement.ValuesClause> vcl = insert.getValuesList();
        for(int j = 0; j<vcl.size(); j++){
            if( j != 0){
                splicingSql.append(",");
            }
            for(int i = 0;i < columns.size();i++){
                //查询改字段是否需要加密
                aeskey = aesKeys.getOrDefault(columns.get(i).toString(),null);
                fildValue = vcl.get(j).getValues().get(i).toString();
                if(i == 0){
                    splicingSql.append("(");
                    if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                        splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                    }else{
                        splicingSql.append(fildValue);
                    }
                }else if(i == columns.size()-1){
                    splicingSql.append(",");
                    if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                        splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                    }else{
                        splicingSql.append(fildValue);
                    }
                    splicingSql.append(")");
                }else{
                    splicingSql.append(",");
                    if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                        splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                    }else{
                        splicingSql.append(fildValue);
                    }
                }
            }
        }
        return splicingSql.toString();
    }
    /** Handles encrypted data for UPDATE statements.
     * @param sql sql语句
@@ -314,6 +331,7 @@
     * @return
     */
    public static String updateSql(String sql,Map<String,Map<String,String>> aesKeysTable){
        //装载重写的sql语句
        StringBuilder splicingSql = new StringBuilder();
@@ -323,7 +341,6 @@
        //获取格式化的slq语句
        sql = sqlStatement.toString();
        MySqlUpdateStatement updateStatement = (MySqlUpdateStatement)sqlStatement;
        String insertName = updateStatement.getTableName().getSimpleName();
@@ -331,9 +348,11 @@
        String[] datas = sql.split("WHERE",2);
        Map<String,String> aesKeys = aesKeysTable.get(insertName);
        if(aesKeys == null){
            return sql;
        }
        splicingSql.append("UPDATE "+insertName+" SET ");
        String aeskey = null;
        String fildValue = null;
        List<SQLUpdateSetItem> items = updateStatement.getItems();
@@ -341,9 +360,7 @@
            if(i != 0){
                splicingSql.append(",");
            }
            SQLUpdateSetItem item = items.get(i);
            //查询改字段是否需要加密
            aeskey = aesKeys.getOrDefault(item.getColumn().toString(),null);
@@ -354,7 +371,6 @@
                splicingSql.append(item.getColumn()+" = "+fildValue);
            }
        }
        String sqlWhere = " WHERE";
        //把剩下的拼接上来
        if(datas.length > 1){
@@ -395,7 +411,7 @@
                }else{
                    tableAl = column.getName();
                }
                sqlWhere = sqlWhere.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') ");
                sqlWhere = sqlWhere.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
            }
        }
@@ -425,6 +441,9 @@
        String[] datas = sql.split("WHERE",2);
        Map<String,String> aesKeys = aesKeysTable.get(insertName);
        if(aesKeys == null){
            return sql;
        }
        splicingSql.append("DELETE FROM "+insertName);
@@ -470,7 +489,7 @@
                }else{
                    tableAl = column.getName();
                }
                sqlWhere = sqlWhere.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') ");
                sqlWhere = sqlWhere.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
            }
        }