使用 Spring MVC 作为 Controller 层进行 Web 开发时，经常需要对 Controller 方法的参数进行校验。Spring MVC 支持 @Valid 和 @Validated 注解（基于 Bean Validation），但在默认配置下它们只对 bean 类型的参数生效；若要校验 String、Long 等简单类型的参数，还需要在 Controller 类上标注 @Validated 并启用方法级校验。另一种思路是利用 Spring 的 AOP 和自定义注解，自己实现一个参数校验功能，本文介绍的就是这种做法。
一.自定义注解:
ValidParam.java:
import java.lang.annotation.*;
/**
 * Marks a controller method parameter as a bean whose annotated fields should be validated.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface ValidParam {
}
NotNull.java（非空校验，可标注在字段或方法参数上）:
import java.lang.annotation.*;
/**
 * Marks a field or method parameter whose value must not be {@code null}.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.PARAMETER, ElementType.FIELD})
public @interface NotNull {

    /**
     * Message used when the constraint is violated.
     *
     * @return the violation message
     */
    String msg() default "字段不能为空";
}
NotEmpty.java（字符串非空白校验，只能用于 String 类型的字段或参数）:
import java.lang.annotation.*;
/**
 * Marks a {@code String} field or method parameter whose value must not be blank.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.PARAMETER, ElementType.FIELD})
public @interface NotEmpty {

    /**
     * Message used when the constraint is violated.
     *
     * @return the violation message
     */
    String msg() default "字段不能为空";
}
二.参数校验切面类
ParamCheckAspect.java:
import com.example.recordlog.anotation.NotEmpty;
import com.example.recordlog.anotation.NotNull;
import com.example.recordlog.anotation.ValidParam;
import com.example.recordlog.exception.CustomException;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.lang.reflect.Field;
import java.lang.reflect.Parameter;
import java.util.Arrays;
/**
 * Aspect that checks controller method arguments before the controller method runs.
 *
 * <p>Two kinds of parameters are checked:
 * <ul>
 *   <li>Simple values (primitives, their wrapper types such as {@code Integer}, and
 *       {@code String}): {@link NotNull} / {@link NotEmpty} placed directly on the parameter.</li>
 *   <li>Parameters annotated with {@link ValidParam}: every field of the declared parameter type
 *       and of its superclasses that carries {@link NotNull} / {@link NotEmpty}.</li>
 * </ul>
 * A violated constraint is reported by throwing {@link CustomException}.
 */
@Aspect
@Component
@Order(1)
public class ParamCheckAspect {

    /**
     * Not read by this aspect; kept only so existing references keep compiling. This component is a
     * shared singleton bean, so this field must not be used to carry per-request state.
     */
    public Object msg = null;

    /**
     * Validates the arguments of every method of the classes in
     * {@code com.example.recordlog.controller}.
     *
     * @param joinPoint the intercepted controller invocation
     * @throws CustomException if a constraint is violated
     * @throws Exception       if a bean field cannot be read reflectively
     */
    @Before("execution(* com.example.recordlog.controller.*.*(..))")
    public void paramCheck(JoinPoint joinPoint) throws Exception {
        Object[] args = joinPoint.getArgs();
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Parameter[] parameters = signature.getMethod().getParameters();
        for (int i = 0; i < parameters.length; i++) {
            Parameter parameter = parameters[i];
            // Take the argument by position: args[i] always belongs to parameters[i], even when it
            // is null or when several parameters share the same type.
            Object arg = args[i];
            Class<?> type = parameter.getType();
            if (isPrimite(type)) {
                checkSimpleParameter(parameter, arg);
                continue;
            }
            // Servlet infrastructure objects and parameters without @ValidParam are not validated.
            if (isServletType(type) || parameter.getAnnotation(ValidParam.class) == null) {
                continue;
            }
            checkBean(parameter, arg);
        }
    }

