AES
chenjiahe
2022-01-07 e29546af6fcc89c64b17cee6de01ce2a963b33f4
AES
4个文件已添加
305 ■■■■■ 已修改文件
src/main/java/com/hx/common/annotations/MysqlHexAes.java 20 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/hx/mybatis/handler/aes/GenericStringHandler.java 55 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/hx/springbean/ConstantBean.java 34 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/hx/springbean/VariableAesKey.java 196 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
src/main/java/com/hx/common/annotations/MysqlHexAes.java
New file
@@ -0,0 +1,20 @@
package com.hx.common.annotations;
import java.lang.annotation.*;
/**
 * 指定mysql的AES加密字段
 * @author CJH
 */
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface MysqlHexAes {
    /**秘钥字段*/
    String aesKeyField() default "";
    /**秘钥*/
    String aesKey();
    /**查询解密*/
    boolean selectDec() default false;
}
src/main/java/com/hx/mybatis/handler/aes/GenericStringHandler.java
New file
@@ -0,0 +1,55 @@
package com.hx.mybatis.handler.aes;
import com.hx.util.mysql.aes.MysqlHexAes;
import org.apache.ibatis.type.BaseTypeHandler;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.MappedJdbcTypes;
import org.apache.ibatis.type.MappedTypes;
import java.sql.CallableStatement;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
/**
 * @author CJH
 * @Date 2021-01-02
 * // @MappedTypes注解中的类代表此转换器可以自动转换为的java对象,@MappedJdbcTypes注解中设置的是对应的jdbctype,mysql的json对象对应的jdbctype为VARCHAR。
 */
@MappedTypes(value = {String.class})
@MappedJdbcTypes(value = {JdbcType.VARCHAR}, includeNullJdbcType = true)
public class GenericStringHandler extends BaseTypeHandler<String> {
    public GenericStringHandler() {
    }
    @Override
    public void setNonNullParameter(PreparedStatement ps, int i, String parameter, JdbcType jdbcType) throws SQLException {
        ps.setString(i, parameter);
    }
    @Override
    public String getNullableResult(ResultSet rs, String columnName) throws SQLException {
        String data = rs.getString(columnName);
        if(data != null && data.length()%32==0 && MysqlHexAes.isHexStrValid(data)){
            try{
                data = MysqlHexAes.decryptData(data,"123456",null);
            }catch (Exception e){
                //e.printStackTrace();
            }
        }
        return data;
    }
    @Override
    public String getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
        return rs.getString(columnIndex);
    }
    @Override
    public String getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
        return cs.getString(columnIndex);
    }
}
src/main/java/com/hx/springbean/ConstantBean.java
New file
@@ -0,0 +1,34 @@
package com.hx.springbean;
import com.hx.common.annotations.MysqlHexAes;
import com.hx.util.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.URL;
import java.net.URLDecoder;
import java.util.*;
/**
 * 通用常量集中营
 * @author CJH
 */
@Component
public class ConstantBean {
    /**获取AES秘钥的配置(从什么包获取到)*/
    @Value("${mysql.hxe.aex.find.packs:null}")
    private String packPath;
    public String getPackPath() {
        return packPath;
    }
    public void setPackPath(String packPath) {
        this.packPath = packPath;
    }
}
src/main/java/com/hx/springbean/VariableAesKey.java
New file
@@ -0,0 +1,196 @@
package com.hx.springbean;
import com.hx.common.annotations.MysqlHexAes;
import com.hx.util.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.URL;
import java.net.URLDecoder;
import java.util.*;
/**
 * 获取指定包里面的AES秘钥
 */
