b
zhouxiang
2022-04-24 3261046de3da38bfc68d4bd8fa9140e0044d2715
提交 | 用户 | age
c64e12 1 package com.hx.mybatis.aes.springbean;
C 2
3 import com.alibaba.druid.sql.SQLUtils;
4 import com.alibaba.druid.sql.ast.SQLExpr;
f35d93 5 import com.alibaba.druid.sql.ast.SQLObject;
c64e12 6 import com.alibaba.druid.sql.ast.SQLStatement;
C 7 import com.alibaba.druid.sql.ast.statement.*;
8 import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
9 import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
10 import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
11 import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
12 import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
13 import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor;
14 import com.alibaba.druid.stat.TableStat;
15 import com.alibaba.druid.util.JdbcConstants;
16 import com.alibaba.druid.util.JdbcUtils;
17 import com.hx.util.StringUtils;
18
f35d93 19 import java.util.ArrayList;
c64e12 20 import java.util.Collection;
C 21 import java.util.List;
22 import java.util.Map;
23
24 /**
25  * sql语句处理工具
26  * @author CJH 2022-01-12
27  */
28 public class SqlUtils {
29
30     /**查询加密数据处理,只对查询做处理,select返回不做处理
31      * @param sql sql语句
32      * @param aesKeysTable aes秘钥
33      * @return
34      */
35     public static String selectSql(String sql,Map<String,Map<String,String>> aesKeysTable){
36
37         MySqlStatementParser parser = new MySqlStatementParser(sql);
38         SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
39         //获取格式化的slq语句
40         sql = sqlStatement.toString();
41
42         //解析select查询
43         //SQLSelect sqlSelect = sqlStatement.getSelect()
44         //获取sql查询块
326104 45         SQLSelectQueryBlock sqlSelectQuery;
Z 46         try{
47             sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
48         }catch (Exception e){
49             return sql;
50         }
51
c64e12 52         StringBuffer out = new StringBuffer() ;
C 53         //创建sql解析的标准化输出
54         SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
55
56         //获取表和别名
57         ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
58         sqlStatement.accept(visitorTable);
59         Map<String,String> tableMaps = visitorTable.getTableMap();
60
61         //获取所有的字段
62         MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
63         sqlStatement.accept(visitor);
64         //遍历所有字段
65         Collection<TableStat.Column> columns= visitor.getColumns();
66
67         StringBuilder sqlWhere = new StringBuilder();
68
69         StringBuilder sqlSelect = new StringBuilder();
70         String expr = null;
71         sqlSelect.append("SELECT ");
72         //解析select返回的数据字段项
73         for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
74             if(sqlSelect.length() > 7){
75                 sqlSelect.append(",");
76             }
f35d93 77
C 78             out.delete(0, out.length()) ;
79             sqlSelectItem.accept(sqlastOutputVisitor) ;
80             expr = out.toString();
81             sqlSelect.append(expr);
82
83            /* if(expr.indexOf("SELECT") == -1){
c64e12 84                 sqlSelect.append(expr);
C 85             }else{
f35d93 86                 //sqlSelect.append("(");
C 87                 sqlSelect.append(expr);
88                 //sqlSelect.append(")");
89                *//* if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
c64e12 90                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
f35d93 91                 }*//*
C 92             }*/
c64e12 93         }
C 94
95         //解析from
346eb2 96         if(sqlSelectQuery.getFrom() != null){
C 97             out.delete(0, out.length()) ;
98             sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
99             sqlWhere.append(" FROM "+out);
100         }
c64e12 101
C 102         //解析where
346eb2 103         if(sqlSelectQuery.getWhere() != null){
C 104             out.delete(0, out.length()) ;
105             sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
106             sqlWhere.append(" WHERE "+out+" ");
107         }
c64e12 108
C 109         if(sqlSelectQuery.getGroupBy() != null){
110             out.delete(0, out.length()) ;
111             sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
112             sqlWhere.append(" "+out);
113         }
114         if(sqlSelectQuery.getOrderBy() != null){
115             out.delete(0, out.length()) ;
116             sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
117             sqlWhere.append(" "+out);
118         }
119         if(sqlSelectQuery.getLimit() != null){
120             out.delete(0, out.length()) ;
121             sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
122             sqlWhere.append(" "+out);
123         }
124
125         //处理where需要加密得字段
126         sql = sqlWhere.toString();
127         if(!StringUtils.isEmpty(sql)){
128             Map<String,String> aesKeys = null;
129             String aeskey = null;
130             //把剩下的拼接上来
131             String tableAl = null;
132             for(TableStat.Column column:columns){
133                 aesKeys= aesKeysTable.get(column.getTable());
134                 if(aesKeys == null){
135                     continue;
136                 }
137                 aeskey = aesKeys.getOrDefault(column.getName(),null);
138                 if(StringUtils.isEmpty(aeskey)){
139                     continue;
140                 }
141                 tableAl = tableMaps.get(column.getTable());
142                 if(!StringUtils.isEmpty(tableAl)){
143                     tableAl = tableAl+"."+column.getName();
144                 }else{
145                     tableAl = column.getName();
146                 }
5c933d 147                 sql = sql.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 148             }
C 149         }
150         return sqlSelect.toString()+sql;
151     }
152
f35d93 153     /**
C 154      * 处理select返回字段的参数
155      * @param sql
156      * @param aesKeysTable
157      * @param tableMaps
158      * @param columns
159      * @return
160      */
c64e12 161     public static String selectSqlHandle(String sql,Map<String,Map<String,String>> aesKeysTable
C 162             ,Map<String,String> tableMaps,Collection<TableStat.Column> columns){
163
164
165         MySqlStatementParser parser = new MySqlStatementParser(sql);
166         SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
167         //获取格式化的slq语句
168         sql = sqlStatement.toString();
169
170         //解析select查询
171         //SQLSelect sqlSelect = sqlStatement.getSelect() ;
172         //获取sql查询块
173         SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
174         StringBuffer out = new StringBuffer() ;
175         //创建sql解析的标准化输出
176         SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
177
178         StringBuilder sqlWhere = new StringBuilder();
179
180         StringBuilder sqlSelect = new StringBuilder();
181         String expr = null;
182         sqlSelect.append("SELECT ");
183         //解析select返回的数据字段项
184         for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
185             if(sqlSelect.length() > 7){
186                 sqlSelect.append(",");
187             }
188             expr = sqlSelectItem.getExpr().toString();
189             if(expr.indexOf("SELECT") == -1){
190                 sqlSelect.append(expr);
191                 if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
192                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
193                 }
194             }else{
195                 sqlSelect.append("(");
196                 selectSqlHandle(expr,aesKeysTable,tableMaps,columns);
197                 sqlSelect.append(")");
198                 if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
199                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
200                 }
201             }
202         }
203
204         //解析from
346eb2 205         if(sqlSelectQuery.getFrom() != null){
C 206             out.delete(0, out.length()) ;
207             sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
208             sqlWhere.append(" FROM "+out);
209         }
c64e12 210
C 211         //解析where
346eb2 212         if(sqlSelectQuery.getWhere() != null){
C 213             out.delete(0, out.length()) ;
214             sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
215             sqlWhere.append(" WHERE "+out+" ");
216         }
c64e12 217
C 218         if(sqlSelectQuery.getGroupBy() != null){
219             out.delete(0, out.length()) ;
220             sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
221             sqlWhere.append(" "+out);
222         }
223         if(sqlSelectQuery.getOrderBy() != null){
224             out.delete(0, out.length()) ;
225             sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
226             sqlWhere.append(" "+out);
227         }
228         if(sqlSelectQuery.getLimit() != null){
229             out.delete(0, out.length()) ;
230             sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
231             sqlWhere.append(" "+out);
232         }
233
234         sql = sqlWhere.toString();
235         if(!StringUtils.isEmpty(sql)){
236             Map<String,String> aesKeys = null;
237             String aeskey = null;
238             //把剩下的拼接上来
239             String tableAl = null;
240
241             for(TableStat.Column column:columns){
242                 aesKeys= aesKeysTable.get(column.getTable());
243                 if(aesKeys == null){
244                     continue;
245                 }
246                 aeskey = aesKeys.getOrDefault(column.getName(),null);
247                 if(StringUtils.isEmpty(aeskey)){
248                     continue;
249                 }
250                 tableAl = tableMaps.get(column.getTable());
251                 if(!StringUtils.isEmpty(tableAl)){
252                     tableAl = tableAl+"."+column.getName();
253                 }else{
254                     tableAl = column.getName();
255                 }
5c933d 256                 sql = sql.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 257             }
C 258         }
259         return sqlSelect.toString()+sql;
260     }
261
262     /**新增加密数据处理
263      * @param sql sql语句
264      * @param aesKeysTable aes秘钥
265      * @return
266      */
346eb2 267     public static String insertSql(String sql,Map<String,Map<String,String>> aesKeysTable){
C 268         //装载重写的sql语句
269         StringBuilder splicingSql = new StringBuilder();
c64e12 270
346eb2 271         sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
C 272         String[] datas = sql.split("VALUES",2);
c64e12 273
346eb2 274         splicingSql.append(datas[0]+"VALUES ");
c64e12 275
346eb2 276         //重新拼接SQL语句
c64e12 277
346eb2 278         //解析sql语句
C 279         MySqlStatementParser parser = new MySqlStatementParser(sql);
280         SQLStatement statement = parser.parseStatement();
281         MySqlInsertStatement insert = (MySqlInsertStatement)statement;
c64e12 282
346eb2 283         String insertName = insert.getTableName().getSimpleName();
c64e12 284
346eb2 285         //根据表名称获取到AES秘钥
C 286         Map<String,String> aesKeys= aesKeysTable.get(insertName);
287         if(aesKeys == null){
288             return sql;
289         }
c64e12 290
346eb2 291         //获取所有的字段
C 292         List<SQLExpr> columns = insert.getColumns();
c64e12 293
346eb2 294         String fildValue = null;
C 295         String aeskey = null;
296         //遍历值
297         List<SQLInsertStatement.ValuesClause> vcl = insert.getValuesList();
298         for(int j = 0; j<vcl.size(); j++){
299             if( j != 0){
300                 splicingSql.append(",");
301             }
302             for(int i = 0;i < columns.size();i++){
303                 //查询改字段是否需要加密
304                 aeskey = aesKeys.getOrDefault(columns.get(i).toString(),null);
305                 fildValue = vcl.get(j).getValues().get(i).toString();
306                 if(i == 0){
307                     splicingSql.append("(");
308                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
309                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
310                     }else{
311                         splicingSql.append(fildValue);
312                     }
313                 }else if(i == columns.size()-1){
314                     splicingSql.append(",");
315                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
316                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
317                     }else{
318                         splicingSql.append(fildValue);
319                     }
320                     splicingSql.append(")");
321                 }else{
322                     splicingSql.append(",");
323                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
324                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
325                     }else{
326                         splicingSql.append(fildValue);
327                     }
328                 }
329             }
330         }
331         return splicingSql.toString();
332     }
c64e12 333
C 334     /**更新加密数据处理
335      * @param sql sql语句
336      * @param aesKeysTable aes秘钥
337      * @return
338      */
339     public static String updateSql(String sql,Map<String,Map<String,String>> aesKeysTable){
346eb2 340
c64e12 341         //装载重写的sql语句
C 342         StringBuilder splicingSql = new StringBuilder();
343
344         //sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
345         MySqlStatementParser parser = new MySqlStatementParser(sql);
346         SQLStatement sqlStatement = parser.parseStatement();
347         //获取格式化的slq语句
348         sql = sqlStatement.toString();
349
350         MySqlUpdateStatement updateStatement = (MySqlUpdateStatement)sqlStatement;
351
352         String insertName = updateStatement.getTableName().getSimpleName();
353
354         String[] datas = sql.split("WHERE",2);
355
356         Map<String,String> aesKeys = aesKeysTable.get(insertName);
346eb2 357         if(aesKeys == null){
C 358             return sql;
359         }
c64e12 360
C 361         splicingSql.append("UPDATE "+insertName+" SET ");
362         String aeskey = null;
363         String fildValue = null;
364         List<SQLUpdateSetItem> items = updateStatement.getItems();
365         for(int i = 0;i<items.size();i++){
366             if(i != 0){
367                 splicingSql.append(",");
368             }
369             SQLUpdateSetItem item = items.get(i);
370             //查询改字段是否需要加密
371             aeskey = aesKeys.getOrDefault(item.getColumn().toString(),null);
372
373             fildValue = item.getValue().toString();
374             if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
375                 splicingSql.append(item.getColumn()+" = HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
376             }else{
377                 splicingSql.append(item.getColumn()+" = "+fildValue);
378             }
379         }
380         String sqlWhere = " WHERE";
381         //把剩下的拼接上来
382         if(datas.length > 1){
383             for(int i =1;i<datas.length;i++){
384                 sqlWhere = sqlWhere+datas[i];
385             }
386
387             parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
388             sqlStatement = parser.parseStatement();
389
390             ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
391             sqlStatement.accept(visitorTable);
392
393             //获取表和别名
394             Map<String,String> tableMaps = visitorTable.getTableMap();
395             tableMaps.put(insertName,null);
396
397             //获取所有的字段
398             MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
399             sqlStatement.accept(visitor);
400
401             String tableAl = null;
402             //遍历所有字段
403             Collection<TableStat.Column> columns= visitor.getColumns();
404             for(TableStat.Column column:columns){
405
406                 aesKeys= aesKeysTable.get(column.getTable());
407                 if(aesKeys == null){
408                     continue;
409                 }
410                 aeskey = aesKeys.getOrDefault(column.getName(),null);
411                 if(StringUtils.isEmpty(aeskey)){
412                     continue;
413                 }
414                 tableAl = tableMaps.get(column.getTable());
415                 if(!StringUtils.isEmpty(tableAl)){
416                     tableAl = tableAl+"."+column.getName();
417                 }else{
418                     tableAl = column.getName();
419                 }
5c933d 420                 sqlWhere = sqlWhere.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 421             }
C 422
423         }
326104 424         splicingSql.append(sqlWhere);
c64e12 425         return splicingSql.toString();
C 426     }
427
428     /**删除加密数据处理
429      * @param sql sql语句
430      * @param aesKeysTable aes秘钥
431      * @return
432      */
433     public static String deleteSql(String sql,Map<String,Map<String,String>> aesKeysTable){
434         //装载重写的sql语句
435         StringBuilder splicingSql = new StringBuilder();
436
437         //sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
438         MySqlStatementParser parser = new MySqlStatementParser(sql);
439         SQLStatement sqlStatement = parser.parseStatement();
440         //获取格式化的slq语句
441         sql = sqlStatement.toString();
442
443         MySqlDeleteStatement deleteStatement = (MySqlDeleteStatement)sqlStatement;
444
445         String insertName = deleteStatement.getTableName().getSimpleName();
446
447         String[] datas = sql.split("WHERE",2);
448
449         Map<String,String> aesKeys = aesKeysTable.get(insertName);
346eb2 450         if(aesKeys == null){
C 451             return sql;
452         }
c64e12 453
C 454         splicingSql.append("DELETE FROM "+insertName);
455
456         String aeskey = null;
457
458         String sqlWhere = " WHERE";
459         //把剩下的拼接上来
460         if(datas.length > 1){
461             for(int i =1;i<datas.length;i++){
462                 sqlWhere = sqlWhere+datas[i];
463             }
464
465             parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
466             sqlStatement = parser.parseStatement();
467
468             ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
469             sqlStatement.accept(visitorTable);
470
471             //获取表和别名
472             Map<String,String> tableMaps = visitorTable.getTableMap();
473             tableMaps.put(insertName,null);
474
475             //获取所有的字段
476             MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
477             sqlStatement.accept(visitor);
478
479             String tableAl = null;
480             //遍历所有字段
481             Collection<TableStat.Column> columns= visitor.getColumns();
482             for(TableStat.Column column:columns){
483
484                 aesKeys= aesKeysTable.get(column.getTable());
485                 if(aesKeys == null){
486                     continue;
487                 }
488                 aeskey = aesKeys.getOrDefault(column.getName(),null);
489                 if(StringUtils.isEmpty(aeskey)){
490                     continue;
491                 }
492                 tableAl = tableMaps.get(column.getTable());
493                 if(!StringUtils.isEmpty(tableAl)){
494                     tableAl = tableAl+"."+column.getName();
495                 }else{
496                     tableAl = column.getName();
497                 }
5c933d 498                 sqlWhere = sqlWhere.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 499             }
C 500
501         }
326104 502         splicingSql.append(sqlWhere);
c64e12 503         return splicingSql.toString();
C 504     }
505
506 }