package limit import ( "context" "errors" "strconv" "time" "github.com/zeromicro/go-zero/core/stores/redis" ) const ( // Unknown means not initialized state. Unknown = iota // Allowed means allowed state. Allowed // HitQuota means this request exactly hit the quota. HitQuota // OverQuota means passed the quota. OverQuota internalOverQuota = 0 internalAllowed = 1 internalHitQuota = 2 ) var ( // ErrUnknownCode is an error that represents unknown status code. ErrUnknownCode = errors.New("unknown status code") // to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key periodScript = redis.NewScript(`local limit = tonumber(ARGV[1]) local window = tonumber(ARGV[2]) local current = redis.call("INCRBY", KEYS[1], 1) if current == 1 then redis.call("expire", KEYS[1], window) end if current < limit then return 1 elseif current == limit then return 2 else return 0 end`) ) type ( // PeriodOption defines the method to customize a PeriodLimit. PeriodOption func(l *PeriodLimit) // A PeriodLimit is used to limit requests during a period of time. PeriodLimit struct { period int quota int limitStore *redis.Redis keyPrefix string align bool } ) // NewPeriodLimit returns a PeriodLimit with given parameters. func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string, opts ...PeriodOption) *PeriodLimit { limiter := &PeriodLimit{ period: period, quota: quota, limitStore: limitStore, keyPrefix: keyPrefix, } for _, opt := range opts { opt(limiter) } return limiter } // Take requests a permit, it returns the permit state. func (h *PeriodLimit) Take(key string) (int, error) { return h.TakeCtx(context.Background(), key) } // TakeCtx requests a permit with context, it returns the permit state. func (h *PeriodLimit) TakeCtx(ctx context.Context, key string) (int, error) { resp, err := h.limitStore.ScriptRunCtx(ctx, periodScript, []string{h.keyPrefix + key}, []string{ strconv.Itoa(h.quota), strconv.Itoa(h.calcExpireSeconds()), }) if err != nil { return Unknown, err } code, ok := resp.(int64) if !ok { return Unknown, ErrUnknownCode } switch code { case internalOverQuota: return OverQuota, nil case internalAllowed: return Allowed, nil case internalHitQuota: return HitQuota, nil default: return Unknown, ErrUnknownCode } } func (h *PeriodLimit) calcExpireSeconds() int { if h.align { now := time.Now() _, offset := now.Zone() unix := now.Unix() + int64(offset) return h.period - int(unix%int64(h.period)) } return h.period } // Align returns a func to customize a PeriodLimit with alignment. // For example, if we want to limit end users with 5 sms verification messages every day, // we need to align with the local timezone and the start of the day. func Align() PeriodOption { return func(l *PeriodLimit) { l.align = true } }