1 限流
为了维护系统稳定性、缓解DDoS等恶意流量的冲击,需要对系统的请求量进行限制。
2 滑动窗口
常见的限流方式有:固定窗口、滑动窗口、令牌桶和漏桶。滑动窗口的意思是:维护一个时间长度固定、随当前时间不断向前滑动的窗口,动态统计窗口内的请求次数;如果窗口内的请求次数已达到阈值,则拒绝本次访问。
3 实现
参考 https://www.jianshu.com/p/cb11e552505b 。采用Redis的zset(有序集合)数据结构,将当前请求的时间戳作为score,统计窗口时间内的请求次数,判断是否超过限制。统计与写入放在同一个Lua脚本中执行,以保证原子性。
完整代码在https://gitcode.com/zsss1/ratelimit/overview
// 限流类型
/**
 * Scope of the rate-limit key built by the limiter aspect.
 */
public enum LimitType {

    /** The key is built from the annotation prefix and the target method only. */
    DEFAULT,

    /** The key additionally carries an address segment, so limits are tracked per address. */
    IP
}
// 限流注解
/**
 * Marks a method for sliding-window rate limiting.
 *
 * <p>Retained at runtime so that an aspect can read the settings when the method is invoked.
 */
@Target({ElementType.METHOD})
@Retention(value = RetentionPolicy.RUNTIME)
public @interface RateLimiter {

    /** Prefix of the Redis key that stores the window's entries. */
    String key() default "rate:limiter:";

    /** Maximum number of requests admitted within one window. */
    long limit() default 1;

    /** Window length in seconds (the Lua script multiplies it by 1000 and also uses it as the key TTL). */
    long expire() default 1;

    /** Message of the exception raised when a request is rejected. */
    String message() default "访问频繁";

    /** How the key is scoped; see {@link LimitType}. */
    LimitType limitType() default LimitType.IP;
}
// 限流切面
@Component
@Aspect
public class RateLimiterHandler {

    private static final Logger LOGGER = LoggerFactory.getLogger(RateLimiterHandler.class);

    @Autowired
    private RedisTemplate<String, Object> redisTemplate;

    // Sliding-window Lua script injected by the "sliding_window" qualifier; it replies 1 to
    // admit a request and 0 to reject it.
    @Autowired
    @Qualifier("sliding_window")
    private RedisScript<Long> redisScript;

    /**
     * Enforces {@link RateLimiter} on every annotated method in {@code com.example} and its
     * sub-packages.
     *
     * @param proceedingJoinPoint the intercepted invocation
     * @param rateLimiter the annotation present on the intercepted method
     * @return whatever the intercepted method returns
     * @throws RuntimeException carrying {@link RateLimiter#message()} when the script rejects the call
     * @throws Throwable anything thrown by the intercepted method itself
     */
    @Around("execution(* com.example..*.*(..)) && @annotation(rateLimiter)")
    public Object around(ProceedingJoinPoint proceedingJoinPoint, RateLimiter rateLimiter) throws Throwable {
        MethodSignature signature = (MethodSignature) proceedingJoinPoint.getSignature();
        String limitKey = buildLimitKey(signature.getMethod(), rateLimiter);

        // "Now" comes from the server clock, not from the intercepted method's arguments:
        // a caller-supplied timestamp can be forged to slip past the window, and annotated
        // methods are not guaranteed to take a numeric String as their first argument.
        long currentTime = System.currentTimeMillis();

        List<String> keyList = new ArrayList<>(1);
        keyList.add(limitKey);
        // ARGV order expected by the script: [window seconds, current ms, limit].
        // NOTE(review): these values pass through redisTemplate's value serializer and the script
        // parses them with tonumber(); this assumes the serializer writes plain numbers — confirm.
        Long result = redisTemplate.execute(redisScript, keyList,
                rateLimiter.expire(), currentTime, rateLimiter.limit());

        // A null reply is treated as "allowed" (fail open).
        if (result != null && result != 1L) {
            LOGGER.warn("Rate limit exceeded for key {}", limitKey);
            throw new RuntimeException(rateLimiter.message());
        }
        return proceedingJoinPoint.proceed();
    }