@Component
public class VariableAesKey {
    @Resource
    private ConstantBean constantBean;
    /**存储AES的秘钥*/
    public static Map<String,String> aesKeys = new HashMap<>();
    /**存储AES秘钥*/
    public static void setAesKey(String aesKeyFild,String aesKey){
        aesKeys.put(aesKeyFild,aesKey);
    }
    /**获取AES秘钥*/
    public static String getAesKey(String aesKeyFild){
        return aesKeys.get(aesKeyFild);
    }
    /**
     * 项目启动就执行后就执行该方法
     */
    @PostConstruct
    public void VariableAesKey(){
        //项目启动的时候填入
        System.err.println("扫描获取AES:" + constantBean.getPackPath());
        if(StringUtils.noNull(constantBean.getPackPath())){
            Set<Class<?>> classes = classData(constantBean.getPackPath());
            for(Class<?> cl:classes){
                // 取得本类的全部属性
                Field[] fields = cl.getDeclaredFields();
                fields = getPatentFields(fields,cl);
                for (Field field:fields) {
                    // 判断方法中是否有指定注解类型的注解
                    boolean hasAnnotation = field.isAnnotationPresent(MysqlHexAes.class);
                    if (hasAnnotation) {
                        // 根据注解类型返回方法的指定类型注解
                        MysqlHexAes mysqlHexAes = field.getAnnotation(MysqlHexAes.class);
                        String aesKeyField = mysqlHexAes.aesKeyField();
                        String aesKey = mysqlHexAes.aesKey();
                        if(StringUtils.isEmpty(aesKey)){
                            throw new RuntimeException("mysql的AES秘钥不能为空:"+field.getName());
                        }
                        if(StringUtils.noNull(aesKeyField)){
                            String key = aesKeys.get(aesKeyField);
                            if(StringUtils.isEmpty(key)){
                                aesKeys.put(aesKeyField,aesKey);
                            }else{
                                if(!aesKey.equals(key)){
                                    throw new RuntimeException("字段/定义的AES秘钥字段【"+field.getName()+"】多个一样,但是AES秘钥不一样");
                                }
                            }
                        }else{
                            String key = aesKeys.get(field.getName());
                            if(StringUtils.isEmpty(key)){
                                aesKeys.put(field.getName(),aesKey);
                            }else{
                                if(!aesKey.equals(key)){
                                    throw new RuntimeException("字段/定义的AES秘钥字段【"+field.getName()+"】多个一样,但是AES秘钥不一样");
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    /**获取包下面的所有文件*/
    public static Set<Class<?>> classData(String packPath){
        Set<Class<?>> classes = new LinkedHashSet<Class<?>>();
        // 是否循环迭代
        boolean recursive = true;
        // 获取包的名字 并进行替换
        String packageName = packPath;
        String packageDirName = packageName.replace('.', '/');
        // 定义一个枚举的集合 并进行循环来处理这个目录下的things
        Enumeration<URL> dirs;
        try{
            dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
            // 循环迭代下去
            while (dirs.hasMoreElements()){
                // 获取下一个元素
                URL url = dirs.nextElement();
                // 得到协议的名称
                String protocol = url.getProtocol();
                // 如果是以文件的形式保存在服务器上
                if ("file".equals(protocol)) {
                    // 获取包的物理路径
                    String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                    // 以文件的方式扫描整个包下的文件 并添加到集合中
                    findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
                }
            }
        }catch (IOException e){
            e.printStackTrace();
        }
        return classes;
    }
    /**
     * 以文件的形式来获取包下的所有Class
     *
     * @param packageName
     * @param packagePath
     * @param recursive
     * @param classes
     */
    public static void findAndAddClassesInPackageByFile(
            String packageName,
            String packagePath,
            final boolean recursive,
            Set<Class<?>> classes){
        // 获取此包的目录 建立一个File
        File dir = new File(packagePath);
        // 如果不存在或者 也不是目录就直接返回
        if (!dir.exists() || !dir.isDirectory()) {
            // log.warn("用户定义包名 " + packageName + " 下没有任何文件");
            return;
        }
        // 如果存在 就获取包下的所有文件 包括目录
        File[] dirfiles = dir.listFiles(new FileFilter(){
            // 自定义过滤规则 如果可以循环(包含子目录) 或则是以.class结尾的文件(编译好的java类文件)
            @Override
            public boolean accept(File file){
                return (recursive && file.isDirectory()) || (file.getName().endsWith(".class"));
            }
        });
        // 循环所有文件
        for (File file : dirfiles){
            // 如果是目录 则继续扫描
            if (file.isDirectory()) {
                findAndAddClassesInPackageByFile(packageName + "." + file.getName(), file.getAbsolutePath(), recursive, classes);
            }else{
                // 如果是java类文件 去掉后面的.class 只留下类名
                String className = file.getName().substring(0, file.getName().length() - 6);
                try{
                    // 添加到集合中去
                    // classes.add(Class.forName(packageName + '.' +
                    // className));
                    // 经过回复同学的提醒,这里用forName有一些不好,会触发static方法,没有使用classLoader的load干净
                    classes.add(Thread.currentThread().getContextClassLoader().loadClass(packageName + '.' + className));
                }catch (ClassNotFoundException e){
                    // log.error("添加用户自定义视图类错误 找不到此类的.class文件");
                    e.printStackTrace();
                }
            }
        }
    }
    /**
     * 获取父类的字段
     * @param fields
     * @param clas
     * @return
     */
    public static Field[] getPatentFields(Field[] fields,Class<?> clas){
        if (clas.getSuperclass() != null) {
            Class clsSup = clas.getSuperclass();
            List<Field> fieldList = new ArrayList<Field>();
            fieldList.addAll(Arrays.asList(fields));
            fieldList.addAll(Arrays.asList(clsSup.getDeclaredFields()));
            fields = new Field[fieldList.size()];
            int i = 0;
            for (Object field : fieldList.toArray()) {
                fields[i] = (Field) field;
                i++;
            }
            fields = getPatentFields(fields,clsSup);
        }
        return  fields;
    }
}