Spring Cloud Gateway 限流适配多规则的解决方案
in 代码 with 0 comment

Spring Cloud Gateway 限流适配多规则的解决方案

in 代码 with 0 comment

Spring Cloud Gateway 限流适配多规则的解决方案

首先要说明,本文是使用的 Spring Cloud Gateway 自带的或者称原生的 Redis 限流!

背景

限流作用就不说了,往往都是防止一些恶意请求,无限制请求接口导致服务处理时间过长,继而导致响应延迟,服务阻塞等等,所以会对高频率的一些接口添加限流这样的功能。

通常,我们往往是针对 1 个路由或者说是对 1 个接口进行限流,限流的规则通常是:XXX 路由 XXX 在 XXX 时间内最多允许访问 XXX 次

比如:查询用户信息接口 [路由] 每个用户 [条件] 每秒 [频率时间] 最多支持访问 10 次 [频率最大限制]

举个明白点的例子,我 1 秒内连续请求 11 次 [查询用户信息接口],那么第 11 次就应该被拦截,提示请求频繁,相信大家在一些双 11 这样的节日里会遇到过类似情况~

我们换个规则,再举个例子:[查询用户信息接口] 每秒 最多支持访问 100 次~,也就是说不管谁请求,反正 [查询用户信息接口] 1 秒内最大支持访问 100 次请求,超过 100 的都会被拦截~

以上两个例子,都是单独使用,是对 1 个路由指定了 1 个限流的规则,实际业务需求中,可能还需要 2 个或者多个规则同时使用。

比如:

[查询用户信息接口] 每个用户 每秒 最多支持访问 10 次 ,这是规则 1,用来限制单个用户的次数

同时,[查询用户信息接口] 每秒 最多支持访问 100 次~,这是规则 2,用来限制接口的次数

这个我再举个明白点的例子,假如有 10 个人在同 1 秒来请求 [查询用户信息接口]

前 8 个人都在 1 秒内请求 10 次,(8 个人每个人都不违反规则 1,接口请求总数也不超过 100 次,接口还可以请求 20 次,不违反规则 2)

第 9 个人请求 11 次,(达到规则 1 限流条件,这个人第 11 次请求肯定被拦截,接口请求总数不超过 100 次,接口还可以请求 9 次)

第 10 个人请求 10 次,(不违反规则 1,接口请求为 101 次,总数超过 100 次,达到规则 2 限流条件,所以这个人第 10 次请求肯定被拦截)

当然请求顺序这都是理想状态,实际场景中顺序会有差别~

既然了解后,那么现在的问题就是:Spring Cloud Gateway 自带的限流默认 1 个路由(或者说是 1 个接口)只能配置 1 个限流规则!本文就是来解决这种问题,让 1 个路由适配多个规则!🤯

!!!Spring Cloud Gateway 提供了一套限流方案的接口,并且也基于 Redis 实现了一套限流方案,这个也就是本文的要着重分析的点!

Spring Cloud Gateway 大致流程熟悉

大致流程图

SpringCloudGateway-01

具体流程这里就不说了,我直接说本文的涉及的要点。

当请求进入到网关,网关会根据请求路由来组装对应的过滤器,而我们的限流也是其中的过滤器,Spring Cloud Gateway 自己的实现就是:RequestRateLimiterGatewayFilterFactory,所以我们要分析其源码,了解它大致干了什么事,我们才好知道有没有办法调整!

平常配置使用回顾

分析前,我们先回顾下平常我们配置限流是怎么配置的。

附: RateLimiterConfig,首先我们定义好限流规则 KeyResolver

/**
 * Author: Suremotoo
 */
@Configuration
public class RateLimiterConfig {


    @Primary
    @Bean(value = "remoteAddrKeyResolver")
    public KeyResolver remoteAddrKeyResolver() {
        return exchange -> {
            String hostAddress = exchange.getRequest().getRemoteAddress().getAddress().getHostAddress();
            // log("remoteAddrKeyResolver 限流规则 ip {}", hostAddress);
            return Mono.just(hostAddress);
        };
    }

    /**
     * 按照 Path 限流
     *
     * @return key
     */
    @Bean(value = "pathKeyResolver")
    public KeyResolver pathKeyResolver() {
        return exchange -> {
            Route route = (Route)exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR);
            // log("pathKeyResolver 限流规则 ip {}", route.getId());
            return Mono.just(exchange.getRequest().getPath().toString());
        };
    }

}

然后在 application.yml 配置路由的限流, 示例:

spring:
  cloud:
    gateway:
      routes:
        - id: query_user_info_route
          uri: lb://user-center
          filters:
            - name: RequestRateLimiter
              args:
                # 令牌桶每秒填充平均速率
                redis-rate-limiter.replenishRate: 1
                # 令牌桶的上限
                redis-rate-limiter.burstCapacity: 10
                # 使用 SpEL 表达式从 Spring 容器中获取 Bean 对象,pathKeyResolver 是根据地址来限流
                key-resolver: "#{@remoteAddrKeyResolver}" # 详情见 RateLimiterConfig

可以看到,过滤器 (filters),我们配置的是 RequestRateLimiter,这里的 RequestRateLimiter 其实指的就是 RequestRateLimiterGatewayFilterFactory ,只是省略了后面的 GatewayFilterFactory~

该过滤器的参数有 redis-rate-limiter、key-resolver,说明这两个其实是很重要的属性!

1 个限流规则我们配置 1 个过滤器及属性,那么我想再加一个规则,正常情况我们会这样做,示例:

