package limit import ( "errors" "strconv" "time" "github.com/zeromicro/go-zero/core/stores/redis" ) const ( // to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key periodScript = `local limit = tonumber(ARGV[1]) local window = tonumber(ARGV[2]) local current = redis.call("INCRBY", KEYS[1], 1) if current == 1 then redis.call("expire", KEYS[1], window) return 1 elseif current < limit then return 1 elseif current == limit then return 2 else return 0 end` zoneDiff = 3600 * 8 // GMT+8 for our services ) const ( // Unknown means not initialized state. Unknown = iota // Allowed means allowed state. Allowed // HitQuota means this request exactly hit the quota. HitQuota // OverQuota means passed the quota. OverQuota internalOverQuota = 0 internalAllowed = 1 internalHitQuota = 2 ) // ErrUnknownCode is an error that represents unknown status code. var ErrUnknownCode = errors.New("unknown status code") type ( // PeriodOption defines the method to customize a PeriodLimit. PeriodOption func(l *PeriodLimit) // A PeriodLimit is used to limit requests during a period of time. PeriodLimit struct { period int quota int limitStore *redis.Redis keyPrefix string align bool } ) // NewPeriodLimit returns a PeriodLimit with given parameters. func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string, opts ...PeriodOption) *PeriodLimit { limiter := &PeriodLimit{ period: period, quota: quota, limitStore: limitStore, keyPrefix: keyPrefix, } for _, opt := range opts { opt(limiter) } return limiter } // Take requests a permit, it returns the permit state. func (h *PeriodLimit) Take(key string) (int, error) { resp, err := h.limitStore.Eval(periodScript, []string{h.keyPrefix + key}, []string{ strconv.Itoa(h.quota), strconv.Itoa(h.calcExpireSeconds()), }) if err != nil { return Unknown, err } code, ok := resp.(int64) if !ok { return Unknown, ErrUnknownCode } switch code { case internalOverQuota: return OverQuota, nil case internalAllowed: return Allowed, nil case internalHitQuota: return HitQuota, nil default: return Unknown, ErrUnknownCode } } func (h *PeriodLimit) calcExpireSeconds() int { if h.align { unix := time.Now().Unix() + zoneDiff return h.period - int(unix%int64(h.period)) } return h.period } // Align returns a func to customize a PeriodLimit with alignment. func Align() PeriodOption { return func(l *PeriodLimit) { l.align = true } }