    /**
     * Applies {@link NotNull} / {@link NotEmpty} declared directly on a simple-valued parameter.
     *
     * @param parameter the reflected method parameter
     * @param arg       the actual argument value (may be {@code null})
     */
    private void checkSimpleParameter(Parameter parameter, Object arg) {
        NotNull notNull = parameter.getAnnotation(NotNull.class);
        if (notNull != null && arg == null) {
            throw new CustomException(parameter.toString() + notNull.msg());
        }
        NotEmpty notEmpty = parameter.getAnnotation(NotEmpty.class);
        if (notEmpty != null) {
            // NotEmpty is only meaningful for String values.
            if (!String.class.isAssignableFrom(parameter.getType())) {
                throw new CustomException("NotEmpty Annotation using in a wrong parameter class");
            }
            if (StringUtils.isBlank((String) arg)) {
                throw new CustomException(parameter.toString() + notEmpty.msg());
            }
        }
    }

    /**
     * Checks every constrained field of a {@link ValidParam} bean argument.
     *
     * @param parameter the reflected method parameter (its declared type drives field lookup)
     * @param bean      the actual argument value
     * @throws IllegalAccessException if a field cannot be read
     */
    private void checkBean(Parameter parameter, Object bean) throws IllegalAccessException {
        if (bean == null) {
            // Without an instance there are no field values to inspect.
            throw new CustomException(parameter.toString() + "参数不能为空");
        }
        // getDeclaredFields() excludes inherited fields, so walk up the superclass chain as well.
        for (Class<?> clazz = parameter.getType(); clazz != null && clazz != Object.class;
                clazz = clazz.getSuperclass()) {
            for (Field field : clazz.getDeclaredFields()) {
                checkField(field, bean);
            }
        }
    }

    /**
     * Applies {@link NotNull} / {@link NotEmpty} declared on a single bean field.
     *
     * @param field the field to check
     * @param bean  the object holding the field
     * @throws IllegalAccessException if the field cannot be read
     */
    private void checkField(Field field, Object bean) throws IllegalAccessException {
        NotNull notNull = field.getAnnotation(NotNull.class);
        NotEmpty notEmpty = field.getAnnotation(NotEmpty.class);
        if (notNull == null && notEmpty == null) {
            // Only constrained fields are made accessible and read.
            return;
        }
        field.setAccessible(true);
        Object value = field.get(bean);
        if (notNull != null && value == null) {
            throw new CustomException(field.getName() + notNull.msg());
        }
        if (notEmpty != null) {
            // NotEmpty is only meaningful for String fields.
            if (!String.class.isAssignableFrom(field.getType())) {
                throw new CustomException("NotEmpty Annotation using in a wrong field class");
            }
            if (StringUtils.isBlank((String) value)) {
                throw new CustomException(field.getName() + notEmpty.msg());
            }
        }
    }

    /**
     * Returns whether the type is a servlet request, response or session (or a subtype of one).
     *
     * @param type the declared parameter type
     * @return true if the parameter is servlet infrastructure and should be skipped
     */
    private boolean isServletType(Class<?> type) {
        return HttpServletRequest.class.isAssignableFrom(type)
                || HttpServletResponse.class.isAssignableFrom(type)
                || HttpSession.class.isAssignableFrom(type);
    }

    /**
     * Returns whether the type is a simple value type: a primitive
     * ({@code boolean, char, byte, short, int, long, float, double}), its wrapper, or
     * {@code String}.
     *
     * @param clazz the declared parameter type
     * @return true if the type is a simple value type; false otherwise
     */
    private boolean isPrimite(Class<?> clazz) {
        return clazz.isPrimitive()
                || clazz == String.class
                || clazz == Boolean.class
                || clazz == Character.class
                || clazz == Byte.class
                || clazz == Short.class
                || clazz == Integer.class
                || clazz == Long.class
                || clazz == Float.class
                || clazz == Double.class;
    }
}
三.自定义异常
/**
 * Unchecked exception for business and validation failures, optionally carrying an
 * application-level error code.
 *
 * <p>When constructed without an explicit code, {@link #getCode()} returns {@code 0}.
 */
public class CustomException extends RuntimeException {

    // Throwable is Serializable; pin the version instead of relying on a compiler-computed one.
    private static final long serialVersionUID = 1L;

