src/main/java/com/hx/common/annotations/MysqlHexAes.java
New file @@ -0,0 +1,22 @@
package com.hx.common.annotations;

import java.lang.annotation.*;

/**
 * Marks an entity field whose MySQL column stores AES-encrypted data.
 * <p>Retained at runtime so it can be discovered by reflection; VariableAesKey and
 * InitMysqlData look it up on the fields of {@code @Table} entities.
 * @author CJH
 */
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface MysqlHexAes {

    /** AES key for this column; when empty, the key set in the configuration file is used instead. */
    String aesKey() default "";

    /**
     * Generated XML: decrypt this column in queries.
     * NOTE(review): not read by any code in this change set; confirm it is consumed elsewhere.
     */
    boolean selectDec() default false;

    /**
     * Generated XML: encrypt this column on update (the "Dec" suffix is misleading).
     * NOTE(review): not read by any code in this change set; confirm it is consumed elsewhere.
     */
    boolean updateDec() default false;

    /**
     * Generated XML: encrypt this column on insert (the "Dec" suffix is misleading).
     * NOTE(review): not read by any code in this change set; confirm it is consumed elsewhere.
     */
    boolean insertDec() default false;

}
src/main/java/com/hx/common/dao/CommonDao.java
@@ -104,4 +104,11 @@ */
    <T extends Serializable> int deleteById(Class<?> mapperClass, Object object);

    /**
     * Execute one complete SQL statement ("full statement") carried by {@code sqlSentence}.
     * NOTE(review): the type parameter {@code T} is not used by this signature.
     * @param sqlSentence holder of the SQL text and its parameters (it is not a query-only object)
     * @return number of affected rows
     */
    <T extends Serializable> int updateSentence( SqlSentence sqlSentence);

}
src/main/java/com/hx/common/dao/CommonMapper.java
New file @@ -0,0 +1,17 @@
package com.hx.common.dao;

import com.hx.mybatisTool.SqlSentence;

import java.io.Serializable;
import java.util.List;
import java.util.Map;

/**
 * Generic MyBatis mapper that runs caller-built SQL held in a {@link SqlSentence}.
 * The statements are declared in CommonMapper.xml under this interface's namespace.
 */
public interface CommonMapper {

    /** Execute the SQL statement held by {@code sqlSentence}; returns the number of affected rows. */
    int updateSentence(SqlSentence sqlSentence);

    /** Run the query held by {@code sqlSentence}; each row is returned as a column -> value map. */
    List<Map<String,Object>> selectListMap(SqlSentence sqlSentence);

}
src/main/java/com/hx/common/service/CommonService.java
@@ -105,4 +105,11 @@ */
    <T extends Serializable> int deleteById(Class<?> mapperClass, Object object);

    /**
     * Execute one complete SQL statement ("full statement") carried by {@code sqlSentence}.
     * NOTE(review): the type parameter {@code T} is not used by this signature.
     * @param sqlSentence holder of the SQL text and its parameters
     * @return number of affected rows
     */
    <T extends Serializable> int updateSentence(SqlSentence sqlSentence);

}
src/main/java/com/hx/common/service/impl/CommonDaoImpl.java
@@ -1,6 +1,7 @@ package com.hx.common.service.impl; import com.hx.common.dao.CommonDao; import com.hx.common.dao.CommonMapper; import com.hx.mybatisTool.SqlSentence; import org.apache.ibatis.session.SqlSessionFactory; import org.springframework.stereotype.Service; @@ -102,4 +103,11 @@ return sqlSessionFactory.openSession().delete(getStatement(mapperClass,"deleteById"),object); } /**更新sql语句(全语句)*/ @Override public <T extends Serializable> int updateSentence(SqlSentence sqlSentence) { return sqlSessionFactory.openSession().delete(getStatement(CommonMapper.class,"updateSentence"),sqlSentence); } } src/main/java/com/hx/common/service/impl/CommonServiceImpl.java
@@ -97,4 +97,12 @@
        return commonDao.deleteById(mapperClass,object);
    }

    /**
     * Execute one complete SQL statement by delegating to the DAO layer.
     * @param sqlSentence SQL text plus its parameters
     * @return the affected-row count reported by the DAO
     */
    @Override
    public <T extends Serializable> int updateSentence(SqlSentence sqlSentence) {
        final int affectedRows = commonDao.updateSentence(sqlSentence);
        return affectedRows;
    }

}
src/main/java/com/hx/common/xml/CommonMapper.xml
New file @@ -0,0 +1,14 @@
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
        "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<!-- namespace: unique identifier of this mapper file; it names the com.hx.common.dao.CommonMapper interface -->
