chenjiahe
2022-04-25 fb2c9fa355cc2e09bc051677dba89f86e9c0bd00
提交 | 用户 | age
c64e12 1 package com.hx.mybatis.aes.springbean;
C 2
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor;
import com.alibaba.druid.stat.TableStat;
import com.alibaba.druid.util.JdbcConstants;
import com.alibaba.druid.util.JdbcUtils;
import com.hx.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
25
26 /**
27  * sql语句处理工具
28  * @author CJH 2022-01-12
29  */
30 public class SqlUtils {
72586e 31     //log4j日志
Z 32     private static Logger logger = LoggerFactory.getLogger(SqlUtils.class.getName());
33
c64e12 34
fb2c9f 35     /**查询加密数据处理,只对查询做处理
c64e12 36      * @param sql sql语句
C 37      * @param aesKeysTable aes秘钥
38      * @return
39      */
40     public static String selectSql(String sql,Map<String,Map<String,String>> aesKeysTable){
41
42         MySqlStatementParser parser = new MySqlStatementParser(sql);
43         SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
fb2c9f 44
C 45         //获取表和别名
46         ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
47         sqlStatement.accept(visitorTable);
48         Map<String,String> tableMaps = visitorTable.getTableMap();
49
50         //获取所有的字段
51         MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
52         sqlStatement.accept(visitor);
53         //遍历所有字段
54         Collection<TableStat.Column> columns= visitor.getColumns();
55
56         //处理需要加密得字段
57
58         if(!StringUtils.isEmpty(sql)){
59             Map<String,String> aesKeys = null;
60             String aeskey = null;
61             //把剩下的拼接上来
62             String tableAl = null;
63             for(TableStat.Column column:columns){
64                 aesKeys= aesKeysTable.get(column.getTable());
65                 if(aesKeys == null){
66                     continue;
67                 }
68                 aeskey = aesKeys.getOrDefault(column.getName(),null);
69                 if(StringUtils.isEmpty(aeskey)){
70                     continue;
71                 }
72                 tableAl = tableMaps.get(column.getTable());
73                 if(!StringUtils.isEmpty(tableAl)){
74                     tableAl = tableAl+"."+column.getName();
75                 }else{
76                     tableAl = column.getName();
77                 }
78                 sql = sql.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
79             }
80         }
81         return sql;
82     }
83
84
85     /**查询加密数据处理,只对查询做处理,select返回不做处理(备份)
86      * @param sql sql语句
87      * @param aesKeysTable aes秘钥
88      * @return
89      */
90     public static String selectSqlDemo(String sql,Map<String,Map<String,String>> aesKeysTable){
91
92         MySqlStatementParser parser = new MySqlStatementParser(sql);
93         SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
c64e12 94         //获取格式化的slq语句
C 95         sql = sqlStatement.toString();
96
72586e 97
Z 98
c64e12 99         //解析select查询
C 100         //SQLSelect sqlSelect = sqlStatement.getSelect()
101         //获取sql查询块
72586e 102         SQLSelectQueryBlock sqlSelectQuery = null;
Z 103         boolean b = true;
326104 104         try{
Z 105             sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
106         }catch (Exception e){
72586e 107             b = false;
8ab2ad 108             logger.error("解析sql报错:"+e.getMessage());
72586e 109         }
Z 110
111         if(!b){
112             return "err";
326104 113         }
Z 114
c64e12 115         StringBuffer out = new StringBuffer() ;
C 116         //创建sql解析的标准化输出
117         SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
118
119         //获取表和别名
120         ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
121         sqlStatement.accept(visitorTable);
122         Map<String,String> tableMaps = visitorTable.getTableMap();
123
124         //获取所有的字段
125         MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
126         sqlStatement.accept(visitor);
127         //遍历所有字段
128         Collection<TableStat.Column> columns= visitor.getColumns();
129
130         StringBuilder sqlWhere = new StringBuilder();
131
132         StringBuilder sqlSelect = new StringBuilder();
133         String expr = null;
134         sqlSelect.append("SELECT ");
135         //解析select返回的数据字段项
136         for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
137             if(sqlSelect.length() > 7){
138                 sqlSelect.append(",");
139             }
f35d93 140
C 141             out.delete(0, out.length()) ;
142             sqlSelectItem.accept(sqlastOutputVisitor) ;
143             expr = out.toString();
144             sqlSelect.append(expr);
145
146            /* if(expr.indexOf("SELECT") == -1){
c64e12 147                 sqlSelect.append(expr);
C 148             }else{
f35d93 149                 //sqlSelect.append("(");
C 150                 sqlSelect.append(expr);
151                 //sqlSelect.append(")");
152                *//* if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
c64e12 153                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
f35d93 154                 }*//*
C 155             }*/
c64e12 156         }
C 157
158         //解析from
346eb2 159         if(sqlSelectQuery.getFrom() != null){
C 160             out.delete(0, out.length()) ;
161             sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
162             sqlWhere.append(" FROM "+out);
163         }
c64e12 164
C 165         //解析where
346eb2 166         if(sqlSelectQuery.getWhere() != null){
C 167             out.delete(0, out.length()) ;
168             sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
169             sqlWhere.append(" WHERE "+out+" ");
170         }
c64e12 171
C 172         if(sqlSelectQuery.getGroupBy() != null){
173             out.delete(0, out.length()) ;
174             sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
175             sqlWhere.append(" "+out);
176         }
177         if(sqlSelectQuery.getOrderBy() != null){
178             out.delete(0, out.length()) ;
179             sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
180             sqlWhere.append(" "+out);
181         }
182         if(sqlSelectQuery.getLimit() != null){
183             out.delete(0, out.length()) ;
184             sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
185             sqlWhere.append(" "+out);
186         }
187
188         //处理where需要加密得字段
189         sql = sqlWhere.toString();
190         if(!StringUtils.isEmpty(sql)){
191             Map<String,String> aesKeys = null;
192             String aeskey = null;
193             //把剩下的拼接上来
194             String tableAl = null;
195             for(TableStat.Column column:columns){
196                 aesKeys= aesKeysTable.get(column.getTable());
197                 if(aesKeys == null){
198                     continue;
199                 }
200                 aeskey = aesKeys.getOrDefault(column.getName(),null);
201                 if(StringUtils.isEmpty(aeskey)){
202                     continue;
203                 }
204                 tableAl = tableMaps.get(column.getTable());
205                 if(!StringUtils.isEmpty(tableAl)){
206                     tableAl = tableAl+"."+column.getName();
207                 }else{
208                     tableAl = column.getName();
209                 }
5c933d 210                 sql = sql.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 211             }
C 212         }
213         return sqlSelect.toString()+sql;
214     }
215
f35d93 216     /**
C 217      * 处理select返回字段的参数
218      * @param sql
219      * @param aesKeysTable
220      * @param tableMaps
221      * @param columns
222      * @return
223      */
c64e12 224     public static String selectSqlHandle(String sql,Map<String,Map<String,String>> aesKeysTable
C 225             ,Map<String,String> tableMaps,Collection<TableStat.Column> columns){
226
227
228         MySqlStatementParser parser = new MySqlStatementParser(sql);
229         SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
230         //获取格式化的slq语句
231         sql = sqlStatement.toString();
232
233         //解析select查询
234         //SQLSelect sqlSelect = sqlStatement.getSelect() ;
235         //获取sql查询块
236         SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
237         StringBuffer out = new StringBuffer() ;
238         //创建sql解析的标准化输出
239         SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
240
241         StringBuilder sqlWhere = new StringBuilder();
242
243         StringBuilder sqlSelect = new StringBuilder();
244         String expr = null;
245         sqlSelect.append("SELECT ");
246         //解析select返回的数据字段项
247         for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
248             if(sqlSelect.length() > 7){
249                 sqlSelect.append(",");
250             }
251             expr = sqlSelectItem.getExpr().toString();
252             if(expr.indexOf("SELECT") == -1){
253                 sqlSelect.append(expr);
254                 if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
255                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
256                 }
257             }else{
258                 sqlSelect.append("(");
259                 selectSqlHandle(expr,aesKeysTable,tableMaps,columns);
260                 sqlSelect.append(")");
261                 if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
262                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
263                 }
264             }
265         }
266
267         //解析from
346eb2 268         if(sqlSelectQuery.getFrom() != null){
C 269             out.delete(0, out.length()) ;
270             sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
271             sqlWhere.append(" FROM "+out);
272         }
c64e12 273
C 274         //解析where
346eb2 275         if(sqlSelectQuery.getWhere() != null){
C 276             out.delete(0, out.length()) ;
277             sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
278             sqlWhere.append(" WHERE "+out+" ");
279         }
c64e12 280
C 281         if(sqlSelectQuery.getGroupBy() != null){
282             out.delete(0, out.length()) ;
283             sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
284             sqlWhere.append(" "+out);
285         }
286         if(sqlSelectQuery.getOrderBy() != null){
287             out.delete(0, out.length()) ;
288             sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
289             sqlWhere.append(" "+out);
290         }
291         if(sqlSelectQuery.getLimit() != null){
292             out.delete(0, out.length()) ;
293             sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
294             sqlWhere.append(" "+out);
295         }
296
297         sql = sqlWhere.toString();
298         if(!StringUtils.isEmpty(sql)){
299             Map<String,String> aesKeys = null;
300             String aeskey = null;
301             //把剩下的拼接上来
302             String tableAl = null;
303
304             for(TableStat.Column column:columns){
305                 aesKeys= aesKeysTable.get(column.getTable());
306                 if(aesKeys == null){
307                     continue;
308                 }
309                 aeskey = aesKeys.getOrDefault(column.getName(),null);
310                 if(StringUtils.isEmpty(aeskey)){
311                     continue;
312                 }
313                 tableAl = tableMaps.get(column.getTable());
314                 if(!StringUtils.isEmpty(tableAl)){
315                     tableAl = tableAl+"."+column.getName();
316                 }else{
317                     tableAl = column.getName();
318                 }
5c933d 319                 sql = sql.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 320             }
C 321         }
322         return sqlSelect.toString()+sql;
323     }
324
325     /**新增加密数据处理
326      * @param sql sql语句
327      * @param aesKeysTable aes秘钥
328      * @return
329      */
346eb2 330     public static String insertSql(String sql,Map<String,Map<String,String>> aesKeysTable){
C 331         //装载重写的sql语句
332         StringBuilder splicingSql = new StringBuilder();
c64e12 333
346eb2 334         sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
C 335         String[] datas = sql.split("VALUES",2);
c64e12 336
346eb2 337         splicingSql.append(datas[0]+"VALUES ");
c64e12 338
346eb2 339         //重新拼接SQL语句
c64e12 340
346eb2 341         //解析sql语句
C 342         MySqlStatementParser parser = new MySqlStatementParser(sql);
343         SQLStatement statement = parser.parseStatement();
344         MySqlInsertStatement insert = (MySqlInsertStatement)statement;
c64e12 345
346eb2 346         String insertName = insert.getTableName().getSimpleName();
c64e12 347
346eb2 348         //根据表名称获取到AES秘钥
C 349         Map<String,String> aesKeys= aesKeysTable.get(insertName);
350         if(aesKeys == null){
351             return sql;
352         }
c64e12 353
346eb2 354         //获取所有的字段
C 355         List<SQLExpr> columns = insert.getColumns();
c64e12 356
346eb2 357         String fildValue = null;
C 358         String aeskey = null;
359         //遍历值
360         List<SQLInsertStatement.ValuesClause> vcl = insert.getValuesList();
361         for(int j = 0; j<vcl.size(); j++){
362             if( j != 0){
363                 splicingSql.append(",");
364             }
365             for(int i = 0;i < columns.size();i++){
366                 //查询改字段是否需要加密
367                 aeskey = aesKeys.getOrDefault(columns.get(i).toString(),null);
368                 fildValue = vcl.get(j).getValues().get(i).toString();
369                 if(i == 0){
370                     splicingSql.append("(");
371                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
372                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
373                     }else{
374                         splicingSql.append(fildValue);
375                     }
376                 }else if(i == columns.size()-1){
377                     splicingSql.append(",");
378                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
379                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
380                     }else{
381                         splicingSql.append(fildValue);
382                     }
383                     splicingSql.append(")");
384                 }else{
385                     splicingSql.append(",");
386                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
387                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
388                     }else{
389                         splicingSql.append(fildValue);
390                     }
391                 }
392             }
393         }
394         return splicingSql.toString();
395     }
c64e12 396
C 397     /**更新加密数据处理
398      * @param sql sql语句
399      * @param aesKeysTable aes秘钥
400      * @return
401      */
402     public static String updateSql(String sql,Map<String,Map<String,String>> aesKeysTable){
346eb2 403
c64e12 404         //装载重写的sql语句
C 405         StringBuilder splicingSql = new StringBuilder();
406
407         //sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
408         MySqlStatementParser parser = new MySqlStatementParser(sql);
409         SQLStatement sqlStatement = parser.parseStatement();
410         //获取格式化的slq语句
411         sql = sqlStatement.toString();
412
413         MySqlUpdateStatement updateStatement = (MySqlUpdateStatement)sqlStatement;
414
415         String insertName = updateStatement.getTableName().getSimpleName();
416
417         String[] datas = sql.split("WHERE",2);
418
419         Map<String,String> aesKeys = aesKeysTable.get(insertName);
346eb2 420         if(aesKeys == null){
C 421             return sql;
422         }
c64e12 423
C 424         splicingSql.append("UPDATE "+insertName+" SET ");
425         String aeskey = null;
426         String fildValue = null;
427         List<SQLUpdateSetItem> items = updateStatement.getItems();
428         for(int i = 0;i<items.size();i++){
429             if(i != 0){
430                 splicingSql.append(",");
431             }
432             SQLUpdateSetItem item = items.get(i);
433             //查询改字段是否需要加密
434             aeskey = aesKeys.getOrDefault(item.getColumn().toString(),null);
435
436             fildValue = item.getValue().toString();
437             if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
438                 splicingSql.append(item.getColumn()+" = HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
439             }else{
440                 splicingSql.append(item.getColumn()+" = "+fildValue);
441             }
442         }
443         String sqlWhere = " WHERE";
444         //把剩下的拼接上来
445         if(datas.length > 1){
446             for(int i =1;i<datas.length;i++){
447                 sqlWhere = sqlWhere+datas[i];
448             }
449
450             parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
451             sqlStatement = parser.parseStatement();
452
453             ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
454             sqlStatement.accept(visitorTable);
455
456             //获取表和别名
457             Map<String,String> tableMaps = visitorTable.getTableMap();
458             tableMaps.put(insertName,null);
459
460             //获取所有的字段
461             MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
462             sqlStatement.accept(visitor);
463
464             String tableAl = null;
465             //遍历所有字段
466             Collection<TableStat.Column> columns= visitor.getColumns();
467             for(TableStat.Column column:columns){
468
469                 aesKeys= aesKeysTable.get(column.getTable());
470                 if(aesKeys == null){
471                     continue;
472                 }
473                 aeskey = aesKeys.getOrDefault(column.getName(),null);
474                 if(StringUtils.isEmpty(aeskey)){
475                     continue;
476                 }
477                 tableAl = tableMaps.get(column.getTable());
478                 if(!StringUtils.isEmpty(tableAl)){
479                     tableAl = tableAl+"."+column.getName();
480                 }else{
481                     tableAl = column.getName();
482                 }
5c933d 483                 sqlWhere = sqlWhere.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 484             }
C 485
486         }
326104 487         splicingSql.append(sqlWhere);
c64e12 488         return splicingSql.toString();
C 489     }
490
491     /**删除加密数据处理
492      * @param sql sql语句
493      * @param aesKeysTable aes秘钥
494      * @return
495      */
496     public static String deleteSql(String sql,Map<String,Map<String,String>> aesKeysTable){
497         //装载重写的sql语句
498         StringBuilder splicingSql = new StringBuilder();
499
500         //sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
501         MySqlStatementParser parser = new MySqlStatementParser(sql);
502         SQLStatement sqlStatement = parser.parseStatement();
503         //获取格式化的slq语句
504         sql = sqlStatement.toString();
505
506         MySqlDeleteStatement deleteStatement = (MySqlDeleteStatement)sqlStatement;
507
508         String insertName = deleteStatement.getTableName().getSimpleName();
509
510         String[] datas = sql.split("WHERE",2);
511
512         Map<String,String> aesKeys = aesKeysTable.get(insertName);
346eb2 513         if(aesKeys == null){
C 514             return sql;
515         }
c64e12 516
C 517         splicingSql.append("DELETE FROM "+insertName);
518
519         String aeskey = null;
520
521         String sqlWhere = " WHERE";
522         //把剩下的拼接上来
523         if(datas.length > 1){
524             for(int i =1;i<datas.length;i++){
525                 sqlWhere = sqlWhere+datas[i];
526             }
527
528             parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
529             sqlStatement = parser.parseStatement();
530
531             ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
532             sqlStatement.accept(visitorTable);
533
534             //获取表和别名
535             Map<String,String> tableMaps = visitorTable.getTableMap();
536             tableMaps.put(insertName,null);
537
538             //获取所有的字段
539             MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
540             sqlStatement.accept(visitor);
541
542             String tableAl = null;
543             //遍历所有字段
544             Collection<TableStat.Column> columns= visitor.getColumns();
545             for(TableStat.Column column:columns){
546
547                 aesKeys= aesKeysTable.get(column.getTable());
548                 if(aesKeys == null){
549                     continue;
550                 }
551                 aeskey = aesKeys.getOrDefault(column.getName(),null);
552                 if(StringUtils.isEmpty(aeskey)){
553                     continue;
554                 }
555                 tableAl = tableMaps.get(column.getTable());
556                 if(!StringUtils.isEmpty(tableAl)){
557                     tableAl = tableAl+"."+column.getName();
558                 }else{
559                     tableAl = column.getName();
560                 }
5c933d 561                 sqlWhere = sqlWhere.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 562             }
C 563
564         }
326104 565         splicingSql.append(sqlWhere);
c64e12 566         return splicingSql.toString();
C 567     }
568
569 }