chenjiahe
2022-04-02 0344dc66783353faff118a0ea91f7d2aa07bbd4a
提交 | 用户 | age
c64e12 1 package com.hx.mybatis.aes.springbean;
C 2
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.SQLSetQuantifier;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLQueryExpr;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor;
import com.alibaba.druid.stat.TableStat;
import com.alibaba.druid.util.JdbcConstants;
import com.alibaba.druid.util.JdbcUtils;
import com.hx.util.StringUtils;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
23
24 /**
25  * sql语句处理工具
26  * @author CJH 2022-01-12
27  */
28 public class SqlUtils {
29
30     /**查询加密数据处理,只对查询做处理,select返回不做处理
31      * @param sql sql语句
32      * @param aesKeysTable aes秘钥
33      * @return
34      */
35     public static String selectSql(String sql,Map<String,Map<String,String>> aesKeysTable){
36
37         MySqlStatementParser parser = new MySqlStatementParser(sql);
38         SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
39         //获取格式化的slq语句
40         sql = sqlStatement.toString();
41
42         //解析select查询
43         //SQLSelect sqlSelect = sqlStatement.getSelect()
44         //获取sql查询块
45         SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
46         StringBuffer out = new StringBuffer() ;
47         //创建sql解析的标准化输出
48         SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
49
50         //获取表和别名
51         ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
52         sqlStatement.accept(visitorTable);
53         Map<String,String> tableMaps = visitorTable.getTableMap();
54
55         //获取所有的字段
56         MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
57         sqlStatement.accept(visitor);
58         //遍历所有字段
59         Collection<TableStat.Column> columns= visitor.getColumns();
60
61         StringBuilder sqlWhere = new StringBuilder();
62
63         StringBuilder sqlSelect = new StringBuilder();
64         String expr = null;
65         sqlSelect.append("SELECT ");
66         //解析select返回的数据字段项
67         for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
68             if(sqlSelect.length() > 7){
69                 sqlSelect.append(",");
70             }
f35d93 71
C 72             out.delete(0, out.length()) ;
73             sqlSelectItem.accept(sqlastOutputVisitor) ;
74             expr = out.toString();
75             sqlSelect.append(expr);
76
77            /* if(expr.indexOf("SELECT") == -1){
c64e12 78                 sqlSelect.append(expr);
C 79             }else{
f35d93 80                 //sqlSelect.append("(");
C 81                 sqlSelect.append(expr);
82                 //sqlSelect.append(")");
83                *//* if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
c64e12 84                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
f35d93 85                 }*//*
C 86             }*/
c64e12 87         }
C 88
89         //解析from
346eb2 90         if(sqlSelectQuery.getFrom() != null){
C 91             out.delete(0, out.length()) ;
92             sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
93             sqlWhere.append(" FROM "+out);
94         }
c64e12 95
C 96         //解析where
346eb2 97         if(sqlSelectQuery.getWhere() != null){
C 98             out.delete(0, out.length()) ;
99             sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
100             sqlWhere.append(" WHERE "+out+" ");
101         }
c64e12 102
C 103         if(sqlSelectQuery.getGroupBy() != null){
104             out.delete(0, out.length()) ;
105             sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
106             sqlWhere.append(" "+out);
107         }
108         if(sqlSelectQuery.getOrderBy() != null){
109             out.delete(0, out.length()) ;
110             sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
111             sqlWhere.append(" "+out);
112         }
113         if(sqlSelectQuery.getLimit() != null){
114             out.delete(0, out.length()) ;
115             sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
116             sqlWhere.append(" "+out);
117         }
118
119         //处理where需要加密得字段
120         sql = sqlWhere.toString();
121         if(!StringUtils.isEmpty(sql)){
122             Map<String,String> aesKeys = null;
123             String aeskey = null;
124             //把剩下的拼接上来
125             String tableAl = null;
126             for(TableStat.Column column:columns){
127                 aesKeys= aesKeysTable.get(column.getTable());
128                 if(aesKeys == null){
129                     continue;
130                 }
131                 aeskey = aesKeys.getOrDefault(column.getName(),null);
132                 if(StringUtils.isEmpty(aeskey)){
133                     continue;
134                 }
135                 tableAl = tableMaps.get(column.getTable());
136                 if(!StringUtils.isEmpty(tableAl)){
137                     tableAl = tableAl+"."+column.getName();
138                 }else{
139                     tableAl = column.getName();
140                 }
5c933d 141                 sql = sql.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 142             }
C 143         }
144         return sqlSelect.toString()+sql;
145     }
146
f35d93 147     /**
C 148      * 处理select返回字段的参数
149      * @param sql
150      * @param aesKeysTable
151      * @param tableMaps
152      * @param columns
153      * @return
154      */
c64e12 155     public static String selectSqlHandle(String sql,Map<String,Map<String,String>> aesKeysTable
C 156             ,Map<String,String> tableMaps,Collection<TableStat.Column> columns){
157
158
159         MySqlStatementParser parser = new MySqlStatementParser(sql);
160         SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
161         //获取格式化的slq语句
162         sql = sqlStatement.toString();
163
164         //解析select查询
165         //SQLSelect sqlSelect = sqlStatement.getSelect() ;
166         //获取sql查询块
167         SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
168         StringBuffer out = new StringBuffer() ;
169         //创建sql解析的标准化输出
170         SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
171
172         StringBuilder sqlWhere = new StringBuilder();
173
174         StringBuilder sqlSelect = new StringBuilder();
175         String expr = null;
176         sqlSelect.append("SELECT ");
177         //解析select返回的数据字段项
178         for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
179             if(sqlSelect.length() > 7){
180                 sqlSelect.append(",");
181             }
182             expr = sqlSelectItem.getExpr().toString();
183             if(expr.indexOf("SELECT") == -1){
184                 sqlSelect.append(expr);
185                 if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
186                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
187                 }
188             }else{
189                 sqlSelect.append("(");
190                 selectSqlHandle(expr,aesKeysTable,tableMaps,columns);
191                 sqlSelect.append(")");
192                 if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
193                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
194                 }
195             }
196         }
197
198         //解析from
346eb2 199         if(sqlSelectQuery.getFrom() != null){
C 200             out.delete(0, out.length()) ;
201             sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
202             sqlWhere.append(" FROM "+out);
203         }
c64e12 204
C 205         //解析where
346eb2 206         if(sqlSelectQuery.getWhere() != null){
C 207             out.delete(0, out.length()) ;
208             sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
209             sqlWhere.append(" WHERE "+out+" ");
210         }
c64e12 211
C 212         if(sqlSelectQuery.getGroupBy() != null){
213             out.delete(0, out.length()) ;
214             sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
215             sqlWhere.append(" "+out);
216         }
217         if(sqlSelectQuery.getOrderBy() != null){
218             out.delete(0, out.length()) ;
219             sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
220             sqlWhere.append(" "+out);
221         }
222         if(sqlSelectQuery.getLimit() != null){
223             out.delete(0, out.length()) ;
224             sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
225             sqlWhere.append(" "+out);
226         }
227
228         sql = sqlWhere.toString();
229         if(!StringUtils.isEmpty(sql)){
230             Map<String,String> aesKeys = null;
231             String aeskey = null;
232             //把剩下的拼接上来
233             String tableAl = null;
234
235             for(TableStat.Column column:columns){
236                 aesKeys= aesKeysTable.get(column.getTable());
237                 if(aesKeys == null){
238                     continue;
239                 }
240                 aeskey = aesKeys.getOrDefault(column.getName(),null);
241                 if(StringUtils.isEmpty(aeskey)){
242                     continue;
243                 }
244                 tableAl = tableMaps.get(column.getTable());
245                 if(!StringUtils.isEmpty(tableAl)){
246                     tableAl = tableAl+"."+column.getName();
247                 }else{
248                     tableAl = column.getName();
249                 }
5c933d 250                 sql = sql.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 251             }
C 252         }
253         return sqlSelect.toString()+sql;
254     }
255
256     /**新增加密数据处理
257      * @param sql sql语句
258      * @param aesKeysTable aes秘钥
259      * @return
260      */
346eb2 261     public static String insertSql(String sql,Map<String,Map<String,String>> aesKeysTable){
C 262         //装载重写的sql语句
263         StringBuilder splicingSql = new StringBuilder();
c64e12 264
346eb2 265         sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
C 266         String[] datas = sql.split("VALUES",2);
c64e12 267
346eb2 268         splicingSql.append(datas[0]+"VALUES ");
c64e12 269
346eb2 270         //重新拼接SQL语句
c64e12 271
346eb2 272         //解析sql语句
C 273         MySqlStatementParser parser = new MySqlStatementParser(sql);
274         SQLStatement statement = parser.parseStatement();
275         MySqlInsertStatement insert = (MySqlInsertStatement)statement;
c64e12 276
346eb2 277         String insertName = insert.getTableName().getSimpleName();
c64e12 278
346eb2 279         //根据表名称获取到AES秘钥
C 280         Map<String,String> aesKeys= aesKeysTable.get(insertName);
281         if(aesKeys == null){
282             return sql;
283         }
c64e12 284
346eb2 285         //获取所有的字段
C 286         List<SQLExpr> columns = insert.getColumns();
c64e12 287
346eb2 288         String fildValue = null;
C 289         String aeskey = null;
290         //遍历值
291         List<SQLInsertStatement.ValuesClause> vcl = insert.getValuesList();
292         for(int j = 0; j<vcl.size(); j++){
293             if( j != 0){
294                 splicingSql.append(",");
295             }
296             for(int i = 0;i < columns.size();i++){
297                 //查询改字段是否需要加密
298                 aeskey = aesKeys.getOrDefault(columns.get(i).toString(),null);
299                 fildValue = vcl.get(j).getValues().get(i).toString();
300                 if(i == 0){
301                     splicingSql.append("(");
302                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
303                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
304                     }else{
305                         splicingSql.append(fildValue);
306                     }
307                 }else if(i == columns.size()-1){
308                     splicingSql.append(",");
309                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
310                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
311                     }else{
312                         splicingSql.append(fildValue);
313                     }
314                     splicingSql.append(")");
315                 }else{
316                     splicingSql.append(",");
317                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
318                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
319                     }else{
320                         splicingSql.append(fildValue);
321                     }
322                 }
323             }
324         }
325         return splicingSql.toString();
326     }
c64e12 327
C 328     /**更新加密数据处理
329      * @param sql sql语句
330      * @param aesKeysTable aes秘钥
331      * @return
332      */
333     public static String updateSql(String sql,Map<String,Map<String,String>> aesKeysTable){
346eb2 334
c64e12 335         //装载重写的sql语句
C 336         StringBuilder splicingSql = new StringBuilder();
337
338         //sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
339         MySqlStatementParser parser = new MySqlStatementParser(sql);
340         SQLStatement sqlStatement = parser.parseStatement();
341         //获取格式化的slq语句
342         sql = sqlStatement.toString();
343
344         MySqlUpdateStatement updateStatement = (MySqlUpdateStatement)sqlStatement;
345
346         String insertName = updateStatement.getTableName().getSimpleName();
347
348         String[] datas = sql.split("WHERE",2);
349
350         Map<String,String> aesKeys = aesKeysTable.get(insertName);
346eb2 351         if(aesKeys == null){
C 352             return sql;
353         }
c64e12 354
C 355         splicingSql.append("UPDATE "+insertName+" SET ");
356         String aeskey = null;
357         String fildValue = null;
358         List<SQLUpdateSetItem> items = updateStatement.getItems();
359         for(int i = 0;i<items.size();i++){
360             if(i != 0){
361                 splicingSql.append(",");
362             }
363             SQLUpdateSetItem item = items.get(i);
364             //查询改字段是否需要加密
365             aeskey = aesKeys.getOrDefault(item.getColumn().toString(),null);
366
367             fildValue = item.getValue().toString();
368             if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
369                 splicingSql.append(item.getColumn()+" = HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
370             }else{
371                 splicingSql.append(item.getColumn()+" = "+fildValue);
372             }
373         }
374         String sqlWhere = " WHERE";
375         //把剩下的拼接上来
376         if(datas.length > 1){
377             for(int i =1;i<datas.length;i++){
378                 sqlWhere = sqlWhere+datas[i];
379             }
380
381             parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
382             sqlStatement = parser.parseStatement();
383
384             ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
385             sqlStatement.accept(visitorTable);
386
387             //获取表和别名
388             Map<String,String> tableMaps = visitorTable.getTableMap();
389             tableMaps.put(insertName,null);
390
391             //获取所有的字段
392             MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
393             sqlStatement.accept(visitor);
394
395             String tableAl = null;
396             //遍历所有字段
397             Collection<TableStat.Column> columns= visitor.getColumns();
398             for(TableStat.Column column:columns){
399
400                 aesKeys= aesKeysTable.get(column.getTable());
401                 if(aesKeys == null){
402                     continue;
403                 }
404                 aeskey = aesKeys.getOrDefault(column.getName(),null);
405                 if(StringUtils.isEmpty(aeskey)){
406                     continue;
407                 }
408                 tableAl = tableMaps.get(column.getTable());
409                 if(!StringUtils.isEmpty(tableAl)){
410                     tableAl = tableAl+"."+column.getName();
411                 }else{
412                     tableAl = column.getName();
413                 }
5c933d 414                 sqlWhere = sqlWhere.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 415             }
C 416
417         }
418         splicingSql.append(sqlWhere.toString());
419         return splicingSql.toString();
420     }
421
422     /**删除加密数据处理
423      * @param sql sql语句
424      * @param aesKeysTable aes秘钥
425      * @return
426      */
427     public static String deleteSql(String sql,Map<String,Map<String,String>> aesKeysTable){
428         //装载重写的sql语句
429         StringBuilder splicingSql = new StringBuilder();
430
431         //sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
432         MySqlStatementParser parser = new MySqlStatementParser(sql);
433         SQLStatement sqlStatement = parser.parseStatement();
434         //获取格式化的slq语句
435         sql = sqlStatement.toString();
436
437         MySqlDeleteStatement deleteStatement = (MySqlDeleteStatement)sqlStatement;
438
439         String insertName = deleteStatement.getTableName().getSimpleName();
440
441         String[] datas = sql.split("WHERE",2);
442
443         Map<String,String> aesKeys = aesKeysTable.get(insertName);
346eb2 444         if(aesKeys == null){
C 445             return sql;
446         }
c64e12 447
C 448         splicingSql.append("DELETE FROM "+insertName);
449
450         String aeskey = null;
451
452         String sqlWhere = " WHERE";
453         //把剩下的拼接上来
454         if(datas.length > 1){
455             for(int i =1;i<datas.length;i++){
456                 sqlWhere = sqlWhere+datas[i];
457             }
458
459             parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
460             sqlStatement = parser.parseStatement();
461
462             ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
463             sqlStatement.accept(visitorTable);
464
465             //获取表和别名
466             Map<String,String> tableMaps = visitorTable.getTableMap();
467             tableMaps.put(insertName,null);
468
469             //获取所有的字段
470             MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
471             sqlStatement.accept(visitor);
472
473             String tableAl = null;
474             //遍历所有字段
475             Collection<TableStat.Column> columns= visitor.getColumns();
476             for(TableStat.Column column:columns){
477
478                 aesKeys= aesKeysTable.get(column.getTable());
479                 if(aesKeys == null){
480                     continue;
481                 }
482                 aeskey = aesKeys.getOrDefault(column.getName(),null);
483                 if(StringUtils.isEmpty(aeskey)){
484                     continue;
485                 }
486                 tableAl = tableMaps.get(column.getTable());
487                 if(!StringUtils.isEmpty(tableAl)){
488                     tableAl = tableAl+"."+column.getName();
489                 }else{
490                     tableAl = column.getName();
491                 }
5c933d 492                 sqlWhere = sqlWhere.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 493             }
C 494
495         }
496         splicingSql.append(sqlWhere.toString());
497         return splicingSql.toString();
498     }
499
500 }