ChenJiaHe
2022-01-23 5c933de8e9024a194432e6feeef71bf987e3939f
提交 | 用户 | age
c64e12 1 package com.hx.mybatis.aes.springbean;
e29546 2
import com.gitee.sunchenbin.mybatis.actable.annotation.Table;
import com.hx.common.annotations.MysqlHexAes;
import com.hx.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.JarURLConnection;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLDecoder;
import java.util.*;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
e29546 22
C 23 /**
24  * 获取指定包里面的AES秘钥
25  */
26 @Component
27 public class VariableAesKey {
346eb2 28
C 29     //log4j日志
30     private static Logger logger = LoggerFactory.getLogger(VariableAesKey.class.getName());
e29546 31
C 32     @Resource
33     private ConstantBean constantBean;
34
c64e12 35     /**是否已经启动完*/
C 36     public static int isRun = 0;
37
38     /**存储所有AES的秘钥*/
e29546 39     public static Map<String,String> aesKeys = new HashMap<>();
c64e12 40     /**根据表明来存储AES秘钥*/
C 41     public static Map<String,Map<String,String>> aesKeysTable = new HashMap<>();
42
43     /**固定的aes秘钥*/
44     public static String AES_KEY = null;
e29546 45
C 46     /**存储AES秘钥*/
47     public static void setAesKey(String aesKeyFild,String aesKey){
48         aesKeys.put(aesKeyFild,aesKey);
49     }
50     /**获取AES秘钥*/
51     public static String getAesKey(String aesKeyFild){
c64e12 52         if(aesKeyFild == null){
C 53             return AES_KEY;
54         }
55         if(StringUtils.isEmpty(aesKeys.get(aesKeyFild))){
56             return AES_KEY;
57         }else {
58             return  aesKeys.get(aesKeyFild);
59         }
e29546 60     }
C 61
62     /**
63      * 项目启动就执行后就执行该方法
64      */
65     @PostConstruct
66     public void VariableAesKey(){
c64e12 67
C 68         isRun = 1;
e29546 69         //项目启动的时候填入
346eb2 70         logger.info("扫描获取AES的包:" + constantBean.getPackPath());
c64e12 71         AES_KEY = constantBean.getFixedAesKey();
af1eee 72         if(!StringUtils.isEmpty(constantBean.getPackPath())){
e29546 73             Set<Class<?>> classes = classData(constantBean.getPackPath());
346eb2 74             logger.info("扫描获取AES的包classes:" + classes.size());
c64e12 75             Map<String,String> aesKeysFild = new HashMap<>();
C 76             boolean isAes = false;
77             String tableName = null;
78
e29546 79             for(Class<?> cl:classes){
c64e12 80                 //表名称
346eb2 81                 boolean hasAnnotation = cl.isAnnotationPresent(Table.class);
C 82                 if(!hasAnnotation){
c64e12 83                     continue;
C 84                 }
346eb2 85                 Table table = cl.getAnnotation(Table.class);
c64e12 86                 tableName = table.name();
C 87
88                 aesKeysFild = new HashMap<>();
89                 isAes = false;
90
e29546 91                 // 取得本类的全部属性
C 92                 Field[] fields = cl.getDeclaredFields();
93                 fields = getPatentFields(fields,cl);
94                 for (Field field:fields) {
95                     // 判断方法中是否有指定注解类型的注解
346eb2 96                     hasAnnotation = field.isAnnotationPresent(MysqlHexAes.class);
e29546 97                     if (hasAnnotation) {
C 98                         // 根据注解类型返回方法的指定类型注解
99                         MysqlHexAes mysqlHexAes = field.getAnnotation(MysqlHexAes.class);
100
c64e12 101                         //String aesKeyField = mysqlHexAes.aesKeyField();
e29546 102                         String aesKey = mysqlHexAes.aesKey();
C 103
104                         if(StringUtils.isEmpty(aesKey)){
c64e12 105                             aesKey = constantBean.getFixedAesKey();
C 106                             if(StringUtils.isEmpty(aesKey)){
107                                 throw new RuntimeException("mysql的AES秘钥不能为空:"+field.getName());
e29546 108                             }
c64e12 109                         }
C 110                         String key = aesKeys.get(field.getName());
111                         if(StringUtils.isEmpty(key)){
112                             aesKeys.put(field.getName(),aesKey);
113                             aesKeysFild.put(field.getName(),aesKey);
114                             isAes = true;
e29546 115                         }else{
c64e12 116                             isAes = true;
C 117                             aesKeysFild.put(field.getName(),aesKey);
118                             if(!aesKey.equals(key)){
119                                 throw new RuntimeException("字段/定义的AES秘钥字段【"+field.getName()+"】多个一样,但是AES秘钥不一样");
e29546 120                             }
C 121                         }
122                     }
c64e12 123                 }
C 124                 if(isAes){
125                     aesKeysTable.put(tableName,aesKeysFild);
e29546 126                 }
C 127             }
128         }
129     }
130
131     /**获取包下面的所有文件*/
132     public static Set<Class<?>> classData(String packPath){
346eb2 133         Set<Class<?>> classes = new LinkedHashSet();
C 134         String[] split = packPath.split(",|;");
135         String[] var3 = split;
136         int var4 = split.length;
c64e12 137
346eb2 138         label82:
C 139         for(int var5 = 0; var5 < var4; ++var5) {
140             String pack = var3[var5];
c64e12 141             boolean recursive = true;
346eb2 142             String packageName = pack;
C 143             String packageDirName = pack.replace('.', '/');
144
145             try {
146                 Enumeration dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
147
148                 while(true) {
149                     label75:
150                     while(true) {
151                         if (!dirs.hasMoreElements()) {
152                             continue label82;
153                         }
154
155                         URL url = (URL)dirs.nextElement();
156                         String protocol = url.getProtocol();
157                         if ("file".equals(protocol)) {
158                             System.err.println("file类型的扫描:" + pack);
159                             String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
160                             findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
161                         } else if ("jar".equals(protocol)) {
162                             System.err.println("jar类型的扫描");
163
164                             try {
165                                 JarFile jar = ((JarURLConnection)url.openConnection()).getJarFile();
166                                 Enumeration entries = jar.entries();
167
168                                 while(true) {
169                                     JarEntry entry;
170                                     String name;
171                                     int idx;
172                                     do {
173                                         do {
174                                             if (!entries.hasMoreElements()) {
175                                                 continue label75;
176                                             }
177
178                                             entry = (JarEntry)entries.nextElement();
179                                             name = entry.getName();
180                                             if (name.charAt(0) == '/') {
181                                                 name = name.substring(1);
182                                             }
183                                         } while(!name.startsWith(packageDirName));
184
185                                         idx = name.lastIndexOf(47);
186                                         if (idx != -1) {
187                                             packageName = name.substring(0, idx).replace('/', '.');
188                                         }
189                                     } while(idx == -1 && !recursive);
190
191                                     if (name.endsWith(".class") && !entry.isDirectory()) {
192                                         String className = name.substring(packageName.length() + 1, name.length() - 6);
193
194                                         try {
195                                             classes.add(Class.forName(packageName + '.' + className));
196                                         } catch (ClassNotFoundException var20) {
197                                             var20.printStackTrace();
198                                         }
199                                     }
200                                 }
201                             } catch (IOException var21) {
202                                 var21.printStackTrace();
203                             }
204                         }
c64e12 205                     }
e29546 206                 }
346eb2 207             } catch (IOException var22) {
C 208                 var22.printStackTrace();
e29546 209             }
C 210         }
346eb2 211
e29546 212         return classes;
C 213     }
214
215     /**
216      * 以文件的形式来获取包下的所有Class
217      *
218      * @param packageName
219      * @param packagePath
220      * @param recursive
221      * @param classes
222      */
223     public static void findAndAddClassesInPackageByFile(
224             String packageName,
225             String packagePath,
226             final boolean recursive,
227             Set<Class<?>> classes){
228         // 获取此包的目录 建立一个File
229         File dir = new File(packagePath);
230         // 如果不存在或者 也不是目录就直接返回
231         if (!dir.exists() || !dir.isDirectory()) {
232             // log.warn("用户定义包名 " + packageName + " 下没有任何文件");
233             return;
234         }
235         // 如果存在 就获取包下的所有文件 包括目录
236         File[] dirfiles = dir.listFiles(new FileFilter(){
237             // 自定义过滤规则 如果可以循环(包含子目录) 或则是以.class结尾的文件(编译好的java类文件)
238             @Override
239             public boolean accept(File file){
240                 return (recursive && file.isDirectory()) || (file.getName().endsWith(".class"));
241             }
242         });
243         // 循环所有文件
244         for (File file : dirfiles){
245             // 如果是目录 则继续扫描
246             if (file.isDirectory()) {
247                 findAndAddClassesInPackageByFile(packageName + "." + file.getName(), file.getAbsolutePath(), recursive, classes);
248             }else{
249                 // 如果是java类文件 去掉后面的.class 只留下类名
250                 String className = file.getName().substring(0, file.getName().length() - 6);
251                 try{
252                     // 添加到集合中去
253                     // classes.add(Class.forName(packageName + '.' +
254                     // className));
255                     // 经过回复同学的提醒,这里用forName有一些不好,会触发static方法,没有使用classLoader的load干净
256                     classes.add(Thread.currentThread().getContextClassLoader().loadClass(packageName + '.' + className));
257                 }catch (ClassNotFoundException e){
258                     // log.error("添加用户自定义视图类错误 找不到此类的.class文件");
259                     e.printStackTrace();
260                 }
261             }
262         }
263     }
264
265     /**
266      * 获取父类的字段
267      * @param fields
268      * @param clas
269      * @return
270      */
271     public static Field[] getPatentFields(Field[] fields,Class<?> clas){
272         if (clas.getSuperclass() != null) {
273             Class clsSup = clas.getSuperclass();
274             List<Field> fieldList = new ArrayList<Field>();
275             fieldList.addAll(Arrays.asList(fields));
276             fieldList.addAll(Arrays.asList(clsSup.getDeclaredFields()));
277             fields = new Field[fieldList.size()];
278             int i = 0;
279             for (Object field : fieldList.toArray()) {
280                 fields[i] = (Field) field;
281                 i++;
282             }
283             fields = getPatentFields(fields,clsSup);
284         }
285         return  fields;
286     }
287
288
289 }