spring:
  cloud:
    gateway:
      routes:
        - id: query_user_info_route
          uri: lb://user-center
          filters:
          # 第一个限流过滤器和规则
            - name: RequestRateLimiter
              args:
                # 令牌桶每秒填充平均速率
                redis-rate-limiter.replenishRate: 1
                # 令牌桶的上限
                redis-rate-limiter.burstCapacity: 100
                # 使用 SpEL 表达式从 Spring 容器中获取 Bean 对象,pathKeyResolver 是根据地址来限流
                key-resolver: "#{@pathKeyResolver}" # 详情见 RateLimiterConfig
          # 第二个限流过滤器和规则
            - name: RequestRateLimiter
              args:
                # 令牌桶每秒填充平均速率
                redis-rate-limiter.replenishRate: 1
                # 令牌桶的上限
                redis-rate-limiter.burstCapacity: 10
                # 使用 SpEL 表达式从 Spring 容器中获取 Bean 对象,remoteAddrKeyResolver 是根据请求 ip 来限流
                key-resolver: "#{@remoteAddrKeyResolver}" # 详情见 RateLimiterConfig

写的时候还洋洋洒洒~

skrjlg

写完后看上没问题,程序也能跑起来,但你会发现实际就只有 1 个生效,下面属性的把上面的覆盖了!,欧了买了噶,~

omg

是配置了两个一样的过滤器,实际运行的时候,也确实都跑了两次这个过滤器,只是每次取的速率什么的,是相同的~,相当于同一个限流规则,校验了两遍~

好得很

具体跑起来效果我就不展示了,接下来我们来正儿八经分析下源码,看看什么情况!

RequestRateLimiterGatewayFilterFactory 源码分析

这里仅列出核心代码分析😬

public class RequestRateLimiterGatewayFilterFactory extends AbstractGatewayFilterFactory<RequestRateLimiterGatewayFilterFactory.Config> {
    public static final String KEY_RESOLVER_KEY = "keyResolver";
    private static final String EMPTY_KEY = "____EMPTY_KEY__";
   
    /**
     * 限流算法及实现
     */
    private final RateLimiter defaultRateLimiter;
  
     /**
     * 限流关键字 key
     */
    private final KeyResolver defaultKeyResolver;
    private boolean denyEmptyKey = true;
    private String emptyKeyStatusCode;

    public RequestRateLimiterGatewayFilterFactory(RateLimiter defaultRateLimiter, KeyResolver defaultKeyResolver) {
        super(RequestRateLimiterGatewayFilterFactory.Config.class);
        this.emptyKeyStatusCode = HttpStatus.FORBIDDEN.name();
        this.defaultRateLimiter = defaultRateLimiter;
        this.defaultKeyResolver = defaultKeyResolver;
    }

    public KeyResolver getDefaultKeyResolver() {
        return this.defaultKeyResolver;
    }

    public RateLimiter getDefaultRateLimiter() {
        return this.defaultRateLimiter;
    }

    public boolean isDenyEmptyKey() {
        return this.denyEmptyKey;
    }

    public void setDenyEmptyKey(boolean denyEmptyKey) {
        this.denyEmptyKey = denyEmptyKey;
    }

    public String getEmptyKeyStatusCode() {
        return this.emptyKeyStatusCode;
    }

    public void setEmptyKeyStatusCode(String emptyKeyStatusCode) {
        this.emptyKeyStatusCode = emptyKeyStatusCode;
    }

    public GatewayFilter apply(RequestRateLimiterGatewayFilterFactory.Config config) {
        KeyResolver resolver = (KeyResolver)this.getOrDefault(config.keyResolver, this.defaultKeyResolver);
        RateLimiter<Object> limiter = (RateLimiter)this.getOrDefault(config.rateLimiter, this.defaultRateLimiter);
        boolean denyEmpty = (Boolean)this.getOrDefault(config.denyEmptyKey, this.denyEmptyKey);
        HttpStatusHolder emptyKeyStatus = HttpStatusHolder.parse((String)this.getOrDefault(config.emptyKeyStatus, this.emptyKeyStatusCode));
        return (exchange, chain) -> {
            Route route = (Route)exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR);
            return resolver.resolve(exchange).defaultIfEmpty("____EMPTY_KEY__").flatMap((key) -> {
                if ("____EMPTY_KEY__".equals(key)) {
                    if (denyEmpty) {
                        ServerWebExchangeUtils.setResponseStatus(exchange, emptyKeyStatus);
                        return exchange.getResponse().setComplete();
                    } else {
                        return chain.filter(exchange);
                    }
                } else {
                    // isAllowed 方法,根据 [路由 id] 和 [限流关键字 key] 来判断是否要限流
                    return limiter.isAllowed(route.getId(), key).flatMap((response) -> {
                        Iterator var4 = response.getHeaders().entrySet().iterator();

                        while(var4.hasNext()) {
                            Entry<String, String> header = (Entry)var4.next();
                            exchange.getResponse().getHeaders().add((String)header.getKey(), (String)header.getValue());
                        }

                        if (response.isAllowed()) {
                            return chain.filter(exchange);
                        } else {
                            ServerWebExchangeUtils.setResponseStatus(exchange, config.getStatusCode());
                            return exchange.getResponse().setComplete();
                        }
                    });
                }
            });
        };
    }

    private <T> T getOrDefault(T configValue, T defaultValue) {
        return configValue != null ? configValue : defaultValue;
    }

