package com.hx.mybatis.aes.springbean;
|
|
import com.gitee.sunchenbin.mybatis.actable.annotation.Table;
|
import com.hx.common.annotations.MysqlHexAes;
|
import com.hx.util.StringUtils;
|
import org.springframework.stereotype.Component;
|
|
import javax.annotation.PostConstruct;
|
import javax.annotation.Resource;
|
import java.io.File;
|
import java.io.FileFilter;
|
import java.io.IOException;
|
import java.lang.reflect.Field;
|
import java.net.URL;
|
import java.net.URLDecoder;
|
import java.util.*;
|
|
/**
|
* 获取指定包里面的AES秘钥
|
*/
|
@Component
|
public class VariableAesKey {
|
|
@Resource
|
private ConstantBean constantBean;
|
|
/**是否已经启动完*/
|
public static int isRun = 0;
|
|
/**存储所有AES的秘钥*/
|
public static Map<String,String> aesKeys = new HashMap<>();
|
/**根据表明来存储AES秘钥*/
|
public static Map<String,Map<String,String>> aesKeysTable = new HashMap<>();
|
|
/**固定的aes秘钥*/
|
public static String AES_KEY = null;
|
|
/**存储AES秘钥*/
|
public static void setAesKey(String aesKeyFild,String aesKey){
|
aesKeys.put(aesKeyFild,aesKey);
|
}
|
/**获取AES秘钥*/
|
public static String getAesKey(String aesKeyFild){
|
if(aesKeyFild == null){
|
return AES_KEY;
|
}
|
if(StringUtils.isEmpty(aesKeys.get(aesKeyFild))){
|
return AES_KEY;
|
}else {
|
return aesKeys.get(aesKeyFild);
|
}
|
}
|
|
/**
|
* 项目启动就执行后就执行该方法
|
*/
|
@PostConstruct
|
public void VariableAesKey(){
|
|
isRun = 1;
|
//项目启动的时候填入
|
System.err.println("扫描获取AES:" + constantBean.getPackPath());
|
AES_KEY = constantBean.getFixedAesKey();
|
if(!StringUtils.isEmpty(constantBean.getPackPath())){
|
Set<Class<?>> classes = classData(constantBean.getPackPath());
|
|
Map<String,String> aesKeysFild = new HashMap<>();
|
boolean isAes = false;
|
String tableName = null;
|
|
for(Class<?> cl:classes){
|
//表名称
|
Table table = cl.getAnnotation(Table.class);
|
if(table == null){
|
continue;
|
}
|
tableName = table.name();
|
|
aesKeysFild = new HashMap<>();
|
isAes = false;
|
|
// 取得本类的全部属性
|
Field[] fields = cl.getDeclaredFields();
|
fields = getPatentFields(fields,cl);
|
for (Field field:fields) {
|
// 判断方法中是否有指定注解类型的注解
|
boolean hasAnnotation = field.isAnnotationPresent(MysqlHexAes.class);
|
if (hasAnnotation) {
|
// 根据注解类型返回方法的指定类型注解
|
MysqlHexAes mysqlHexAes = field.getAnnotation(MysqlHexAes.class);
|
|
//String aesKeyField = mysqlHexAes.aesKeyField();
|
String aesKey = mysqlHexAes.aesKey();
|
|
if(StringUtils.isEmpty(aesKey)){
|
aesKey = constantBean.getFixedAesKey();
|
if(StringUtils.isEmpty(aesKey)){
|
throw new RuntimeException("mysql的AES秘钥不能为空:"+field.getName());
|
}
|
}
|
String key = aesKeys.get(field.getName());
|
if(StringUtils.isEmpty(key)){
|
aesKeys.put(field.getName(),aesKey);
|
aesKeysFild.put(field.getName(),aesKey);
|
isAes = true;
|
}else{
|
isAes = true;
|
aesKeysFild.put(field.getName(),aesKey);
|
if(!aesKey.equals(key)){
|
throw new RuntimeException("字段/定义的AES秘钥字段【"+field.getName()+"】多个一样,但是AES秘钥不一样");
|
}
|
}
|
}
|
}
|
if(isAes){
|
aesKeysTable.put(tableName,aesKeysFild);
|
}
|
}
|
}
|
}
|
|
/**获取包下面的所有文件*/
|
public static Set<Class<?>> classData(String packPath){
|
Set<Class<?>> classes = new LinkedHashSet<Class<?>>();
|
|
//截取
|
String[] packPaths = packPath.split(";|,");
|
for( String packageName : packPaths){
|
// 是否循环迭代
|
boolean recursive = true;
|
// 获取包的名字 并进行替换
|
String packageDirName = packageName.replace('.', '/');
|
// 定义一个枚举的集合 并进行循环来处理这个目录下的things
|
Enumeration<URL> dirs;
|
try{
|
dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
|
// 循环迭代下去
|
while (dirs.hasMoreElements()){
|
// 获取下一个元素
|
URL url = dirs.nextElement();
|
// 得到协议的名称
|
String protocol = url.getProtocol();
|
// 如果是以文件的形式保存在服务器上
|
if ("file".equals(protocol)) {
|
// 获取包的物理路径
|
String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
|
// 以文件的方式扫描整个包下的文件 并添加到集合中
|
findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
|
}
|
}
|
}catch (IOException e){
|
e.printStackTrace();
|
}
|
}
|
return classes;
|
}
|
|
/**
|
* 以文件的形式来获取包下的所有Class
|
*
|
* @param packageName
|
* @param packagePath
|
* @param recursive
|
* @param classes
|
*/
|
public static void findAndAddClassesInPackageByFile(
|
String packageName,
|
String packagePath,
|
final boolean recursive,
|
Set<Class<?>> classes){
|
// 获取此包的目录 建立一个File
|
File dir = new File(packagePath);
|
// 如果不存在或者 也不是目录就直接返回
|
if (!dir.exists() || !dir.isDirectory()) {
|
// log.warn("用户定义包名 " + packageName + " 下没有任何文件");
|
return;
|
}
|
// 如果存在 就获取包下的所有文件 包括目录
|
File[] dirfiles = dir.listFiles(new FileFilter(){
|
// 自定义过滤规则 如果可以循环(包含子目录) 或则是以.class结尾的文件(编译好的java类文件)
|
@Override
|
public boolean accept(File file){
|
return (recursive && file.isDirectory()) || (file.getName().endsWith(".class"));
|
}
|
});
|
// 循环所有文件
|
for (File file : dirfiles){
|
// 如果是目录 则继续扫描
|
if (file.isDirectory()) {
|
findAndAddClassesInPackageByFile(packageName + "." + file.getName(), file.getAbsolutePath(), recursive, classes);
|
}else{
|
// 如果是java类文件 去掉后面的.class 只留下类名
|
String className = file.getName().substring(0, file.getName().length() - 6);
|
try{
|
// 添加到集合中去
|
// classes.add(Class.forName(packageName + '.' +
|
// className));
|
// 经过回复同学的提醒,这里用forName有一些不好,会触发static方法,没有使用classLoader的load干净
|
classes.add(Thread.currentThread().getContextClassLoader().loadClass(packageName + '.' + className));
|
}catch (ClassNotFoundException e){
|
// log.error("添加用户自定义视图类错误 找不到此类的.class文件");
|
e.printStackTrace();
|
}
|
}
|
}
|
}
|
|
/**
|
* 获取父类的字段
|
* @param fields
|
* @param clas
|
* @return
|
*/
|
public static Field[] getPatentFields(Field[] fields,Class<?> clas){
|
if (clas.getSuperclass() != null) {
|
Class clsSup = clas.getSuperclass();
|
List<Field> fieldList = new ArrayList<Field>();
|
fieldList.addAll(Arrays.asList(fields));
|
fieldList.addAll(Arrays.asList(clsSup.getDeclaredFields()));
|
fields = new Field[fieldList.size()];
|
int i = 0;
|
for (Object field : fieldList.toArray()) {
|
fields[i] = (Field) field;
|
i++;
|
}
|
fields = getPatentFields(fields,clsSup);
|
}
|
return fields;
|
}
|
|
|
}
|