AES
chenjiahe
2022-01-07 e29546af6fcc89c64b17cee6de01ce2a963b33f4
提交 | 用户 | age
e29546 1 package com.hx.springbean;
C 2
3 import com.hx.common.annotations.MysqlHexAes;
4 import com.hx.util.StringUtils;
5 import org.springframework.beans.factory.annotation.Value;
6 import org.springframework.stereotype.Component;
7
8 import javax.annotation.PostConstruct;
9 import javax.annotation.Resource;
10 import java.io.File;
11 import java.io.FileFilter;
12 import java.io.IOException;
13 import java.lang.reflect.Field;
14 import java.net.URL;
15 import java.net.URLDecoder;
16 import java.util.*;
17
18 /**
19  * 获取指定包里面的AES秘钥
20  */
21 @Component
22 public class VariableAesKey {
23
24     @Resource
25     private ConstantBean constantBean;
26
27     /**存储AES的秘钥*/
28     public static Map<String,String> aesKeys = new HashMap<>();
29
30     /**存储AES秘钥*/
31     public static void setAesKey(String aesKeyFild,String aesKey){
32         aesKeys.put(aesKeyFild,aesKey);
33     }
34     /**获取AES秘钥*/
35     public static String getAesKey(String aesKeyFild){
36         return aesKeys.get(aesKeyFild);
37     }
38
39     /**
40      * 项目启动就执行后就执行该方法
41      */
42     @PostConstruct
43     public void VariableAesKey(){
44         //项目启动的时候填入
45         System.err.println("扫描获取AES:" + constantBean.getPackPath());
46         if(StringUtils.noNull(constantBean.getPackPath())){
47             Set<Class<?>> classes = classData(constantBean.getPackPath());
48             for(Class<?> cl:classes){
49                 // 取得本类的全部属性
50                 Field[] fields = cl.getDeclaredFields();
51                 fields = getPatentFields(fields,cl);
52                 for (Field field:fields) {
53                     // 判断方法中是否有指定注解类型的注解
54                     boolean hasAnnotation = field.isAnnotationPresent(MysqlHexAes.class);
55                     if (hasAnnotation) {
56                         // 根据注解类型返回方法的指定类型注解
57                         MysqlHexAes mysqlHexAes = field.getAnnotation(MysqlHexAes.class);
58
59                         String aesKeyField = mysqlHexAes.aesKeyField();
60                         String aesKey = mysqlHexAes.aesKey();
61
62                         if(StringUtils.isEmpty(aesKey)){
63                             throw new RuntimeException("mysql的AES秘钥不能为空:"+field.getName());
64                         }
65                         if(StringUtils.noNull(aesKeyField)){
66                             String key = aesKeys.get(aesKeyField);
67                             if(StringUtils.isEmpty(key)){
68                                 aesKeys.put(aesKeyField,aesKey);
69                             }else{
70                                 if(!aesKey.equals(key)){
71                                     throw new RuntimeException("字段/定义的AES秘钥字段【"+field.getName()+"】多个一样,但是AES秘钥不一样");
72                                 }
73                             }
74                         }else{
75                             String key = aesKeys.get(field.getName());
76                             if(StringUtils.isEmpty(key)){
77                                 aesKeys.put(field.getName(),aesKey);
78                             }else{
79                                 if(!aesKey.equals(key)){
80                                     throw new RuntimeException("字段/定义的AES秘钥字段【"+field.getName()+"】多个一样,但是AES秘钥不一样");
81                                 }
82                             }
83                         }
84                     }
85                 }
86             }
87         }
88     }
89
90     /**获取包下面的所有文件*/
91     public static Set<Class<?>> classData(String packPath){
92         Set<Class<?>> classes = new LinkedHashSet<Class<?>>();
93         // 是否循环迭代
94         boolean recursive = true;
95         // 获取包的名字 并进行替换
96         String packageName = packPath;
97         String packageDirName = packageName.replace('.', '/');
98         // 定义一个枚举的集合 并进行循环来处理这个目录下的things
99         Enumeration<URL> dirs;
100         try{
101             dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
102             // 循环迭代下去
103             while (dirs.hasMoreElements()){
104                 // 获取下一个元素
105                 URL url = dirs.nextElement();
106                 // 得到协议的名称
107                 String protocol = url.getProtocol();
108                 // 如果是以文件的形式保存在服务器上
109                 if ("file".equals(protocol)) {
110                     // 获取包的物理路径
111                     String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
112                     // 以文件的方式扫描整个包下的文件 并添加到集合中
113                     findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
114                 }
115             }
116         }catch (IOException e){
117             e.printStackTrace();
118         }
119         return classes;
120     }
121
122     /**
123      * 以文件的形式来获取包下的所有Class
124      *
125      * @param packageName
126      * @param packagePath
127      * @param recursive
128      * @param classes
129      */
130     public static void findAndAddClassesInPackageByFile(
131             String packageName,
132             String packagePath,
133             final boolean recursive,
134             Set<Class<?>> classes){
135         // 获取此包的目录 建立一个File
136         File dir = new File(packagePath);
137         // 如果不存在或者 也不是目录就直接返回
138         if (!dir.exists() || !dir.isDirectory()) {
139             // log.warn("用户定义包名 " + packageName + " 下没有任何文件");
140             return;
141         }
142         // 如果存在 就获取包下的所有文件 包括目录
143         File[] dirfiles = dir.listFiles(new FileFilter(){
144             // 自定义过滤规则 如果可以循环(包含子目录) 或则是以.class结尾的文件(编译好的java类文件)
145             @Override
146             public boolean accept(File file){
147                 return (recursive && file.isDirectory()) || (file.getName().endsWith(".class"));
148             }
149         });
150         // 循环所有文件
151         for (File file : dirfiles){
152             // 如果是目录 则继续扫描
153             if (file.isDirectory()) {
154                 findAndAddClassesInPackageByFile(packageName + "." + file.getName(), file.getAbsolutePath(), recursive, classes);
155             }else{
156                 // 如果是java类文件 去掉后面的.class 只留下类名
157                 String className = file.getName().substring(0, file.getName().length() - 6);
158                 try{
159                     // 添加到集合中去
160                     // classes.add(Class.forName(packageName + '.' +
161                     // className));
162                     // 经过回复同学的提醒,这里用forName有一些不好,会触发static方法,没有使用classLoader的load干净
163                     classes.add(Thread.currentThread().getContextClassLoader().loadClass(packageName + '.' + className));
164                 }catch (ClassNotFoundException e){
165                     // log.error("添加用户自定义视图类错误 找不到此类的.class文件");
166                     e.printStackTrace();
167                 }
168             }
169         }
170     }
171
172     /**
173      * 获取父类的字段
174      * @param fields
175      * @param clas
176      * @return
177      */
178     public static Field[] getPatentFields(Field[] fields,Class<?> clas){
179         if (clas.getSuperclass() != null) {
180             Class clsSup = clas.getSuperclass();
181             List<Field> fieldList = new ArrayList<Field>();
182             fieldList.addAll(Arrays.asList(fields));
183             fieldList.addAll(Arrays.asList(clsSup.getDeclaredFields()));
184             fields = new Field[fieldList.size()];
185             int i = 0;
186             for (Object field : fieldList.toArray()) {
187                 fields[i] = (Field) field;
188                 i++;
189             }
190             fields = getPatentFields(fields,clsSup);
191         }
192         return  fields;
193     }
194
195
196 }