上述代码中,结合我们从 application.yml 配置中查看,该源码中其实最重要的就是:

RateLimiter :限流算法及实现(实际实现是令牌桶算法,这里先不做深入探究)

KeyResolver :限流关键字 key(这里 key 其实就是我们说的对用户限流、对接口限流,当我们要对 ip 限流时,这个 key 就是请求的 ip)

还有就是 limiter.isAllowed 这个函数,是校验是否达到限流条件的重要方法!

KeyResolver 看上去就不是影响多规则限流的重要因素~,那么我们就直接来看看 RateLimiter~

RateLimiter 源码分析

打开源码一看,哦是 interface,我们看看实现类( idea 中点击下图标记📌处即可查看)

RateLimiter

发现有两个实现类,一个是抽象类 AbstractRateLimiter,一个是基于 Redis 实现的 RedisRateLimiter,(o゜▽゜)o☆[BINGO!],肯定是 RedisRateLimiter,我们直接打开它~

RateLimiter-Impl

RedisRateLimiter 源码(别着急看代码,先往下翻)

@ConfigurationProperties("spring.cloud.gateway.redis-rate-limiter")
public class RedisRateLimiter extends AbstractRateLimiter<RedisRateLimiter.Config> implements ApplicationContextAware {
    /** @deprecated */
    @Deprecated
    public static final String REPLENISH_RATE_KEY = "replenishRate";
    /** @deprecated */
    @Deprecated
    public static final String BURST_CAPACITY_KEY = "burstCapacity";
    public static final String CONFIGURATION_PROPERTY_NAME = "redis-rate-limiter";
    public static final String REDIS_SCRIPT_NAME = "redisRequestRateLimiterScript";
    public static final String REMAINING_HEADER = "X-RateLimit-Remaining";
    public static final String REPLENISH_RATE_HEADER = "X-RateLimit-Replenish-Rate";
    public static final String BURST_CAPACITY_HEADER = "X-RateLimit-Burst-Capacity";
    private Log log = LogFactory.getLog(this.getClass());
    private ReactiveRedisTemplate<String, String> redisTemplate;
    private RedisScript<List<Long>> script;
    private AtomicBoolean initialized = new AtomicBoolean(false);
    private RedisRateLimiter.Config defaultConfig;
    private boolean includeHeaders = true;
    private String remainingHeader = "X-RateLimit-Remaining";
    private String replenishRateHeader = "X-RateLimit-Replenish-Rate";
    private String burstCapacityHeader = "X-RateLimit-Burst-Capacity";

    public RedisRateLimiter(ReactiveRedisTemplate<String, String> redisTemplate, RedisScript<List<Long>> script, Validator validator) {
        super(RedisRateLimiter.Config.class, "redis-rate-limiter", validator);
        this.redisTemplate = redisTemplate;
        this.script = script;
        this.initialized.compareAndSet(false, true);
    }

    public RedisRateLimiter(int defaultReplenishRate, int defaultBurstCapacity) {
        super(RedisRateLimiter.Config.class, "redis-rate-limiter", (Validator)null);
        this.defaultConfig = (new RedisRateLimiter.Config()).setReplenishRate(defaultReplenishRate).setBurstCapacity(defaultBurstCapacity);
    }

    static List<String> getKeys(String id) {
        String prefix = "request_rate_limiter.{" + id;
        String tokenKey = prefix + "}.tokens";
        String timestampKey = prefix + "}.timestamp";
        return Arrays.asList(tokenKey, timestampKey);
    }

    public boolean isIncludeHeaders() {
        return this.includeHeaders;
    }

    public void setIncludeHeaders(boolean includeHeaders) {
        this.includeHeaders = includeHeaders;
    }

    public String getRemainingHeader() {
        return this.remainingHeader;
    }

    public void setRemainingHeader(String remainingHeader) {
        this.remainingHeader = remainingHeader;
    }

    public String getReplenishRateHeader() {
        return this.replenishRateHeader;
    }

    public void setReplenishRateHeader(String replenishRateHeader) {
        this.replenishRateHeader = replenishRateHeader;
    }

    public String getBurstCapacityHeader() {
        return this.burstCapacityHeader;
    }

    public void setBurstCapacityHeader(String burstCapacityHeader) {
        this.burstCapacityHeader = burstCapacityHeader;
    }

    public void setApplicationContext(ApplicationContext context) throws BeansException {
        if (this.initialized.compareAndSet(false, true)) {
            this.redisTemplate = (ReactiveRedisTemplate)context.getBean("stringReactiveRedisTemplate", ReactiveRedisTemplate.class);
            this.script = (RedisScript)context.getBean("redisRequestRateLimiterScript", RedisScript.class);
            if (context.getBeanNamesForType(Validator.class).length > 0) {
                this.setValidator((Validator)context.getBean(Validator.class));
            }
        }

    }

    RedisRateLimiter.Config getDefaultConfig() {
        return this.defaultConfig;
    }

