package com.hx.mybatis.aes.springbean;
|
|
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor;
import com.alibaba.druid.stat.TableStat;
import com.alibaba.druid.util.JdbcConstants;
import com.alibaba.druid.util.JdbcUtils;
import com.hx.util.StringUtils;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
|
|
/**
|
* sql语句处理工具
|
* @author CJH 2022-01-12
|
*/
|
public class SqlUtils {
|
|
/**查询加密数据处理,只对查询做处理,select返回不做处理
|
* @param sql sql语句
|
* @param aesKeysTable aes秘钥
|
* @return
|
*/
|
public static String selectSql(String sql,Map<String,Map<String,String>> aesKeysTable){
|
|
MySqlStatementParser parser = new MySqlStatementParser(sql);
|
SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
|
//获取格式化的slq语句
|
sql = sqlStatement.toString();
|
|
//解析select查询
|
//SQLSelect sqlSelect = sqlStatement.getSelect()
|
//获取sql查询块
|
SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
|
StringBuffer out = new StringBuffer() ;
|
//创建sql解析的标准化输出
|
SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
|
|
//获取表和别名
|
ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
|
sqlStatement.accept(visitorTable);
|
Map<String,String> tableMaps = visitorTable.getTableMap();
|
|
//获取所有的字段
|
MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
|
sqlStatement.accept(visitor);
|
//遍历所有字段
|
Collection<TableStat.Column> columns= visitor.getColumns();
|
|
StringBuilder sqlWhere = new StringBuilder();
|
|
StringBuilder sqlSelect = new StringBuilder();
|
String expr = null;
|
sqlSelect.append("SELECT ");
|
//解析select返回的数据字段项
|
for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
|
if(sqlSelect.length() > 7){
|
sqlSelect.append(",");
|
}
|
|
out.delete(0, out.length()) ;
|
sqlSelectItem.accept(sqlastOutputVisitor) ;
|
expr = out.toString();
|
sqlSelect.append(expr);
|
|
/* if(expr.indexOf("SELECT") == -1){
|
sqlSelect.append(expr);
|
}else{
|
//sqlSelect.append("(");
|
sqlSelect.append(expr);
|
//sqlSelect.append(")");
|
*//* if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
|
sqlSelect.append(" AS "+sqlSelectItem.getAlias());
|
}*//*
|
}*/
|
}
|
|
//解析from
|
if(sqlSelectQuery.getFrom() != null){
|
out.delete(0, out.length()) ;
|
sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
|
sqlWhere.append(" FROM "+out);
|
}
|
|
//解析where
|
if(sqlSelectQuery.getWhere() != null){
|
out.delete(0, out.length()) ;
|
sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
|
sqlWhere.append(" WHERE "+out+" ");
|
}
|
|
if(sqlSelectQuery.getGroupBy() != null){
|
out.delete(0, out.length()) ;
|
sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
|
sqlWhere.append(" "+out);
|
}
|
if(sqlSelectQuery.getOrderBy() != null){
|
out.delete(0, out.length()) ;
|
sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
|
sqlWhere.append(" "+out);
|
}
|
if(sqlSelectQuery.getLimit() != null){
|
out.delete(0, out.length()) ;
|
sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
|
sqlWhere.append(" "+out);
|
}
|
|
//处理where需要加密得字段
|
sql = sqlWhere.toString();
|
if(!StringUtils.isEmpty(sql)){
|
Map<String,String> aesKeys = null;
|
String aeskey = null;
|
//把剩下的拼接上来
|
String tableAl = null;
|
for(TableStat.Column column:columns){
|
aesKeys= aesKeysTable.get(column.getTable());
|
if(aesKeys == null){
|
continue;
|
}
|
aeskey = aesKeys.getOrDefault(column.getName(),null);
|
if(StringUtils.isEmpty(aeskey)){
|
continue;
|
}
|
tableAl = tableMaps.get(column.getTable());
|
if(!StringUtils.isEmpty(tableAl)){
|
tableAl = tableAl+"."+column.getName();
|
}else{
|
tableAl = column.getName();
|
}
|
sql = sql.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
|
}
|
}
|
return sqlSelect.toString()+sql;
|
}
|
|
/**
|
* 处理select返回字段的参数
|
* @param sql
|
* @param aesKeysTable
|
* @param tableMaps
|
* @param columns
|
* @return
|
*/
|
public static String selectSqlHandle(String sql,Map<String,Map<String,String>> aesKeysTable
|
,Map<String,String> tableMaps,Collection<TableStat.Column> columns){
|
|
|
MySqlStatementParser parser = new MySqlStatementParser(sql);
|
SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
|
//获取格式化的slq语句
|
sql = sqlStatement.toString();
|
|
//解析select查询
|
//SQLSelect sqlSelect = sqlStatement.getSelect() ;
|
//获取sql查询块
|
SQLSelectQueryBlock sqlSelectQuery = (SQLSelectQueryBlock)sqlStatement.getSelect().getQuery() ;
|
StringBuffer out = new StringBuffer() ;
|
//创建sql解析的标准化输出
|
SQLASTOutputVisitor sqlastOutputVisitor = SQLUtils.createFormatOutputVisitor(out , null , JdbcUtils.MYSQL) ;
|
|
StringBuilder sqlWhere = new StringBuilder();
|
|
StringBuilder sqlSelect = new StringBuilder();
|
String expr = null;
|
sqlSelect.append("SELECT ");
|
//解析select返回的数据字段项
|
for (SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()) {
|
if(sqlSelect.length() > 7){
|
sqlSelect.append(",");
|
}
|
expr = sqlSelectItem.getExpr().toString();
|
if(expr.indexOf("SELECT") == -1){
|
sqlSelect.append(expr);
|
if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
|
sqlSelect.append(" AS "+sqlSelectItem.getAlias());
|
}
|
}else{
|
sqlSelect.append("(");
|
selectSqlHandle(expr,aesKeysTable,tableMaps,columns);
|
sqlSelect.append(")");
|
if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
|
sqlSelect.append(" AS "+sqlSelectItem.getAlias());
|
}
|
}
|
}
|
|
//解析from
|
if(sqlSelectQuery.getFrom() != null){
|
out.delete(0, out.length()) ;
|
sqlSelectQuery.getFrom().accept(sqlastOutputVisitor) ;
|
sqlWhere.append(" FROM "+out);
|
}
|
|
//解析where
|
if(sqlSelectQuery.getWhere() != null){
|
out.delete(0, out.length()) ;
|
sqlSelectQuery.getWhere().accept(sqlastOutputVisitor) ;
|
sqlWhere.append(" WHERE "+out+" ");
|
}
|
|
if(sqlSelectQuery.getGroupBy() != null){
|
out.delete(0, out.length()) ;
|
sqlSelectQuery.getGroupBy().accept(sqlastOutputVisitor) ;
|
sqlWhere.append(" "+out);
|
}
|
if(sqlSelectQuery.getOrderBy() != null){
|
out.delete(0, out.length()) ;
|
sqlSelectQuery.getOrderBy().accept(sqlastOutputVisitor) ;
|
sqlWhere.append(" "+out);
|
}
|
if(sqlSelectQuery.getLimit() != null){
|
out.delete(0, out.length()) ;
|
sqlSelectQuery.getLimit().accept(sqlastOutputVisitor) ;
|
sqlWhere.append(" "+out);
|
}
|
|
sql = sqlWhere.toString();
|
if(!StringUtils.isEmpty(sql)){
|
Map<String,String> aesKeys = null;
|
String aeskey = null;
|
//把剩下的拼接上来
|
String tableAl = null;
|
|
for(TableStat.Column column:columns){
|
aesKeys= aesKeysTable.get(column.getTable());
|
if(aesKeys == null){
|
continue;
|
}
|
aeskey = aesKeys.getOrDefault(column.getName(),null);
|
if(StringUtils.isEmpty(aeskey)){
|
continue;
|
}
|
tableAl = tableMaps.get(column.getTable());
|
if(!StringUtils.isEmpty(tableAl)){
|
tableAl = tableAl+"."+column.getName();
|
}else{
|
tableAl = column.getName();
|
}
|
sql = sql.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
|
}
|
}
|
return sqlSelect.toString()+sql;
|
}
|
|
/**新增加密数据处理
|
* @param sql sql语句
|
* @param aesKeysTable aes秘钥
|
* @return
|
*/
|
public static String insertSql(String sql,Map<String,Map<String,String>> aesKeysTable){
|
//装载重写的sql语句
|
StringBuilder splicingSql = new StringBuilder();
|
|
sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
|
String[] datas = sql.split("VALUES",2);
|
|
splicingSql.append(datas[0]+"VALUES ");
|
|
//重新拼接SQL语句
|
|
//解析sql语句
|
MySqlStatementParser parser = new MySqlStatementParser(sql);
|
SQLStatement statement = parser.parseStatement();
|
MySqlInsertStatement insert = (MySqlInsertStatement)statement;
|
|
String insertName = insert.getTableName().getSimpleName();
|
|
//根据表名称获取到AES秘钥
|
Map<String,String> aesKeys= aesKeysTable.get(insertName);
|
if(aesKeys == null){
|
return sql;
|
}
|
|
//获取所有的字段
|
List<SQLExpr> columns = insert.getColumns();
|
|
String fildValue = null;
|
String aeskey = null;
|
//遍历值
|
List<SQLInsertStatement.ValuesClause> vcl = insert.getValuesList();
|
for(int j = 0; j<vcl.size(); j++){
|
if( j != 0){
|
splicingSql.append(",");
|
}
|
for(int i = 0;i < columns.size();i++){
|
//查询改字段是否需要加密
|
aeskey = aesKeys.getOrDefault(columns.get(i).toString(),null);
|
fildValue = vcl.get(j).getValues().get(i).toString();
|
if(i == 0){
|
splicingSql.append("(");
|
if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
|
splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
|
}else{
|
splicingSql.append(fildValue);
|
}
|
}else if(i == columns.size()-1){
|
splicingSql.append(",");
|
if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
|
splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
|
}else{
|
splicingSql.append(fildValue);
|
}
|
splicingSql.append(")");
|
}else{
|
splicingSql.append(",");
|
if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
|
splicingSql.append("HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
|
}else{
|
splicingSql.append(fildValue);
|
}
|
}
|
}
|
}
|
return splicingSql.toString();
|
}
|
|
/**更新加密数据处理
|
* @param sql sql语句
|
* @param aesKeysTable aes秘钥
|
* @return
|
*/
|
public static String updateSql(String sql,Map<String,Map<String,String>> aesKeysTable){
|
|
//装载重写的sql语句
|
StringBuilder splicingSql = new StringBuilder();
|
|
//sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
|
MySqlStatementParser parser = new MySqlStatementParser(sql);
|
SQLStatement sqlStatement = parser.parseStatement();
|
//获取格式化的slq语句
|
sql = sqlStatement.toString();
|
|
MySqlUpdateStatement updateStatement = (MySqlUpdateStatement)sqlStatement;
|
|
String insertName = updateStatement.getTableName().getSimpleName();
|
|
String[] datas = sql.split("WHERE",2);
|
|
Map<String,String> aesKeys = aesKeysTable.get(insertName);
|
if(aesKeys == null){
|
return sql;
|
}
|
|
splicingSql.append("UPDATE "+insertName+" SET ");
|
String aeskey = null;
|
String fildValue = null;
|
List<SQLUpdateSetItem> items = updateStatement.getItems();
|
for(int i = 0;i<items.size();i++){
|
if(i != 0){
|
splicingSql.append(",");
|
}
|
SQLUpdateSetItem item = items.get(i);
|
//查询改字段是否需要加密
|
aeskey = aesKeys.getOrDefault(item.getColumn().toString(),null);
|
|
fildValue = item.getValue().toString();
|
if(aeskey != null && fildValue.indexOf("AES_ENCRYPT") == -1){
|
splicingSql.append(item.getColumn()+" = HEX(AES_ENCRYPT("+fildValue+",'"+aeskey+"'))");
|
}else{
|
splicingSql.append(item.getColumn()+" = "+fildValue);
|
}
|
}
|
String sqlWhere = " WHERE";
|
//把剩下的拼接上来
|
if(datas.length > 1){
|
for(int i =1;i<datas.length;i++){
|
sqlWhere = sqlWhere+datas[i];
|
}
|
|
parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
|
sqlStatement = parser.parseStatement();
|
|
ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
|
sqlStatement.accept(visitorTable);
|
|
//获取表和别名
|
Map<String,String> tableMaps = visitorTable.getTableMap();
|
tableMaps.put(insertName,null);
|
|
//获取所有的字段
|
MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
|
sqlStatement.accept(visitor);
|
|
String tableAl = null;
|
//遍历所有字段
|
Collection<TableStat.Column> columns= visitor.getColumns();
|
for(TableStat.Column column:columns){
|
|
aesKeys= aesKeysTable.get(column.getTable());
|
if(aesKeys == null){
|
continue;
|
}
|
aeskey = aesKeys.getOrDefault(column.getName(),null);
|
if(StringUtils.isEmpty(aeskey)){
|
continue;
|
}
|
tableAl = tableMaps.get(column.getTable());
|
if(!StringUtils.isEmpty(tableAl)){
|
tableAl = tableAl+"."+column.getName();
|
}else{
|
tableAl = column.getName();
|
}
|
sqlWhere = sqlWhere.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
|
}
|
|
}
|
splicingSql.append(sqlWhere.toString());
|
return splicingSql.toString();
|
}
|
|
/**删除加密数据处理
|
* @param sql sql语句
|
* @param aesKeysTable aes秘钥
|
* @return
|
*/
|
public static String deleteSql(String sql,Map<String,Map<String,String>> aesKeysTable){
|
//装载重写的sql语句
|
StringBuilder splicingSql = new StringBuilder();
|
|
//sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
|
MySqlStatementParser parser = new MySqlStatementParser(sql);
|
SQLStatement sqlStatement = parser.parseStatement();
|
//获取格式化的slq语句
|
sql = sqlStatement.toString();
|
|
MySqlDeleteStatement deleteStatement = (MySqlDeleteStatement)sqlStatement;
|
|
String insertName = deleteStatement.getTableName().getSimpleName();
|
|
String[] datas = sql.split("WHERE",2);
|
|
Map<String,String> aesKeys = aesKeysTable.get(insertName);
|
if(aesKeys == null){
|
return sql;
|
}
|
|
splicingSql.append("DELETE FROM "+insertName);
|
|
String aeskey = null;
|
|
String sqlWhere = " WHERE";
|
//把剩下的拼接上来
|
if(datas.length > 1){
|
for(int i =1;i<datas.length;i++){
|
sqlWhere = sqlWhere+datas[i];
|
}
|
|
parser = new MySqlStatementParser("SELECT * FROM "+insertName+" "+sqlWhere);
|
sqlStatement = parser.parseStatement();
|
|
ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
|
sqlStatement.accept(visitorTable);
|
|
//获取表和别名
|
Map<String,String> tableMaps = visitorTable.getTableMap();
|
tableMaps.put(insertName,null);
|
|
//获取所有的字段
|
MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
|
sqlStatement.accept(visitor);
|
|
String tableAl = null;
|
//遍历所有字段
|
Collection<TableStat.Column> columns= visitor.getColumns();
|
for(TableStat.Column column:columns){
|
|
aesKeys= aesKeysTable.get(column.getTable());
|
if(aesKeys == null){
|
continue;
|
}
|
aeskey = aesKeys.getOrDefault(column.getName(),null);
|
if(StringUtils.isEmpty(aeskey)){
|
continue;
|
}
|
tableAl = tableMaps.get(column.getTable());
|
if(!StringUtils.isEmpty(tableAl)){
|
tableAl = tableAl+"."+column.getName();
|
}else{
|
tableAl = column.getName();
|
}
|
sqlWhere = sqlWhere.replaceAll("((?<!\\.)\\b"+tableAl+"\\b(?!\\.))","AES_DECRYPT(UNHEX("+tableAl+"),'"+aeskey+"')");
|
}
|
|
}
|
splicingSql.append(sqlWhere.toString());
|
return splicingSql.toString();
|
}
|
|
}
|