|
|
|
package limit
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
|
|
|
"strconv"
|
|
|
|
"sync"
|
|
|
|
"sync/atomic"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|
|
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
|
|
|
xrate "golang.org/x/time/rate"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Redis key templates and the health-check interval used by TokenLimiter.
const (
	// tokenFormat builds the key (KEYS[1] of the script) holding the tokens left in the bucket.
	tokenFormat = "{%s}.tokens"

	// timestampFormat builds the key (KEYS[2] of the script) holding the Unix second of the last update.
	timestampFormat = "{%s}.ts"

	// pingInterval is how often waitForRedis pings Redis while the rescue limiter is in use.
	pingInterval = 100 * time.Millisecond
)
|
|
|
|
|
|
|
|
// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
|
|
|
|
// KEYS[1] as tokens_key
|
|
|
|
// KEYS[2] as timestamp_key
|
|
|
|
var script = redis.NewScript(`local rate = tonumber(ARGV[1])
|
|
|
|
local capacity = tonumber(ARGV[2])
|
|
|
|
local now = tonumber(ARGV[3])
|
|
|
|
local requested = tonumber(ARGV[4])
|
|
|
|
local fill_time = capacity/rate
|
|
|
|
local ttl = math.floor(fill_time*2)
|
|
|
|
local last_tokens = tonumber(redis.call("get", KEYS[1]))
|
|
|
|
if last_tokens == nil then
|
|
|
|
last_tokens = capacity
|
|
|
|
end
|
|
|
|
|
|
|
|
local last_refreshed = tonumber(redis.call("get", KEYS[2]))
|
|
|
|
if last_refreshed == nil then
|
|
|
|
last_refreshed = 0
|
|
|
|
end
|
|
|
|
|
|
|
|
local delta = math.max(0, now-last_refreshed)
|
|
|
|
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
|
|
|
|
local allowed = filled_tokens >= requested
|
|
|
|
local new_tokens = filled_tokens
|
|
|
|
if allowed then
|
|
|
|
new_tokens = filled_tokens - requested
|
|
|
|
end
|
|
|
|
|
|
|
|
redis.call("setex", KEYS[1], ttl, new_tokens)
|
|
|
|
redis.call("setex", KEYS[2], ttl, now)
|
|
|
|
|
|
|
|
return allowed`)
|
|
|
|
|
|
|
|
// A TokenLimiter controls how frequently events are allowed to happen within one second.
//
// The token bucket state lives in Redis and is updated atomically by a Lua
// script. If Redis cannot be used, the limiter falls back to an in-process
// rate limiter until a background monitor sees Redis answer a ping again.
type TokenLimiter struct {
	// rate is the number of tokens refilled per second.
	rate int

	// burst is the bucket capacity, i.e. the most tokens that can be held at once.
	burst int

	// store is the Redis client the token bucket script runs against.
	store *redis.Redis

	// tokenKey is "{key}.tokens": the Redis key holding the remaining tokens.
	// NOTE(review): the braces presumably act as a Redis Cluster hash tag so both keys share a slot.
	tokenKey string

	// timestampKey is "{key}.ts": the Redis key holding the Unix second of the last update.
	timestampKey string

	// rescueLock guards monitorStarted.
	rescueLock sync.Mutex

	// redisAlive is 1 while Redis is used and 0 while the rescue limiter is used;
	// it is only read and written through sync/atomic.
	redisAlive uint32

	// monitorStarted reports whether a waitForRedis goroutine is running; guarded by rescueLock.
	monitorStarted bool

	// rescueLimiter is the in-process fallback used while Redis is unavailable.
	rescueLimiter *xrate.Limiter
}
|
|
|
|
|
|
|
|
// NewTokenLimiter returns a new TokenLimiter that allows events up to rate and permits
|
|
|
|
// bursts of at most burst tokens.
|
|
|
|
func NewTokenLimiter(rate, burst int, store *redis.Redis, key string) *TokenLimiter {
|
|
|
|
tokenKey := fmt.Sprintf(tokenFormat, key)
|
|
|
|
timestampKey := fmt.Sprintf(timestampFormat, key)
|
|
|
|
|
|
|
|
return &TokenLimiter{
|
|
|
|
rate: rate,
|
|
|
|
burst: burst,
|
|
|
|
store: store,
|
|
|
|
tokenKey: tokenKey,
|
|
|
|
timestampKey: timestampKey,
|
|
|
|
redisAlive: 1,
|
|
|
|
rescueLimiter: xrate.NewLimiter(xrate.Every(time.Second/time.Duration(rate)), burst),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Allow is shorthand for AllowN(time.Now(), 1).
|
|
|
|
func (lim *TokenLimiter) Allow() bool {
|
|
|
|
return lim.AllowN(time.Now(), 1)
|
|
|
|
}
|
|
|
|
|
|
|
|
// AllowCtx is shorthand for AllowNCtx(ctx,time.Now(), 1) with incoming context.
|
|
|
|
func (lim *TokenLimiter) AllowCtx(ctx context.Context) bool {
|
|
|
|
return lim.AllowNCtx(ctx, time.Now(), 1)
|
|
|
|
}
|
|
|
|
|
|
|
|
// AllowN reports whether n events may happen at time now.
|
|
|
|
// Use this method if you intend to drop / skip events that exceed the rate.
|
|
|
|
// Otherwise, use Reserve or Wait.
|
|
|
|
func (lim *TokenLimiter) AllowN(now time.Time, n int) bool {
|
|
|
|
return lim.reserveN(context.Background(), now, n)
|
|
|
|
}
|
|
|
|
|
|
|
|
// AllowNCtx reports whether n events may happen at time now with incoming context.
|
|
|
|
// Use this method if you intend to drop / skip events that exceed the rate.
|
|
|
|
// Otherwise, use Reserve or Wait.
|
|
|
|
func (lim *TokenLimiter) AllowNCtx(ctx context.Context, now time.Time, n int) bool {
|
|
|
|
return lim.reserveN(ctx, now, n)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) bool {
|
|
|
|
if atomic.LoadUint32(&lim.redisAlive) == 0 {
|
|
|
|
return lim.rescueLimiter.AllowN(now, n)
|
|
|
|
}
|
|
|
|
|
|
|
|
resp, err := lim.store.ScriptRunCtx(ctx,
|
|
|
|
script,
|
|
|
|
[]string{
|
|
|
|
lim.tokenKey,
|
|
|
|
lim.timestampKey,
|
|
|
|
},
|
|
|
|
[]string{
|
|
|
|
strconv.Itoa(lim.rate),
|
|
|
|
strconv.Itoa(lim.burst),
|
|
|
|
strconv.FormatInt(now.Unix(), 10),
|
|
|
|
strconv.Itoa(n),
|
|
|
|
})
|
|
|
|
// redis allowed == false
|
|
|
|
// Lua boolean false -> r Nil bulk reply
|
|
|
|
if err == redis.Nil {
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
|
|
|
logx.Errorf("fail to use rate limiter: %s", err)
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
if err != nil {
|
|
|
|
logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err)
|
|
|
|
lim.startMonitor()
|
|
|
|
return lim.rescueLimiter.AllowN(now, n)
|
|
|
|
}
|
|
|
|
|
|
|
|
code, ok := resp.(int64)
|
|
|
|
if !ok {
|
|
|
|
logx.Errorf("fail to eval redis script: %v, use in-process limiter for rescue", resp)
|
|
|
|
lim.startMonitor()
|
|
|
|
return lim.rescueLimiter.AllowN(now, n)
|
|
|
|
}
|
|
|
|
|
|
|
|
// redis allowed == true
|
|
|
|
// Lua boolean true -> r integer reply with value of 1
|
|
|
|
return code == 1
|
|
|
|
}
|
|
|
|
|
|
|
|
func (lim *TokenLimiter) startMonitor() {
|
|
|
|
lim.rescueLock.Lock()
|
|
|
|
defer lim.rescueLock.Unlock()
|
|
|
|
|
|
|
|
if lim.monitorStarted {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
lim.monitorStarted = true
|
|
|
|
atomic.StoreUint32(&lim.redisAlive, 0)
|
|
|
|
|
|
|
|
go lim.waitForRedis()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (lim *TokenLimiter) waitForRedis() {
|
|
|
|
ticker := time.NewTicker(pingInterval)
|
|
|
|
defer func() {
|
|
|
|
ticker.Stop()
|
|
|
|
lim.rescueLock.Lock()
|
|
|
|
lim.monitorStarted = false
|
|
|
|
lim.rescueLock.Unlock()
|
|
|
|
}()
|
|
|
|
|
|
|
|
for range ticker.C {
|
|
|
|
if lim.store.Ping() {
|
|
|
|
atomic.StoreUint32(&lim.redisAlive, 1)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|