    public Mono<Response> isAllowed(String routeId, String id) {
        if (!this.initialized.get()) {
            throw new IllegalStateException("RedisRateLimiter is not initialized");
        } else {
            RedisRateLimiter.Config routeConfig = this.loadConfiguration(routeId);
            int replenishRate = routeConfig.getReplenishRate();
            int burstCapacity = routeConfig.getBurstCapacity();

            try {
                List<String> keys = getKeys(id);
                List<String> scriptArgs = Arrays.asList(replenishRate + "", burstCapacity + "", Instant.now().getEpochSecond() + "", "1");
                Flux<List<Long>> flux = this.redisTemplate.execute(this.script, keys, scriptArgs);
                return flux.onErrorResume((throwable) -> {
                    return Flux.just(Arrays.asList(1L, -1L));
                }).reduce(new ArrayList(), (longs, l) -> {
                    longs.addAll(l);
                    return longs;
                }).map((results) -> {
                    boolean allowed = (Long)results.get(0) == 1L;
                    Long tokensLeft = (Long)results.get(1);
                    Response response = new Response(allowed, this.getHeaders(routeConfig, tokensLeft));
                    if (this.log.isDebugEnabled()) {
                        this.log.debug("response: " + response);
                    }

                    return response;
                });
            } catch (Exception var9) {
                this.log.error("Error determining if user allowed from redis", var9);
                return Mono.just(new Response(true, this.getHeaders(routeConfig, -1L)));
            }
        }
    }

    RedisRateLimiter.Config loadConfiguration(String routeId) {
        RedisRateLimiter.Config routeConfig = (RedisRateLimiter.Config)this.getConfig().getOrDefault(routeId, this.defaultConfig);
        if (routeConfig == null) {
            routeConfig = (RedisRateLimiter.Config)this.getConfig().get("defaultFilters");
        }

        if (routeConfig == null) {
            throw new IllegalArgumentException("No Configuration found for route " + routeId + " or defaultFilters");
        } else {
            return routeConfig;
        }
    }

    @NotNull
    public Map<String, String> getHeaders(RedisRateLimiter.Config config, Long tokensLeft) {
        Map<String, String> headers = new HashMap();
        if (this.isIncludeHeaders()) {
            headers.put(this.remainingHeader, tokensLeft.toString());
            headers.put(this.replenishRateHeader, String.valueOf(config.getReplenishRate()));
            headers.put(this.burstCapacityHeader, String.valueOf(config.getBurstCapacity()));
        }

        return headers;
    }

    @Validated
    public static class Config {
        @Min(1L)
        private int replenishRate;
        @Min(1L)
        private int burstCapacity = 1;

        public Config() {
        }

        public int getReplenishRate() {
            return this.replenishRate;
        }

        public RedisRateLimiter.Config setReplenishRate(int replenishRate) {
            this.replenishRate = replenishRate;
            return this;
        }

        public int getBurstCapacity() {
            return this.burstCapacity;
        }

        public RedisRateLimiter.Config setBurstCapacity(int burstCapacity) {
            this.burstCapacity = burstCapacity;
            return this;
        }

        public String toString() {
            return "Config{replenishRate=" + this.replenishRate + ", burstCapacity=" + this.burstCapacity + '}';
        }
    }
}

别看代码多,不要慌!实际上就是基于 Redis 限流是怎么个算法实现的,但是和限流为什么只能有一个规则,好像啥都没有~🤣,说明不在这里~

吓傻了

📢注意了,但是它 extends AbstractRateLimiter 了,继承了 AbstractRateLimiter 类,我们还是看看这个类吧~

AbstractRateLimiter 源码分析

AbstractRateLimiter 源码

public abstract class AbstractRateLimiter<C> extends AbstractStatefulConfigurable<C> implements RateLimiter<C>, ApplicationListener<FilterArgsEvent> {
    private String configurationPropertyName;
    private Validator validator;

    protected AbstractRateLimiter(Class<C> configClass, String configurationPropertyName, Validator validator) {
        super(configClass);
        this.configurationPropertyName = configurationPropertyName;
        this.validator = validator;
    }

    protected String getConfigurationPropertyName() {
        return this.configurationPropertyName;
    }

    protected Validator getValidator() {
        return this.validator;
    }

    public void setValidator(Validator validator) {
        this.validator = validator;
    }
		
    public void onApplicationEvent(FilterArgsEvent event) {
        Map<String, Object> args = event.getArgs();
        if (!args.isEmpty() && this.hasRelevantKey(args)) {
            String routeId = event.getRouteId();
            C routeConfig = this.newConfig();
            ConfigurationUtils.bind(routeConfig, args, this.configurationPropertyName, this.configurationPropertyName, this.validator);
          // 重点  
          this.getConfig().put(routeId, routeConfig);
        }
    }

    private boolean hasRelevantKey(Map<String, Object> args) {
        return args.keySet().stream().anyMatch((key) -> {
            return key.startsWith(this.configurationPropertyName + ".");
        });
    }

    public String toString() {
        return (new ToStringCreator(this)).append("configurationPropertyName", this.configurationPropertyName).append("config", this.getConfig()).append("configClass", this.getConfigClass()).toString();
    }
}

代码不多,就一个核心方法 onApplicationEvent,参数是个 FilterArgsEvent,看上去是把过滤器的参数 args 都获取出来,再做处理

小提示,看看人家的命名,一看就让人知道大概什么意思,以后大家也注意下命名!

贴一下限流的核心配置示例:

        - name: RequestRateLimiter
          args:
            # 令牌桶每秒填充平均速率
            redis-rate-limiter.replenishRate: 1
            # 令牌桶的上限
            redis-rate-limiter.burstCapacity: 10
            # 使用 SpEL 表达式从 Spring 容器中获取 Bean 对象,pathKeyResolver 是根据地址来限流
            key-resolver: "#{@pathKeyResolver}"

捋一捋,这个过滤器的参数 args 就是限流参数,而

