Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
package com.yxw.zuoye.plugins;
import java.lang.reflect.Field;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;

import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.defaults.DefaultSqlSession.StrictMap;
import org.springframework.stereotype.Component;
/**
 * MyBatis interceptor that prints every statement executed through a {@link StatementHandler}
 * ({@code query}, {@code update}, {@code batch}) with its bound parameter values inlined, together
 * with how long the call took.
 *
 * <p>The printed SQL is for reading/copying only: values are rendered as SQL literals on a
 * best-effort basis, and a value that cannot be located is shown as a bare {@code ?}.
 */
@Component
@Intercepts({@Signature(type = StatementHandler.class, method = "query", args = {Statement.class, ResultHandler.class}),
        @Signature(type = StatementHandler.class, method = "update", args = {Statement.class}),
        @Signature(type = StatementHandler.class, method = "batch", args = {Statement.class})})
public class mybatisPlugin implements Interceptor {

    /** Any run of whitespace (spaces, tabs, line breaks) in the mapper SQL; compiled once. */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    /** Sentinel for a parameter whose value could not be located; its placeholder stays "?". */
    private static final Object UNRESOLVED = new Object();

    /**
     * Runs the intercepted statement and prints its formatted SQL and elapsed time afterwards,
     * whether the statement succeeded or threw.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        // nanoTime is monotonic, so the measurement is not affected by wall-clock adjustments.
        long startNanos = System.nanoTime();
        try {
            return invocation.proceed();
        } finally {
            long sqlCost = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
            printSql(statementHandler, sqlCost);
        }
    }

    /** Lets MyBatis's {@link Plugin} helper decide whether {@code target} needs a proxy. */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /** This interceptor has no configurable properties; anything passed here is ignored. */
    @Override
    public void setProperties(Properties properties) {
    }

    /**
     * Prints the formatted SQL and its cost. Called from a {@code finally} block, so it must not
     * throw: an exception escaping here would replace the statement's own result or exception.
     */
    private void printSql(StatementHandler statementHandler, long sqlCost) {
        try {
            String sql = formatSql(statementHandler.getBoundSql());
            System.out.println("SQL:[" + sql + "]执行耗时[" + sqlCost + "ms]");
        } catch (RuntimeException e) {
            // Still report the timing even when the SQL text could not be produced.
            System.out.println("SQL:[格式化失败: " + e + "]执行耗时[" + sqlCost + "ms]");
        }
    }

    /**
     * Collapses the whitespace of the statement's SQL and replaces each {@code ?} with the literal
     * form of the value bound to it.
     *
     * @param boundSql the statement whose SQL and parameters are rendered
     * @return "" for empty SQL; the whitespace-collapsed SQL with values inlined; or, if reading a
     *         parameter value fails, the whitespace-collapsed SQL without substitution
     */
    private String formatSql(BoundSql boundSql) {
        String sql = boundSql.getSql();
        if (sql == null || sql.isEmpty()) {
            return "";
        }
        sql = beautifySql(sql);
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        if (parameterMappings == null || parameterMappings.isEmpty()) {
            return sql;
        }
        Object parameterObject = boundSql.getParameterObject();
        try {
            List<String> literals = new ArrayList<String>(parameterMappings.size());
            for (ParameterMapping parameterMapping : parameterMappings) {
                Object value = resolveParameterValue(boundSql, parameterObject, parameterMapping.getProperty());
                literals.add(toSqlLiteral(value));
            }
            return fillPlaceholders(sql, literals);
        } catch (Exception e) {
            // Reflection can fail (inaccessible field, security/module restrictions). Logging is
            // best-effort, so fall back to the readable but unsubstituted SQL.
            return sql;
        }
    }

    /** Collapses every run of whitespace into a single space and trims both ends. */
    private String beautifySql(String sql) {
        return WHITESPACE.matcher(sql).replaceAll(" ").trim();
    }

