package limit

import (
	"context"
	"errors"
	"fmt"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/zeromicro/go-zero/core/logx"
	"github.com/zeromicro/go-zero/core/stores/redis"
	xrate "golang.org/x/time/rate"
)

const (
	// script is the Lua token-bucket script evaluated in Redis by reserveN.
	//
	// To be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key.
	//   KEYS[1] is the tokens key.
	//   KEYS[2] is the timestamp key.
	//   ARGV[1] is the refill rate, ARGV[2] the bucket capacity,
	//   ARGV[3] the current time (reserveN passes Unix seconds) and
	//   ARGV[4] the number of requested tokens.
	//
	// A missing tokens key is treated as a full bucket and a missing timestamp
	// key as a last refresh at time 0. The bucket is refilled by delta*rate,
	// capped at capacity; the requested tokens are deducted only when enough are
	// available. Both keys are always rewritten with a TTL of
	// floor(2*capacity/rate), and the script returns whether the request was allowed.
	script = `local rate = tonumber(ARGV[1]) local capacity = tonumber(ARGV[2]) local now = tonumber(ARGV[3]) local requested = tonumber(ARGV[4]) local fill_time = capacity/rate local ttl = math.floor(fill_time*2) local last_tokens = tonumber(redis.call("get", KEYS[1])) if last_tokens == nil then last_tokens = capacity end local last_refreshed = tonumber(redis.call("get", KEYS[2])) if last_refreshed == nil then last_refreshed = 0 end local delta = math.max(0, now-last_refreshed) local filled_tokens = math.min(capacity, last_tokens+(delta*rate)) local allowed = filled_tokens >= requested local new_tokens = filled_tokens if allowed then new_tokens = filled_tokens - requested end redis.call("setex", KEYS[1], ttl, new_tokens) redis.call("setex", KEYS[2], ttl, now) return allowed`

	// tokenFormat is the fmt pattern for the Redis key that stores the remaining tokens.
	// NOTE(review): the braces look like a Redis Cluster hash tag that keeps both
	// keys in the same slot -- confirm.
	tokenFormat = "{%s}.tokens"

	// timestampFormat is the fmt pattern for the Redis key that stores the last refresh time.
	timestampFormat = "{%s}.ts"

	// pingInterval is how often waitForRedis pings Redis while the limiter is in rescue mode.
	pingInterval = time.Millisecond * 100
)

// A TokenLimiter controls how frequently events are allowed to happen within one second.
type TokenLimiter struct {
	// rate is passed to the Redis script as the refill rate (ARGV[1]).
	rate int
	// burst is passed to the Redis script as the bucket capacity (ARGV[2]).
	burst int
	// store is the Redis client used to evaluate the script and to ping Redis.
	store *redis.Redis
	// tokenKey is the Redis key built from tokenFormat.
	tokenKey string
	// timestampKey is the Redis key built from timestampFormat.
	timestampKey string
	// rescueLock guards monitorStarted.
	rescueLock sync.Mutex
	// redisAlive is 1 while Redis is used and 0 while the in-process rescue
	// limiter is used; it is read and written with sync/atomic.
	redisAlive uint32
	// rescueLimiter is the in-process limiter used when Redis cannot be used.
	rescueLimiter *xrate.Limiter
	// monitorStarted reports whether a waitForRedis goroutine is currently running.
	monitorStarted bool
}

// NewTokenLimiter returns a new TokenLimiter that allows events up to rate and permits
// bursts of at most burst tokens.
func NewTokenLimiter(rate, burst int, store *redis.Redis, key string) *TokenLimiter { tokenKey := fmt.Sprintf(tokenFormat, key) timestampKey := fmt.Sprintf(timestampFormat, key) return &TokenLimiter{ rate: rate, burst: burst, store: store, tokenKey: tokenKey, timestampKey: timestampKey, redisAlive: 1, rescueLimiter: xrate.NewLimiter(xrate.Every(time.Second/time.Duration(rate)), burst), } } // Allow is shorthand for AllowN(time.Now(), 1). func (lim *TokenLimiter) Allow() bool { return lim.AllowN(time.Now(), 1) } // AllowCtx is shorthand for AllowNCtx(ctx,time.Now(), 1) with incoming context. func (lim *TokenLimiter) AllowCtx(ctx context.Context) bool { return lim.AllowNCtx(ctx, time.Now(), 1) } // AllowN reports whether n events may happen at time now. // Use this method if you intend to drop / skip events that exceed the rate. // Otherwise, use Reserve or Wait. func (lim *TokenLimiter) AllowN(now time.Time, n int) bool { return lim.reserveN(context.Background(), now, n) } // AllowNCtx reports whether n events may happen at time now with incoming context. // Use this method if you intend to drop / skip events that exceed the rate. // Otherwise, use Reserve or Wait. 
func (lim *TokenLimiter) AllowNCtx(ctx context.Context, now time.Time, n int) bool { return lim.reserveN(ctx, now, n) } func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) bool { select { case <-ctx.Done(): logx.Errorf("fail to use rate limiter: %s", ctx.Err()) return false default: } if atomic.LoadUint32(&lim.redisAlive) == 0 { return lim.rescueLimiter.AllowN(now, n) } resp, err := lim.store.EvalCtx(ctx, script, []string{ lim.tokenKey, lim.timestampKey, }, []string{ strconv.Itoa(lim.rate), strconv.Itoa(lim.burst), strconv.FormatInt(now.Unix(), 10), strconv.Itoa(n), }) // redis allowed == false // Lua boolean false -> r Nil bulk reply if err == redis.Nil { return false } if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { logx.Errorf("fail to use rate limiter: %s", err) return false } if err != nil { logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err) lim.startMonitor() return lim.rescueLimiter.AllowN(now, n) } code, ok := resp.(int64) if !ok { logx.Errorf("fail to eval redis script: %v, use in-process limiter for rescue", resp) lim.startMonitor() return lim.rescueLimiter.AllowN(now, n) } // redis allowed == true // Lua boolean true -> r integer reply with value of 1 return code == 1 } func (lim *TokenLimiter) startMonitor() { lim.rescueLock.Lock() defer lim.rescueLock.Unlock() if lim.monitorStarted { return } lim.monitorStarted = true atomic.StoreUint32(&lim.redisAlive, 0) go lim.waitForRedis() } func (lim *TokenLimiter) waitForRedis() { ticker := time.NewTicker(pingInterval) defer func() { ticker.Stop() lim.rescueLock.Lock() lim.monitorStarted = false lim.rescueLock.Unlock() }() for range ticker.C { if lim.store.Ping() { atomic.StoreUint32(&lim.redisAlive, 1) return } } }