    /**
     * Builds the Redis key {@code <prefix>[<address>]_<declaring class>_<method name>}.
     */
    private String buildLimitKey(Method method, RateLimiter rateLimiter) {
        StringBuilder limitKey = new StringBuilder(rateLimiter.key());
        if (rateLimiter.limitType() == LimitType.IP) {
            // TODO: placeholder address, so every caller currently shares one bucket. Resolve
            // the real client IP (e.g. from the current HTTP request) before relying on LimitType.IP.
            limitKey.append("127.0.0.1");
        }
        return limitKey.append('_').append(method.getDeclaringClass().getName())
                .append('_').append(method.getName())
                .toString();
    }
}
-- Sliding-window rate limiter backed by a Redis sorted set.
-- KEYS[1]: sorted-set key with one member per admitted request, scored by its timestamp (ms).
-- ARGV[1]: window length in seconds (also used as the key's TTL).
-- ARGV[2]: current time in milliseconds.
-- ARGV[3]: maximum number of requests admitted per window.
-- Returns 1 if the request is admitted (and records it), 0 if it is rejected.
local key = KEYS[1]
local windowSeconds = tonumber(ARGV[1])
local currentMs = tonumber(ARGV[2])
local limitCount = tonumber(ARGV[3])

-- The window is the half-open interval (currentMs - window, currentMs].
local windowStartMs = currentMs - windowSeconds * 1000

-- Evict entries that have slid out of the window before counting, so the count and the
-- cleanup agree on where the window starts.
redis.call('ZREMRANGEBYSCORE', key, '-inf', windowStartMs)

local current = redis.call('ZCOUNT', key, '(' .. windowStartMs, currentMs)
if current >= limitCount then
  return 0
end

-- Members must be unique. With the bare timestamp as the member, two requests in the same
-- millisecond would collapse into a single entry and be undercounted, so the timestamp is
-- suffixed with the number of entries already sharing this score.
local sameMs = redis.call('ZCOUNT', key, currentMs, currentMs)
redis.call('ZADD', key, currentMs, currentMs .. '-' .. sameMs)
redis.call('EXPIRE', key, windowSeconds)
return 1
// 测试类
// 为了方便统计当前时间,将时间作为请求参数传入接口
@SpringBootTest(classes = DemoApplication.class)
@AutoConfigureMockMvc
public class RateLimitControllerTest {@Autowiredprivate WebApplicationContext webApplicationContext;private MockMvc mockMvc;@BeforeEachpublic void setUp() throws Exception {mockMvc = MockMvcBuilders.webAppContextSetup(webApplicationContext).build();}@Testpublic void test_rate_limit() throws Exception {String url = "/rate/test";Map<Long, Integer> timeStatusMap = new LinkedHashMap<>();for (int i = 0; i < 20; i++) {Thread.sleep(800);long currentTime = System.currentTimeMillis();MockHttpServletRequestBuilder builder = MockMvcRequestBuilders.get(url).param("currentTime", String.valueOf(currentTime)).accept(MediaType.APPLICATION_JSON);int status = mockMvc.perform(builder).andReturn().getResponse().getStatus();timeStatusMap.put(currentTime, status);}for (Map.Entry<Long, Integer> entry : timeStatusMap.entrySet()) {Long currentTime = entry.getKey();int status = entry.getValue();int spectedStatus = getStatusOfCurrentTime(currentTime, timeStatusMap.entrySet());System.out.println(status + ", " + spectedStatus + ", " + currentTime);// assertEquals(status, spectedStatus);}}private int getStatusOfCurrentTime(Long currentTime, Set<Map.Entry<Long, Integer>> set) {long startTime = currentTime - 5000;int count = 0;for (Map.Entry<Long, Integer> entry : set) {if (entry.getKey() >= startTime && entry.getKey() < currentTime && entry.getValue() == 200) {count++;}}if (count < 5) {return 200;}return 400;}
}
// 接口
/**
 * Demo endpoint used to exercise the sliding-window rate limiter.
 */
@RestController
@RequestMapping(path = "/rate")
public class RateLimitController {

    /**
     * Returns a fixed body. Guarded by {@link RateLimiter} with limit = 5 and expire = 5,
     * keyed per {@link LimitType#IP}.
     *
     * @param currentTime request parameter accepted by the endpoint; not read by this method
     * @return the constant body {@code "h"}
     */
    @RateLimiter(limitType = LimitType.IP, limit = 5, expire = 5)
    @GetMapping(path = "/test")
    public String test(String currentTime) {
        final String body = "h";
        return body;
    }
}