chenjiahe
2022-01-13 c64e1248bfda3ac8c5120e529fd096dfc4846629
src/main/java/com/hx/mybatis/aes/springbean/VariableAesKey.java
File was renamed from src/main/java/com/hx/springbean/VariableAesKey.java
@@ -1,8 +1,8 @@
package com.hx.springbean;
package com.hx.mybatis.aes.springbean;
import com.gitee.sunchenbin.mybatis.actable.annotation.Table;
import com.hx.common.annotations.MysqlHexAes;
import com.hx.util.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
@@ -24,8 +24,16 @@
    @Resource
    private ConstantBean constantBean;
    /**存储AES的秘钥*/
    /**是否已经启动完*/
    public static int isRun = 0;
    /**存储所有AES的秘钥*/
    public static Map<String,String> aesKeys = new HashMap<>();
    /**根据表明来存储AES秘钥*/
    public static Map<String,Map<String,String>> aesKeysTable = new HashMap<>();
    /**固定的aes秘钥*/
    public static String AES_KEY = null;
    /**存储AES秘钥*/
    public static void setAesKey(String aesKeyFild,String aesKey){
@@ -33,7 +41,14 @@
    }
    /**获取AES秘钥*/
    public static String getAesKey(String aesKeyFild){
        return aesKeys.get(aesKeyFild);
        if(aesKeyFild == null){
            return AES_KEY;
        }
        if(StringUtils.isEmpty(aesKeys.get(aesKeyFild))){
            return AES_KEY;
        }else {
            return  aesKeys.get(aesKeyFild);
        }
    }
    /**
@@ -41,11 +56,29 @@
     */
    @PostConstruct
    public void VariableAesKey(){
        isRun = 1;
        //项目启动的时候填入
        System.err.println("扫描获取AES:" + constantBean.getPackPath());
        AES_KEY = constantBean.getFixedAesKey();
        if(StringUtils.noNull(constantBean.getPackPath())){
            Set<Class<?>> classes = classData(constantBean.getPackPath());
            Map<String,String> aesKeysFild = new HashMap<>();
            boolean isAes = false;
            String tableName = null;
            for(Class<?> cl:classes){
                //表名称
                Table table = cl.getAnnotation(Table.class);
                if(table == null){
                    continue;
                }
                tableName = table.name();
                aesKeysFild = new HashMap<>();
                isAes = false;
                // 取得本类的全部属性
                Field[] fields = cl.getDeclaredFields();
                fields = getPatentFields(fields,cl);
@@ -56,32 +89,31 @@
                        // 根据注解类型返回方法的指定类型注解
                        MysqlHexAes mysqlHexAes = field.getAnnotation(MysqlHexAes.class);
                        String aesKeyField = mysqlHexAes.aesKeyField();
                        //String aesKeyField = mysqlHexAes.aesKeyField();
                        String aesKey = mysqlHexAes.aesKey();
                        if(StringUtils.isEmpty(aesKey)){
                            throw new RuntimeException("mysql的AES秘钥不能为空:"+field.getName());
                        }
                        if(StringUtils.noNull(aesKeyField)){
                            String key = aesKeys.get(aesKeyField);
                            if(StringUtils.isEmpty(key)){
                                aesKeys.put(aesKeyField,aesKey);
                            }else{
                                if(!aesKey.equals(key)){
                                    throw new RuntimeException("字段/定义的AES秘钥字段【"+field.getName()+"】多个一样,但是AES秘钥不一样");
                                }
                            aesKey = constantBean.getFixedAesKey();
                            if(StringUtils.isEmpty(aesKey)){
                                throw new RuntimeException("mysql的AES秘钥不能为空:"+field.getName());
                            }
                        }
                        String key = aesKeys.get(field.getName());
                        if(StringUtils.isEmpty(key)){
                            aesKeys.put(field.getName(),aesKey);
                            aesKeysFild.put(field.getName(),aesKey);
                            isAes = true;
                        }else{
                            String key = aesKeys.get(field.getName());
                            if(StringUtils.isEmpty(key)){
                                aesKeys.put(field.getName(),aesKey);
                            }else{
                                if(!aesKey.equals(key)){
                                    throw new RuntimeException("字段/定义的AES秘钥字段【"+field.getName()+"】多个一样,但是AES秘钥不一样");
                                }
                            isAes = true;
                            aesKeysFild.put(field.getName(),aesKey);
                            if(!aesKey.equals(key)){
                                throw new RuntimeException("字段/定义的AES秘钥字段【"+field.getName()+"】多个一样,但是AES秘钥不一样");
                            }
                        }
                    }
                }
                if(isAes){
                    aesKeysTable.put(tableName,aesKeysFild);
                }
            }
        }
@@ -90,31 +122,35 @@
    /**获取包下面的所有文件*/
    public static Set<Class<?>> classData(String packPath){
        Set<Class<?>> classes = new LinkedHashSet<Class<?>>();
        // 是否循环迭代
        boolean recursive = true;
        // 获取包的名字 并进行替换
        String packageName = packPath;
        String packageDirName = packageName.replace('.', '/');
        // 定义一个枚举的集合 并进行循环来处理这个目录下的things
        Enumeration<URL> dirs;
        try{
            dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
            // 循环迭代下去
            while (dirs.hasMoreElements()){
                // 获取下一个元素
                URL url = dirs.nextElement();
                // 得到协议的名称
                String protocol = url.getProtocol();
                // 如果是以文件的形式保存在服务器上
                if ("file".equals(protocol)) {
                    // 获取包的物理路径
                    String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                    // 以文件的方式扫描整个包下的文件 并添加到集合中
                    findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
        //截取
        String[] packPaths = packPath.split(";|,");
        for( String packageName : packPaths){
            // 是否循环迭代
            boolean recursive = true;
            // 获取包的名字 并进行替换
            String packageDirName = packageName.replace('.', '/');
            // 定义一个枚举的集合 并进行循环来处理这个目录下的things
            Enumeration<URL> dirs;
            try{
                dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
                // 循环迭代下去
                while (dirs.hasMoreElements()){
                    // 获取下一个元素
                    URL url = dirs.nextElement();
                    // 得到协议的名称
                    String protocol = url.getProtocol();
                    // 如果是以文件的形式保存在服务器上
                    if ("file".equals(protocol)) {
                        // 获取包的物理路径
                        String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                        // 以文件的方式扫描整个包下的文件 并添加到集合中
                        findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
                    }
                }
            }catch (IOException e){
                e.printStackTrace();
            }
        }catch (IOException e){
            e.printStackTrace();
        }
        return classes;
    }