    /** Application-level error code; {@code 0} when not supplied. */
    private final int code;

    /** Creates an exception with no message and code {@code 0}. */
    public CustomException() {
        super();
        this.code = 0;
    }

    /**
     * Creates an exception with the given message and code {@code 0}.
     *
     * @param message the detail message
     */
    public CustomException(String message) {
        super(message);
        this.code = 0;
    }

    /**
     * Creates an exception with the given code and message.
     *
     * @param code    application-level error code
     * @param message the detail message
     */
    public CustomException(int code, String message) {
        super(message);
        this.code = code;
    }

    /**
     * Creates an exception with the given message and cause, and code {@code 0}.
     *
     * @param message the detail message
     * @param cause   the underlying cause
     */
    public CustomException(String message, Throwable cause) {
        super(message, cause);
        this.code = 0;
    }

    /**
     * Creates an exception with the given code, message and cause.
     *
     * @param code    application-level error code
     * @param message the detail message
     * @param cause   the underlying cause
     */
    public CustomException(int code, String message, Throwable cause) {
        super(message, cause);
        this.code = code;
    }

    /**
     * Returns the application-level error code.
     *
     * @return the code, or {@code 0} if none was supplied
     */
    public int getCode() {
        return this.code;
    }
}
四.全局异常统一处理
import com.example.recordlog.tools.ResponseUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;
/**
 * Global exception handler that converts exceptions escaping controllers into
 * {@link ResponseUtils} failure bodies.
 */
@RestControllerAdvice
public class RRExceptionHandler {

    private static final Logger logger = LoggerFactory.getLogger(RRExceptionHandler.class);

    /**
     * Handles {@link CustomException}, returning its code and message to the client.
     *
     * @param e the business/validation exception
     * @return a failure response carrying {@code e.getCode()} and {@code e.getMessage()}
     */
    @ExceptionHandler(CustomException.class)
    public ResponseUtils<Void> handleRRException(CustomException e) {
        // Expected failures: record them briefly, without a stack trace.
        logger.warn("Request rejected: code={}, msg={}", e.getCode(), e.getMessage());
        return ResponseUtils.fail(e.getCode(), e.getMessage());
    }

    /**
     * Handles any other exception with code {@code -1}.
     *
     * @param e the unexpected exception
     * @return a failure response carrying code {@code -1} and {@code e.getMessage()}
     */
    @ExceptionHandler(Exception.class)
    public ResponseUtils<Void> handleException(Exception e) {
        // Constant log message: e.getMessage() may be null; the stack trace carries the details.
        logger.error("Unhandled exception", e);
        // NOTE(review): the raw exception message is returned to the client and may expose
        // internal details; consider replacing it with a generic message.
        return ResponseUtils.fail(-1, e.getMessage());
    }
}
五.返回数据工具类
/**
 * Uniform envelope for REST responses: a status code, an optional message and an optional
 * payload.
 *
 * @param <T> payload type
 */
public class ResponseUtils<T> implements Serializable {

    private static final long serialVersionUID = 3728877563912075885L;

    /** Code used by every {@code success(...)} factory. */
    private static final int SUCCESS_CODE = 200;

    /** Code used by {@link #fail(String)}. */
    private static final int FAILURE_CODE = 500;

    private int code;
    private String msg;
    private T data;

    /** Creates an empty response (code 0, no message, no data). */
    public ResponseUtils() {
    }

    /**
     * Creates a response with code, message and payload.
     *
     * @param code    status code
     * @param message message text (may be null)
     * @param data    payload (may be null)
     */
    public ResponseUtils(int code, String message, T data) {
        this.code = code;
        this.setMsg(message);
        this.data = data;
    }

    /**
     * Creates a response with code and payload, and no message.
     *
     * @param code status code
     * @param data payload (may be null)
     */
    public ResponseUtils(int code, T data) {
        this.code = code;
        this.data = data;
    }

    /**
     * Creates a response with code and message, and no payload.
     *
     * @param code    status code
     * @param message message text (may be null)
     */
    public ResponseUtils(int code, String message) {
        this(code, message, null);
    }