RedisRateLimiter extends AbstractRateLimiter<RedisRateLimiter.Config>

那么 onApplicationEvent,应该是把参数对应的 routeConfig 对象初始化出来~,也就是 RedisRateLimiter.Config 的这个 Config 对象,

Config 里就两个属性,也就是限流的重要参数,果然没错~

简单再贴一下 RedisRateLimiter.Config 代码

@Validated
public static class Config {
    @Min(1L)
    private int replenishRate;
    @Min(1L)
    private int burstCapacity = 1;

    public Config() {
    }
    // 省略...
}

最后最后有个 this.getConfig().put(routeId, routeConfig);

这不就是把路由和其对应的限流规则存到一个 Map 里嘛~,盲猜都知道这个 this.getConfig() 是个 Map,可以去 AbstractStatefulConfigurable 代码里看,这里就不展示了~

AbstractRateLimiter<C> extends AbstractStatefulConfigurable<C>

其实看到 this.getConfig().put(routeId, routeConfig); 这里我大概已经知道是什么问题了:我们针对一个 1 路由配置多个限流规则,最终名为 Config 的 Map 里存储的只有 1 个!!因为存储到 Map 里的 key 就是 routeId~

你的路由 id 是固定的,所以后面的规则把前面的覆盖了~~~🎉🎉🎉🎉

叉会腰

改造方案

既然找到问题了,我们就想办法改造它!经过前面的分析,我们应该要改造的就是 onApplicationEvent 方法里的 this.getConfig().put(routeId, routeConfig);

我们应该改造成,放入 Config 里的 Key 不用 RouteId!

用什么呢,我这里 使用 routeId 和 KeyResolver 的 hashcode 组合

改造前再捋清楚,Spring Cloud Gateway 自带的 Redis 限流实现类是 RedisRateLimiter,它继承的抽象类 AbstractRateLimiter,而我们要改造的方法在 AbstractRateLimiter

所以我们重写一个 RedisRateLimiter,重写 onApplicationEvent 方法 !

ok!Just Do It~

自定义 DiyRedisRateLimiter

首先,我们新建 1 个类,叫 DiyRedisRateLimiter,剩下的代码就从 RedisRateLimiter 全部拷贝过来!

然后重写 onApplicationEvent 方法!

DiyRedisRateLimiter 代码:

/**
 * Author: Suremotoo
 */
