New file |
| | |
| | | package com.hx.mybatis.aes.springbean; |
| | | |
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.SQLSetQuantifier;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLInListExpr;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.mysql.ast.clause.MySqlSelectIntoStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor;
import com.alibaba.druid.stat.TableStat;
import com.alibaba.druid.util.JdbcConstants;
import com.alibaba.druid.util.JdbcUtils;
import com.hx.util.StringUtils;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
| | | |
| | | /** |
| | | * sql语句处理工具 |
| | | * @author CJH 2022-01-12 |
| | | */ |
| | | public class SqlUtils { |
| | | |
| | | /**查询加密数据处理,只对查询做处理,select返回不做处理 |
| | | * @param sql sql语句 |
| | | * @param aesKeysTable aes秘钥 |
| | | * @return |
| | | */ |
| | | public static String selectSql(String sql,Map<String,Map<String,String>> aesKeysTable){ |
| | | |
| | | MySqlStatementParser parser = new MySqlStatementParser(sql); |
| | | SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect(); |
| | | //获取格式化的slq语句 |
| | | sql = sqlStatement.toString(); |
| | | |
| | | //解析select查询 |
| | | //SQLSelect sqlSelect = sqlStatement.getSelect() |
| | | //获取sql查询块 |
| | | SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ; |
| | | StringBuffer out = new StringBuffer() ; |
| | | //创建sql解析的标准化输出 |
| | | SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ; |
| | | |
| | | //获取表和别名 |
| | | ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor(); |
| | | sqlStatement.accept(visitorTable); |
| | | Map<String,String> tableMaps = visitorTable.getTableMap(); |
| | | |
| | | //获取所有的字段 |
| | | MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor(); |
| | | sqlStatement.accept(visitor); |
| | | //遍历所有字段 |
| | | Collection<TableStat.Column> columns= visitor.getColumns(); |
| | | |
| | | StringBuilder sqlWhere = new StringBuilder(); |
| | | |
| | | StringBuilder sqlSelect = new StringBuilder(); |
| | | String expr = null; |
| | | sqlSelect.append("SELECT "); |
| | | //解析select返回的数据字段项 |
| | | for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) { |
| | | if(sqlSelect.length() > 7){ |
| | | sqlSelect.append(","); |
| | | } |
| | | expr = sqlSelectItem.getExpr().toString(); |
| | | if(expr.indexOf("SELECT") == -1){ |
| | | sqlSelect.append(expr); |
| | | if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){ |
| | | sqlSelect.append(" AS "+sqlSelectItem.getAlias()); |
| | | } |
| | | }else{ |
| | | sqlSelect.append("("); |
| | | sqlSelect.append(selectSqlHandle(expr,aesKeysTable,tableMaps,columns)); |
| | | sqlSelect.append(")"); |
| | | if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){ |
| | | sqlSelect.append(" AS "+sqlSelectItem.getAlias()); |
| | | } |
| | | } |
| | | } |
| | | |
| | | //解析from |
| | | out.delete(0, out.length()) ; |
| | | sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ; |
| | | sqlWhere.append(" FROM "+out); |
| | | |
| | | //解析where |
| | | out.delete(0, out.length()) ; |
| | | sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ; |
| | | sqlWhere.append(" WHERE "+out+" "); |
| | | |
| | | if(sqlSelectQuery.getGroupBy() != null){ |
| | | out.delete(0, out.length()) ; |
| | | sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ; |
| | | sqlWhere.append(" "+out); |
| | | } |
| | | if(sqlSelectQuery.getOrderBy() != null){ |
| | | out.delete(0, out.length()) ; |
| | | sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ; |
| | | sqlWhere.append(" "+out); |
| | | } |
| | | if(sqlSelectQuery.getLimit() != null){ |
| | | out.delete(0, out.length()) ; |
| | | sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ; |
| | | sqlWhere.append(" "+out); |
| | | } |
| | | |
| | | //处理where需要加密得字段 |
| | | sql = sqlWhere.toString(); |
| | | if(!StringUtils.isEmpty(sql)){ |
| | | Map<String,String> aesKeys = null; |
| | | String aeskey = null; |
| | | //把剩下的拼接上来 |
| | | String tableAl = null; |
| | | for(TableStat.Column column:columns){ |
| | | aesKeys= aesKeysTable.get(column.getTable()); |
| | | if(aesKeys == null){ |
| | | continue; |
| | | } |
| | | aeskey = aesKeys.getOrDefault(column.getName(),null); |
| | | if(StringUtils.isEmpty(aeskey)){ |
| | | continue; |
| | | } |
| | | tableAl = tableMaps.get(column.getTable()); |
| | | if(!StringUtils.isEmpty(tableAl)){ |
| | | tableAl = tableAl+"."+column.getName(); |
| | | }else{ |
| | | tableAl = column.getName(); |
| | | } |
| | | sql = sql.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') "); |
| | | } |
| | | } |
| | | return sqlSelect.toString()+sql; |
| | | } |
| | | |
| | | public static String selectSqlHandle(String sql,Map<String,Map<String,String>> aesKeysTable |
| | | ,Map<String,String> tableMaps,Collection<TableStat.Column> columns){ |
| | | |
| | | |
| | | MySqlStatementParser parser = new MySqlStatementParser(sql); |
| | | SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect(); |
| | | //获取格式化的slq语句 |
| | | sql = sqlStatement.toString(); |
| | | |
| | | //解析select查询 |
| | | //SQLSelect sqlSelect = sqlStatement.getSelect() ; |
| | | //获取sql查询块 |
| | | SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ; |
| | | StringBuffer out = new StringBuffer() ; |
| | | //创建sql解析的标准化输出 |
| | | SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ; |
| | | |
| | | StringBuilder sqlWhere = new StringBuilder(); |
| | | |
| | | StringBuilder sqlSelect = new StringBuilder(); |
| | | String expr = null; |
| | | sqlSelect.append("SELECT "); |
| | | //解析select返回的数据字段项 |
| | | for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) { |
| | | if(sqlSelect.length() > 7){ |
| | | sqlSelect.append(","); |
| | | } |
| | | expr = sqlSelectItem.getExpr().toString(); |
| | | if(expr.indexOf("SELECT") == -1){ |
| | | sqlSelect.append(expr); |
| | | if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){ |
| | | sqlSelect.append(" AS "+sqlSelectItem.getAlias()); |
| | | } |
| | | }else{ |
| | | sqlSelect.append("("); |
| | | selectSqlHandle(expr,aesKeysTable,tableMaps,columns); |
| | | sqlSelect.append(")"); |
| | | if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){ |
| | | sqlSelect.append(" AS "+sqlSelectItem.getAlias()); |
| | | } |
| | | } |
| | | } |
| | | |
| | | //解析from |
| | | out.delete(0, out.length()) ; |
| | | sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ; |
| | | sqlWhere.append(" FROM "+out); |
| | | |
| | | //解析where |
| | | out.delete(0, out.length()) ; |
| | | sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ; |
| | | sqlWhere.append(" WHERE "+out+" "); |
| | | |
| | | if(sqlSelectQuery.getGroupBy() != null){ |
| | | out.delete(0, out.length()) ; |
| | | sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ; |
| | | sqlWhere.append(" "+out); |
| | | } |
| | | if(sqlSelectQuery.getOrderBy() != null){ |
| | | out.delete(0, out.length()) ; |
| | | sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ; |
| | | sqlWhere.append(" "+out); |
| | | } |
| | | if(sqlSelectQuery.getLimit() != null){ |
| | | out.delete(0, out.length()) ; |
| | | sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ; |
| | | sqlWhere.append(" "+out); |
| | | } |
| | | |
| | | sql = sqlWhere.toString(); |
| | | if(!StringUtils.isEmpty(sql)){ |
| | | Map<String,String> aesKeys = null; |
| | | String aeskey = null; |
| | | //把剩下的拼接上来 |
| | | String tableAl = null; |
| | | |
| | | for(TableStat.Column column:columns){ |
| | | aesKeys= aesKeysTable.get(column.getTable()); |
| | | if(aesKeys == null){ |
| | | continue; |
| | | } |
| | | aeskey = aesKeys.getOrDefault(column.getName(),null); |
| | | if(StringUtils.isEmpty(aeskey)){ |
| | | continue; |
| | | } |
| | | tableAl = tableMaps.get(column.getTable()); |
| | | if(!StringUtils.isEmpty(tableAl)){ |
| | | tableAl = tableAl+"."+column.getName(); |
| | | }else{ |
| | | tableAl = column.getName(); |
| | | } |
| | | sql = sql.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') "); |
| | | } |
| | | } |
| | | return sqlSelect.toString()+sql; |
| | | } |
| | | |
| | | /**新增加密数据处理 |
| | | * @param sql sql语句 |
| | | * @param aesKeysTable aes秘钥 |
| | | * @return |
| | | */ |
| | | public static String insertSql(String sql,Map<String,Map<String,String>> aesKeysTable){ |
| | | //装载重写的sql语句 |
| | | StringBuilder splicingSql = new StringBuilder(); |
| | | |
| | | sql = SQLUtils.format(sql, JdbcConstants.MYSQL); |
| | | String[] datas = sql.split("VALUES",2); |
| | | |
| | | splicingSql.append(datas[0]+"VALUES "); |
| | | |
| | | //重新拼接SQL语句 |
| | | |
| | | //解析sql语句 |
| | | MySqlStatementParser parser = new MySqlStatementParser(sql); |
| | | SQLStatement statement = parser.parseStatement(); |
| | | MySqlInsertStatement insert = (MySqlInsertStatement)statement; |
| | | |
| | | String insertName = insert.getTableName().getSimpleName(); |
| | | |
| | | //根据表名称获取到AES秘钥 |
| | | Map<String,String> aesKeys= aesKeysTable.get(insertName); |
| | | if(aesKeys == null){ |
| | | return sql; |
| | | } |
| | | |
| | | //获取所有的字段 |
| | | List<SQLExpr> columns = insert.getColumns(); |
| | | |
| | | String fildValue = null; |
| | | String aeskey = null; |
| | | //遍历值 |
| | | List<SQLInsertStatement.ValuesClause> vcl = insert.getValuesList(); |
| | | for(int j = 0; j<vcl.size(); j++){ |
| | | if( j != 0){ |
| | | splicingSql.append(","); |
| | | } |
| | | for(int i = 0;i < columns.size();i++){ |
| | | //查询改字段是否需要加密 |
| | | aeskey = aesKeys.getOrDefault(columns.get(i).toString(),null); |
| | | fildValue = vcl.get(j).getValues().get(i).toString(); |
| | | if(i == 0){ |
| | | splicingSql.append("("); |
| | | if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){ |
| | | splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))"); |
| | | }else{ |
| | | splicingSql.append(fildValue); |
| | | } |
| | | }else if(i == columns.size()-1){ |
| | | splicingSql.append(","); |
| | | if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){ |
| | | splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))"); |
| | | }else{ |
| | | splicingSql.append(fildValue); |
| | | } |
| | | splicingSql.append(")"); |
| | | }else{ |
| | | splicingSql.append(","); |
| | | if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){ |
| | | splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))"); |
| | | }else{ |
| | | splicingSql.append(fildValue); |
| | | } |
| | | } |
| | | } |
| | | } |
| | | return splicingSql.toString(); |
| | | } |
| | | |
| | | /**更新加密数据处理 |
| | | * @param sql sql语句 |
| | | * @param aesKeysTable aes秘钥 |
| | | * @return |
| | | */ |
| | | public static String updateSql(String sql,Map<String,Map<String,String>> aesKeysTable){ |
| | | //装载重写的sql语句 |
| | | StringBuilder splicingSql = new StringBuilder(); |
| | | |
| | | //sql = SQLUtils.format(sql, JdbcConstants.MYSQL); |
| | | MySqlStatementParser parser = new MySqlStatementParser(sql); |
| | | SQLStatement sqlStatement = parser.parseStatement(); |
| | | //获取格式化的slq语句 |
| | | sql = sqlStatement.toString(); |
| | | |
| | | |
| | | MySqlUpdateStatement updateStatement = (MySqlUpdateStatement)sqlStatement; |
| | | |
| | | String insertName = updateStatement.getTableName().getSimpleName(); |
| | | |
| | | String[] datas = sql.split("WHERE",2); |
| | | |
| | | Map<String,String> aesKeys = aesKeysTable.get(insertName); |
| | | |
| | | splicingSql.append("UPDATE "+insertName+" SET "); |
| | | |
| | | String aeskey = null; |
| | | String fildValue = null; |
| | | List<SQLUpdateSetItem> items = updateStatement.getItems(); |
| | | for(int i = 0;i<items.size();i++){ |
| | | if(i != 0){ |
| | | splicingSql.append(","); |
| | | } |
| | | |
| | | SQLUpdateSetItem item = items.get(i); |
| | | |
| | | //查询改字段是否需要加密 |
| | | aeskey = aesKeys.getOrDefault(item.getColumn().toString(),null); |
| | | |
| | | fildValue = item.getValue().toString(); |
| | | if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){ |
| | | splicingSql.append(item.getColumn()+" = HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))"); |
| | | }else{ |
| | | splicingSql.append(item.getColumn()+" = "+fildValue); |
| | | } |
| | | } |
| | | |
| | | String sqlWhere = " WHERE"; |
| | | //把剩下的拼接上来 |
| | | if(datas.length > 1){ |
| | | for(int i =1;i<datas.length;i++){ |
| | | sqlWhere = sqlWhere+datas[i]; |
| | | } |
| | | |
| | | parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere); |
| | | sqlStatement = parser.parseStatement(); |
| | | |
| | | ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor(); |
| | | sqlStatement.accept(visitorTable); |
| | | |
| | | //获取表和别名 |
| | | Map<String,String> tableMaps = visitorTable.getTableMap(); |
| | | tableMaps.put(insertName,null); |
| | | |
| | | //获取所有的字段 |
| | | MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor(); |
| | | sqlStatement.accept(visitor); |
| | | |
| | | String tableAl = null; |
| | | //遍历所有字段 |
| | | Collection<TableStat.Column> columns= visitor.getColumns(); |
| | | for(TableStat.Column column:columns){ |
| | | |
| | | aesKeys= aesKeysTable.get(column.getTable()); |
| | | if(aesKeys == null){ |
| | | continue; |
| | | } |
| | | aeskey = aesKeys.getOrDefault(column.getName(),null); |
| | | if(StringUtils.isEmpty(aeskey)){ |
| | | continue; |
| | | } |
| | | tableAl = tableMaps.get(column.getTable()); |
| | | if(!StringUtils.isEmpty(tableAl)){ |
| | | tableAl = tableAl+"."+column.getName(); |
| | | }else{ |
| | | tableAl = column.getName(); |
| | | } |
| | | sqlWhere = sqlWhere.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') "); |
| | | } |
| | | |
| | | } |
| | | splicingSql.append(sqlWhere.toString()); |
| | | return splicingSql.toString(); |
| | | } |
| | | |
| | | /**删除加密数据处理 |
| | | * @param sql sql语句 |
| | | * @param aesKeysTable aes秘钥 |
| | | * @return |
| | | */ |
| | | public static String deleteSql(String sql,Map<String,Map<String,String>> aesKeysTable){ |
| | | //装载重写的sql语句 |
| | | StringBuilder splicingSql = new StringBuilder(); |
| | | |
| | | //sql = SQLUtils.format(sql, JdbcConstants.MYSQL); |
| | | MySqlStatementParser parser = new MySqlStatementParser(sql); |
| | | SQLStatement sqlStatement = parser.parseStatement(); |
| | | //获取格式化的slq语句 |
| | | sql = sqlStatement.toString(); |
| | | |
| | | MySqlDeleteStatement deleteStatement = (MySqlDeleteStatement)sqlStatement; |
| | | |
| | | String insertName = deleteStatement.getTableName().getSimpleName(); |
| | | |
| | | String[] datas = sql.split("WHERE",2); |
| | | |
| | | Map<String,String> aesKeys = aesKeysTable.get(insertName); |
| | | |
| | | splicingSql.append("DELETE FROM "+insertName); |
| | | |
| | | String aeskey = null; |
| | | |
| | | String sqlWhere = " WHERE"; |
| | | //把剩下的拼接上来 |
| | | if(datas.length > 1){ |
| | | for(int i =1;i<datas.length;i++){ |
| | | sqlWhere = sqlWhere+datas[i]; |
| | | } |
| | | |
| | | parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere); |
| | | sqlStatement = parser.parseStatement(); |
| | | |
| | | ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor(); |
| | | sqlStatement.accept(visitorTable); |
| | | |
| | | //获取表和别名 |
| | | Map<String,String> tableMaps = visitorTable.getTableMap(); |
| | | tableMaps.put(insertName,null); |
| | | |
| | | //获取所有的字段 |
| | | MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor(); |
| | | sqlStatement.accept(visitor); |
| | | |
| | | String tableAl = null; |
| | | //遍历所有字段 |
| | | Collection<TableStat.Column> columns= visitor.getColumns(); |
| | | for(TableStat.Column column:columns){ |
| | | |
| | | aesKeys= aesKeysTable.get(column.getTable()); |
| | | if(aesKeys == null){ |
| | | continue; |
| | | } |
| | | aeskey = aesKeys.getOrDefault(column.getName(),null); |
| | | if(StringUtils.isEmpty(aeskey)){ |
| | | continue; |
| | | } |
| | | tableAl = tableMaps.get(column.getTable()); |
| | | if(!StringUtils.isEmpty(tableAl)){ |
| | | tableAl = tableAl+"."+column.getName(); |
| | | }else{ |
| | | tableAl = column.getName(); |
| | | } |
| | | sqlWhere = sqlWhere.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') "); |
| | | } |
| | | |
| | | } |
| | | splicingSql.append(sqlWhere.toString()); |
| | | return splicingSql.toString(); |
| | | } |
| | | |
| | | } |