    public int getCode() {
        return code;
    }

    public void setCode(int code) {
        this.code = code;
    }

    public String getMsg() {
        return msg;
    }

    public void setMsg(String msg) {
        this.msg = msg;
    }

    public T getData() {
        return data;
    }

    public void setData(T data) {
        this.data = data;
    }

    /**
     * Builds a success response carrying only a payload.
     *
     * @param data payload
     * @param <T>  payload type
     * @return response with code 200, no message and the given payload
     */
    public static <T> ResponseUtils<T> success(T data) {
        return success(null, data);
    }

    /**
     * Builds a success response carrying only a message.
     *
     * @param msg message text
     * @param <T> payload type
     * @return response with code 200, the given message and no payload
     */
    public static <T> ResponseUtils<T> success(String msg) {
        return new ResponseUtils<T>(SUCCESS_CODE, msg);
    }

    /**
     * Builds a success response carrying both a message and a payload.
     *
     * @param msg  message text
     * @param data payload
     * @param <T>  payload type
     * @return response with code 200, the given message and payload
     */
    public static <T> ResponseUtils<T> success(String msg, T data) {
        return new ResponseUtils<T>(SUCCESS_CODE, msg, data);
    }

    /**
     * Builds a failure response with code 500.
     *
     * @param msg failure message
     * @param <T> payload type
     * @return response with code 500, the given message and no payload
     */
    public static <T> ResponseUtils<T> fail(String msg) {
        return fail(FAILURE_CODE, msg);
    }

    /**
     * Builds a failure response with a caller-chosen code.
     *
     * @param code failure code
     * @param msg  failure message
     * @param <T>  payload type
     * @return response with the given code and message and no payload
     */
    public static <T> ResponseUtils<T> fail(int code, String msg) {
        return new ResponseUtils<T>(code, msg, null);
    }

    @Override
    public String toString() {
        return new StringBuilder("RestResponse{")
                .append("code=").append(code)
                .append(", msg='").append(msg).append('\'')
                .append(", data=").append(data)
                .append('}')
                .toString();
    }
}
六.JavaBean
import com.example.recordlog.anotation.NotEmpty;
import com.example.recordlog.anotation.NotNull;
import lombok.Data;

import java.io.Serializable;
import java.util.Date;
/**
 * Request parameter bean. Lombok's {@code @Data} generates its getters, setters,
 * {@code equals}/{@code hashCode} and {@code toString}.
 *
 * <p>Fields carry the project's custom constraint annotations; they are checked by the
 * parameter-check aspect when a controller parameter of this type is marked {@code @ValidParam}.
 */
@Data
public class StudentParam {
// Required: must not be null.
@NotNull
private Integer id;
// No constraint.
private Integer age;
// Required: must be a non-blank String.
@NotEmpty
private String name;
}
七.测试
注意事项：对象（bean）参数需要标注 @ValidParam 注解，才会对其字段进行校验；单个简单类型参数直接使用 @NotNull 注解。下面示例中的 UserInfo 是项目中的参数 bean，也可以替换为上文定义的 StudentParam。
import com.example.recordlog.anotation.NotNull;
import com.example.recordlog.anotation.ValidParam;
import com.example.recordlog.bean.UserInfo;
import com.example.recordlog.tools.ResponseUtils;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
/**
 * Demo controller used to exercise the custom parameter-validation annotations.
 */
@RestController
@RequestMapping
public class ValidatorTestController {

    /**
     * Echoes the bound bean back. The bean is marked {@code @ValidParam} (field-level checks) and
     * {@code limit} is marked {@code @NotNull} (parameter-level check).
     *
     * @param userInfo bean argument whose annotated fields are subject to validation
     * @param limit    simple argument marked as required
     * @return success response wrapping {@code userInfo}
     */
    @PostMapping("/validator")
    public ResponseUtils<UserInfo> validatorObject(@ValidParam UserInfo userInfo, @NotNull Integer limit) {
        return ResponseUtils.success(userInfo);
    }
}