public class DiyRedisRateLimiter extends AbstractRateLimiter<DiyRedisRateLimiter.Config>
    implements ApplicationContextAware {

    public static final String REPLENISH_RATE_KEY = "replenishRate";

    public static final String BURST_CAPACITY_KEY = "burstCapacity";
    public static final String CONFIGURATION_PROPERTY_NAME = "redis-rate-limiter";
    public static final String REDIS_SCRIPT_NAME = "redisRequestRateLimiterScript";
    public static final String REMAINING_HEADER = "X-RateLimit-Remaining";
    public static final String REPLENISH_RATE_HEADER = "X-RateLimit-Replenish-Rate";
    public static final String BURST_CAPACITY_HEADER = "X-RateLimit-Burst-Capacity";
    private Log log = LogFactory.getLog(this.getClass());
    private ReactiveRedisTemplate<String, String> redisTemplate;
    private RedisScript<List<Long>> script;
    private AtomicBoolean initialized = new AtomicBoolean(false);
    private DiyRedisRateLimiter.Config defaultConfig;
    private boolean includeHeaders = true;
    private String remainingHeader = "X-RateLimit-Remaining";
    private String replenishRateHeader = "X-RateLimit-Replenish-Rate";
    private String burstCapacityHeader = "X-RateLimit-Burst-Capacity";

    public DiyRedisRateLimiter(ReactiveRedisTemplate<String, String> redisTemplate, RedisScript<List<Long>> script,
        Validator validator) {
        super(DiyRedisRateLimiter.Config.class, "redis-rate-limiter", validator);
        this.redisTemplate = redisTemplate;
        this.script = script;
        this.initialized.compareAndSet(false, true);
    }

    public DiyRedisRateLimiter(int defaultReplenishRate, int defaultBurstCapacity) {
        super(DiyRedisRateLimiter.Config.class, "redis-rate-limiter", (Validator)null);
        this.defaultConfig = (new DiyRedisRateLimiter.Config()).setReplenishRate(defaultReplenishRate)
            .setBurstCapacity(defaultBurstCapacity);
    }

    static List<String> getKeys(String id) {
        String prefix = "request_rate_limiter.{" + id;
        String tokenKey = prefix + "}.tokens";
        String timestampKey = prefix + "}.timestamp";
        return Arrays.asList(tokenKey, timestampKey);
    }

    public boolean isIncludeHeaders() {
        return this.includeHeaders;
    }

    public void setIncludeHeaders(boolean includeHeaders) {
        this.includeHeaders = includeHeaders;
    }

    public String getRemainingHeader() {
        return this.remainingHeader;
    }

    public void setRemainingHeader(String remainingHeader) {
        this.remainingHeader = remainingHeader;
    }

    public String getReplenishRateHeader() {
        return this.replenishRateHeader;
    }

    public void setReplenishRateHeader(String replenishRateHeader) {
        this.replenishRateHeader = replenishRateHeader;
    }

    public String getBurstCapacityHeader() {
        return this.burstCapacityHeader;
    }

    public void setBurstCapacityHeader(String burstCapacityHeader) {
        this.burstCapacityHeader = burstCapacityHeader;
    }

    @Override
    public void setApplicationContext(ApplicationContext context) throws BeansException {
        if (this.initialized.compareAndSet(false, true)) {
            this.redisTemplate =
                (ReactiveRedisTemplate)context.getBean("stringReactiveRedisTemplate", ReactiveRedisTemplate.class);
            this.script = (RedisScript)context.getBean("redisRequestRateLimiterScript", RedisScript.class);
            if (context.getBeanNamesForType(Validator.class).length > 0) {
                this.setValidator((Validator)context.getBean(Validator.class));
            }
        }

    }

    DiyRedisRateLimiter.Config getDefaultConfig() {
        return this.defaultConfig;
    }

  
                                                                                                      
   //                                    ``.-/++oosyyyyyyyyyssso++/-.`                                 
   //                                ./oydmmmhyso+/:------://++oshdmmmhs/-`                            
   //                            `:odmdhs+:.``                   ```-/ohmNmyo:`                        
   //                         ./ymmy+-``                               ``-+ymNd+.                      
   //                      `:ymmy/.`            小哥哥                      `.-sdmy:                    
   //                    `/hmh/.                    小姐姐                      ./hmy:`                 
   //                   :hNy:./o-                                                `:dNh/`               
   //                 `oNd:`:dNs.             帅哥                                  `+dNy.              
   //                `sNd- /Nd:`                   美女                               .yNm/`            
   //               `yMh. +Nd-                                                        `+NNo`           
   //              `oNd` /Nm-                                           `ss.            :mNo`          
   //              /Nm- .dN/       `..-:/+osssssooooooooooooo+//:-.``   `dMo             +NN/          
   //             -mN:  /Nh``..:/oyhddddhysssoossssssssssyyhhhddmmmdhys+:hMh`            `hMd`         
   //            `hNo` `yMyohmmdhso/-..`                      ``.-:/oyhdmNMd`             :NM/         
   //            +Mh`  `yNmhmMd.                                       `.yMm.             `hMh         
   //           .mN:    `-. sMd`                                         oMm.              +MN-        
   //  ./+:.    oMy`        sMm.                                         sMm.              :NN:        
   // .dMNNms. .dN:--`      sMm.                                         yMh`              .mM/        
   // :NN+:yNmo/NmyNN-      sMm.        我是 Suremotoo                  `dMy               `mN/        
   // .mMo  :dMNMMMNN:    `+NNs`                                        .NN+               -Nm`        
   //  yMd.  `oNMMdyMs:///hMm:         看这里                            oMm.               oMy         
   //  .NMo    -hMMNNmdhmMMh.                                          -mMs           /yo`:mN-         
   //   +MN:    `/dMm/``hMh`              看这里                         oMMho/-.``    :NMMymN:          
   //    +NN+`    `/mMdsMN:                                            `:oydmNMNmdhhhmMNMMm:           
   //     /mMy`     .yNMMd`                  看这里                       ``.-yMNhhhyo.-:.            
   //      -dMd-      :mMh`                                                  .sNN+`                    
   //       `yMm/      /NNs.                     看这里                      .+mNh:                      
   //        `oNNo`    `/mMmo.`                                         `-smNd/`                       
   //          :dMh-    :dNmNNdo:.`                                ``-/sdNMMy.                         
   //           .yMm+` /mm/.:sdmNmdyo++/:--..`````````.---::::/+osydmNmdy+:yNh-                        
   //            `oNNy/mm:    `.-+oydNMNmmmNmmmmmmmmmmmNNNNNNNNNMNhyo:.`   `/mm+`                      
   //              :dMNMh`      .yy..dN+.-yMh///++++++//::sMd:-:Nm-          -yNy.                     
   //               .sNMh`      +Mm.-Nm.  sMs   下   |    -Nd. `dNy/          `oNh.                    
   //                `oNN+.    `hMh`+Mh`  yM+   面   |   `dN/  sMMN+`         `oNd.                   
   //                  -sdmy+-`.mMo sMs  `hN/   就   |     sMs  :NmdN+`         `hMs                   
   //                    ./sddddMN-`dM/  `dN:   是   |    /Nd` .mN/dN+`    `.:ohNM+                   
   //                       `-+NMh`-Nm.  .NN.   改   |     .mm. `hM/-dNo-:+shddhosNh.                  
   //                         -NMo +Mh   /Nd`   造   |   `mM+ `yMs .ymNMNy+:.` `sNy.                 
   //                         /NN: oMh   oMd`   的   |     `dMs `yMd` `.+Nd.      `yNs`                
   //                         -ss` /s+   /yo`   方   |   `sdo  +dh`    +ms`      .hd- 
   //        这                                      |
   //  					这																				↓
   // 👇👇👇👇👇👇👇👇👇👇👇👇👇👇
    @Override
    public void onApplicationEvent(FilterArgsEvent event) {
        Map<String, Object> args = event.getArgs();
        if (!args.isEmpty() && this.hasRelevantKey(args)) {
            String routeId = event.getRouteId();
            Config routeConfig = this.newConfig();
            ConfigurationUtils.bind(routeConfig, args, super.getConfigurationPropertyName(),
                this.getConfigurationPropertyName(), this.getValidator());
            /**
             * 这里重写 id,防止过冲规则重复被覆盖 by Suremotoo
             */
            // 使用 routeId + KeyResolver 的 hashcode 组合作为配置 id,防止重复
            routeId = routeId + event.getArgs().get("key-resolver").hashCode();
            // System.out.println("routeId put = " + routeId);
            this.getConfig().put(routeId, routeConfig);
        }
    }
  
    /*****************************************************************/

    private boolean hasRelevantKey(Map<String, Object> args) {
        return args.keySet().stream().anyMatch((key) -> {
            return key.startsWith(this.getConfigurationPropertyName() + ".");
        });
    }

    @Override
    public Mono<Response> isAllowed(String routeId, String id) {
        if (!this.initialized.get()) {
            throw new IllegalStateException("DiyRedisRateLimiter is not initialized");
        } else {
            DiyRedisRateLimiter.Config routeConfig = this.loadConfiguration(routeId);
            int replenishRate = routeConfig.getReplenishRate();
            int burstCapacity = routeConfig.getBurstCapacity();

            try {
                List<String> keys = getKeys(id);
                List<String> scriptArgs =
                    Arrays.asList(replenishRate + "", burstCapacity + "", Instant.now().getEpochSecond() + "", "1");
                Flux<List<Long>> flux = this.redisTemplate.execute(this.script, keys, scriptArgs);
                return flux.onErrorResume((throwable) -> {
                    return Flux.just(Arrays.asList(1L, -1L));
                }).reduce(new ArrayList(), (longs, l) -> {
                    longs.addAll(l);
                    return longs;
                }).map((results) -> {
                    boolean allowed = (Long)results.get(0) == 1L;
                    Long tokensLeft = (Long)results.get(1);
                    Response response = new Response(allowed, this.getHeaders(routeConfig, tokensLeft));
                    if (this.log.isDebugEnabled()) {
                        this.log.debug("response: " + response);
                    }

                    return response;
                });
            } catch (Exception var9) {
                this.log.error("Error determining if user allowed from redis", var9);
                return Mono.just(new Response(true, this.getHeaders(routeConfig, -1L)));
            }
        }
    }

    DiyRedisRateLimiter.Config loadConfiguration(String routeId) {
        DiyRedisRateLimiter.Config routeConfig =
            (DiyRedisRateLimiter.Config)this.getConfig().getOrDefault(routeId, this.defaultConfig);
        if (routeConfig == null) {
            routeConfig = (DiyRedisRateLimiter.Config)this.getConfig().get("defaultFilters");
        }

        if (routeConfig == null) {
            throw new IllegalArgumentException("No Configuration found for route " + routeId + " or defaultFilters");
        } else {
            return routeConfig;
        }
    }

    @NotNull
    public Map<String, String> getHeaders(DiyRedisRateLimiter.Config config, Long tokensLeft) {
        Map<String, String> headers = new HashMap();
        if (this.isIncludeHeaders()) {
            headers.put(this.remainingHeader, tokensLeft.toString());
            headers.put(this.replenishRateHeader, String.valueOf(config.getReplenishRate()));
            headers.put(this.burstCapacityHeader, String.valueOf(config.getBurstCapacity()));
        }

        return headers;
    }

    @Validated
    public static class Config {
        @Min(1L)
        private int replenishRate;
        @Min(1L)
        private int burstCapacity = 1;

        public Config() {}

        public int getReplenishRate() {
            return this.replenishRate;
        }

        public DiyRedisRateLimiter.Config setReplenishRate(int replenishRate) {
            this.replenishRate = replenishRate;
            return this;
        }

        public int getBurstCapacity() {
            return this.burstCapacity;
        }

        public DiyRedisRateLimiter.Config setBurstCapacity(int burstCapacity) {
            this.burstCapacity = burstCapacity;
            return this;
        }

        @Override
        public String toString() {
            return "Config{replenishRate=" + this.replenishRate + ", burstCapacity=" + this.burstCapacity + '}';
        }
    }
}