    /**
     * Finds the value bound to one {@code #{...}} property, following the lookup order of
     * MyBatis's own DefaultParameterHandler.
     *
     * @return the value (possibly null), or {@link #UNRESOLVED} if it cannot be located
     */
    private Object resolveParameterValue(BoundSql boundSql, Object parameterObject, String property)
            throws IllegalAccessException {
        // Values produced while expanding dynamic SQL (e.g. <foreach> items, <bind> variables)
        // are stored on the BoundSql rather than in the parameter object.
        if (boundSql.hasAdditionalParameter(property)) {
            return boundSql.getAdditionalParameter(property);
        }
        if (parameterObject == null) {
            return null;
        }
        // A single scalar parameter is bound as a whole, whatever name the mapper gave it.
        if (isSimpleValue(parameterObject)) {
            return parameterObject;
        }
        return readProperty(parameterObject, property);
    }

    /**
     * Walks a dotted property path ({@code "user.address.city"}) through maps (by key) and plain
     * objects (by field, including inherited fields).
     *
     * @return the value at the end of the path, null if an intermediate value is null, or
     *         {@link #UNRESOLVED} if a key/field does not exist (including indexed paths like
     *         {@code list[0]})
     */
    private static Object readProperty(Object root, String propertyPath) throws IllegalAccessException {
        Object current = root;
        for (String name : propertyPath.split("\\.")) {
            if (current == null) {
                return null;
            }
            if (current instanceof Map) {
                Map<?, ?> map = (Map<?, ?>) current;
                // containsKey first: MyBatis's ParamMap/StrictMap throw on get() of a missing key.
                if (!map.containsKey(name)) {
                    return UNRESOLVED;
                }
                current = map.get(name);
            } else {
                Field field = findField(current.getClass(), name);
                if (field == null) {
                    return UNRESOLVED;
                }
                // Entity fields are usually private.
                field.setAccessible(true);
                current = field.get(current);
            }
        }
        return current;
    }

    /** Looks up a field declared on {@code type} or any of its superclasses; null if none. */
    private static Field findField(Class<?> type, String name) {
        for (Class<?> c = type; c != null && c != Object.class; c = c.getSuperclass()) {
            for (Field field : c.getDeclaredFields()) {
                if (field.getName().equals(name)) {
                    return field;
                }
            }
        }
        return null;
    }

    /**
     * Whether a parameter object should be bound as one value rather than read property by
     * property: strings, numbers, booleans, characters, enums and other JDK value types
     * (BigDecimal, dates, ...). This approximates MyBatis's "has a TypeHandler" check.
     */
    private static boolean isSimpleValue(Object value) {
        return value instanceof CharSequence || value instanceof Number || value instanceof Boolean
                || value instanceof Character || value instanceof Enum
                || value.getClass().getName().startsWith("java.");
    }

    /** Renders a value the way it would appear in hand-written SQL. */
    private static String toSqlLiteral(Object value) {
        if (value == UNRESOLVED) {
            return "?";
        }
        if (value == null) {
            return "null";
        }
        if (value instanceof Number || value instanceof Boolean) {
            return value.toString();
        }
        // Standard SQL string literal: single quotes, embedded single quotes doubled.
        return "'" + value.toString().replace("'", "''") + "'";
    }

    /**
     * Replaces the {@code ?} placeholders of {@code sql} left to right with {@code literals}.
     * The SQL is scanned once, so a "?", "$" or "\" inside a substituted value is copied verbatim
     * instead of being taken as the next placeholder or as a regex group reference. A "?" inside
     * a string literal of the original SQL text is still counted as a placeholder.
     */
    private static String fillPlaceholders(String sql, List<String> literals) {
        StringBuilder result = new StringBuilder(sql.length() + 16 * literals.size());
        int copiedUpTo = 0;
        for (String literal : literals) {
            int placeholder = sql.indexOf('?', copiedUpTo);
            if (placeholder < 0) {
                break;
            }
            result.append(sql, copiedUpTo, placeholder).append(literal);
            copiedUpTo = placeholder + 1;
        }
        return result.append(sql, copiedUpTo, sql.length()).toString();
    }
}