chenjiahe
2022-01-14 346eb22b5ab622064d1f61819240656529800f2b
提交 | 用户 | age
c64e12 1 package com.hx.mybatis.aes.springbean;
C 2
3 import com.alibaba.druid.sql.SQLUtils;
4 import com.alibaba.druid.sql.ast.SQLExpr;
5 import com.alibaba.druid.sql.ast.SQLStatement;
6 import com.alibaba.druid.sql.ast.statement.*;
7 import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
8 import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
9 import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
10 import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
11 import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
12 import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor;
13 import com.alibaba.druid.stat.TableStat;
14 import com.alibaba.druid.util.JdbcConstants;
15 import com.alibaba.druid.util.JdbcUtils;
16 import com.hx.util.StringUtils;
17
18 import java.util.Collection;
19 import java.util.List;
20 import java.util.Map;
21
22 /**
23  * sql语句处理工具
24  * @author CJH 2022-01-12
25  */
26 public class SqlUtils {
27
28     /**查询加密数据处理,只对查询做处理,select返回不做处理
29      * @param sql sql语句
30      * @param aesKeysTable aes秘钥
31      * @return
32      */
33     public static String selectSql(String sql,Map<String,Map<String,String>> aesKeysTable){
34
35         MySqlStatementParser parser = new MySqlStatementParser(sql);
36         SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
37         //获取格式化的slq语句
38         sql = sqlStatement.toString();
39
40         //解析select查询
41         //SQLSelect sqlSelect = sqlStatement.getSelect()
42         //获取sql查询块
43         SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
44         StringBuffer out = new StringBuffer() ;
45         //创建sql解析的标准化输出
46         SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
47
48         //获取表和别名
49         ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
50         sqlStatement.accept(visitorTable);
51         Map<String,String> tableMaps = visitorTable.getTableMap();
52
53         //获取所有的字段
54         MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
55         sqlStatement.accept(visitor);
56         //遍历所有字段
57         Collection<TableStat.Column> columns= visitor.getColumns();
58
59         StringBuilder sqlWhere = new StringBuilder();
60
61         StringBuilder sqlSelect = new StringBuilder();
62         String expr = null;
63         sqlSelect.append("SELECT ");
64         //解析select返回的数据字段项
65         for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
66             if(sqlSelect.length() > 7){
67                 sqlSelect.append(",");
68             }
69             expr = sqlSelectItem.getExpr().toString();
70             if(expr.indexOf("SELECT") == -1){
71                 sqlSelect.append(expr);
72                 if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
73                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
74                 }
75             }else{
76                 sqlSelect.append("(");
77                 sqlSelect.append(selectSqlHandle(expr,aesKeysTable,tableMaps,columns));
78                 sqlSelect.append(")");
79                 if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
80                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
81                 }
82             }
83         }
84
85         //解析from
346eb2 86         if(sqlSelectQuery.getFrom() != null){
C 87             out.delete(0, out.length()) ;
88             sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
89             sqlWhere.append(" FROM "+out);
90         }
c64e12 91
C 92         //解析where
346eb2 93         if(sqlSelectQuery.getWhere() != null){
C 94             out.delete(0, out.length()) ;
95             sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
96             sqlWhere.append(" WHERE "+out+" ");
97         }
c64e12 98
C 99         if(sqlSelectQuery.getGroupBy() != null){
100             out.delete(0, out.length()) ;
101             sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
102             sqlWhere.append(" "+out);
103         }
104         if(sqlSelectQuery.getOrderBy() != null){
105             out.delete(0, out.length()) ;
106             sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
107             sqlWhere.append(" "+out);
108         }
109         if(sqlSelectQuery.getLimit() != null){
110             out.delete(0, out.length()) ;
111             sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
112             sqlWhere.append(" "+out);
113         }
114
115         //处理where需要加密得字段
116         sql = sqlWhere.toString();
117         if(!StringUtils.isEmpty(sql)){
118             Map<String,String> aesKeys = null;
119             String aeskey = null;
120             //把剩下的拼接上来
121             String tableAl = null;
122             for(TableStat.Column column:columns){
123                 aesKeys= aesKeysTable.get(column.getTable());
124                 if(aesKeys == null){
125                     continue;
126                 }
127                 aeskey = aesKeys.getOrDefault(column.getName(),null);
128                 if(StringUtils.isEmpty(aeskey)){
129                     continue;
130                 }
131                 tableAl = tableMaps.get(column.getTable());
132                 if(!StringUtils.isEmpty(tableAl)){
133                     tableAl = tableAl+"."+column.getName();
134                 }else{
135                     tableAl = column.getName();
136                 }
137                 sql = sql.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') ");
138             }
139         }
140         return sqlSelect.toString()+sql;
141     }
142
143     public static String selectSqlHandle(String sql,Map<String,Map<String,String>> aesKeysTable
144             ,Map<String,String> tableMaps,Collection<TableStat.Column> columns){
145
146
147         MySqlStatementParser parser = new MySqlStatementParser(sql);
148         SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
149         //获取格式化的slq语句
150         sql = sqlStatement.toString();
151
152         //解析select查询
153         //SQLSelect sqlSelect = sqlStatement.getSelect() ;
154         //获取sql查询块
155         SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
156         StringBuffer out = new StringBuffer() ;
157         //创建sql解析的标准化输出
158         SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
159
160         StringBuilder sqlWhere = new StringBuilder();
161
162         StringBuilder sqlSelect = new StringBuilder();
163         String expr = null;
164         sqlSelect.append("SELECT ");
165         //解析select返回的数据字段项
166         for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
167             if(sqlSelect.length() > 7){
168                 sqlSelect.append(",");
169             }
170             expr = sqlSelectItem.getExpr().toString();
171             if(expr.indexOf("SELECT") == -1){
172                 sqlSelect.append(expr);
173                 if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
174                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
175                 }
176             }else{
177                 sqlSelect.append("(");
178                 selectSqlHandle(expr,aesKeysTable,tableMaps,columns);
179                 sqlSelect.append(")");
180                 if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
181                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
182                 }
183             }
184         }
185
186         //解析from
346eb2 187         if(sqlSelectQuery.getFrom() != null){
C 188             out.delete(0, out.length()) ;
189             sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
190             sqlWhere.append(" FROM "+out);
191         }
c64e12 192
C 193         //解析where
346eb2 194         if(sqlSelectQuery.getWhere() != null){
C 195             out.delete(0, out.length()) ;
196             sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
197             sqlWhere.append(" WHERE "+out+" ");
198         }
c64e12 199
C 200         if(sqlSelectQuery.getGroupBy() != null){
201             out.delete(0, out.length()) ;
202             sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
203             sqlWhere.append(" "+out);
204         }
205         if(sqlSelectQuery.getOrderBy() != null){
206             out.delete(0, out.length()) ;
207             sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
208             sqlWhere.append(" "+out);
209         }
210         if(sqlSelectQuery.getLimit() != null){
211             out.delete(0, out.length()) ;
212             sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
213             sqlWhere.append(" "+out);
214         }
215
216         sql = sqlWhere.toString();
217         if(!StringUtils.isEmpty(sql)){
218             Map<String,String> aesKeys = null;
219             String aeskey = null;
220             //把剩下的拼接上来
221             String tableAl = null;
222
223             for(TableStat.Column column:columns){
224                 aesKeys= aesKeysTable.get(column.getTable());
225                 if(aesKeys == null){
226                     continue;
227                 }
228                 aeskey = aesKeys.getOrDefault(column.getName(),null);
229                 if(StringUtils.isEmpty(aeskey)){
230                     continue;
231                 }
232                 tableAl = tableMaps.get(column.getTable());
233                 if(!StringUtils.isEmpty(tableAl)){
234                     tableAl = tableAl+"."+column.getName();
235                 }else{
236                     tableAl = column.getName();
237                 }
238                 sql = sql.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') ");
239             }
240         }
241         return sqlSelect.toString()+sql;
242     }
243
244     /**新增加密数据处理
245      * @param sql sql语句
246      * @param aesKeysTable aes秘钥
247      * @return
248      */
346eb2 249     public static String insertSql(String sql,Map<String,Map<String,String>> aesKeysTable){
C 250         //装载重写的sql语句
251         StringBuilder splicingSql = new StringBuilder();
c64e12 252
346eb2 253         sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
C 254         String[] datas = sql.split("VALUES",2);
c64e12 255
346eb2 256         splicingSql.append(datas[0]+"VALUES ");
c64e12 257
346eb2 258         //重新拼接SQL语句
c64e12 259
346eb2 260         //解析sql语句
C 261         MySqlStatementParser parser = new MySqlStatementParser(sql);
262         SQLStatement statement = parser.parseStatement();
263         MySqlInsertStatement insert = (MySqlInsertStatement)statement;
c64e12 264
346eb2 265         String insertName = insert.getTableName().getSimpleName();
c64e12 266
346eb2 267         //根据表名称获取到AES秘钥
C 268         Map<String,String> aesKeys= aesKeysTable.get(insertName);
269         if(aesKeys == null){
270             return sql;
271         }
c64e12 272
346eb2 273         //获取所有的字段
C 274         List<SQLExpr> columns = insert.getColumns();
c64e12 275
346eb2 276         String fildValue = null;
C 277         String aeskey = null;
278         //遍历值
279         List<SQLInsertStatement.ValuesClause> vcl = insert.getValuesList();
280         for(int j = 0; j<vcl.size(); j++){
281             if( j != 0){
282                 splicingSql.append(",");
283             }
284             for(int i = 0;i < columns.size();i++){
285                 //查询改字段是否需要加密
286                 aeskey = aesKeys.getOrDefault(columns.get(i).toString(),null);
287                 fildValue = vcl.get(j).getValues().get(i).toString();
288                 if(i == 0){
289                     splicingSql.append("(");
290                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
291                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
292                     }else{
293                         splicingSql.append(fildValue);
294                     }
295                 }else if(i == columns.size()-1){
296                     splicingSql.append(",");
297                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
298                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
299                     }else{
300                         splicingSql.append(fildValue);
301                     }
302                     splicingSql.append(")");
303                 }else{
304                     splicingSql.append(",");
305                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
306                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
307                     }else{
308                         splicingSql.append(fildValue);
309                     }
310                 }
311             }
312         }
313         return splicingSql.toString();
314     }
c64e12 315
C 316     /**更新加密数据处理
317      * @param sql sql语句
318      * @param aesKeysTable aes秘钥
319      * @return
320      */
321     public static String updateSql(String sql,Map<String,Map<String,String>> aesKeysTable){
346eb2 322
c64e12 323         //装载重写的sql语句
C 324         StringBuilder splicingSql = new StringBuilder();
325
326         //sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
327         MySqlStatementParser parser = new MySqlStatementParser(sql);
328         SQLStatement sqlStatement = parser.parseStatement();
329         //获取格式化的slq语句
330         sql = sqlStatement.toString();
331
332         MySqlUpdateStatement updateStatement = (MySqlUpdateStatement)sqlStatement;
333
334         String insertName = updateStatement.getTableName().getSimpleName();
335
336         String[] datas = sql.split("WHERE",2);
337
338         Map<String,String> aesKeys = aesKeysTable.get(insertName);
346eb2 339         if(aesKeys == null){
C 340             return sql;
341         }
c64e12 342
C 343         splicingSql.append("UPDATE "+insertName+" SET ");
344         String aeskey = null;
345         String fildValue = null;
346         List<SQLUpdateSetItem> items = updateStatement.getItems();
347         for(int i = 0;i<items.size();i++){
348             if(i != 0){
349                 splicingSql.append(",");
350             }
351             SQLUpdateSetItem item = items.get(i);
352             //查询改字段是否需要加密
353             aeskey = aesKeys.getOrDefault(item.getColumn().toString(),null);
354
355             fildValue = item.getValue().toString();
356             if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
357                 splicingSql.append(item.getColumn()+" = HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
358             }else{
359                 splicingSql.append(item.getColumn()+" = "+fildValue);
360             }
361         }
362         String sqlWhere = " WHERE";
363         //把剩下的拼接上来
364         if(datas.length > 1){
365             for(int i =1;i<datas.length;i++){
366                 sqlWhere = sqlWhere+datas[i];
367             }
368
369             parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
370             sqlStatement = parser.parseStatement();
371
372             ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
373             sqlStatement.accept(visitorTable);
374
375             //获取表和别名
376             Map<String,String> tableMaps = visitorTable.getTableMap();
377             tableMaps.put(insertName,null);
378
379             //获取所有的字段
380             MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
381             sqlStatement.accept(visitor);
382
383             String tableAl = null;
384             //遍历所有字段
385             Collection<TableStat.Column> columns= visitor.getColumns();
386             for(TableStat.Column column:columns){
387
388                 aesKeys= aesKeysTable.get(column.getTable());
389                 if(aesKeys == null){
390                     continue;
391                 }
392                 aeskey = aesKeys.getOrDefault(column.getName(),null);
393                 if(StringUtils.isEmpty(aeskey)){
394                     continue;
395                 }
396                 tableAl = tableMaps.get(column.getTable());
397                 if(!StringUtils.isEmpty(tableAl)){
398                     tableAl = tableAl+"."+column.getName();
399                 }else{
400                     tableAl = column.getName();
401                 }
402                 sqlWhere = sqlWhere.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') ");
403             }
404
405         }
406         splicingSql.append(sqlWhere.toString());
407         return splicingSql.toString();
408     }
409
410     /**删除加密数据处理
411      * @param sql sql语句
412      * @param aesKeysTable aes秘钥
413      * @return
414      */
415     public static String deleteSql(String sql,Map<String,Map<String,String>> aesKeysTable){
416         //装载重写的sql语句
417         StringBuilder splicingSql = new StringBuilder();
418
419         //sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
420         MySqlStatementParser parser = new MySqlStatementParser(sql);
421         SQLStatement sqlStatement = parser.parseStatement();
422         //获取格式化的slq语句
423         sql = sqlStatement.toString();
424
425         MySqlDeleteStatement deleteStatement = (MySqlDeleteStatement)sqlStatement;
426
427         String insertName = deleteStatement.getTableName().getSimpleName();
428
429         String[] datas = sql.split("WHERE",2);
430
431         Map<String,String> aesKeys = aesKeysTable.get(insertName);
346eb2 432         if(aesKeys == null){
C 433             return sql;
434         }
c64e12 435
C 436         splicingSql.append("DELETE FROM "+insertName);
437
438         String aeskey = null;
439
440         String sqlWhere = " WHERE";
441         //把剩下的拼接上来
442         if(datas.length > 1){
443             for(int i =1;i<datas.length;i++){
444                 sqlWhere = sqlWhere+datas[i];
445             }
446
447             parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
448             sqlStatement = parser.parseStatement();
449
450             ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
451             sqlStatement.accept(visitorTable);
452
453             //获取表和别名
454             Map<String,String> tableMaps = visitorTable.getTableMap();
455             tableMaps.put(insertName,null);
456
457             //获取所有的字段
458             MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
459             sqlStatement.accept(visitor);
460
461             String tableAl = null;
462             //遍历所有字段
463             Collection<TableStat.Column> columns= visitor.getColumns();
464             for(TableStat.Column column:columns){
465
466                 aesKeys= aesKeysTable.get(column.getTable());
467                 if(aesKeys == null){
468                     continue;
469                 }
470                 aeskey = aesKeys.getOrDefault(column.getName(),null);
471                 if(StringUtils.isEmpty(aeskey)){
472                     continue;
473                 }
474                 tableAl = tableMaps.get(column.getTable());
475                 if(!StringUtils.isEmpty(tableAl)){
476                     tableAl = tableAl+"."+column.getName();
477                 }else{
478                     tableAl = column.getName();
479                 }
480                 sqlWhere = sqlWhere.replaceAll("( |\\n|\\()"+tableAl+"( |\\n|\\))"," AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"') ");
481             }
482
483         }
484         splicingSql.append(sqlWhere.toString());
485         return splicingSql.toString();
486     }
487
488 }