chenjiahe
2022-01-19 d4efbe547f8b239ebe04c5a30927ddbff1dfd91d
提交 | 用户 | age
c64e12 1 package com.hx.mybatis.aes.springbean;
C 2
3 import com.alibaba.druid.sql.SQLUtils;
4 import com.alibaba.druid.sql.ast.SQLExpr;
f35d93 5 import com.alibaba.druid.sql.ast.SQLObject;
c64e12 6 import com.alibaba.druid.sql.ast.SQLStatement;
C 7 import com.alibaba.druid.sql.ast.statement.*;
8 import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
9 import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
10 import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
11 import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
12 import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
13 import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor;
14 import com.alibaba.druid.stat.TableStat;
15 import com.alibaba.druid.util.JdbcConstants;
16 import com.alibaba.druid.util.JdbcUtils;
17 import com.hx.util.StringUtils;
18
f35d93 19 import java.util.ArrayList;
c64e12 20 import java.util.Collection;
C 21 import java.util.List;
22 import java.util.Map;
23
24 /**
25  * sql语句处理工具
26  * @author CJH 2022-01-12
27  */
28 public class SqlUtils {
29
30     /**查询加密数据处理,只对查询做处理,select返回不做处理
31      * @param sql sql语句
32      * @param aesKeysTable aes秘钥
33      * @return
34      */
35     public static String selectSql(String sql,Map<String,Map<String,String>> aesKeysTable){
36
37         MySqlStatementParser parser = new MySqlStatementParser(sql);
38         SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
39         //获取格式化的slq语句
40         sql = sqlStatement.toString();
41
d4efbe 42         System.out.println("sql:"+sql);
C 43
c64e12 44         //解析select查询
C 45         //SQLSelect sqlSelect = sqlStatement.getSelect()
46         //获取sql查询块
47         SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
48         StringBuffer out = new StringBuffer() ;
49         //创建sql解析的标准化输出
50         SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
51
52         //获取表和别名
53         ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
54         sqlStatement.accept(visitorTable);
55         Map<String,String> tableMaps = visitorTable.getTableMap();
56
57         //获取所有的字段
58         MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
59         sqlStatement.accept(visitor);
60         //遍历所有字段
61         Collection<TableStat.Column> columns= visitor.getColumns();
62
63         StringBuilder sqlWhere = new StringBuilder();
64
65         StringBuilder sqlSelect = new StringBuilder();
66         String expr = null;
67         sqlSelect.append("SELECT ");
68         //解析select返回的数据字段项
69         for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
70             if(sqlSelect.length() > 7){
71                 sqlSelect.append(",");
72             }
f35d93 73
C 74             out.delete(0, out.length()) ;
75             sqlSelectItem.accept(sqlastOutputVisitor) ;
76             expr = out.toString();
77             sqlSelect.append(expr);
78
79            /* if(expr.indexOf("SELECT") == -1){
c64e12 80                 sqlSelect.append(expr);
C 81             }else{
f35d93 82                 //sqlSelect.append("(");
C 83                 sqlSelect.append(expr);
84                 //sqlSelect.append(")");
85                *//* if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
c64e12 86                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
f35d93 87                 }*//*
C 88             }*/
c64e12 89         }
C 90
91         //解析from
346eb2 92         if(sqlSelectQuery.getFrom() != null){
C 93             out.delete(0, out.length()) ;
94             sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
95             sqlWhere.append(" FROM "+out);
96         }
c64e12 97
C 98         //解析where
346eb2 99         if(sqlSelectQuery.getWhere() != null){
C 100             out.delete(0, out.length()) ;
101             sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
102             sqlWhere.append(" WHERE "+out+" ");
103         }
c64e12 104
C 105         if(sqlSelectQuery.getGroupBy() != null){
106             out.delete(0, out.length()) ;
107             sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
108             sqlWhere.append(" "+out);
109         }
110         if(sqlSelectQuery.getOrderBy() != null){
111             out.delete(0, out.length()) ;
112             sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
113             sqlWhere.append(" "+out);
114         }
115         if(sqlSelectQuery.getLimit() != null){
116             out.delete(0, out.length()) ;
117             sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
118             sqlWhere.append(" "+out);
119         }
120
121         //处理where需要加密得字段
122         sql = sqlWhere.toString();
123         if(!StringUtils.isEmpty(sql)){
124             Map<String,String> aesKeys = null;
125             String aeskey = null;
126             //把剩下的拼接上来
127             String tableAl = null;
128             for(TableStat.Column column:columns){
129                 aesKeys= aesKeysTable.get(column.getTable());
130                 if(aesKeys == null){
131                     continue;
132                 }
133                 aeskey = aesKeys.getOrDefault(column.getName(),null);
134                 if(StringUtils.isEmpty(aeskey)){
135                     continue;
136                 }
137                 tableAl = tableMaps.get(column.getTable());
138                 if(!StringUtils.isEmpty(tableAl)){
139                     tableAl = tableAl+"."+column.getName();
140                 }else{
141                     tableAl = column.getName();
142                 }
d4efbe 143                 sql = sql.replaceAll("\\b"+tableAl+"\\b","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 144             }
C 145         }
146         return sqlSelect.toString()+sql;
147     }
148
f35d93 149     /**
C 150      * 处理select返回字段的参数
151      * @param sql
152      * @param aesKeysTable
153      * @param tableMaps
154      * @param columns
155      * @return
156      */
c64e12 157     public static String selectSqlHandle(String sql,Map<String,Map<String,String>> aesKeysTable
C 158             ,Map<String,String> tableMaps,Collection<TableStat.Column> columns){
159
160
161         MySqlStatementParser parser = new MySqlStatementParser(sql);
162         SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
163         //获取格式化的slq语句
164         sql = sqlStatement.toString();
165
166         //解析select查询
167         //SQLSelect sqlSelect = sqlStatement.getSelect() ;
168         //获取sql查询块
169         SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
170         StringBuffer out = new StringBuffer() ;
171         //创建sql解析的标准化输出
172         SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
173
174         StringBuilder sqlWhere = new StringBuilder();
175
176         StringBuilder sqlSelect = new StringBuilder();
177         String expr = null;
178         sqlSelect.append("SELECT ");
179         //解析select返回的数据字段项
180         for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
181             if(sqlSelect.length() > 7){
182                 sqlSelect.append(",");
183             }
184             expr = sqlSelectItem.getExpr().toString();
185             if(expr.indexOf("SELECT") == -1){
186                 sqlSelect.append(expr);
187                 if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
188                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
189                 }
190             }else{
191                 sqlSelect.append("(");
192                 selectSqlHandle(expr,aesKeysTable,tableMaps,columns);
193                 sqlSelect.append(")");
194                 if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
195                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
196                 }
197             }
198         }
199
200         //解析from
346eb2 201         if(sqlSelectQuery.getFrom() != null){
C 202             out.delete(0, out.length()) ;
203             sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
204             sqlWhere.append(" FROM "+out);
205         }
c64e12 206
C 207         //解析where
346eb2 208         if(sqlSelectQuery.getWhere() != null){
C 209             out.delete(0, out.length()) ;
210             sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
211             sqlWhere.append(" WHERE "+out+" ");
212         }
c64e12 213
C 214         if(sqlSelectQuery.getGroupBy() != null){
215             out.delete(0, out.length()) ;
216             sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
217             sqlWhere.append(" "+out);
218         }
219         if(sqlSelectQuery.getOrderBy() != null){
220             out.delete(0, out.length()) ;
221             sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
222             sqlWhere.append(" "+out);
223         }
224         if(sqlSelectQuery.getLimit() != null){
225             out.delete(0, out.length()) ;
226             sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
227             sqlWhere.append(" "+out);
228         }
229
230         sql = sqlWhere.toString();
231         if(!StringUtils.isEmpty(sql)){
232             Map<String,String> aesKeys = null;
233             String aeskey = null;
234             //把剩下的拼接上来
235             String tableAl = null;
236
237             for(TableStat.Column column:columns){
238                 aesKeys= aesKeysTable.get(column.getTable());
239                 if(aesKeys == null){
240                     continue;
241                 }
242                 aeskey = aesKeys.getOrDefault(column.getName(),null);
243                 if(StringUtils.isEmpty(aeskey)){
244                     continue;
245                 }
246                 tableAl = tableMaps.get(column.getTable());
247                 if(!StringUtils.isEmpty(tableAl)){
248                     tableAl = tableAl+"."+column.getName();
249                 }else{
250                     tableAl = column.getName();
251                 }
d4efbe 252                 sql = sql.replaceAll("\\b"+tableAl+"\\b","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 253             }
C 254         }
255         return sqlSelect.toString()+sql;
256     }
257
258     /**新增加密数据处理
259      * @param sql sql语句
260      * @param aesKeysTable aes秘钥
261      * @return
262      */
346eb2 263     public static String insertSql(String sql,Map<String,Map<String,String>> aesKeysTable){
C 264         //装载重写的sql语句
265         StringBuilder splicingSql = new StringBuilder();
c64e12 266
346eb2 267         sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
C 268         String[] datas = sql.split("VALUES",2);
c64e12 269
346eb2 270         splicingSql.append(datas[0]+"VALUES ");
c64e12 271
346eb2 272         //重新拼接SQL语句
c64e12 273
346eb2 274         //解析sql语句
C 275         MySqlStatementParser parser = new MySqlStatementParser(sql);
276         SQLStatement statement = parser.parseStatement();
277         MySqlInsertStatement insert = (MySqlInsertStatement)statement;
c64e12 278
346eb2 279         String insertName = insert.getTableName().getSimpleName();
c64e12 280
346eb2 281         //根据表名称获取到AES秘钥
C 282         Map<String,String> aesKeys= aesKeysTable.get(insertName);
283         if(aesKeys == null){
284             return sql;
285         }
c64e12 286
346eb2 287         //获取所有的字段
C 288         List<SQLExpr> columns = insert.getColumns();
c64e12 289
346eb2 290         String fildValue = null;
C 291         String aeskey = null;
292         //遍历值
293         List<SQLInsertStatement.ValuesClause> vcl = insert.getValuesList();
294         for(int j = 0; j<vcl.size(); j++){
295             if( j != 0){
296                 splicingSql.append(",");
297             }
298             for(int i = 0;i < columns.size();i++){
299                 //查询改字段是否需要加密
300                 aeskey = aesKeys.getOrDefault(columns.get(i).toString(),null);
301                 fildValue = vcl.get(j).getValues().get(i).toString();
302                 if(i == 0){
303                     splicingSql.append("(");
304                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
305                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
306                     }else{
307                         splicingSql.append(fildValue);
308                     }
309                 }else if(i == columns.size()-1){
310                     splicingSql.append(",");
311                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
312                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
313                     }else{
314                         splicingSql.append(fildValue);
315                     }
316                     splicingSql.append(")");
317                 }else{
318                     splicingSql.append(",");
319                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
320                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
321                     }else{
322                         splicingSql.append(fildValue);
323                     }
324                 }
325             }
326         }
327         return splicingSql.toString();
328     }
c64e12 329
C 330     /**更新加密数据处理
331      * @param sql sql语句
332      * @param aesKeysTable aes秘钥
333      * @return
334      */
335     public static String updateSql(String sql,Map<String,Map<String,String>> aesKeysTable){
346eb2 336
c64e12 337         //装载重写的sql语句
C 338         StringBuilder splicingSql = new StringBuilder();
339
340         //sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
341         MySqlStatementParser parser = new MySqlStatementParser(sql);
342         SQLStatement sqlStatement = parser.parseStatement();
343         //获取格式化的slq语句
344         sql = sqlStatement.toString();
345
346         MySqlUpdateStatement updateStatement = (MySqlUpdateStatement)sqlStatement;
347
348         String insertName = updateStatement.getTableName().getSimpleName();
349
350         String[] datas = sql.split("WHERE",2);
351
352         Map<String,String> aesKeys = aesKeysTable.get(insertName);
346eb2 353         if(aesKeys == null){
C 354             return sql;
355         }
c64e12 356
C 357         splicingSql.append("UPDATE "+insertName+" SET ");
358         String aeskey = null;
359         String fildValue = null;
360         List<SQLUpdateSetItem> items = updateStatement.getItems();
361         for(int i = 0;i<items.size();i++){
362             if(i != 0){
363                 splicingSql.append(",");
364             }
365             SQLUpdateSetItem item = items.get(i);
366             //查询改字段是否需要加密
367             aeskey = aesKeys.getOrDefault(item.getColumn().toString(),null);
368
369             fildValue = item.getValue().toString();
370             if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
371                 splicingSql.append(item.getColumn()+" = HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
372             }else{
373                 splicingSql.append(item.getColumn()+" = "+fildValue);
374             }
375         }
376         String sqlWhere = " WHERE";
377         //把剩下的拼接上来
378         if(datas.length > 1){
379             for(int i =1;i<datas.length;i++){
380                 sqlWhere = sqlWhere+datas[i];
381             }
382
383             parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
384             sqlStatement = parser.parseStatement();
385
386             ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
387             sqlStatement.accept(visitorTable);
388
389             //获取表和别名
390             Map<String,String> tableMaps = visitorTable.getTableMap();
391             tableMaps.put(insertName,null);
392
393             //获取所有的字段
394             MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
395             sqlStatement.accept(visitor);
396
397             String tableAl = null;
398             //遍历所有字段
399             Collection<TableStat.Column> columns= visitor.getColumns();
400             for(TableStat.Column column:columns){
401
402                 aesKeys= aesKeysTable.get(column.getTable());
403                 if(aesKeys == null){
404                     continue;
405                 }
406                 aeskey = aesKeys.getOrDefault(column.getName(),null);
407                 if(StringUtils.isEmpty(aeskey)){
408                     continue;
409                 }
410                 tableAl = tableMaps.get(column.getTable());
411                 if(!StringUtils.isEmpty(tableAl)){
412                     tableAl = tableAl+"."+column.getName();
413                 }else{
414                     tableAl = column.getName();
415                 }
d4efbe 416                 sqlWhere = sqlWhere.replaceAll("\\b"+tableAl+"\\b","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 417             }
C 418
419         }
420         splicingSql.append(sqlWhere.toString());
421         return splicingSql.toString();
422     }
423
424     /**删除加密数据处理
425      * @param sql sql语句
426      * @param aesKeysTable aes秘钥
427      * @return
428      */
429     public static String deleteSql(String sql,Map<String,Map<String,String>> aesKeysTable){
430         //装载重写的sql语句
431         StringBuilder splicingSql = new StringBuilder();
432
433         //sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
434         MySqlStatementParser parser = new MySqlStatementParser(sql);
435         SQLStatement sqlStatement = parser.parseStatement();
436         //获取格式化的slq语句
437         sql = sqlStatement.toString();
438
439         MySqlDeleteStatement deleteStatement = (MySqlDeleteStatement)sqlStatement;
440
441         String insertName = deleteStatement.getTableName().getSimpleName();
442
443         String[] datas = sql.split("WHERE",2);
444
445         Map<String,String> aesKeys = aesKeysTable.get(insertName);
346eb2 446         if(aesKeys == null){
C 447             return sql;
448         }
c64e12 449
C 450         splicingSql.append("DELETE FROM "+insertName);
451
452         String aeskey = null;
453
454         String sqlWhere = " WHERE";
455         //把剩下的拼接上来
456         if(datas.length > 1){
457             for(int i =1;i<datas.length;i++){
458                 sqlWhere = sqlWhere+datas[i];
459             }
460
461             parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
462             sqlStatement = parser.parseStatement();
463
464             ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
465             sqlStatement.accept(visitorTable);
466
467             //获取表和别名
468             Map<String,String> tableMaps = visitorTable.getTableMap();
469             tableMaps.put(insertName,null);
470
471             //获取所有的字段
472             MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
473             sqlStatement.accept(visitor);
474
475             String tableAl = null;
476             //遍历所有字段
477             Collection<TableStat.Column> columns= visitor.getColumns();
478             for(TableStat.Column column:columns){
479
480                 aesKeys= aesKeysTable.get(column.getTable());
481                 if(aesKeys == null){
482                     continue;
483                 }
484                 aeskey = aesKeys.getOrDefault(column.getName(),null);
485                 if(StringUtils.isEmpty(aeskey)){
486                     continue;
487                 }
488                 tableAl = tableMaps.get(column.getTable());
489                 if(!StringUtils.isEmpty(tableAl)){
490                     tableAl = tableAl+"."+column.getName();
491                 }else{
492                     tableAl = column.getName();
493                 }
d4efbe 494                 sqlWhere = sqlWhere.replaceAll("\\b"+tableAl+"\\b","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 495             }
C 496
497         }
498         splicingSql.append(sqlWhere.toString());
499         return splicingSql.toString();
500     }
501
502 }