创建完后,我们再在将这个类初始化为 Spring 里



/**
 * Author: Suremotoo
 */
@Configuration
public class RateLimiterConfig {

    /**
     * 使用自定义的限流类
     * 
     * @param redisTemplate
     * @param redisScript
     * @param validator
     * @return
     */
    @Bean
    @Primary
    public DiyRedisRateLimiter diyRedisRateLimiter(ReactiveRedisTemplate<String, String> redisTemplate,
        @Qualifier(DiyRedisRateLimiter.REDIS_SCRIPT_NAME) RedisScript<List<Long>> redisScript, Validator validator) {
        return new DiyRedisRateLimiter(redisTemplate, redisScript, validator);
    }
  
  // .... 其他 KeyResolver 省略,详情见上文描述中的 RateLimiterConfig
  
}

这就弄好了,但是注意,我们还没有改造完!

这仅仅是放入 Map 中已经不是 1 个了!但是用的时候呢?还记得前面提到的 isAllow 方法吗?这个方法是在 RequestRateLimiterGatewayFilterFactory 里的 apply 方法中 ,所以我们还要重写 这里!

复制粘贴

自定义 DiyRequestRateLimiterGatewayFilterFactory

首先,我们新建 1 个类,叫 DiyRequestRateLimiterGatewayFilterFactory,继承 RequestRateLimiterGatewayFilterFactory

然后重写 apply 方法!

DiyRequestRateLimiterGatewayFilterFactory 代码示例:

/**
 * Author: Suremotoo
 */
@Component
public class DiyRequestRateLimiterGatewayFilterFactory extends RequestRateLimiterGatewayFilterFactory {

    private final RateLimiter defaultRateLimiter;

    private final KeyResolver defaultKeyResolver;

    private boolean denyEmptyKey = true;

