ChenJiaHe
2022-01-22 f12d2c4c92dcd7929bce4d2df74f48f8abc08350
src/main/java/com/hx/mybatis/aes/springbean/VariableAesKey.java
New file
@@ -0,0 +1,289 @@
package com.hx.mybatis.aes.springbean;
import com.gitee.sunchenbin.mybatis.actable.annotation.Table;
import com.hx.common.annotations.MysqlHexAes;
import com.hx.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.JarURLConnection;
import java.net.URL;
import java.net.URLDecoder;
import java.util.*;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
/**
 * 获取指定包里面的AES秘钥
 */
@Component
public class VariableAesKey {
    //log4j日志
    private static Logger logger = LoggerFactory.getLogger(VariableAesKey.class.getName());
    @Resource
    private ConstantBean constantBean;
    /**是否已经启动完*/
    public static int isRun = 0;
    /**存储所有AES的秘钥*/
    public static Map<String,String> aesKeys = new HashMap<>();
    /**根据表明来存储AES秘钥*/
    public static Map<String,Map<String,String>> aesKeysTable = new HashMap<>();
    /**固定的aes秘钥*/
    public static String AES_KEY = null;
    /**存储AES秘钥*/
    public static void setAesKey(String aesKeyFild,String aesKey){
        aesKeys.put(aesKeyFild,aesKey);
    }
    /**获取AES秘钥*/
    public static String getAesKey(String aesKeyFild){
        if(aesKeyFild == null){
            return AES_KEY;
        }
        if(StringUtils.isEmpty(aesKeys.get(aesKeyFild))){
            return AES_KEY;
        }else {
            return  aesKeys.get(aesKeyFild);
        }
    }
    /**
     * 项目启动就执行后就执行该方法
     */
    @PostConstruct
    public void VariableAesKey(){
        isRun = 1;
        //项目启动的时候填入
        logger.info("扫描获取AES的包:" + constantBean.getPackPath());
        AES_KEY = constantBean.getFixedAesKey();
        if(!StringUtils.isEmpty(constantBean.getPackPath())){
            Set<Class<?>> classes = classData(constantBean.getPackPath());
            logger.info("扫描获取AES的包classes:" + classes.size());
            Map<String,String> aesKeysFild = new HashMap<>();
            boolean isAes = false;
            String tableName = null;
            for(Class<?> cl:classes){
                //表名称
                boolean hasAnnotation = cl.isAnnotationPresent(Table.class);
                if(!hasAnnotation){
                    continue;
                }
                Table table = cl.getAnnotation(Table.class);
                tableName = table.name();
                aesKeysFild = new HashMap<>();
                isAes = false;
                // 取得本类的全部属性
                Field[] fields = cl.getDeclaredFields();
                fields = getPatentFields(fields,cl);
                for (Field field:fields) {
                    // 判断方法中是否有指定注解类型的注解
                    hasAnnotation = field.isAnnotationPresent(MysqlHexAes.class);
                    if (hasAnnotation) {
                        // 根据注解类型返回方法的指定类型注解
                        MysqlHexAes mysqlHexAes = field.getAnnotation(MysqlHexAes.class);
                        //String aesKeyField = mysqlHexAes.aesKeyField();
                        String aesKey = mysqlHexAes.aesKey();
                        if(StringUtils.isEmpty(aesKey)){
                            aesKey = constantBean.getFixedAesKey();
                            if(StringUtils.isEmpty(aesKey)){
                                throw new RuntimeException("mysql的AES秘钥不能为空:"+field.getName());
                            }
                        }
                        String key = aesKeys.get(field.getName());
                        if(StringUtils.isEmpty(key)){
                            aesKeys.put(field.getName(),aesKey);
                            aesKeysFild.put(field.getName(),aesKey);
                            isAes = true;
                        }else{
                            isAes = true;
                            aesKeysFild.put(field.getName(),aesKey);
                            if(!aesKey.equals(key)){
                                throw new RuntimeException("字段/定义的AES秘钥字段【"+field.getName()+"】多个一样,但是AES秘钥不一样");
                            }
                        }
                    }
                }
                if(isAes){
                    aesKeysTable.put(tableName,aesKeysFild);
                }
            }
        }
    }
    /**获取包下面的所有文件*/
    public static Set<Class<?>> classData(String packPath){
        Set<Class<?>> classes = new LinkedHashSet();
        String[] split = packPath.split(",|;");
        String[] var3 = split;
        int var4 = split.length;
        label82:
        for(int var5 = 0; var5 < var4; ++var5) {
            String pack = var3[var5];
            boolean recursive = true;
            String packageName = pack;
            String packageDirName = pack.replace('.', '/');
            try {
                Enumeration dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
                while(true) {
                    label75:
                    while(true) {
                        if (!dirs.hasMoreElements()) {
                            continue label82;
                        }
                        URL url = (URL)dirs.nextElement();
                        String protocol = url.getProtocol();
                        if ("file".equals(protocol)) {
                            System.err.println("file类型的扫描:" + pack);
                            String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                            findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
                        } else if ("jar".equals(protocol)) {
                            System.err.println("jar类型的扫描");
                            try {
                                JarFile jar = ((JarURLConnection)url.openConnection()).getJarFile();
                                Enumeration entries = jar.entries();
                                while(true) {
                                    JarEntry entry;
                                    String name;
                                    int idx;
                                    do {
                                        do {
                                            if (!entries.hasMoreElements()) {
                                                continue label75;
                                            }
                                            entry = (JarEntry)entries.nextElement();
                                            name = entry.getName();
                                            if (name.charAt(0) == '/') {
                                                name = name.substring(1);
                                            }
                                        } while(!name.startsWith(packageDirName));
                                        idx = name.lastIndexOf(47);
                                        if (idx != -1) {
                                            packageName = name.substring(0, idx).replace('/', '.');
                                        }
                                    } while(idx == -1 && !recursive);
                                    if (name.endsWith(".class") && !entry.isDirectory()) {
                                        String className = name.substring(packageName.length() + 1, name.length() - 6);
                                        try {
                                            classes.add(Class.forName(packageName + '.' + className));
                                        } catch (ClassNotFoundException var20) {
                                            var20.printStackTrace();
                                        }
                                    }
                                }
                            } catch (IOException var21) {
                                var21.printStackTrace();
                            }
                        }
                    }
                }
            } catch (IOException var22) {
                var22.printStackTrace();
            }
        }
        return classes;
    }
    /**
     * 以文件的形式来获取包下的所有Class
     *
     * @param packageName
     * @param packagePath
     * @param recursive
     * @param classes
     */
    public static void findAndAddClassesInPackageByFile(
            String packageName,
            String packagePath,
            final boolean recursive,
            Set<Class<?>> classes){
        // 获取此包的目录 建立一个File
        File dir = new File(packagePath);
        // 如果不存在或者 也不是目录就直接返回
        if (!dir.exists() || !dir.isDirectory()) {
            // log.warn("用户定义包名 " + packageName + " 下没有任何文件");
            return;
        }
        // 如果存在 就获取包下的所有文件 包括目录
        File[] dirfiles = dir.listFiles(new FileFilter(){
            // 自定义过滤规则 如果可以循环(包含子目录) 或则是以.class结尾的文件(编译好的java类文件)
            @Override
            public boolean accept(File file){
                return (recursive && file.isDirectory()) || (file.getName().endsWith(".class"));
            }
        });
        // 循环所有文件
        for (File file : dirfiles){
            // 如果是目录 则继续扫描
            if (file.isDirectory()) {
                findAndAddClassesInPackageByFile(packageName + "." + file.getName(), file.getAbsolutePath(), recursive, classes);
            }else{
                // 如果是java类文件 去掉后面的.class 只留下类名
                String className = file.getName().substring(0, file.getName().length() - 6);
                try{
                    // 添加到集合中去
                    // classes.add(Class.forName(packageName + '.' +
                    // className));
                    // 经过回复同学的提醒,这里用forName有一些不好,会触发static方法,没有使用classLoader的load干净
                    classes.add(Thread.currentThread().getContextClassLoader().loadClass(packageName + '.' + className));
                }catch (ClassNotFoundException e){
                    // log.error("添加用户自定义视图类错误 找不到此类的.class文件");
                    e.printStackTrace();
                }
            }
        }
    }
    /**
     * 获取父类的字段
     * @param fields
     * @param clas
     * @return
     */
    public static Field[] getPatentFields(Field[] fields,Class<?> clas){
        if (clas.getSuperclass() != null) {
            Class clsSup = clas.getSuperclass();
            List<Field> fieldList = new ArrayList<Field>();
            fieldList.addAll(Arrays.asList(fields));
            fieldList.addAll(Arrays.asList(clsSup.getDeclaredFields()));
            fields = new Field[fieldList.size()];
            int i = 0;
            for (Object field : fieldList.toArray()) {
                fields[i] = (Field) field;
                i++;
            }
            fields = getPatentFields(fields,clsSup);
        }
        return  fields;
    }
}