ChenJiaHe
2022-01-22 f12d2c4c92dcd7929bce4d2df74f48f8abc08350
src/main/java/com/hx/mybatis/aes/springbean/SqlUtils.java
New file
@@ -0,0 +1,500 @@
package com.hx.mybatis.aes.springbean;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor;
import com.alibaba.druid.stat.TableStat;
import com.alibaba.druid.util.JdbcConstants;
import com.alibaba.druid.util.JdbcUtils;
import com.hx.util.StringUtils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * sql语句处理工具
 * @author CJH 2022-01-12
 */
public class SqlUtils {
    /**查询加密数据处理,只对查询做处理,select返回不做处理
     * @param sql sql语句
     * @param aesKeysTable aes秘钥
     * @return
     */
    public static String selectSql(String sql,Map<String,Map<String,String>> aesKeysTable){
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
        //获取格式化的slq语句
        sql = sqlStatement.toString();
        //解析select查询
        //SQLSelect sqlSelect = sqlStatement.getSelect()
        //获取sql查询块
        SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
        StringBuffer out = new StringBuffer() ;
        //创建sql解析的标准化输出
        SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
        //获取表和别名
        ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
        sqlStatement.accept(visitorTable);
        Map<String,String> tableMaps = visitorTable.getTableMap();
        //获取所有的字段
        MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
        sqlStatement.accept(visitor);
        //遍历所有字段
        Collection<TableStat.Column> columns= visitor.getColumns();
        StringBuilder sqlWhere = new StringBuilder();
        StringBuilder sqlSelect = new StringBuilder();
        String expr = null;
        sqlSelect.append("SELECT ");
        //解析select返回的数据字段项
        for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
            if(sqlSelect.length() > 7){
                sqlSelect.append(",");
            }
            out.delete(0, out.length()) ;
            sqlSelectItem.accept(sqlastOutputVisitor) ;
            expr = out.toString();
            sqlSelect.append(expr);
           /* if(expr.indexOf("SELECT") == -1){
                sqlSelect.append(expr);
            }else{
                //sqlSelect.append("(");
                sqlSelect.append(expr);
                //sqlSelect.append(")");
               *//* if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
                    sqlSelect.append(" AS "+sqlSelectItem.getAlias());
                }*//*
            }*/
        }
        //解析from
        if(sqlSelectQuery.getFrom() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" FROM "+out);
        }
        //解析where
        if(sqlSelectQuery.getWhere() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" WHERE "+out+" ");
        }
        if(sqlSelectQuery.getGroupBy() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" "+out);
        }
        if(sqlSelectQuery.getOrderBy() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" "+out);
        }
        if(sqlSelectQuery.getLimit() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" "+out);
        }
        //处理where需要加密得字段
        sql = sqlWhere.toString();
        if(!StringUtils.isEmpty(sql)){
            Map<String,String> aesKeys = null;
            String aeskey = null;
            //把剩下的拼接上来
            String tableAl = null;
            for(TableStat.Column column:columns){
                aesKeys= aesKeysTable.get(column.getTable());
                if(aesKeys == null){
                    continue;
                }
                aeskey = aesKeys.getOrDefault(column.getName(),null);
                if(StringUtils.isEmpty(aeskey)){
                    continue;
                }
                tableAl = tableMaps.get(column.getTable());
                if(!StringUtils.isEmpty(tableAl)){
                    tableAl = tableAl+"."+column.getName();
                }else{
                    tableAl = column.getName();
                }
                sql = sql.replaceAll("\\b"+tableAl+"\\b","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
            }
        }
        return sqlSelect.toString()+sql;
    }
    /**
     * 处理select返回字段的参数
     * @param sql
     * @param aesKeysTable
     * @param tableMaps
     * @param columns
     * @return
     */
    public static String selectSqlHandle(String sql,Map<String,Map<String,String>> aesKeysTable
            ,Map<String,String> tableMaps,Collection<TableStat.Column> columns){
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
        //获取格式化的slq语句
        sql = sqlStatement.toString();
        //解析select查询
        //SQLSelect sqlSelect = sqlStatement.getSelect() ;
        //获取sql查询块
        SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
        StringBuffer out = new StringBuffer() ;
        //创建sql解析的标准化输出
        SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
        StringBuilder sqlWhere = new StringBuilder();
        StringBuilder sqlSelect = new StringBuilder();
        String expr = null;
        sqlSelect.append("SELECT ");
        //解析select返回的数据字段项
        for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
            if(sqlSelect.length() > 7){
                sqlSelect.append(",");
            }
            expr = sqlSelectItem.getExpr().toString();
            if(expr.indexOf("SELECT") == -1){
                sqlSelect.append(expr);
                if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
                    sqlSelect.append(" AS "+sqlSelectItem.getAlias());
                }
            }else{
                sqlSelect.append("(");
                selectSqlHandle(expr,aesKeysTable,tableMaps,columns);
                sqlSelect.append(")");
                if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
                    sqlSelect.append(" AS "+sqlSelectItem.getAlias());
                }
            }
        }
        //解析from
        if(sqlSelectQuery.getFrom() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" FROM "+out);
        }
        //解析where
        if(sqlSelectQuery.getWhere() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" WHERE "+out+" ");
        }
        if(sqlSelectQuery.getGroupBy() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" "+out);
        }
        if(sqlSelectQuery.getOrderBy() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" "+out);
        }
        if(sqlSelectQuery.getLimit() != null){
            out.delete(0, out.length()) ;
            sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
            sqlWhere.append(" "+out);
        }
        sql = sqlWhere.toString();
        if(!StringUtils.isEmpty(sql)){
            Map<String,String> aesKeys = null;
            String aeskey = null;
            //把剩下的拼接上来
            String tableAl = null;
            for(TableStat.Column column:columns){
                aesKeys= aesKeysTable.get(column.getTable());
                if(aesKeys == null){
                    continue;
                }
                aeskey = aesKeys.getOrDefault(column.getName(),null);
                if(StringUtils.isEmpty(aeskey)){
                    continue;
                }
                tableAl = tableMaps.get(column.getTable());
                if(!StringUtils.isEmpty(tableAl)){
                    tableAl = tableAl+"."+column.getName();
                }else{
                    tableAl = column.getName();
                }
                sql = sql.replaceAll("\\b"+tableAl+"\\b","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
            }
        }
        return sqlSelect.toString()+sql;
    }
    /**新增加密数据处理
     * @param sql sql语句
     * @param aesKeysTable aes秘钥
     * @return
     */
    public static String insertSql(String sql,Map<String,Map<String,String>> aesKeysTable){
        //装载重写的sql语句
        StringBuilder splicingSql = new StringBuilder();
        sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
        String[] datas = sql.split("VALUES",2);
        splicingSql.append(datas[0]+"VALUES ");
        //重新拼接SQL语句
        //解析sql语句
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLStatement statement = parser.parseStatement();
        MySqlInsertStatement insert = (MySqlInsertStatement)statement;
        String insertName = insert.getTableName().getSimpleName();
        //根据表名称获取到AES秘钥
        Map<String,String> aesKeys= aesKeysTable.get(insertName);
        if(aesKeys == null){
            return sql;
        }
        //获取所有的字段
        List<SQLExpr> columns = insert.getColumns();
        String fildValue = null;
        String aeskey = null;
        //遍历值
        List<SQLInsertStatement.ValuesClause> vcl = insert.getValuesList();
        for(int j = 0; j<vcl.size(); j++){
            if( j != 0){
                splicingSql.append(",");
            }
            for(int i = 0;i < columns.size();i++){
                //查询改字段是否需要加密
                aeskey = aesKeys.getOrDefault(columns.get(i).toString(),null);
                fildValue = vcl.get(j).getValues().get(i).toString();
                if(i == 0){
                    splicingSql.append("(");
                    if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                        splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                    }else{
                        splicingSql.append(fildValue);
                    }
                }else if(i == columns.size()-1){
                    splicingSql.append(",");
                    if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                        splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                    }else{
                        splicingSql.append(fildValue);
                    }
                    splicingSql.append(")");
                }else{
                    splicingSql.append(",");
                    if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                        splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
                    }else{
                        splicingSql.append(fildValue);
                    }
                }
            }
        }
        return splicingSql.toString();
    }
    /**更新加密数据处理
     * @param sql sql语句
     * @param aesKeysTable aes秘钥
     * @return
     */
    public static String updateSql(String sql,Map<String,Map<String,String>> aesKeysTable){
        //装载重写的sql语句
        StringBuilder splicingSql = new StringBuilder();
        //sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLStatement sqlStatement = parser.parseStatement();
        //获取格式化的slq语句
        sql = sqlStatement.toString();
        MySqlUpdateStatement updateStatement = (MySqlUpdateStatement)sqlStatement;
        String insertName = updateStatement.getTableName().getSimpleName();
        String[] datas = sql.split("WHERE",2);
        Map<String,String> aesKeys = aesKeysTable.get(insertName);
        if(aesKeys == null){
            return sql;
        }
        splicingSql.append("UPDATE "+insertName+" SET ");
        String aeskey = null;
        String fildValue = null;
        List<SQLUpdateSetItem> items = updateStatement.getItems();
        for(int i = 0;i<items.size();i++){
            if(i != 0){
                splicingSql.append(",");
            }
            SQLUpdateSetItem item = items.get(i);
            //查询改字段是否需要加密
            aeskey = aesKeys.getOrDefault(item.getColumn().toString(),null);
            fildValue = item.getValue().toString();
            if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
                splicingSql.append(item.getColumn()+" = HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
            }else{
                splicingSql.append(item.getColumn()+" = "+fildValue);
            }
        }
        String sqlWhere = " WHERE";
        //把剩下的拼接上来
        if(datas.length > 1){
            for(int i =1;i<datas.length;i++){
                sqlWhere = sqlWhere+datas[i];
            }
            parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
            sqlStatement = parser.parseStatement();
            ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
            sqlStatement.accept(visitorTable);
            //获取表和别名
            Map<String,String> tableMaps = visitorTable.getTableMap();
            tableMaps.put(insertName,null);
            //获取所有的字段
            MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
            sqlStatement.accept(visitor);
            String tableAl = null;
            //遍历所有字段
            Collection<TableStat.Column> columns= visitor.getColumns();
            for(TableStat.Column column:columns){
                aesKeys= aesKeysTable.get(column.getTable());
                if(aesKeys == null){
                    continue;
                }
                aeskey = aesKeys.getOrDefault(column.getName(),null);
                if(StringUtils.isEmpty(aeskey)){
                    continue;
                }
                tableAl = tableMaps.get(column.getTable());
                if(!StringUtils.isEmpty(tableAl)){
                    tableAl = tableAl+"."+column.getName();
                }else{
                    tableAl = column.getName();
                }
                sqlWhere = sqlWhere.replaceAll("\\b"+tableAl+"\\b","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
            }
        }
        splicingSql.append(sqlWhere.toString());
        return splicingSql.toString();
    }
    /**删除加密数据处理
     * @param sql sql语句
     * @param aesKeysTable aes秘钥
     * @return
     */
    public static String deleteSql(String sql,Map<String,Map<String,String>> aesKeysTable){
        //装载重写的sql语句
        StringBuilder splicingSql = new StringBuilder();
        //sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLStatement sqlStatement = parser.parseStatement();
        //获取格式化的slq语句
        sql = sqlStatement.toString();
        MySqlDeleteStatement deleteStatement = (MySqlDeleteStatement)sqlStatement;
        String insertName = deleteStatement.getTableName().getSimpleName();
        String[] datas = sql.split("WHERE",2);
        Map<String,String> aesKeys = aesKeysTable.get(insertName);
        if(aesKeys == null){
            return sql;
        }
        splicingSql.append("DELETE FROM "+insertName);
        String aeskey = null;
        String sqlWhere = " WHERE";
        //把剩下的拼接上来
        if(datas.length > 1){
            for(int i =1;i<datas.length;i++){
                sqlWhere = sqlWhere+datas[i];
            }
            parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
            sqlStatement = parser.parseStatement();
            ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
            sqlStatement.accept(visitorTable);
            //获取表和别名
            Map<String,String> tableMaps = visitorTable.getTableMap();
            tableMaps.put(insertName,null);
            //获取所有的字段
            MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
            sqlStatement.accept(visitor);
            String tableAl = null;
            //遍历所有字段
            Collection<TableStat.Column> columns= visitor.getColumns();
            for(TableStat.Column column:columns){
                aesKeys= aesKeysTable.get(column.getTable());
                if(aesKeys == null){
                    continue;
                }
                aeskey = aesKeys.getOrDefault(column.getName(),null);
                if(StringUtils.isEmpty(aeskey)){
                    continue;
                }
                tableAl = tableMaps.get(column.getTable());
                if(!StringUtils.isEmpty(tableAl)){
                    tableAl = tableAl+"."+column.getName();
                }else{
                    tableAl = column.getName();
                }
                sqlWhere = sqlWhere.replaceAll("\\b"+tableAl+"\\b","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
            }
        }
        splicingSql.append(sqlWhere.toString());
        return splicingSql.toString();
    }
}