package com.hx.mybatis.aes.springbean;
|
|
import com.gitee.sunchenbin.mybatis.actable.annotation.Table;
|
import com.hx.common.annotations.MysqlHexAes;
|
import com.hx.util.StringUtils;
|
import org.slf4j.Logger;
|
import org.slf4j.LoggerFactory;
|
import org.springframework.stereotype.Component;
|
|
import javax.annotation.PostConstruct;
|
import javax.annotation.Resource;
|
import java.io.File;
|
import java.io.FileFilter;
|
import java.io.IOException;
|
import java.lang.reflect.Field;
|
import java.net.JarURLConnection;
|
import java.net.URL;
|
import java.net.URLDecoder;
|
import java.util.*;
|
import java.util.jar.JarEntry;
|
import java.util.jar.JarFile;
|
|
/**
|
* 获取指定包里面的AES秘钥
|
*/
|
@Component
|
public class VariableAesKey {
|
|
//log4j日志
|
private static Logger logger = LoggerFactory.getLogger(VariableAesKey.class.getName());
|
|
@Resource
|
private ConstantBean constantBean;
|
|
/**是否已经启动完*/
|
public static int isRun = 0;
|
|
/**存储所有AES的秘钥*/
|
public static Map<String,String> aesKeys = new HashMap<>();
|
/**根据表明来存储AES秘钥*/
|
public static Map<String,Map<String,String>> aesKeysTable = new HashMap<>();
|
|
/**固定的aes秘钥*/
|
public static String AES_KEY = null;
|
/**数据库加密字段初始化版本号*/
|
public static String INIT_VERSION = null;
|
|
|
|
/**存储AES秘钥*/
|
public static void setAesKey(String aesKeyFild,String aesKey){
|
aesKeys.put(aesKeyFild,aesKey);
|
}
|
/**获取AES秘钥*/
|
public static String getAesKey(String aesKeyFild){
|
if(aesKeyFild == null){
|
return AES_KEY;
|
}
|
if(StringUtils.isEmpty(aesKeys.get(aesKeyFild))){
|
return AES_KEY;
|
}else {
|
return aesKeys.get(aesKeyFild);
|
}
|
}
|
|
/**
|
* 项目启动就执行后就执行该方法
|
*/
|
@PostConstruct
|
public void VariableAesKey(){
|
|
isRun = 1;
|
//项目启动的时候填入
|
logger.info("扫描获取AES的包:" + constantBean.getPackPath());
|
AES_KEY = constantBean.getFixedAesKey();
|
INIT_VERSION = constantBean.getInitVersion();
|
if(!StringUtils.isEmpty(constantBean.getPackPath())){
|
Set<Class<?>> classes = classData(constantBean.getPackPath());
|
logger.info("扫描获取AES的包classes:" + classes.size());
|
Map<String,String> aesKeysFild = new HashMap<>();
|
boolean isAes = false;
|
String tableName = null;
|
|
for(Class<?> cl:classes){
|
//表名称
|
boolean hasAnnotation = cl.isAnnotationPresent(Table.class);
|
if(!hasAnnotation){
|
continue;
|
}
|
Table table = cl.getAnnotation(Table.class);
|
tableName = table.name();
|
|
aesKeysFild = new HashMap<>();
|
isAes = false;
|
|
// 取得本类的全部属性
|
Field[] fields = cl.getDeclaredFields();
|
fields = getPatentFields(fields,cl);
|
for (Field field:fields) {
|
// 判断方法中是否有指定注解类型的注解
|
hasAnnotation = field.isAnnotationPresent(MysqlHexAes.class);
|
if (hasAnnotation) {
|
// 根据注解类型返回方法的指定类型注解
|
MysqlHexAes mysqlHexAes = field.getAnnotation(MysqlHexAes.class);
|
|
//String aesKeyField = mysqlHexAes.aesKeyField();
|
String aesKey = mysqlHexAes.aesKey();
|
|
if(StringUtils.isEmpty(aesKey)){
|
aesKey = constantBean.getFixedAesKey();
|
if(StringUtils.isEmpty(aesKey)){
|
throw new RuntimeException("mysql的AES秘钥不能为空:"+field.getName());
|
}
|
}
|
String key = aesKeys.get(field.getName());
|
if(StringUtils.isEmpty(key)){
|
aesKeys.put(field.getName(),aesKey);
|
aesKeysFild.put(field.getName(),aesKey);
|
isAes = true;
|
}else{
|
isAes = true;
|
aesKeysFild.put(field.getName(),aesKey);
|
if(!aesKey.equals(key)){
|
throw new RuntimeException("字段/定义的AES秘钥字段【"+field.getName()+"】多个一样,但是AES秘钥不一样");
|
}
|
}
|
}
|
}
|
if(isAes){
|
aesKeysTable.put(tableName,aesKeysFild);
|
}
|
}
|
}
|
}
|
|
/**获取包下面的所有文件*/
|
public static Set<Class<?>> classData(String packPath){
|
Set<Class<?>> classes = new LinkedHashSet();
|
String[] split = packPath.split(",|;");
|
String[] var3 = split;
|
int var4 = split.length;
|
|
label82:
|
for(int var5 = 0; var5 < var4; ++var5) {
|
String pack = var3[var5];
|
boolean recursive = true;
|
String packageName = pack;
|
String packageDirName = pack.replace('.', '/');
|
|
try {
|
Enumeration dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
|
|
while(true) {
|
label75:
|
while(true) {
|
if (!dirs.hasMoreElements()) {
|
continue label82;
|
}
|
|
URL url = (URL)dirs.nextElement();
|
String protocol = url.getProtocol();
|
if ("file".equals(protocol)) {
|
System.err.println("file类型的扫描:" + pack);
|
String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
|
findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes);
|
} else if ("jar".equals(protocol)) {
|
System.err.println("jar类型的扫描");
|
|
try {
|
JarFile jar = ((JarURLConnection)url.openConnection()).getJarFile();
|
Enumeration entries = jar.entries();
|
|
while(true) {
|
JarEntry entry;
|
String name;
|
int idx;
|
do {
|
do {
|
if (!entries.hasMoreElements()) {
|
continue label75;
|
}
|
|
entry = (JarEntry)entries.nextElement();
|
name = entry.getName();
|
if (name.charAt(0) == '/') {
|
name = name.substring(1);
|
}
|
} while(!name.startsWith(packageDirName));
|
|
idx = name.lastIndexOf(47);
|
if (idx != -1) {
|
packageName = name.substring(0, idx).replace('/', '.');
|
}
|
} while(idx == -1 && !recursive);
|
|
if (name.endsWith(".class") && !entry.isDirectory()) {
|
String className = name.substring(packageName.length() + 1, name.length() - 6);
|
|
try {
|
classes.add(Class.forName(packageName + '.' + className));
|
} catch (ClassNotFoundException var20) {
|
var20.printStackTrace();
|
}
|
}
|
}
|
} catch (IOException var21) {
|
var21.printStackTrace();
|
}
|
}
|
}
|
}
|
} catch (IOException var22) {
|
var22.printStackTrace();
|
}
|
}
|
|
return classes;
|
}
|
|
/**
|
* 以文件的形式来获取包下的所有Class
|
*
|
* @param packageName
|
* @param packagePath
|
* @param recursive
|
* @param classes
|
*/
|
public static void findAndAddClassesInPackageByFile(
|
String packageName,
|
String packagePath,
|
final boolean recursive,
|
Set<Class<?>> classes){
|
// 获取此包的目录 建立一个File
|
File dir = new File(packagePath);
|
// 如果不存在或者 也不是目录就直接返回
|
if (!dir.exists() || !dir.isDirectory()) {
|
// log.warn("用户定义包名 " + packageName + " 下没有任何文件");
|
return;
|
}
|
// 如果存在 就获取包下的所有文件 包括目录
|
File[] dirfiles = dir.listFiles(new FileFilter(){
|
// 自定义过滤规则 如果可以循环(包含子目录) 或则是以.class结尾的文件(编译好的java类文件)
|
@Override
|
public boolean accept(File file){
|
return (recursive && file.isDirectory()) || (file.getName().endsWith(".class"));
|
}
|
});
|
// 循环所有文件
|
for (File file : dirfiles){
|
// 如果是目录 则继续扫描
|
if (file.isDirectory()) {
|
findAndAddClassesInPackageByFile(packageName + "." + file.getName(), file.getAbsolutePath(), recursive, classes);
|
}else{
|
// 如果是java类文件 去掉后面的.class 只留下类名
|
String className = file.getName().substring(0, file.getName().length() - 6);
|
try{
|
// 添加到集合中去
|
// classes.add(Class.forName(packageName + '.' +
|
// className));
|
// 经过回复同学的提醒,这里用forName有一些不好,会触发static方法,没有使用classLoader的load干净
|
classes.add(Thread.currentThread().getContextClassLoader().loadClass(packageName + '.' + className));
|
}catch (ClassNotFoundException e){
|
// log.error("添加用户自定义视图类错误 找不到此类的.class文件");
|
e.printStackTrace();
|
}
|
}
|
}
|
}
|
|
/**
|
* 获取父类的字段
|
* @param fields
|
* @param clas
|
* @return
|
*/
|
public static Field[] getPatentFields(Field[] fields,Class<?> clas){
|
if (clas.getSuperclass() != null) {
|
Class clsSup = clas.getSuperclass();
|
List<Field> fieldList = new ArrayList<Field>();
|
fieldList.addAll(Arrays.asList(fields));
|
fieldList.addAll(Arrays.asList(clsSup.getDeclaredFields()));
|
fields = new Field[fieldList.size()];
|
int i = 0;
|
for (Object field : fieldList.toArray()) {
|
fields[i] = (Field) field;
|
i++;
|
}
|
fields = getPatentFields(fields,clsSup);
|
}
|
return fields;
|
}
|
|
|
}
|