<!-- SECURITY: ${sqlSentence} inserts the SQL text verbatim, with no escaping. Only trusted code may
     build that text. Values must travel as #{...} placeholders inside the text
     (InitMysqlData uses #{m.id} / #{m.filedData}), never by string concatenation. -->
<mapper namespace="com.hx.common.dao.CommonMapper">

    <!-- Executes the full statement held by the SqlSentence; returns the affected-row count -->
    <update id="updateSentence" parameterType="com.hx.mybatisTool.SqlSentence" >
        ${sqlSentence}
    </update>

    <!-- Runs the query held by the SqlSentence; each row is mapped to a java.util.Map -->
    <select id="selectListMap" resultType="java.util.Map" parameterType="com.hx.mybatisTool.SqlSentence" >
        ${sqlSentence}
    </select>

</mapper>
src/main/java/com/hx/mybatis/aes/handler/GenericStringHandler.java
New file @@ -0,0 +1,72 @@
package com.hx.mybatis.aes.handler;

import com.hx.mybatis.aes.springbean.VariableAesKey;
import com.hx.util.mysql.aes.MysqlHexAesTool;
import org.apache.ibatis.type.BaseTypeHandler;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.MappedJdbcTypes;
import org.apache.ibatis.type.MappedTypes;

import java.sql.CallableStatement;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;

/**
 * MyBatis type handler for String / VARCHAR (and null-JDBC-type) values.
 * It transparently decrypts values that look like hex-encoded AES ciphertext.
 * <p>Parameters are written unchanged. On read, a value is decrypted with
 * MysqlHexAesTool when all of these hold:
 * <ul>
 *   <li>it is non-null;</li>
 *   <li>its length is a multiple of 32;</li>
 *   <li>it is valid hex.</li>
 * </ul>
 * If decryption fails, the raw value is returned.
 * <p>NOTE(review): this heuristic also matches ordinary 32-char hex data (MD5 digests,
 * dash-less UUIDs). Whether such data can be mis-"decrypted" depends on
 * MysqlHexAesTool.decryptData; confirm.
 * @author CJH
 * @Date 2021-01-02
 */
@MappedTypes(value = {String.class})
@MappedJdbcTypes(value = {JdbcType.VARCHAR}, includeNullJdbcType = true)
public class GenericStringHandler extends BaseTypeHandler<String> {

    /** Without a length cap, any length is accepted. */
    private static final int NO_LENGTH_LIMIT = Integer.MAX_VALUE;

    /** CallableStatement results are only decrypted up to this length (i.e. shorter than 129 chars). */
    private static final int CALLABLE_MAX_LENGTH = 128;

    public GenericStringHandler() {
    }

    @Override
    public void setNonNullParameter(PreparedStatement ps, int i, String parameter, JdbcType jdbcType) throws SQLException {
        ps.setString(i, parameter);
    }

    @Override
    public String getNullableResult(ResultSet rs, String columnName) throws SQLException {
        // the column name selects a per-column key; getAesKey falls back to the fixed key
        return decryptIfCipherText(rs.getString(columnName), columnName, NO_LENGTH_LIMIT);
    }

    @Override
    public String getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
        return decryptIfCipherText(rs.getString(columnIndex), null, NO_LENGTH_LIMIT);
    }

    @Override
    public String getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
        return decryptIfCipherText(cs.getString(columnIndex), null, CALLABLE_MAX_LENGTH);
    }

    /**
     * Decrypt {@code raw} when it looks like hex ciphertext; otherwise return it unchanged.
     * @param raw       value read from JDBC (may be null)
     * @param keyName   name passed to VariableAesKey.getAesKey; null selects the fixed key
     * @param maxLength longest value that is still considered for decryption
     */
    private static String decryptIfCipherText(String raw, String keyName, int maxLength) {
        boolean candidate = raw != null
                && raw.length() <= maxLength
                && raw.length() % 32 == 0
                && MysqlHexAesTool.isHexStrValid(raw);
        if (!candidate) {
            return raw;
        }
        try {
            return MysqlHexAesTool.decryptData(raw, VariableAesKey.getAesKey(keyName), null);
        } catch (Exception ignored) {
            // deliberate best effort: hex-looking plaintext that is not ciphertext is returned as-is
            return raw;
        }
    }
}
src/main/java/com/hx/mybatis/aes/springbean/ConstantBean.java
New file @@ -0,0 +1,35 @@
package com.hx.mybatis.aes.springbean;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

/**
 * Central holder for the AES-related configuration values.
 * <p>Missing properties now default to a real {@code null} via {@code #{null}}.
 * The previous {@code :null} default injected the four-character string "null".
 * That string is non-empty, so it was silently used as the AES key and defeated
 * the "key must not be empty" checks.
 * @author CJH
 */
@Component
public class ConstantBean {

    /** Comma/semicolon separated packages scanned for AES-annotated entities; null when not configured. */
    @Value("${mysql.hxe.aes.find.packs:#{null}}")
    private String packPath;

    /**
     * Fixed AES key used when an annotation does not specify one; null when not configured.
     * The property name keeps its original spelling ("fixd") for configuration compatibility.
     */
    @Value("${mysql.hxe.aes.fixd.key:#{null}}")
    private String fixedAesKey;

    public String getPackPath() {
        return packPath;
    }

    public void setPackPath(String packPath) {
        this.packPath = packPath;
    }

    public String getFixedAesKey() {
        return fixedAesKey;
    }

    public void setFixedAesKey(String fixedAesKey) {
        this.fixedAesKey = fixedAesKey;
    }
}
src/main/java/com/hx/mybatis/aes/springbean/ExportTableAliasVisitor.java
New file @@ -0,0 +1,53 @@
package com.hx.mybatis.aes.springbean;

import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.visitor.SQLASTVisitorAdapter;

import java.util.HashMap;
import java.util.Map;

/**
 * Druid AST visitor that records, for each table source it visits, table expression -> alias.
 * <p>The alias is null when the table has none. When the same table expression appears more
 * than once, the alias visited last wins (e.g. in a self-join only one alias is kept).
 * @author Mwg
 * @date 2020/09/08 23:47
 */
public class ExportTableAliasVisitor extends SQLASTVisitorAdapter {

    /** table expression text -> alias */
    private Map<String,String> tableMap = new HashMap<>();

    public Map<String, String> getTableMap() {
        return tableMap;
    }

    public void setTableMap(Map<String, String> tableMap) {
        this.tableMap = tableMap;
    }

    @Override
    public boolean visit(SQLExprTableSource x) {
        String tableExpr = x.getExpr().toString();
        String alias = x.getAlias();
        tableMap.put(tableExpr, alias);
        // true: let the visitor continue into this node's children
        return true;
    }
}
src/main/java/com/hx/mybatis/aes/springbean/InitMysqlData.java
New file @@ -0,0 +1,108 @@
package com.hx.mybatis.aes.springbean;

import com.gitee.sunchenbin.mybatis.actable.annotation.Column;
import com.gitee.sunchenbin.mybatis.actable.annotation.Table;
import com.hx.common.annotations.MysqlHexAes;
import com.hx.common.dao.CommonMapper;
import com.hx.common.service.CommonService;
import com.hx.exception.ServiceException;
import com.hx.mybatisTool.SqlSentence;
import com.hx.util.StringUtils;
import com.hx.util.mysql.aes.MysqlHexAesTool;

