package com.dl.framework.aspectj; import cn.hutool.core.util.ArrayUtil; import com.dl.common.annotation.RateLimiter; import com.dl.common.constant.CacheConstants; import com.dl.common.enums.LimitType; import com.dl.common.exception.ServiceException; import com.dl.common.utils.MessageUtils; import com.dl.common.utils.ServletUtils; import com.dl.common.utils.StringUtils; import com.dl.common.utils.redis.RedisUtils; import lombok.extern.slf4j.Slf4j; import org.aspectj.lang.JoinPoint; import org.aspectj.lang.annotation.Aspect; import org.aspectj.lang.annotation.Before; import org.aspectj.lang.reflect.MethodSignature; import org.redisson.api.RateType; import org.springframework.core.DefaultParameterNameDiscoverer; import org.springframework.core.ParameterNameDiscoverer; import org.springframework.expression.EvaluationContext; import org.springframework.expression.Expression; import org.springframework.expression.ExpressionParser; import org.springframework.expression.ParserContext; import org.springframework.expression.common.TemplateParserContext; import org.springframework.expression.spel.standard.SpelExpressionParser; import org.springframework.expression.spel.support.StandardEvaluationContext; import org.springframework.stereotype.Component; import java.lang.reflect.Method; /** * 限流处理 * * @author Lion Li */ @Slf4j @Aspect @Component public class RateLimiterAspect { /** * 定义spel表达式解析器 */ private final ExpressionParser parser = new SpelExpressionParser(); /** * 定义spel解析模版 */ private final ParserContext parserContext = new TemplateParserContext(); /** * 定义spel上下文对象进行解析 */ private final EvaluationContext context = new StandardEvaluationContext(); /** * 方法参数解析器 */ private final ParameterNameDiscoverer pnd = new DefaultParameterNameDiscoverer(); @Before("@annotation(rateLimiter)") public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable { int time = rateLimiter.time(); int count = rateLimiter.count(); String combineKey = getCombineKey(rateLimiter, point); try { RateType rateType = RateType.OVERALL; if (rateLimiter.limitType() == LimitType.CLUSTER) { rateType = RateType.PER_CLIENT; } long number = RedisUtils.rateLimiter(combineKey, rateType, count, time); if (number == -1) { String message = rateLimiter.message(); if (StringUtils.startsWith(message, "{") && StringUtils.endsWith(message, "}")) { message = MessageUtils.message(StringUtils.substring(message, 1, message.length() - 1)); } throw new ServiceException(message); } log.info("限制令牌 => {}, 剩余令牌 => {}, 缓存key => '{}'", count, number, combineKey); } catch (Exception e) { if (e instanceof ServiceException) { throw e; } else { throw new RuntimeException("服务器限流异常,请稍候再试"); } } } public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) { String key = rateLimiter.key(); // 获取方法(通过方法签名来获取) MethodSignature signature = (MethodSignature) point.getSignature(); Method method = signature.getMethod(); Class targetClass = method.getDeclaringClass(); // 判断是否是spel格式 if (StringUtils.containsAny(key, "#")) { // 获取参数值 Object[] args = point.getArgs(); // 获取方法上参数的名称 String[] parameterNames = pnd.getParameterNames(method); if (ArrayUtil.isEmpty(parameterNames)) { throw new ServiceException("限流key解析异常!请联系管理员!"); } for (int i = 0; i < parameterNames.length; i++) { context.setVariable(parameterNames[i], args[i]); } // 解析返回给key try { Expression expression; if (StringUtils.startsWith(key, parserContext.getExpressionPrefix()) && StringUtils.endsWith(key, parserContext.getExpressionSuffix())) { expression = parser.parseExpression(key, parserContext); } else { expression = parser.parseExpression(key); } key = expression.getValue(context, String.class) + ":"; } catch (Exception e) { throw new ServiceException("限流key解析异常!请联系管理员!"); } } StringBuilder stringBuffer = new StringBuilder(CacheConstants.RATE_LIMIT_KEY); stringBuffer.append(ServletUtils.getRequest().getRequestURI()).append(":"); if (rateLimiter.limitType() == LimitType.IP) { // 获取请求ip stringBuffer.append(ServletUtils.getClientIP()).append(":"); } else if (rateLimiter.limitType() == LimitType.CLUSTER) { // 获取客户端实例id stringBuffer.append(RedisUtils.getClient().getId()).append(":"); } return stringBuffer.append(key).toString(); } }