zhouxiang
2022-04-25 8ab2ad5580212b91df848e4c127f2a682485fde3
提交 | 用户 | age
c64e12 1 package com.hx.mybatis.aes.springbean;
C 2
3 import com.alibaba.druid.sql.SQLUtils;
4 import com.alibaba.druid.sql.ast.SQLExpr;
f35d93 5 import com.alibaba.druid.sql.ast.SQLObject;
c64e12 6 import com.alibaba.druid.sql.ast.SQLStatement;
C 7 import com.alibaba.druid.sql.ast.statement.*;
8 import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
9 import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
10 import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
11 import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
12 import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
13 import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor;
14 import com.alibaba.druid.stat.TableStat;
15 import com.alibaba.druid.util.JdbcConstants;
16 import com.alibaba.druid.util.JdbcUtils;
17 import com.hx.util.StringUtils;
72586e 18 import org.slf4j.Logger;
Z 19 import org.slf4j.LoggerFactory;
c64e12 20
f35d93 21 import java.util.ArrayList;
c64e12 22 import java.util.Collection;
C 23 import java.util.List;
24 import java.util.Map;
25
26 /**
27  * sql语句处理工具
28  * @author CJH 2022-01-12
29  */
30 public class SqlUtils {
72586e 31     //log4j日志
Z 32     private static Logger logger = LoggerFactory.getLogger(SqlUtils.class.getName());
33
c64e12 34
C 35     /**查询加密数据处理,只对查询做处理,select返回不做处理
36      * @param sql sql语句
37      * @param aesKeysTable aes秘钥
38      * @return
39      */
40     public static String selectSql(String sql,Map<String,Map<String,String>> aesKeysTable){
41
42         MySqlStatementParser parser = new MySqlStatementParser(sql);
43         SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
44         //获取格式化的slq语句
45         sql = sqlStatement.toString();
46
72586e 47
Z 48
c64e12 49         //解析select查询
C 50         //SQLSelect sqlSelect = sqlStatement.getSelect()
51         //获取sql查询块
72586e 52         SQLSelectQueryBlock sqlSelectQuery = null;
Z 53         boolean b = true;
326104 54         try{
Z 55             sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
56         }catch (Exception e){
72586e 57             b = false;
8ab2ad 58             logger.error("解析sql报错:"+e.getMessage());
72586e 59         }
Z 60
61         if(!b){
62             return "err";
326104 63         }
Z 64
c64e12 65         StringBuffer out = new StringBuffer() ;
C 66         //创建sql解析的标准化输出
67         SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
68
69         //获取表和别名
70         ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
71         sqlStatement.accept(visitorTable);
72         Map<String,String> tableMaps = visitorTable.getTableMap();
73
74         //获取所有的字段
75         MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
76         sqlStatement.accept(visitor);
77         //遍历所有字段
78         Collection<TableStat.Column> columns= visitor.getColumns();
79
80         StringBuilder sqlWhere = new StringBuilder();
81
82         StringBuilder sqlSelect = new StringBuilder();
83         String expr = null;
84         sqlSelect.append("SELECT ");
85         //解析select返回的数据字段项
86         for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
87             if(sqlSelect.length() > 7){
88                 sqlSelect.append(",");
89             }
f35d93 90
C 91             out.delete(0, out.length()) ;
92             sqlSelectItem.accept(sqlastOutputVisitor) ;
93             expr = out.toString();
94             sqlSelect.append(expr);
95
96            /* if(expr.indexOf("SELECT") == -1){
c64e12 97                 sqlSelect.append(expr);
C 98             }else{
f35d93 99                 //sqlSelect.append("(");
C 100                 sqlSelect.append(expr);
101                 //sqlSelect.append(")");
102                *//* if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
c64e12 103                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
f35d93 104                 }*//*
C 105             }*/
c64e12 106         }
C 107
108         //解析from
346eb2 109         if(sqlSelectQuery.getFrom() != null){
C 110             out.delete(0, out.length()) ;
111             sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
112             sqlWhere.append(" FROM "+out);
113         }
c64e12 114
C 115         //解析where
346eb2 116         if(sqlSelectQuery.getWhere() != null){
C 117             out.delete(0, out.length()) ;
118             sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
119             sqlWhere.append(" WHERE "+out+" ");
120         }
c64e12 121
C 122         if(sqlSelectQuery.getGroupBy() != null){
123             out.delete(0, out.length()) ;
124             sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
125             sqlWhere.append(" "+out);
126         }
127         if(sqlSelectQuery.getOrderBy() != null){
128             out.delete(0, out.length()) ;
129             sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
130             sqlWhere.append(" "+out);
131         }
132         if(sqlSelectQuery.getLimit() != null){
133             out.delete(0, out.length()) ;
134             sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
135             sqlWhere.append(" "+out);
136         }
137
138         //处理where需要加密得字段
139         sql = sqlWhere.toString();
140         if(!StringUtils.isEmpty(sql)){
141             Map<String,String> aesKeys = null;
142             String aeskey = null;
143             //把剩下的拼接上来
144             String tableAl = null;
145             for(TableStat.Column column:columns){
146                 aesKeys= aesKeysTable.get(column.getTable());
147                 if(aesKeys == null){
148                     continue;
149                 }
150                 aeskey = aesKeys.getOrDefault(column.getName(),null);
151                 if(StringUtils.isEmpty(aeskey)){
152                     continue;
153                 }
154                 tableAl = tableMaps.get(column.getTable());
155                 if(!StringUtils.isEmpty(tableAl)){
156                     tableAl = tableAl+"."+column.getName();
157                 }else{
158                     tableAl = column.getName();
159                 }
5c933d 160                 sql = sql.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 161             }
C 162         }
163         return sqlSelect.toString()+sql;
164     }
165
f35d93 166     /**
C 167      * 处理select返回字段的参数
168      * @param sql
169      * @param aesKeysTable
170      * @param tableMaps
171      * @param columns
172      * @return
173      */
c64e12 174     public static String selectSqlHandle(String sql,Map<String,Map<String,String>> aesKeysTable
C 175             ,Map<String,String> tableMaps,Collection<TableStat.Column> columns){
176
177
178         MySqlStatementParser parser = new MySqlStatementParser(sql);
179         SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
180         //获取格式化的slq语句
181         sql = sqlStatement.toString();
182
183         //解析select查询
184         //SQLSelect sqlSelect = sqlStatement.getSelect() ;
185         //获取sql查询块
186         SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
187         StringBuffer out = new StringBuffer() ;
188         //创建sql解析的标准化输出
189         SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
190
191         StringBuilder sqlWhere = new StringBuilder();
192
193         StringBuilder sqlSelect = new StringBuilder();
194         String expr = null;
195         sqlSelect.append("SELECT ");
196         //解析select返回的数据字段项
197         for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
198             if(sqlSelect.length() > 7){
199                 sqlSelect.append(",");
200             }
201             expr = sqlSelectItem.getExpr().toString();
202             if(expr.indexOf("SELECT") == -1){
203                 sqlSelect.append(expr);
204                 if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
205                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
206                 }
207             }else{
208                 sqlSelect.append("(");
209                 selectSqlHandle(expr,aesKeysTable,tableMaps,columns);
210                 sqlSelect.append(")");
211                 if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
212                     sqlSelect.append(" AS "+sqlSelectItem.getAlias());
213                 }
214             }
215         }
216
217         //解析from
346eb2 218         if(sqlSelectQuery.getFrom() != null){
C 219             out.delete(0, out.length()) ;
220             sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
221             sqlWhere.append(" FROM "+out);
222         }
c64e12 223
C 224         //解析where
346eb2 225         if(sqlSelectQuery.getWhere() != null){
C 226             out.delete(0, out.length()) ;
227             sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
228             sqlWhere.append(" WHERE "+out+" ");
229         }
c64e12 230
C 231         if(sqlSelectQuery.getGroupBy() != null){
232             out.delete(0, out.length()) ;
233             sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
234             sqlWhere.append(" "+out);
235         }
236         if(sqlSelectQuery.getOrderBy() != null){
237             out.delete(0, out.length()) ;
238             sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
239             sqlWhere.append(" "+out);
240         }
241         if(sqlSelectQuery.getLimit() != null){
242             out.delete(0, out.length()) ;
243             sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
244             sqlWhere.append(" "+out);
245         }
246
247         sql = sqlWhere.toString();
248         if(!StringUtils.isEmpty(sql)){
249             Map<String,String> aesKeys = null;
250             String aeskey = null;
251             //把剩下的拼接上来
252             String tableAl = null;
253
254             for(TableStat.Column column:columns){
255                 aesKeys= aesKeysTable.get(column.getTable());
256                 if(aesKeys == null){
257                     continue;
258                 }
259                 aeskey = aesKeys.getOrDefault(column.getName(),null);
260                 if(StringUtils.isEmpty(aeskey)){
261                     continue;
262                 }
263                 tableAl = tableMaps.get(column.getTable());
264                 if(!StringUtils.isEmpty(tableAl)){
265                     tableAl = tableAl+"."+column.getName();
266                 }else{
267                     tableAl = column.getName();
268                 }
5c933d 269                 sql = sql.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 270             }
C 271         }
272         return sqlSelect.toString()+sql;
273     }
274
275     /**新增加密数据处理
276      * @param sql sql语句
277      * @param aesKeysTable aes秘钥
278      * @return
279      */
346eb2 280     public static String insertSql(String sql,Map<String,Map<String,String>> aesKeysTable){
C 281         //装载重写的sql语句
282         StringBuilder splicingSql = new StringBuilder();
c64e12 283
346eb2 284         sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
C 285         String[] datas = sql.split("VALUES",2);
c64e12 286
346eb2 287         splicingSql.append(datas[0]+"VALUES ");
c64e12 288
346eb2 289         //重新拼接SQL语句
c64e12 290
346eb2 291         //解析sql语句
C 292         MySqlStatementParser parser = new MySqlStatementParser(sql);
293         SQLStatement statement = parser.parseStatement();
294         MySqlInsertStatement insert = (MySqlInsertStatement)statement;
c64e12 295
346eb2 296         String insertName = insert.getTableName().getSimpleName();
c64e12 297
346eb2 298         //根据表名称获取到AES秘钥
C 299         Map<String,String> aesKeys= aesKeysTable.get(insertName);
300         if(aesKeys == null){
301             return sql;
302         }
c64e12 303
346eb2 304         //获取所有的字段
C 305         List<SQLExpr> columns = insert.getColumns();
c64e12 306
346eb2 307         String fildValue = null;
C 308         String aeskey = null;
309         //遍历值
310         List<SQLInsertStatement.ValuesClause> vcl = insert.getValuesList();
311         for(int j = 0; j<vcl.size(); j++){
312             if( j != 0){
313                 splicingSql.append(",");
314             }
315             for(int i = 0;i < columns.size();i++){
316                 //查询改字段是否需要加密
317                 aeskey = aesKeys.getOrDefault(columns.get(i).toString(),null);
318                 fildValue = vcl.get(j).getValues().get(i).toString();
319                 if(i == 0){
320                     splicingSql.append("(");
321                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
322                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
323                     }else{
324                         splicingSql.append(fildValue);
325                     }
326                 }else if(i == columns.size()-1){
327                     splicingSql.append(",");
328                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
329                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
330                     }else{
331                         splicingSql.append(fildValue);
332                     }
333                     splicingSql.append(")");
334                 }else{
335                     splicingSql.append(",");
336                     if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
337                         splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
338                     }else{
339                         splicingSql.append(fildValue);
340                     }
341                 }
342             }
343         }
344         return splicingSql.toString();
345     }
c64e12 346
C 347     /**更新加密数据处理
348      * @param sql sql语句
349      * @param aesKeysTable aes秘钥
350      * @return
351      */
352     public static String updateSql(String sql,Map<String,Map<String,String>> aesKeysTable){
346eb2 353
c64e12 354         //装载重写的sql语句
C 355         StringBuilder splicingSql = new StringBuilder();
356
357         //sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
358         MySqlStatementParser parser = new MySqlStatementParser(sql);
359         SQLStatement sqlStatement = parser.parseStatement();
360         //获取格式化的slq语句
361         sql = sqlStatement.toString();
362
363         MySqlUpdateStatement updateStatement = (MySqlUpdateStatement)sqlStatement;
364
365         String insertName = updateStatement.getTableName().getSimpleName();
366
367         String[] datas = sql.split("WHERE",2);
368
369         Map<String,String> aesKeys = aesKeysTable.get(insertName);
346eb2 370         if(aesKeys == null){
C 371             return sql;
372         }
c64e12 373
C 374         splicingSql.append("UPDATE "+insertName+" SET ");
375         String aeskey = null;
376         String fildValue = null;
377         List<SQLUpdateSetItem> items = updateStatement.getItems();
378         for(int i = 0;i<items.size();i++){
379             if(i != 0){
380                 splicingSql.append(",");
381             }
382             SQLUpdateSetItem item = items.get(i);
383             //查询改字段是否需要加密
384             aeskey = aesKeys.getOrDefault(item.getColumn().toString(),null);
385
386             fildValue = item.getValue().toString();
387             if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
388                 splicingSql.append(item.getColumn()+" = HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
389             }else{
390                 splicingSql.append(item.getColumn()+" = "+fildValue);
391             }
392         }
393         String sqlWhere = " WHERE";
394         //把剩下的拼接上来
395         if(datas.length > 1){
396             for(int i =1;i<datas.length;i++){
397                 sqlWhere = sqlWhere+datas[i];
398             }
399
400             parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
401             sqlStatement = parser.parseStatement();
402
403             ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
404             sqlStatement.accept(visitorTable);
405
406             //获取表和别名
407             Map<String,String> tableMaps = visitorTable.getTableMap();
408             tableMaps.put(insertName,null);
409
410             //获取所有的字段
411             MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
412             sqlStatement.accept(visitor);
413
414             String tableAl = null;
415             //遍历所有字段
416             Collection<TableStat.Column> columns= visitor.getColumns();
417             for(TableStat.Column column:columns){
418
419                 aesKeys= aesKeysTable.get(column.getTable());
420                 if(aesKeys == null){
421                     continue;
422                 }
423                 aeskey = aesKeys.getOrDefault(column.getName(),null);
424                 if(StringUtils.isEmpty(aeskey)){
425                     continue;
426                 }
427                 tableAl = tableMaps.get(column.getTable());
428                 if(!StringUtils.isEmpty(tableAl)){
429                     tableAl = tableAl+"."+column.getName();
430                 }else{
431                     tableAl = column.getName();
432                 }
5c933d 433                 sqlWhere = sqlWhere.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 434             }
C 435
436         }
326104 437         splicingSql.append(sqlWhere);
c64e12 438         return splicingSql.toString();
C 439     }
440
441     /**删除加密数据处理
442      * @param sql sql语句
443      * @param aesKeysTable aes秘钥
444      * @return
445      */
446     public static String deleteSql(String sql,Map<String,Map<String,String>> aesKeysTable){
447         //装载重写的sql语句
448         StringBuilder splicingSql = new StringBuilder();
449
450         //sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
451         MySqlStatementParser parser = new MySqlStatementParser(sql);
452         SQLStatement sqlStatement = parser.parseStatement();
453         //获取格式化的slq语句
454         sql = sqlStatement.toString();
455
456         MySqlDeleteStatement deleteStatement = (MySqlDeleteStatement)sqlStatement;
457
458         String insertName = deleteStatement.getTableName().getSimpleName();
459
460         String[] datas = sql.split("WHERE",2);
461
462         Map<String,String> aesKeys = aesKeysTable.get(insertName);
346eb2 463         if(aesKeys == null){
C 464             return sql;
465         }
c64e12 466
C 467         splicingSql.append("DELETE FROM "+insertName);
468
469         String aeskey = null;
470
471         String sqlWhere = " WHERE";
472         //把剩下的拼接上来
473         if(datas.length > 1){
474             for(int i =1;i<datas.length;i++){
475                 sqlWhere = sqlWhere+datas[i];
476             }
477
478             parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
479             sqlStatement = parser.parseStatement();
480
481             ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
482             sqlStatement.accept(visitorTable);
483
484             //获取表和别名
485             Map<String,String> tableMaps = visitorTable.getTableMap();
486             tableMaps.put(insertName,null);
487
488             //获取所有的字段
489             MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
490             sqlStatement.accept(visitor);
491
492             String tableAl = null;
493             //遍历所有字段
494             Collection<TableStat.Column> columns= visitor.getColumns();
495             for(TableStat.Column column:columns){
496
497                 aesKeys= aesKeysTable.get(column.getTable());
498                 if(aesKeys == null){
499                     continue;
500                 }
501                 aeskey = aesKeys.getOrDefault(column.getName(),null);
502                 if(StringUtils.isEmpty(aeskey)){
503                     continue;
504                 }
505                 tableAl = tableMaps.get(column.getTable());
506                 if(!StringUtils.isEmpty(tableAl)){
507                     tableAl = tableAl+"."+column.getName();
508                 }else{
509                     tableAl = column.getName();
510                 }
5c933d 511                 sqlWhere = sqlWhere.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
c64e12 512             }
C 513
514         }
326104 515         splicingSql.append(sqlWhere);
c64e12 516         return splicingSql.toString();
C 517     }
518
519 }