import javax.annotation.PostConstruct;
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * One-off data migration: rewrites existing plaintext values of {@link MysqlHexAes} columns.
 * <p>This class is not a Spring bean; call {@link #initData(String, CommonService)} explicitly.
 */
public class InitMysqlData {

    /**
     * Re-write the plaintext values of every encrypted column, table by table.
     * <p>The method scans {@code packPath} for {@code @Table} entities. For every
     * {@code @MysqlHexAes} field it re-writes each non-empty value that does not already
     * look like hex ciphertext. It only issues a plain
     * {@code UPDATE t SET col = #{m.filedData} WHERE id = #{m.id}}. The value is expected to
     * be wrapped in HEX(AES_ENCRYPT(...)) by MySqlInterceptor when the column is registered in
     * VariableAesKey.aesKeysTable.
     * <p>Intentionally not annotated with {@code @PostConstruct}: lifecycle callbacks must be
     * no-arg instance methods, and this static two-argument method is not one.
     *
     * @param packPath      comma/semicolon separated packages to scan; nothing happens when empty
     * @param commonService service used to run the SELECT / UPDATE statements
     * @throws RuntimeException when an annotated field has no AES key (neither on the annotation nor configured)
     * @throws ServiceException when an UPDATE does not affect exactly one row
     */
    public static void initData(String packPath, CommonService commonService){
        if(StringUtils.isEmpty(packPath)){
            return;
        }
        Set<Class<?>> classes = VariableAesKey.classData(packPath);
        SqlSentence sqlSentence = new SqlSentence();
        for(Class<?> cl : classes){
            if(!cl.isAnnotationPresent(Table.class)){
                continue;
            }
            String tableName = cl.getAnnotation(Table.class).name();
            // own fields plus every superclass' fields
            Field[] fields = VariableAesKey.getPatentFields(cl.getDeclaredFields(), cl);
            for(Field field : fields){
                if(!field.isAnnotationPresent(MysqlHexAes.class)){
                    continue;
                }
                requireAesKey(field);
                encryptColumn(commonService, sqlSentence, tableName, columnName(field));
            }
        }
    }

    /** Fail fast when neither the annotation nor the configuration provides a key for the field. */
    private static void requireAesKey(Field field){
        String aesKey = field.getAnnotation(MysqlHexAes.class).aesKey();
        if(StringUtils.isEmpty(aesKey) && StringUtils.isEmpty(VariableAesKey.AES_KEY)){
            throw new RuntimeException("mysql的AES秘钥不能为空:"+field.getName());
        }
    }

    /** SQL column name of the field: the {@code @Column} name when set, otherwise the Java field name. */
    private static String columnName(Field field){
        Column column = field.getAnnotation(Column.class);
        if(column != null && !StringUtils.isEmpty(column.name())){
            return column.name();
        }
        return field.getName();
    }

    /**
     * Re-write every plaintext value of one column, one row at a time (rows matched by {@code id}).
     * Values are never printed: they are the sensitive data this module exists to protect.
     */
    private static void encryptColumn(CommonService commonService, SqlSentence sqlSentence,
                                      String tableName, String fildName){
        sqlSentence.sqlSentence("SELECT id,"+fildName+" FROM "+tableName, new HashMap<>());
        List<Map<String,Object>> list = commonService.selectListMap(CommonMapper.class, sqlSentence);
        for(Map<String,Object> map : list){
            String fildValue = (String) map.get(fildName);
            if(StringUtils.isEmpty(fildValue)){
                continue;
            }
            // already shaped like HEX(AES_ENCRYPT(...)) output: skip so data is not encrypted twice
            if(fildValue.length()%32==0 && MysqlHexAesTool.isHexStrValid(fildValue)){
                continue;
            }
            // fresh parameter map per statement so no values leak between statements
            Map<String,Object> values = new HashMap<>();
            values.put("id", map.get("id"));
            values.put("filedData", fildValue);
            sqlSentence.sqlSentence("UPDATE "+tableName+" SET "+fildName+" = #{m.filedData} WHERE id = #{m.id}", values);
            int updated = commonService.updateSentence(sqlSentence);
            if(updated != 1){
                throw new ServiceException("更新条数不为1(实际"+updated+"条),更新失败!");
            }
        }
    }
}
src/main/java/com/hx/mybatis/aes/springbean/MySqlInterceptor.java
New file @@ -0,0 +1,97 @@
package com.hx.mybatis.aes.springbean;

import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.Configuration;
import org.springframework.stereotype.Component;

import java.lang.reflect.Field;
import java.sql.Connection;
import java.util.List;
import java.util.Properties;

/**
 * MyBatis plugin that rewrites each statement just before it is prepared.
 * Columns registered in {@link VariableAesKey#aesKeysTable} are handled by {@link SqlUtils}:
 * <ul>
 *   <li>INSERT / UPDATE values are encrypted;</li>
 *   <li>column references in WHERE-style clauses are decrypted.</li>
 * </ul>
 */
@Component
@Intercepts({
        @Signature(
                type = StatementHandler.class,
                method = "prepare",
                args = {Connection.class, Integer.class})
})
public class MySqlInterceptor implements Interceptor {

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // No encrypted table registered (feature unused or not initialised yet): leave MyBatis' SQL
        // untouched instead of parsing and re-printing every statement.
        if (VariableAesKey.aesKeysTable.isEmpty()) {
            return invocation.proceed();
        }

        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        // Expected target: RoutingStatementHandler, whose "delegate" (a BaseStatementHandler) holds
        // the MappedStatement. NOTE(review): this property path may fail if another plugin has
        // already wrapped the handler in a proxy.
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();

        BoundSql boundSql = statementHandler.getBoundSql();
        String sql = boundSql.getSql();
        if (sqlCommandType == SqlCommandType.INSERT) {
            sql = SqlUtils.insertSql(sql, VariableAesKey.aesKeysTable);
        } else if (sqlCommandType == SqlCommandType.UPDATE) {
            sql = SqlUtils.updateSql(sql, VariableAesKey.aesKeysTable);
        } else if (sqlCommandType == SqlCommandType.SELECT) {
            if (VariableAesKey.isRun == 1) {
                sql = SqlUtils.selectSql(sql, VariableAesKey.aesKeysTable);
            }
        } else if (sqlCommandType == SqlCommandType.DELETE) {
            sql = SqlUtils.deleteSql(sql, VariableAesKey.aesKeysTable);
        }

        // BoundSql offers no setter for its SQL text, so overwrite the private field reflectively.
        Field field = boundSql.getClass().getDeclaredField("sql");
        field.setAccessible(true);
        field.set(boundSql, sql);
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    @Override
    public void setProperties(Properties properties) {
    }
}
src/main/java/com/hx/mybatis/aes/springbean/SqlUtils.java
New file @@ -0,0 +1,500 @@
package com.hx.mybatis.aes.springbean;

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLName;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.SQLSetQuantifier;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.sql.visitor.SQLASTOutputVisitor;
import com.alibaba.druid.stat.TableStat;
import com.alibaba.druid.util.JdbcConstants;
import com.alibaba.druid.util.JdbcUtils;
import com.hx.util.StringUtils;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * SQL rewriting helpers for AES-encrypted MySQL columns.
 * <p>{@code aesKeysTable} maps table name -> (column name -> AES key).
 * Values written to encrypted columns are wrapped as {@code HEX(AES_ENCRYPT(value,'key'))}.
 * Encrypted column references outside the select list are wrapped as
 * {@code AES_DECRYPT(UNHEX(col),'key')}, which covers FROM/ON, WHERE, GROUP BY/HAVING,
 * ORDER BY and LIMIT, so comparisons work on plaintext.
 * NOTE(review): a backslash inside a key is still interpreted by MySQL as an escape character.
 * @author CJH 2022-01-12
 */
public class SqlUtils {

    /** Values containing this marker were encrypted by the caller and are left untouched. */
    private static final String ENCRYPT_MARK = "AES_ENCRYPT";

    /**
     * Rewrite a SELECT so that encrypted columns used after the select list are compared in
     * decrypted form. Select-list expressions are not decrypted here.
     * <p>Statements that reference no encrypted column are returned unchanged. This preserves
     * constructs the rebuild below does not reproduce (UNION, FOR UPDATE, hints, ...).
     * @param sql          SQL text to rewrite
     * @param aesKeysTable table -> (column -> AES key)
     * @return the rewritten SQL
     * @throws UnsupportedOperationException when a non-simple query (e.g. UNION) touches an encrypted column
     */
    public static String selectSql(String sql,Map<String,Map<String,String>> aesKeysTable){
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();

        // every referenced column together with its owning table
        MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
        sqlStatement.accept(visitor);
        Collection<TableStat.Column> columns = visitor.getColumns();
        if(!referencesEncryptedColumn(columns, aesKeysTable)){
            return sql;
        }

        SQLSelectQueryBlock sqlSelectQuery = queryBlock(sqlStatement, sql);
        // table -> alias, used to build "alias.column" references
        ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
        sqlStatement.accept(visitorTable);
        Map<String,String> tableMaps = visitorTable.getTableMap();

        StringBuffer out = new StringBuffer();
        SQLASTOutputVisitor printer = SQLUtils.createFormatOutputVisitor(out, null, JdbcUtils.MYSQL);

        StringBuilder sqlSelect = selectHead(sqlSelectQuery);
        boolean first = true;
        for(SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()){
            if(!first){
                sqlSelect.append(",");
            }
            first = false;
            sqlSelect.append(print(sqlSelectItem, printer, out));
        }
        String tail = printTail(sqlSelectQuery, printer, out);
        return sqlSelect + decryptColumnRefs(tail, columns, tableMaps, aesKeysTable);
    }

    /**
     * Variant of {@link #selectSql} for a (sub)query, using table aliases and columns collected by the caller.
     * Select items containing a sub-select are rewritten recursively; the result is wrapped in parentheses.
     * @param sql          SELECT statement text
     * @param aesKeysTable table -> (column -> AES key)
     * @param tableMaps    table -> alias
     * @param columns      columns referenced by the enclosing statement
     * @return the rewritten SQL
     * @throws UnsupportedOperationException when the statement is not a simple SELECT (e.g. UNION)
     */
    public static String selectSqlHandle(String sql,Map<String,Map<String,String>> aesKeysTable
            ,Map<String,String> tableMaps,Collection<TableStat.Column> columns){
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLSelectStatement sqlStatement = (SQLSelectStatement) parser.parseSelect();
        SQLSelectQueryBlock sqlSelectQuery = queryBlock(sqlStatement, sql);

        StringBuffer out = new StringBuffer();
        SQLASTOutputVisitor printer = SQLUtils.createFormatOutputVisitor(out, null, JdbcUtils.MYSQL);

        StringBuilder sqlSelect = selectHead(sqlSelectQuery);
        boolean first = true;
        for(SQLSelectItem sqlSelectItem : sqlSelectQuery.getSelectList()){
            if(!first){
                sqlSelect.append(",");
            }
            first = false;
            String expr = sqlSelectItem.getExpr().toString();
            if(expr.indexOf("SELECT") == -1){
                sqlSelect.append(expr);
            }else{
                // the recursive result used to be discarded, leaving an empty "()" in the SQL
                sqlSelect.append("(").append(selectSqlHandle(expr, aesKeysTable, tableMaps, columns)).append(")");
            }
            if(!StringUtils.isEmpty(sqlSelectItem.getAlias())){
                sqlSelect.append(" AS ").append(sqlSelectItem.getAlias());
            }
        }
        String tail = printTail(sqlSelectQuery, printer, out);
        return sqlSelect + decryptColumnRefs(tail, columns, tableMaps, aesKeysTable);
    }

    /**
     * Wrap every value written to an encrypted column in {@code HEX(AES_ENCRYPT(value,'key'))}.
     * <p>This covers every VALUES row and the assignments of ON DUPLICATE KEY UPDATE.
     * A {@code VALUES(col)} reference there is left as-is, because it refers to a value that
     * is already encrypted.
     * @param sql          INSERT statement
     * @param aesKeysTable table -> (column -> AES key)
     * @return the formatted, rewritten INSERT (only formatted when the table has no encrypted column)
     * @throws UnsupportedOperationException for INSERT ... SELECT or an INSERT without a column list
     *         on an encrypted table, since those cannot be rewritten safely
     */
    public static String insertSql(String sql,Map<String,Map<String,String>> aesKeysTable){
        sql = SQLUtils.format(sql, JdbcConstants.MYSQL);
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        MySqlInsertStatement insert = (MySqlInsertStatement) parser.parseStatement();
        String tableName = insert.getTableName().getSimpleName();
        Map<String,String> aesKeys = aesKeysTable.get(normalize(tableName));
        if(aesKeys == null){
            return sql;
        }

        List<SQLExpr> columns = insert.getColumns();
        List<SQLInsertStatement.ValuesClause> rows = insert.getValuesList();
        if(columns.isEmpty() || rows.isEmpty()){
            throw new UnsupportedOperationException(
                    "INSERT into encrypted table " + tableName + " must list its columns and use VALUES: " + sql);
        }

        // keep everything before VALUES (INSERT [IGNORE] INTO t (cols)) and rebuild the rows
        String[] datas = sql.split("VALUES",2);
        StringBuilder splicingSql = new StringBuilder(datas[0]).append("VALUES ");
        for(int j = 0; j < rows.size(); j++){
            if(j != 0){
                splicingSql.append(",");
            }
            List<SQLExpr> values = rows.get(j).getValues();
            splicingSql.append("(");
            for(int i = 0; i < columns.size(); i++){
                if(i != 0){
                    splicingSql.append(",");
                }
                String aeskey = aesKeys.get(columnName(columns.get(i)));
                splicingSql.append(encryptValue(values.get(i).toString(), aeskey));
            }
            // always close the row, including single-column inserts
            splicingSql.append(")");
        }

        List<SQLExpr> onDuplicate = insert.getDuplicateKeyUpdate();
        if(onDuplicate != null && !onDuplicate.isEmpty()){
            splicingSql.append(" ON DUPLICATE KEY UPDATE ");
            for(int i = 0; i < onDuplicate.size(); i++){
                if(i != 0){
                    splicingSql.append(", ");
                }
                splicingSql.append(encryptAssignment(onDuplicate.get(i), aesKeys));
            }
        }
        return splicingSql.toString();
    }

    /**
     * Encrypt values assigned to encrypted columns and decrypt encrypted columns referenced in
     * the WHERE part. A table alias is preserved. Without a WHERE clause, ORDER BY / LIMIT are
     * kept and no dangling {@code WHERE} is emitted.
     * @param sql          UPDATE statement
     * @param aesKeysTable table -> (column -> AES key)
     * @return the rewritten UPDATE (only formatted when the table has no encrypted column)
     */
    public static String updateSql(String sql,Map<String,Map<String,String>> aesKeysTable){
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLStatement sqlStatement = parser.parseStatement();
        // formatted statement text
        sql = sqlStatement.toString();
        MySqlUpdateStatement updateStatement = (MySqlUpdateStatement) sqlStatement;
        String tableName = updateStatement.getTableName().getSimpleName();
        Map<String,String> aesKeys = aesKeysTable.get(normalize(tableName));
        if(aesKeys == null){
            return sql;
        }
        String alias = updateStatement.getTableSource().getAlias();
        String target = StringUtils.isEmpty(alias) ? tableName : tableName + " " + alias;

        StringBuilder splicingSql = new StringBuilder("UPDATE ").append(target).append(" SET ");
        List<SQLUpdateSetItem> items = updateStatement.getItems();
        for(int i = 0; i < items.size(); i++){
            if(i != 0){
                splicingSql.append(",");
            }
            SQLUpdateSetItem item = items.get(i);
            String aeskey = aesKeys.get(columnName(item.getColumn()));
            splicingSql.append(item.getColumn()).append(" = ")
                    .append(encryptValue(item.getValue().toString(), aeskey));
        }

        String[] datas = sql.split("WHERE",2);
        if(datas.length > 1){
            splicingSql.append(decryptWhere(" WHERE" + datas[1], target, tableName, alias, aesKeysTable));
        }else{
            appendTrailing(splicingSql, updateStatement.getOrderBy(), updateStatement.getLimit());
        }
        return splicingSql.toString();
    }

    /**
     * Decrypt encrypted columns referenced in the WHERE part of a DELETE. A table alias is
     * preserved. Without a WHERE clause, ORDER BY / LIMIT are kept and no dangling
     * {@code WHERE} is emitted.
     * @param sql          DELETE statement
     * @param aesKeysTable table -> (column -> AES key)
     * @return the rewritten DELETE (only formatted when the table has no encrypted column)
     */
    public static String deleteSql(String sql,Map<String,Map<String,String>> aesKeysTable){
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLStatement sqlStatement = parser.parseStatement();
        // formatted statement text
        sql = sqlStatement.toString();
        MySqlDeleteStatement deleteStatement = (MySqlDeleteStatement) sqlStatement;
        String tableName = deleteStatement.getTableName().getSimpleName();
        Map<String,String> aesKeys = aesKeysTable.get(normalize(tableName));
        if(aesKeys == null){
            return sql;
        }
        String alias = deleteStatement.getTableSource().getAlias();
        String target = StringUtils.isEmpty(alias) ? tableName : tableName + " " + alias;

        StringBuilder splicingSql = new StringBuilder("DELETE FROM ").append(target);
        String[] datas = sql.split("WHERE",2);
        if(datas.length > 1){
            splicingSql.append(decryptWhere(" WHERE" + datas[1], target, tableName, alias, aesKeysTable));
        }else{
            appendTrailing(splicingSql, deleteStatement.getOrderBy(), deleteStatement.getLimit());
        }
        return splicingSql.toString();
    }

    /** The single query block of a SELECT; compound queries (UNION, ...) cannot be rebuilt here. */
    private static SQLSelectQueryBlock queryBlock(SQLSelectStatement statement, String sql){
        SQLSelectQuery query = statement.getSelect().getQuery();
        if(!(query instanceof SQLSelectQueryBlock)){
            throw new UnsupportedOperationException(
                    "Only simple SELECT statements can be rewritten for encrypted columns: " + sql);
        }
        return (SQLSelectQueryBlock) query;
    }

    /** "SELECT " plus "DISTINCT " when the query is distinct (previously dropped by the rebuild). */
    private static StringBuilder selectHead(SQLSelectQueryBlock block){
        StringBuilder head = new StringBuilder("SELECT ");
        if(block.getDistionOption() == SQLSetQuantifier.DISTINCT){
            head.append("DISTINCT ");
        }
        return head;
    }

    /** Everything after the select list: FROM, WHERE, GROUP BY (with HAVING), ORDER BY, LIMIT. */
    private static String printTail(SQLSelectQueryBlock block, SQLASTOutputVisitor printer, StringBuffer out){
        StringBuilder tail = new StringBuilder();
        if(block.getFrom() != null){
            tail.append(" FROM ").append(print(block.getFrom(), printer, out));
        }
        if(block.getWhere() != null){
            tail.append(" WHERE ").append(print(block.getWhere(), printer, out)).append(" ");
        }
        if(block.getGroupBy() != null){
            tail.append(" ").append(print(block.getGroupBy(), printer, out));
        }
        if(block.getOrderBy() != null){
            tail.append(" ").append(print(block.getOrderBy(), printer, out));
        }
        if(block.getLimit() != null){
            tail.append(" ").append(print(block.getLimit(), printer, out));
        }
        return tail.toString();
    }

    /** Render one AST node with the shared MySQL output visitor; {@code out} is its (reused) buffer. */
    private static String print(SQLObject node, SQLASTOutputVisitor printer, StringBuffer out){
        out.setLength(0);
        node.accept(printer);
        return out.toString();
    }

    /** Append " ORDER BY ..." / " LIMIT ..." of an UPDATE/DELETE that has no WHERE clause. */
    private static void appendTrailing(StringBuilder sql, SQLObject orderBy, SQLObject limit){
        StringBuffer out = new StringBuffer();
        SQLASTOutputVisitor printer = SQLUtils.createFormatOutputVisitor(out, null, JdbcUtils.MYSQL);
        if(orderBy != null){
            sql.append(" ").append(print(orderBy, printer, out));
        }
        if(limit != null){
            sql.append(" ").append(print(limit, printer, out));
        }
    }

    /**
     * Decrypt encrypted column references in the WHERE part of an UPDATE/DELETE.
     * Columns and aliases are resolved by parsing a probe {@code SELECT * FROM <target> <where>}.
     */
    private static String decryptWhere(String whereSql, String fromClause, String tableName, String alias,
                                       Map<String,Map<String,String>> aesKeysTable){
        SQLStatement probe = new MySqlStatementParser("SELECT * FROM " + fromClause + " " + whereSql).parseStatement();
        ExportTableAliasVisitor visitorTable = new ExportTableAliasVisitor();
        probe.accept(visitorTable);
        Map<String,String> tableMaps = visitorTable.getTableMap();
        // the statement's own table keeps its own alias even if a sub-query re-aliases it
        tableMaps.put(tableName, alias);
        MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
        probe.accept(visitor);
        return decryptColumnRefs(whereSql, visitor.getColumns(), tableMaps, aesKeysTable);
    }

    /** Wrap every reference to an encrypted column in {@code text} with AES_DECRYPT(UNHEX(..),'key'). */
    private static String decryptColumnRefs(String text, Collection<TableStat.Column> columns,
                                            Map<String,String> tableMaps, Map<String,Map<String,String>> aesKeysTable){
        if(StringUtils.isEmpty(text)){
            return text;
        }
        for(TableStat.Column column : columns){
            String aeskey = keyFor(column, aesKeysTable);
            if(aeskey == null){
                continue;
            }
            String tableAl = tableMaps.get(column.getTable());
            String ref = StringUtils.isEmpty(tableAl) ? column.getName() : tableAl + "." + column.getName();
            String replacement = "AES_DECRYPT(UNHEX(" + ref + "),'" + quoteKey(aeskey) + "')";
            // quote both sides: '.' in "alias.col" and '$' or '\' in keys must not act as regex syntax
            text = Pattern.compile("\\b" + Pattern.quote(ref) + "\\b").matcher(text)
                    .replaceAll(Matcher.quoteReplacement(replacement));
        }
        return text;
    }

    /** True when at least one referenced column is registered as encrypted. */
    private static boolean referencesEncryptedColumn(Collection<TableStat.Column> columns,
                                                     Map<String,Map<String,String>> aesKeysTable){
        for(TableStat.Column column : columns){
            if(keyFor(column, aesKeysTable) != null){
                return true;
            }
        }
        return false;
    }

    /** AES key of a column, or null when the column is not encrypted. */
    private static String keyFor(TableStat.Column column, Map<String,Map<String,String>> aesKeysTable){
        Map<String,String> aesKeys = aesKeysTable.get(normalize(column.getTable()));
        if(aesKeys == null){
            return null;
        }
        String aeskey = aesKeys.get(normalize(column.getName()));
        return StringUtils.isEmpty(aeskey) ? null : aeskey;
    }

    /** One ON DUPLICATE KEY UPDATE assignment with its value encrypted when the column is encrypted. */
    private static String encryptAssignment(SQLExpr assignment, Map<String,String> aesKeys){
        if(!(assignment instanceof SQLBinaryOpExpr)){
            return assignment.toString();
        }
        SQLBinaryOpExpr op = (SQLBinaryOpExpr) assignment;
        String value = op.getRight().toString();
        String aeskey = aesKeys.get(columnName(op.getLeft()));
        // VALUES(col) points at the row value, which was already encrypted in the VALUES list
        if(value.replace(" ", "").regionMatches(true, 0, "VALUES(", 0, 7)){
            aeskey = null;
        }
        return op.getLeft() + " = " + encryptValue(value, aeskey);
    }

    /** {@code HEX(AES_ENCRYPT(value,'key'))} unless there is no key or the value is already encrypted. */
    private static String encryptValue(String value, String aesKey){
        if(aesKey == null || value.contains(ENCRYPT_MARK)){
            return value;
        }
        return "HEX(AES_ENCRYPT(" + value + ",'" + quoteKey(aesKey) + "'))";
    }

    /** Escape single quotes so a key can be embedded in a MySQL string literal. */
    private static String quoteKey(String aesKey){
        return aesKey.replace("'", "''");
    }

    /** Bare column name of a column expression: qualifier and backticks removed ("a.`name`" -> "name"). */
    private static String columnName(SQLExpr expr){
        String name = expr instanceof SQLName ? ((SQLName) expr).getSimpleName() : expr.toString();
        return normalize(name);
    }

    /** Strip surrounding backticks and whitespace from an identifier; null stays null. */
    private static String normalize(String name){
        if(name == null){
            return null;
        }
        String trimmed = name.trim();
        if(trimmed.length() >= 2 && trimmed.charAt(0) == '`' && trimmed.charAt(trimmed.length() - 1) == '`'){
            return trimmed.substring(1, trimmed.length() - 1);
        }
        return trimmed;
    }
}
src/main/java/com/hx/mybatis/aes/springbean/VariableAesKey.java
New file @@ -0,0 +1,289 @@
package com.hx.mybatis.aes.springbean;

import com.gitee.sunchenbin.mybatis.actable.annotation.Column;
import com.gitee.sunchenbin.mybatis.actable.annotation.Table;
import com.hx.common.annotations.MysqlHexAes;
import com.hx.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.JarURLConnection;
import java.net.URL;
import java.net.URLDecoder;
import java.util.*;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;

/**
 * Collects, at startup, the AES keys declared with {@link MysqlHexAes} on {@code @Table} entities.
 * <p>Results are published in static fields:
 * <ul>
 *   <li>{@link #aesKeys}: name -> key, across all tables;</li>
 *   <li>{@link #aesKeysTable}: table -> (name -> key);</li>
 *   <li>{@link #AES_KEY}: the configured fallback key.</li>
 * </ul>
 * Each encrypted field is registered under its Java field name and, when different, under its
 * {@code @Column} name. The SQL rewriter matches SQL column names, so the column name is needed.
 */
@Component
public class VariableAesKey {

    private static final Logger logger = LoggerFactory.getLogger(VariableAesKey.class.getName());

    @Resource
    private ConstantBean constantBean;

    /** 1 once {@link #VariableAesKey()} has started. */
    public static int isRun = 0;
    /** All AES keys: field/column name -> key. */
    public static Map<String,String> aesKeys = new HashMap<>();
    /** AES keys per table: table name -> (field/column name -> key). */
    public static Map<String,Map<String,String>> aesKeysTable = new HashMap<>();
    /** Fixed AES key from the configuration; null when not configured. */
    public static String AES_KEY = null;

    /** Store the AES key for a field/column name. */
    public static void setAesKey(String aesKeyFild,String aesKey){
        aesKeys.put(aesKeyFild,aesKey);
    }

    /** AES key for a field/column name, falling back to {@link #AES_KEY} when unknown or when the name is null. */
    public static String getAesKey(String aesKeyFild){
        if(aesKeyFild == null){
            return AES_KEY;
        }
        String key = aesKeys.get(aesKeyFild);
        return StringUtils.isEmpty(key) ? AES_KEY : key;
    }

    /**
     * Runs once after injection: scans the configured packages and registers every
     * {@code @MysqlHexAes} field of every {@code @Table} entity.
     * @throws RuntimeException when an annotated field has no key, or when one name is bound to two different keys
     */
    @PostConstruct
    public void VariableAesKey(){
        isRun = 1;
        String packPath = configValue(constantBean.getPackPath());
        logger.info("扫描获取AES的包:" + packPath);
        AES_KEY = configValue(constantBean.getFixedAesKey());
        if(StringUtils.isEmpty(packPath)){
            return;
        }
        Set<Class<?>> classes = classData(packPath);
        logger.info("扫描获取AES的包classes:" + classes.size());
        for(Class<?> cl : classes){
            if(!cl.isAnnotationPresent(Table.class)){
                continue;
            }
            String tableName = cl.getAnnotation(Table.class).name();
            Map<String,String> tableKeys = new HashMap<>();
            for(Field field : getPatentFields(cl.getDeclaredFields(), cl)){
                MysqlHexAes mysqlHexAes = field.getAnnotation(MysqlHexAes.class);
                if(mysqlHexAes == null){
                    continue;
                }
                String aesKey = mysqlHexAes.aesKey();
                if(StringUtils.isEmpty(aesKey)){
                    aesKey = AES_KEY;
                    if(StringUtils.isEmpty(aesKey)){
                        throw new RuntimeException("mysql的AES秘钥不能为空:"+field.getName());
                    }
                }
                registerKey(tableKeys, field.getName(), aesKey);
                String column = columnName(field);
                if(!column.equals(field.getName())){
                    registerKey(tableKeys, column, aesKey);
                }
            }
            if(!tableKeys.isEmpty()){
                aesKeysTable.put(tableName, tableKeys);
            }
        }
    }

    /**
     * A configured value, or null when it is missing, empty or the literal "null".
     * The literal "null" is what a {@code ${key:null}} placeholder default injects.
     */
    private static String configValue(String value){
        if(StringUtils.isEmpty(value) || "null".equals(value)){
            return null;
        }
        return value;
    }

    /** SQL column name of a field: the {@code @Column} name when set, otherwise the Java field name. */
    private static String columnName(Field field){
        Column column = field.getAnnotation(Column.class);
        if(column != null && !StringUtils.isEmpty(column.name())){
            return column.name();
        }
        return field.getName();
    }

    /** Register {@code name -> aesKey} globally and for the current table; one name may not map to two keys. */
    private static void registerKey(Map<String,String> tableKeys, String name, String aesKey){
        String existing = aesKeys.get(name);
        if(StringUtils.isEmpty(existing)){
            aesKeys.put(name, aesKey);
        }else if(!aesKey.equals(existing)){
            throw new RuntimeException("字段/定义的AES秘钥字段【"+name+"】多个一样,但是AES秘钥不一样");
        }
        tableKeys.put(name, aesKey);
    }

    /**
     * Load every class found (recursively) under the given packages.
     * Classes are loaded with the context class loader and not initialised.
     * @param packPath package names separated by ',' or ';' (surrounding blanks are ignored)
     * @return the loaded classes, in discovery order
     */
    public static Set<Class<?>> classData(String packPath){
        Set<Class<?>> classes = new LinkedHashSet<>();
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        for(String rawPack : packPath.split(",|;")){
            String pack = rawPack.trim();
            // an empty name would resolve to the classpath roots and scan everything
            if(pack.isEmpty()){
                continue;
            }
            String packageDirName = pack.replace('.', '/');
            try{
                Enumeration<URL> dirs = loader.getResources(packageDirName);
                while(dirs.hasMoreElements()){
                    URL url = dirs.nextElement();
                    String protocol = url.getProtocol();
                    if("file".equals(protocol)){
                        logger.debug("file类型的扫描:" + pack);
                        String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                        findAndAddClassesInPackageByFile(pack, filePath, true, classes);
                    }else if("jar".equals(protocol)){
                        logger.debug("jar类型的扫描");
                        addClassesFromJar(url, packageDirName, loader, classes);
                    }
                }
            }catch(IOException e){
                logger.error("扫描包失败:" + pack, e);
            }
        }
        return classes;
    }

    /** Add every class stored under {@code packageDirName/} (recursively) in the jar behind {@code url}. */
    private static void addClassesFromJar(URL url, String packageDirName, ClassLoader loader, Set<Class<?>> classes){
        String prefix = packageDirName + "/";
        try{
            // not closed on purpose: JarURLConnection may hand out a cached, shared JarFile
            JarFile jar = ((JarURLConnection) url.openConnection()).getJarFile();
            Enumeration<JarEntry> entries = jar.entries();
            while(entries.hasMoreElements()){
                JarEntry entry = entries.nextElement();
                String name = entry.getName();
                if(name.startsWith("/")){
                    name = name.substring(1);
                }
                // "pkg/" prefix, so a sibling package such as "pkgX" is not matched
                if(entry.isDirectory() || !name.startsWith(prefix) || !name.endsWith(".class")){
                    continue;
                }
                String className = name.substring(0, name.length() - 6).replace('/', '.');
                try{
                    classes.add(loader.loadClass(className));
                }catch(ClassNotFoundException e){
                    logger.error("加载类失败:" + className, e);
                }
            }
        }catch(IOException e){
            logger.error("读取jar失败:" + url, e);
        }
    }

    /**
     * Load every .class file below a directory, treating sub-directories as sub-packages.
     *
     * @param packageName package that corresponds to {@code packagePath}
     * @param packagePath directory to scan
     * @param recursive   whether sub-directories are scanned
     * @param classes     receives the loaded classes
     */
    public static void findAndAddClassesInPackageByFile(
            String packageName, String packagePath, final boolean recursive, Set<Class<?>> classes){
        File dir = new File(packagePath);
        if (!dir.exists() || !dir.isDirectory()) {
            return;
        }
        File[] dirfiles = dir.listFiles(new FileFilter(){
            // directories (when recursive) and compiled class files
            @Override
            public boolean accept(File file){
                return (recursive && file.isDirectory()) || (file.getName().endsWith(".class"));
            }
        });
        // listFiles returns null on I/O errors
        if(dirfiles == null){
            logger.warn("无法读取目录:" + packagePath);
            return;
        }
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        for (File file : dirfiles){
            if (file.isDirectory()) {
                findAndAddClassesInPackageByFile(packageName + "." + file.getName(),
                        file.getAbsolutePath(), recursive, classes);
            }else{
                String className = file.getName().substring(0, file.getName().length() - 6);
                try{
                    // loadClass, not Class.forName: does not run static initialisers
                    classes.add(loader.loadClass(packageName + '.' + className));
                }catch (ClassNotFoundException e){
                    logger.error("加载类失败:" + packageName + '.' + className, e);
                }
            }
        }
    }

    /**
     * The given fields followed by the declared fields of every superclass of {@code clas},
     * nearest superclass first.
     * @param fields fields to start with (normally {@code clas.getDeclaredFields()})
     * @param clas   class whose superclass chain is walked
     * @return a new array with all fields
     */
    public static Field[] getPatentFields(Field[] fields,Class<?> clas){
        List<Field> all = new ArrayList<>(Arrays.asList(fields));
        for(Class<?> sup = clas.getSuperclass(); sup != null; sup = sup.getSuperclass()){
            all.addAll(Arrays.asList(sup.getDeclaredFields()));
        }
        return all.toArray(new Field[0]);
    }
}
src/main/java/com/hx/util/mysql/aes/MysqlHexAesTool.java
File was renamed from src/main/java/com/hx/util/mysql/aes/MysqlHexAes.java @@ -14,7 +14,7 @@ * @author CJH * @Date 2021-01-06 */ public class MysqlHexAes { public class MysqlHexAesTool { public static SecretKeySpec generateMySQLAESKey(final String key, final String encoding) {