chenjiahe
2022-01-13 c64e1248bfda3ac8c5120e529fd096dfc4846629
提交 | 用户 | age
c64e12 1 package com.hx.mybatis.aes.springbean;
e29546 2
c64e12 3 import com.gitee.sunchenbin.mybatis.actable.annotation.Table;
e29546 4 import com.hx.common.annotations.MysqlHexAes;
C 5 import com.hx.util.StringUtils;
6 import org.springframework.stereotype.Component;
7
8 import javax.annotation.PostConstruct;
9 import javax.annotation.Resource;
10 import java.io.File;
11 import java.io.FileFilter;
12 import java.io.IOException;
13 import java.lang.reflect.Field;
14 import java.net.URL;
15 import java.net.URLDecoder;
16 import java.util.*;
17
18 /**
19  * 获取指定包里面的AES秘钥
20  */
21 @Component
22 public class VariableAesKey {
23
24     @Resource
25     private ConstantBean constantBean;
26
c64e12 27     /**是否已经启动完*/
C 28     public static int isRun = 0;
29
30     /**存储所有AES的秘钥*/
e29546 31     public static Map<String,String> aesKeys = new HashMap<>();
c64e12 32     /**根据表明来存储AES秘钥*/
C 33     public static Map<String,Map<String,String>> aesKeysTable = new HashMap<>();
34
35     /**固定的aes秘钥*/
36     public static String AES_KEY = null;
e29546 37
C 38     /**存储AES秘钥*/
39     public static void setAesKey(String aesKeyFild,String aesKey){
40         aesKeys.put(aesKeyFild,aesKey);
41     }
42     /**获取AES秘钥*/
43     public static String getAesKey(String aesKeyFild){
c64e12 44         if(aesKeyFild == null){
C 45             return AES_KEY;
46         }
47         if(StringUtils.isEmpty(aesKeys.get(aesKeyFild))){
48             return AES_KEY;
49         }else {
50             return  aesKeys.get(aesKeyFild);
51         }
e29546 52     }
C 53
54     /**
55      * 项目启动就执行后就执行该方法
56      */
57     @PostConstruct
58     public void VariableAesKey(){
c64e12 59
C 60         isRun = 1;
e29546 61         //项目启动的时候填入
C 62         System.err.println("扫描获取AES:" + constantBean.getPackPath());
c64e12 63         AES_KEY = constantBean.getFixedAesKey();
e29546 64         if(StringUtils.noNull(constantBean.getPackPath())){
C 65             Set<Class<?>> classes = classData(constantBean.getPackPath());
c64e12 66
C 67             Map<String,String> aesKeysFild = new HashMap<>();
68             boolean isAes = false;
69             String tableName = null;
70
e29546 71             for(Class<?> cl:classes){
c64e12 72                 //表名称
C 73                 Table table = cl.getAnnotation(Table.class);
74                 if(table == null){
75                     continue;
76                 }
77                 tableName = table.name();
78
79                 aesKeysFild = new HashMap<>();
80                 isAes = false;
81
e29546 82                 // 取得本类的全部属性
C 83                 Field[] fields = cl.getDeclaredFields();
84                 fields = getPatentFields(fields,cl);
85                 for (Field field:fields) {
86                     // 判断方法中是否有指定注解类型的注解
87                     boolean hasAnnotation = field.isAnnotationPresent(MysqlHexAes.class);
88                     if (hasAnnotation) {
89                         // 根据注解类型返回方法的指定类型注解
90                         MysqlHexAes mysqlHexAes = field.getAnnotation(MysqlHexAes.class);
91
c64e12 92                         //String aesKeyField = mysqlHexAes.aesKeyField();
e29546 93                         String aesKey = mysqlHexAes.aesKey();
C 94
95                         if(StringUtils.isEmpty(aesKey)){
c64e12 96                             aesKey = constantBean.getFixedAesKey();
C 97                             if(StringUtils.isEmpty(aesKey)){
98                                 throw new RuntimeException("mysql的AES秘钥不能为空:"+field.getName());
e29546 99                             }
c64e12 100                         }
C 101                         String key = aesKeys.get(field.getName());
102                         if(StringUtils.isEmpty(key)){
103                             aesKeys.put(field.getName(),aesKey);
104                             aesKeysFild.put(field.getName(),aesKey);
105                             isAes = true;
e29546 106                         }else{
c64e12 107                             isAes = true;
C 108                             aesKeysFild.put(field.getName(),aesKey);
109                             if(!aesKey.equals(key)){
110                                 throw new RuntimeException("字段/定义的AES秘钥字段【"+field.getName()+"】多个一样,但是AES秘钥不一样");
e29546 111                             }
C 112                         }
113                     }
c64e12 114                 }
C 115                 if(isAes){
116                     aesKeysTable.put(tableName,aesKeysFild);
e29546 117                 }
C 118             }
119         }
120     }
121
122     /**获取包下面的所有文件*/
123     public static Set<Class<?>> classData(String packPath){
124         Set<Class<?>> classes = new LinkedHashSet<Class<?>>();
c64e12 125
C 126         //截取
127         String[] packPaths = packPath.split(";|,");
128         for( String packageName : packPaths){
129             // 是否循环迭代
130             boolean recursive = true;
131             // 获取包的名字 并进行替换
132             String packageDirName = packageName.replace('.', '/');
133             // 定义一个枚举的集合 并进行循环来处理这个目录下的things
134             Enumeration<URL> dirs;
135             try{
136                 dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
137                 // 循环迭代下去
138                 while (dirs.hasMoreElements()){
139                     // 获取下一个元素
140                     URL url = dirs.nextElement();
141                     // 得到协议的名称
142                     String protocol = url.getProtocol();
143                     // 如果是以文件的形式保存在服务器上
144                     if ("file".equals(protocol)) {
145                         // 获取包的物理路径
146                         String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
147                         // 以文件的方式扫描整个包下的文件 并添加到集合中
148                         findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
149                     }
e29546 150                 }
c64e12 151             }catch (IOException e){
C 152                 e.printStackTrace();
e29546 153             }
C 154         }
155         return classes;
156     }
157
158     /**
159      * 以文件的形式来获取包下的所有Class
160      *
161      * @param packageName
162      * @param packagePath
163      * @param recursive
164      * @param classes
165      */
166     public static void findAndAddClassesInPackageByFile(
167             String packageName,
168             String packagePath,
169             final boolean recursive,
170             Set<Class<?>> classes){
171         // 获取此包的目录 建立一个File
172         File dir = new File(packagePath);
173         // 如果不存在或者 也不是目录就直接返回
174         if (!dir.exists() || !dir.isDirectory()) {
175             // log.warn("用户定义包名 " + packageName + " 下没有任何文件");
176             return;
177         }
178         // 如果存在 就获取包下的所有文件 包括目录
179         File[] dirfiles = dir.listFiles(new FileFilter(){
180             // 自定义过滤规则 如果可以循环(包含子目录) 或则是以.class结尾的文件(编译好的java类文件)
181             @Override
182             public boolean accept(File file){
183                 return (recursive && file.isDirectory()) || (file.getName().endsWith(".class"));
184             }
185         });
186         // 循环所有文件
187         for (File file : dirfiles){
188             // 如果是目录 则继续扫描
189             if (file.isDirectory()) {
190                 findAndAddClassesInPackageByFile(packageName + "." + file.getName(), file.getAbsolutePath(), recursive, classes);
191             }else{
192                 // 如果是java类文件 去掉后面的.class 只留下类名
193                 String className = file.getName().substring(0, file.getName().length() - 6);
194                 try{
195                     // 添加到集合中去
196                     // classes.add(Class.forName(packageName + '.' +
197                     // className));
198                     // 经过回复同学的提醒,这里用forName有一些不好,会触发static方法,没有使用classLoader的load干净
199                     classes.add(Thread.currentThread().getContextClassLoader().loadClass(packageName + '.' + className));
200                 }catch (ClassNotFoundException e){
201                     // log.error("添加用户自定义视图类错误 找不到此类的.class文件");
202                     e.printStackTrace();
203                 }
204             }
205         }
206     }
207
208     /**
209      * 获取父类的字段
210      * @param fields
211      * @param clas
212      * @return
213      */
214     public static Field[] getPatentFields(Field[] fields,Class<?> clas){
215         if (clas.getSuperclass() != null) {
216             Class clsSup = clas.getSuperclass();
217             List<Field> fieldList = new ArrayList<Field>();
218             fieldList.addAll(Arrays.asList(fields));
219             fieldList.addAll(Arrays.asList(clsSup.getDeclaredFields()));
220             fields = new Field[fieldList.size()];
221             int i = 0;
222             for (Object field : fieldList.toArray()) {
223                 fields[i] = (Field) field;
224                 i++;
225             }
226             fields = getPatentFields(fields,clsSup);
227         }
228         return  fields;
229     }
230
231
232 }