    public DiyRequestRateLimiterGatewayFilterFactory(RateLimiter defaultRateLimiter, KeyResolver defaultKeyResolver) {
        super(defaultRateLimiter, defaultKeyResolver);
        this.defaultKeyResolver = defaultKeyResolver;
        this.defaultRateLimiter = defaultRateLimiter;

    }

    @Override
    public GatewayFilter apply(RequestRateLimiterGatewayFilterFactory.Config config) {
        KeyResolver resolver = (KeyResolver)this.getOrDefault(config.getKeyResolver(), this.defaultKeyResolver);
        RateLimiter<Object> limiter = (RateLimiter)this.getOrDefault(config.getRateLimiter(), this.defaultRateLimiter);
        boolean denyEmpty = (Boolean)this.getOrDefault(config.getDenyEmptyKey(), this.denyEmptyKey);
        HttpStatusHolder emptyKeyStatus =
            HttpStatusHolder.parse((String)this.getOrDefault(config.getEmptyKeyStatus(), this.getEmptyKeyStatusCode()));
        return (exchange, chain) -> {
            Route route = (Route)exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR);
            return resolver.resolve(exchange).defaultIfEmpty("____EMPTY_KEY__").flatMap((key) -> {
                
                if ("____EMPTY_KEY__".equals(key)) {
                    if (denyEmpty) {
                        ServerWebExchangeUtils.setResponseStatus(exchange, emptyKeyStatus);
                        return exchange.getResponse().setComplete();
                    } else {
                        return chain.filter(exchange);
                    }
                } else {
                    
                     //             `-/++//////+++-`           
                     //          .+o+:.          `:os+`        
                     //        .so:      在这里      .+o.      
                     //       :y:y.                    `yo     
                     //      -h.h`  .-://::///:-.` y`    os    
                     //     `h`+hyo+/-.``.....-:/ooN-    `d.   
                     // `:. o/ ``d.                d-     oo   
                     // +yssmd- `m.    在这里     `m`     +o   
                     // .d`.hmssd-                /h    /-h.   
                     //  :h. :hm:    在这里       .+oyyshy-    
                     //   .h: `h+                  `oy.        
                     //    `ss`+yys+:-.`````..-:/osyh`         
                     //      /dh  `soyod+ooo+doh+`  -y-        
                     //       .y+-/ys::s     o//N/   `d.       
                     //         `:m:d`++     -y.hsooooy/       
                     //           y.s +:     .y`h .y`  s.     
                     //           
                     //   👇👇👇👇👇🏻
                  
                    /**
                     * 改造校验时传入的 routeId
                     */
                    // 获取 routeId
                    String routeId = route.getId();
                    // 使用 routeId+KeyResolver 的 hashcode 获取 config
                    routeId = routeId + resolver.hashCode();
                    
                    return limiter.isAllowed(routeId, key).flatMap((response) -> {
                        Iterator var4 = response.getHeaders().entrySet().iterator();

                        while (var4.hasNext()) {
                            Map.Entry<String, String> header = (Map.Entry)var4.next();
                            exchange.getResponse().getHeaders().add((String)header.getKey(), (String)header.getValue());
                        }

                        if (response.isAllowed()) {
                            return chain.filter(exchange);
                        } else {
                            // 原返回信息代码
                            // ServerWebExchangeUtils.setResponseStatus(exchange, config.getStatusCode());
                            // return exchange.getResponse().setComplete();
                                 

                            // 自定义返回信息 | start
                            ServerHttpResponse httpResponse = exchange.getResponse();
                            httpResponse.setStatusCode(config.getStatusCode());

                            Map<String, Object> dataMap = new HashMap<>(4);
                            dataMap.put("errorCode", "1000");
                            dataMap.put("errorMsg", "操作频繁,歇会再来~");

                            DataBuffer buffer = httpResponse.bufferFactory()
                                .wrap(JSONObject.wrap(dataMap).toString().getBytes(StandardCharsets.UTF_8));
                            return httpResponse.writeWith(Mono.just(buffer));
                            // 自定义返回信息 | end
                        }
                    });
                }
            });
        };
    }

    private <T> T getOrDefault(T configValue, T defaultValue) {
        return configValue != null ? configValue : defaultValue;
    }

}

然后在 application.yml 中使用的时候用自己定义的 DiyRequestRateLimiterGatewayFilterFactory

示例:

spring:
  cloud:
    gateway:
      routes:
        - id: query_user_info_route
          uri: lb://user-center
          filters:
          # 第一个限流过滤器和规则
            - name: DiyRequestRateLimiter
              args:
                # 令牌桶每秒填充平均速率
                redis-rate-limiter.replenishRate: 1
                # 令牌桶的上限
                redis-rate-limiter.burstCapacity: 100
                # 使用 SpEL 表达式从 Spring 容器中获取 Bean 对象,pathKeyResolver 是根据地址来限流
                key-resolver: "#{@pathKeyResolver}" # 详情见 RateLimiterConfig
          # 第二个限流过滤器和规则
            - name: DiyRequestRateLimiter
              args:
                # 令牌桶每秒填充平均速率
                redis-rate-limiter.replenishRate: 1
                # 令牌桶的上限
                redis-rate-limiter.burstCapacity: 10
                # 使用 SpEL 表达式从 Spring 容器中获取 Bean 对象,remoteAddrKeyResolver 是根据请求 ip 来限流
                key-resolver: "#{@remoteAddrKeyResolver}" # 详情见 RateLimiterConfig

终于大功告成~

a