token limit support context (#2335)

* token limit support context

* add token limit with ctx

add token limit with ctx

Co-authored-by: sado <liaoyonglin@bilibili.com>
master
sado 2 years ago committed by GitHub
parent 799c118d95
commit f068062b13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,8 @@
package limit package limit
import ( import (
"context"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"sync" "sync"
@ -84,19 +86,38 @@ func (lim *TokenLimiter) Allow() bool {
return lim.AllowN(time.Now(), 1) return lim.AllowN(time.Now(), 1)
} }
// AllowCtx is shorthand for AllowNCtx(ctx,time.Now(), 1) with incoming context.
func (lim *TokenLimiter) AllowCtx(ctx context.Context) bool {
return lim.AllowNCtx(ctx, time.Now(), 1)
}
// AllowN reports whether n events may happen at time now. // AllowN reports whether n events may happen at time now.
// Use this method if you intend to drop / skip events that exceed the rate. // Use this method if you intend to drop / skip events that exceed the rate.
// Otherwise, use Reserve or Wait. // Otherwise, use Reserve or Wait.
func (lim *TokenLimiter) AllowN(now time.Time, n int) bool { func (lim *TokenLimiter) AllowN(now time.Time, n int) bool {
return lim.reserveN(now, n) return lim.reserveN(context.Background(), now, n)
}
// AllowNCtx reports whether n events may happen at time now with incoming context.
// Use this method if you intend to drop / skip events that exceed the rate.
// Otherwise, use Reserve or Wait.
func (lim *TokenLimiter) AllowNCtx(ctx context.Context, now time.Time, n int) bool {
return lim.reserveN(ctx, now, n)
}
func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) bool {
select {
case <-ctx.Done():
logx.Errorf("fail to use rate limiter: %s", ctx.Err())
return false
default:
} }
func (lim *TokenLimiter) reserveN(now time.Time, n int) bool {
if atomic.LoadUint32(&lim.redisAlive) == 0 { if atomic.LoadUint32(&lim.redisAlive) == 0 {
return lim.rescueLimiter.AllowN(now, n) return lim.rescueLimiter.AllowN(now, n)
} }
resp, err := lim.store.Eval( resp, err := lim.store.EvalCtx(ctx,
script, script,
[]string{ []string{
lim.tokenKey, lim.tokenKey,
@ -113,6 +134,12 @@ func (lim *TokenLimiter) reserveN(now time.Time, n int) bool {
if err == redis.Nil { if err == redis.Nil {
return false return false
} }
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
logx.Errorf("fail to use rate limiter: %s", err)
return false
}
if err != nil { if err != nil {
logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err) logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err)
lim.startMonitor() lim.startMonitor()

@ -1,6 +1,7 @@
package limit package limit
import ( import (
"context"
"testing" "testing"
"time" "time"
@ -15,6 +16,30 @@ func init() {
logx.Disable() logx.Disable()
} }
func TestTokenLimit_WithCtx(t *testing.T) {
s, err := miniredis.Run()
assert.Nil(t, err)
const (
total = 100
rate = 5
burst = 10
)
l := NewTokenLimiter(rate, burst, redis.New(s.Addr()), "tokenlimit")
defer s.Close()
ctx, cancel := context.WithCancel(context.Background())
ok := l.AllowCtx(ctx)
assert.True(t, ok)
cancel()
for i := 0; i < total; i++ {
ok := l.AllowCtx(ctx)
assert.False(t, ok)
assert.False(t, l.monitorStarted)
}
}
func TestTokenLimit_Rescue(t *testing.T) { func TestTokenLimit_Rescue(t *testing.T) {
s, err := miniredis.Run() s, err := miniredis.Run()
assert.Nil(t, err) assert.Nil(t, err)

Loading…
Cancel
Save