diff --git a/core/limit/tokenlimit.go b/core/limit/tokenlimit.go index 9fabc8f6..7e43e6db 100644 --- a/core/limit/tokenlimit.go +++ b/core/limit/tokenlimit.go @@ -1,6 +1,8 @@ package limit import ( + "context" + "errors" "fmt" "strconv" "sync" @@ -84,19 +86,38 @@ func (lim *TokenLimiter) Allow() bool { return lim.AllowN(time.Now(), 1) } +// AllowCtx is shorthand for AllowNCtx(ctx,time.Now(), 1) with incoming context. +func (lim *TokenLimiter) AllowCtx(ctx context.Context) bool { + return lim.AllowNCtx(ctx, time.Now(), 1) +} + // AllowN reports whether n events may happen at time now. // Use this method if you intend to drop / skip events that exceed the rate. // Otherwise, use Reserve or Wait. func (lim *TokenLimiter) AllowN(now time.Time, n int) bool { - return lim.reserveN(now, n) + return lim.reserveN(context.Background(), now, n) } -func (lim *TokenLimiter) reserveN(now time.Time, n int) bool { +// AllowNCtx reports whether n events may happen at time now with incoming context. +// Use this method if you intend to drop / skip events that exceed the rate. +// Otherwise, use Reserve or Wait. +func (lim *TokenLimiter) AllowNCtx(ctx context.Context, now time.Time, n int) bool { + return lim.reserveN(ctx, now, n) +} + +func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) bool { + select { + case <-ctx.Done(): + logx.Errorf("fail to use rate limiter: %s", ctx.Err()) + return false + default: + } + if atomic.LoadUint32(&lim.redisAlive) == 0 { return lim.rescueLimiter.AllowN(now, n) } - resp, err := lim.store.Eval( + resp, err := lim.store.EvalCtx(ctx, script, []string{ lim.tokenKey, @@ -113,6 +134,12 @@ func (lim *TokenLimiter) reserveN(now time.Time, n int) bool { if err == redis.Nil { return false } + + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + logx.Errorf("fail to use rate limiter: %s", err) + return false + } + if err != nil { logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err) lim.startMonitor() diff --git a/core/limit/tokenlimit_test.go b/core/limit/tokenlimit_test.go index c0bd6a38..65107d01 100644 --- a/core/limit/tokenlimit_test.go +++ b/core/limit/tokenlimit_test.go @@ -1,6 +1,7 @@ package limit import ( + "context" "testing" "time" @@ -15,6 +16,30 @@ func init() { logx.Disable() } +func TestTokenLimit_WithCtx(t *testing.T) { + s, err := miniredis.Run() + assert.Nil(t, err) + + const ( + total = 100 + rate = 5 + burst = 10 + ) + l := NewTokenLimiter(rate, burst, redis.New(s.Addr()), "tokenlimit") + defer s.Close() + + ctx, cancel := context.WithCancel(context.Background()) + ok := l.AllowCtx(ctx) + assert.True(t, ok) + + cancel() + for i := 0; i < total; i++ { + ok := l.AllowCtx(ctx) + assert.False(t, ok) + assert.False(t, l.monitorStarted) + } +} + func TestTokenLimit_Rescue(t *testing.T) { s, err := miniredis.Run() assert.Nil(t, err)