wangrenhuang
2022-03-02 1ffbaa7b82f99443d933b0f7d0add0b35d2db01c
src/main/java/com/hx/mybatis/aes/springbean/VariableAesKey.java
@@ -3,6 +3,8 @@
import com.gitee.sunchenbin.mybatis.actable.annotation.Table;
import com.hx.common.annotations.MysqlHexAes;
import com.hx.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
@@ -11,15 +13,21 @@
import java.io.FileFilter;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.JarURLConnection;
import java.net.URL;
import java.net.URLDecoder;
import java.util.*;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
/**
 * 获取指定包里面的AES秘钥
 */
@Component
public class VariableAesKey {
    //log4j日志
    private static Logger logger = LoggerFactory.getLogger(VariableAesKey.class.getName());
    @Resource
    private ConstantBean constantBean;
@@ -34,6 +42,10 @@
    /**固定的aes秘钥*/
    public static String AES_KEY = null;
    /**数据库加密字段初始化版本号*/
    public static String INIT_VERSION = null;
    /**存储AES秘钥*/
    public static void setAesKey(String aesKeyFild,String aesKey){
@@ -59,21 +71,23 @@
        isRun = 1;
        //项目启动的时候填入
        System.err.println("扫描获取AES:" + constantBean.getPackPath());
        logger.info("扫描获取AES的包:" + constantBean.getPackPath());
        AES_KEY = constantBean.getFixedAesKey();
        if(StringUtils.noNull(constantBean.getPackPath())){
        INIT_VERSION = constantBean.getInitVersion();
        if(!StringUtils.isEmpty(constantBean.getPackPath())){
            Set<Class<?>> classes = classData(constantBean.getPackPath());
            logger.info("扫描获取AES的包classes:" + classes.size());
            Map<String,String> aesKeysFild = new HashMap<>();
            boolean isAes = false;
            String tableName = null;
            for(Class<?> cl:classes){
                //表名称
                Table table = cl.getAnnotation(Table.class);
                if(table == null){
                boolean hasAnnotation = cl.isAnnotationPresent(Table.class);
                if(!hasAnnotation){
                    continue;
                }
                Table table = cl.getAnnotation(Table.class);
                tableName = table.name();
                aesKeysFild = new HashMap<>();
@@ -84,7 +98,7 @@
                fields = getPatentFields(fields,cl);
                for (Field field:fields) {
                    // 判断方法中是否有指定注解类型的注解
                    boolean hasAnnotation = field.isAnnotationPresent(MysqlHexAes.class);
                    hasAnnotation = field.isAnnotationPresent(MysqlHexAes.class);
                    if (hasAnnotation) {
                        // 根据注解类型返回方法的指定类型注解
                        MysqlHexAes mysqlHexAes = field.getAnnotation(MysqlHexAes.class);
@@ -121,37 +135,85 @@
    /**
     * Collects the classes found under one or more packages.
     *
     * <p>Each package name is turned into a resource path and looked up through the
     * current thread's context class loader. Resources with the {@code file} protocol
     * are handed to {@code findAndAddClassesInPackageByFile}; resources with the
     * {@code jar} protocol are scanned entry by entry and every matching
     * {@code .class} entry is loaded with {@link Class#forName(String)}.
     * I/O and class-loading failures are printed and do not abort the scan.
     *
     * <p>NOTE(review): the lines below appear to interleave two variants of this
     * method body (for example {@code classes} is declared twice, there are two
     * loops over the package names and two {@code catch (IOException ...)} clauses).
     * This looks like diff residue - confirm against the real source file before
     * relying on this text.
     *
     * @param packPath one or more package names separated by {@code ','} or {@code ';'}
     * @return the set that the discovered classes were added to
     */
    public static Set<Class<?>> classData(String packPath){
        Set<Class<?>> classes = new LinkedHashSet<Class<?>>();
        Set<Class<?>> classes = new LinkedHashSet();
        String[] split = packPath.split(",|;");
        String[] var3 = split;
        int var4 = split.length;
        // Split the input into the individual package names
        String[] packPaths = packPath.split(";|,");
        for( String packageName : packPaths){
            // Whether to iterate recursively
        // label82 names the per-package loop so the inner loops can jump to the next package
        label82:
        for(int var5 = 0; var5 < var4; ++var5) {
            String pack = var3[var5];
            boolean recursive = true;
            // Take the package name and convert it to a resource path
            String packageDirName = packageName.replace('.', '/');
            // Enumeration of the resources found for this package path; looped over below
            Enumeration<URL> dirs;
            try{
                dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
                // Keep iterating over the resources
                while (dirs.hasMoreElements()){
                    // Get the next element
                    URL url = dirs.nextElement();
                    // Get the protocol name
                    String protocol = url.getProtocol();
                    // The package is stored on disk as plain files
                    if ("file".equals(protocol)) {
                        // Get the physical path of the package
                        String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                        // Scan the package as files and add what is found to the set
                        findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
            String packageName = pack;
            String packageDirName = pack.replace('.', '/');
            try {
                Enumeration dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
                // This loop has no exit condition of its own; it is left via "continue label82"
                while(true) {
                    // label75 names the per-resource loop; the jar scan jumps back here when it runs out of entries
                    label75:
                    while(true) {
                        if (!dirs.hasMoreElements()) {
                            // No more resources for this package: move on to the next package
                            continue label82;
                        }
                        URL url = (URL)dirs.nextElement();
                        String protocol = url.getProtocol();
                        if ("file".equals(protocol)) {
                            System.err.println("file类型的扫描:" + pack);
                            String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                            // Helper is defined outside this view; presumably it walks filePath and adds classes to the set - verify
                            findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
                        } else if ("jar".equals(protocol)) {
                            System.err.println("jar类型的扫描");
                            try {
                                JarFile jar = ((JarURLConnection)url.openConnection()).getJarFile();
                                Enumeration entries = jar.entries();
                                while(true) {
                                    JarEntry entry;
                                    String name;
                                    int idx;
                                    do {
                                        // Inner do/while: skip entries whose name does not start with the package path
                                        do {
                                            if (!entries.hasMoreElements()) {
                                                // Jar exhausted: continue with the next resource URL
                                                continue label75;
                                            }
                                            entry = (JarEntry)entries.nextElement();
                                            name = entry.getName();
                                            // Drop a leading '/' so the prefix comparison below lines up
                                            if (name.charAt(0) == '/') {
                                                name = name.substring(1);
                                            }
                                        } while(!name.startsWith(packageDirName));
                                        // 47 is the character code of '/'
                                        idx = name.lastIndexOf(47);
                                        if (idx != -1) {
                                            // Re-derive the package name from the entry's directory part
                                            packageName = name.substring(0, idx).replace('/', '.');
                                        }
                                    // recursive is set to true above, so this condition is false and the loop body runs once per accepted entry
                                    } while(idx == -1 && !recursive);
                                    if (name.endsWith(".class") && !entry.isDirectory()) {
                                        // Skip the package path plus its trailing '/', and cut the 6-character ".class" suffix
                                        String className = name.substring(packageName.length() + 1, name.length() - 6);
                                        try {
                                            classes.add(Class.forName(packageName + '.' + className));
                                        } catch (ClassNotFoundException var20) {
                                            // A class that cannot be loaded is reported and skipped
                                            var20.printStackTrace();
                                        }
                                    }
                                }
                            } catch (IOException var21) {
                                // Failure opening or reading this jar is reported; execution continues after the if/else
                                var21.printStackTrace();
                            }
                        }
                    }
                }
            }catch (IOException e){
                e.printStackTrace();
            } catch (IOException var22) {
                var22.printStackTrace();
            }
        }
        return classes;
    }