From e8c307e4dcca75064b43c0503e560cba39964af6 Mon Sep 17 00:00:00 2001 From: chenquan Date: Sun, 13 Feb 2022 03:28:14 -0600 Subject: [PATCH] feat: support ctx in `Cache` (#1518) * feature: support ctx in `Cache` Signed-off-by: chenquan * fix: `errors.Is` instead of `=` Signed-off-by: chenquan --- core/stores/cache/cache.go | 82 ++++++++++++++++++++++++++--- core/stores/cache/cache_test.go | 32 +++++++++++- core/stores/cache/cachenode.go | 91 +++++++++++++++++++++++---------- 3 files changed, 169 insertions(+), 36 deletions(-) diff --git a/core/stores/cache/cache.go b/core/stores/cache/cache.go index fe3fc8dd..dc491eb9 100644 --- a/core/stores/cache/cache.go +++ b/core/stores/cache/cache.go @@ -1,6 +1,8 @@ package cache import ( + "context" + "errors" "fmt" "log" "time" @@ -13,13 +15,36 @@ import ( type ( // Cache interface is used to define the cache implementation. Cache interface { + // Del deletes cached values with keys. Del(keys ...string) error + // DelCtx deletes cached values with keys. + DelCtx(ctx context.Context, keys ...string) error + // Get gets the cache with key and fills into v. Get(key string, v interface{}) error + // GetCtx gets the cache with key and fills into v. + GetCtx(ctx context.Context, key string, v interface{}) error + // IsNotFound checks if the given error is the defined errNotFound. IsNotFound(err error) bool + // Set sets the cache with key and v, using c.expiry. Set(key string, v interface{}) error + // SetCtx sets the cache with key and v, using c.expiry. + SetCtx(ctx context.Context, key string, v interface{}) error + // SetWithExpire sets the cache with key and v, using given expire. SetWithExpire(key string, v interface{}, expire time.Duration) error + // SetWithExpireCtx sets the cache with key and v, using given expire. + SetWithExpireCtx(ctx context.Context, key string, v interface{}, expire time.Duration) error + // Take takes the result from cache first, if not found, + // query from DB and set cache using c.expiry, then return the result. Take(v interface{}, key string, query func(v interface{}) error) error + // TakeCtx takes the result from cache first, if not found, + // query from DB and set cache using c.expiry, then return the result. + TakeCtx(ctx context.Context, v interface{}, key string, query func(v interface{}) error) error + // TakeWithExpire takes the result from cache first, if not found, + // query from DB and set cache using given expire, then return the result. TakeWithExpire(v interface{}, key string, query func(v interface{}, expire time.Duration) error) error + // TakeWithExpireCtx takes the result from cache first, if not found, + // query from DB and set cache using given expire, then return the result. + TakeWithExpireCtx(ctx context.Context, v interface{}, key string, query func(v interface{}, expire time.Duration) error) error } cacheCluster struct { @@ -51,7 +76,13 @@ func New(c ClusterConf, barrier syncx.SingleFlight, st *Stat, errNotFound error, } } +// Del deletes cached values with keys. func (cc cacheCluster) Del(keys ...string) error { + return cc.DelCtx(context.Background(), keys...) +} + +// DelCtx deletes cached values with keys. +func (cc cacheCluster) DelCtx(ctx context.Context, keys ...string) error { switch len(keys) { case 0: return nil @@ -62,7 +93,7 @@ func (cc cacheCluster) Del(keys ...string) error { return cc.errNotFound } - return c.(Cache).Del(key) + return c.(Cache).DelCtx(ctx, key) default: var be errorx.BatchError nodes := make(map[interface{}][]string) @@ -76,7 +107,7 @@ func (cc cacheCluster) Del(keys ...string) error { nodes[c] = append(nodes[c], key) } for c, ks := range nodes { - if err := c.(Cache).Del(ks...); err != nil { + if err := c.(Cache).DelCtx(ctx, ks...); err != nil { be.Add(err) } } @@ -85,52 +116,87 @@ func (cc cacheCluster) Del(keys ...string) error { } } +// Get gets the cache with key and fills into v. func (cc cacheCluster) Get(key string, v interface{}) error { + return cc.GetCtx(context.Background(), key, v) +} + +// GetCtx gets the cache with key and fills into v. +func (cc cacheCluster) GetCtx(ctx context.Context, key string, v interface{}) error { c, ok := cc.dispatcher.Get(key) if !ok { return cc.errNotFound } - return c.(Cache).Get(key, v) + return c.(Cache).GetCtx(ctx, key, v) } +// IsNotFound checks if the given error is the defined errNotFound. func (cc cacheCluster) IsNotFound(err error) bool { - return err == cc.errNotFound + return errors.Is(err, cc.errNotFound) } +// Set sets the cache with key and v, using c.expiry. func (cc cacheCluster) Set(key string, v interface{}) error { + return cc.SetCtx(context.Background(), key, v) +} + +// SetCtx sets the cache with key and v, using c.expiry. +func (cc cacheCluster) SetCtx(ctx context.Context, key string, v interface{}) error { c, ok := cc.dispatcher.Get(key) if !ok { return cc.errNotFound } - return c.(Cache).Set(key, v) + return c.(Cache).SetCtx(ctx, key, v) } +// SetWithExpire sets the cache with key and v, using given expire. func (cc cacheCluster) SetWithExpire(key string, v interface{}, expire time.Duration) error { + return cc.SetWithExpireCtx(context.Background(), key, v, expire) +} + +// SetWithExpireCtx sets the cache with key and v, using given expire. +func (cc cacheCluster) SetWithExpireCtx(ctx context.Context, key string, v interface{}, expire time.Duration) error { c, ok := cc.dispatcher.Get(key) if !ok { return cc.errNotFound } - return c.(Cache).SetWithExpire(key, v, expire) + return c.(Cache).SetWithExpireCtx(ctx, key, v, expire) } +// Take takes the result from cache first, if not found, +// query from DB and set cache using c.expiry, then return the result. func (cc cacheCluster) Take(v interface{}, key string, query func(v interface{}) error) error { + return cc.TakeCtx(context.Background(), v, key, query) +} + +// TakeCtx takes the result from cache first, if not found, +// query from DB and set cache using c.expiry, then return the result. +func (cc cacheCluster) TakeCtx(ctx context.Context, v interface{}, key string, query func(v interface{}) error) error { c, ok := cc.dispatcher.Get(key) if !ok { return cc.errNotFound } - return c.(Cache).Take(v, key, query) + return c.(Cache).TakeCtx(ctx, v, key, query) } +// TakeWithExpire takes the result from cache first, if not found, +// query from DB and set cache using given expire, then return the result. func (cc cacheCluster) TakeWithExpire(v interface{}, key string, query func(v interface{}, expire time.Duration) error) error { + return cc.TakeWithExpireCtx(context.Background(), v, key, query) +} + +// TakeWithExpireCtx takes the result from cache first, if not found, +// query from DB and set cache using given expire, then return the result. +func (cc cacheCluster) TakeWithExpireCtx(ctx context.Context, v interface{}, key string, query func(v interface{}, expire time.Duration) error) error { c, ok := cc.dispatcher.Get(key) if !ok { return cc.errNotFound } - return c.(Cache).TakeWithExpire(v, key, query) + return c.(Cache).TakeWithExpireCtx(ctx, v, key, query) } diff --git a/core/stores/cache/cache_test.go b/core/stores/cache/cache_test.go index 12c6e281..f819126e 100644 --- a/core/stores/cache/cache_test.go +++ b/core/stores/cache/cache_test.go @@ -1,7 +1,9 @@ package cache import ( + "context" "encoding/json" + "errors" "fmt" "math" "strconv" @@ -16,6 +18,8 @@ import ( "github.com/zeromicro/go-zero/core/syncx" ) +var _ Cache = (*mockedNode)(nil) + type mockedNode struct { vals map[string][]byte errNotFound error @@ -45,7 +49,7 @@ func (mc *mockedNode) Get(key string, v interface{}) error { } func (mc *mockedNode) IsNotFound(err error) bool { - return err == mc.errNotFound + return errors.Is(err, mc.errNotFound) } func (mc *mockedNode) Set(key string, v interface{}) error { @@ -58,7 +62,7 @@ func (mc *mockedNode) Set(key string, v interface{}) error { return nil } -func (mc *mockedNode) SetWithExpire(key string, v interface{}, expire time.Duration) error { +func (mc *mockedNode) SetWithExpire(key string, v interface{}, _ time.Duration) error { return mc.Set(key, v) } @@ -80,6 +84,30 @@ func (mc *mockedNode) TakeWithExpire(v interface{}, key string, query func(v int }) } +func (mc *mockedNode) DelCtx(_ context.Context, keys ...string) error { + return mc.Del(keys...) +} + +func (mc *mockedNode) GetCtx(_ context.Context, key string, v interface{}) error { + return mc.Get(key, v) +} + +func (mc *mockedNode) SetCtx(_ context.Context, key string, v interface{}) error { + return mc.Set(key, v) +} + +func (mc *mockedNode) SetWithExpireCtx(_ context.Context, key string, v interface{}, expire time.Duration) error { + return mc.SetWithExpire(key, v, expire) +} + +func (mc *mockedNode) TakeCtx(_ context.Context, v interface{}, key string, query func(v interface{}) error) error { + return mc.Take(v, key, query) +} + +func (mc *mockedNode) TakeWithExpireCtx(_ context.Context, v interface{}, key string, query func(v interface{}, expire time.Duration) error) error { + return mc.TakeWithExpire(v, key, query) +} + func TestCache_SetDel(t *testing.T) { const total = 1000 r1, clean1, err := redistest.CreateRedis() diff --git a/core/stores/cache/cachenode.go b/core/stores/cache/cachenode.go index 49904071..55a7fc27 100644 --- a/core/stores/cache/cachenode.go +++ b/core/stores/cache/cachenode.go @@ -1,6 +1,7 @@ package cache import ( + "context" "errors" "fmt" "math/rand" @@ -61,20 +62,27 @@ func NewNode(rds *redis.Redis, barrier syncx.SingleFlight, st *Stat, // Del deletes cached values with keys. func (c cacheNode) Del(keys ...string) error { + return c.DelCtx(context.Background(), keys...) +} + +// DelCtx deletes cached values with keys. +func (c cacheNode) DelCtx(ctx context.Context, keys ...string) error { if len(keys) == 0 { return nil } + logger := logx.WithContext(ctx) + if len(keys) > 1 && c.rds.Type == redis.ClusterType { for _, key := range keys { - if _, err := c.rds.Del(key); err != nil { - logx.Errorf("failed to clear cache with key: %q, error: %v", key, err) + if _, err := c.rds.DelCtx(ctx, key); err != nil { + logger.Errorf("failed to clear cache with key: %q, error: %v", key, err) c.asyncRetryDelCache(key) } } } else { - if _, err := c.rds.Del(keys...); err != nil { - logx.Errorf("failed to clear cache with keys: %q, error: %v", formatKeys(keys), err) + if _, err := c.rds.DelCtx(ctx, keys...); err != nil { + logger.Errorf("failed to clear cache with keys: %q, error: %v", formatKeys(keys), err) c.asyncRetryDelCache(keys...) } } @@ -84,7 +92,12 @@ func (c cacheNode) Del(keys ...string) error { // Get gets the cache with key and fills into v. func (c cacheNode) Get(key string, v interface{}) error { - err := c.doGetCache(key, v) + return c.GetCtx(context.Background(), key, v) +} + +// GetCtx gets the cache with key and fills into v. +func (c cacheNode) GetCtx(ctx context.Context, key string, v interface{}) error { + err := c.doGetCache(ctx, key, v) if err == errPlaceholder { return c.errNotFound } @@ -94,22 +107,32 @@ func (c cacheNode) Get(key string, v interface{}) error { // IsNotFound checks if the given error is the defined errNotFound. func (c cacheNode) IsNotFound(err error) bool { - return err == c.errNotFound + return errors.Is(err, c.errNotFound) } // Set sets the cache with key and v, using c.expiry. func (c cacheNode) Set(key string, v interface{}) error { - return c.SetWithExpire(key, v, c.aroundDuration(c.expiry)) + return c.SetCtx(context.Background(), key, v) +} + +// SetCtx sets the cache with key and v, using c.expiry. +func (c cacheNode) SetCtx(ctx context.Context, key string, v interface{}) error { + return c.SetWithExpireCtx(ctx, key, v, c.aroundDuration(c.expiry)) } // SetWithExpire sets the cache with key and v, using given expire. func (c cacheNode) SetWithExpire(key string, v interface{}, expire time.Duration) error { + return c.SetWithExpireCtx(context.Background(), key, v, expire) +} + +// SetWithExpireCtx sets the cache with key and v, using given expire. +func (c cacheNode) SetWithExpireCtx(ctx context.Context, key string, v interface{}, expire time.Duration) error { data, err := jsonx.Marshal(v) if err != nil { return err } - return c.rds.Setex(key, string(data), int(expire.Seconds())) + return c.rds.SetexCtx(ctx, key, string(data), int(expire.Seconds())) } // String returns a string that represents the cacheNode. @@ -120,8 +143,14 @@ func (c cacheNode) String() string { // Take takes the result from cache first, if not found, // query from DB and set cache using c.expiry, then return the result. func (c cacheNode) Take(v interface{}, key string, query func(v interface{}) error) error { - return c.doTake(v, key, query, func(v interface{}) error { - return c.Set(key, v) + return c.TakeCtx(context.Background(), v, key, query) +} + +// TakeCtx takes the result from cache first, if not found, +// query from DB and set cache using c.expiry, then return the result. +func (c cacheNode) TakeCtx(ctx context.Context, v interface{}, key string, query func(v interface{}) error) error { + return c.doTake(ctx, v, key, query, func(v interface{}) error { + return c.SetCtx(ctx, key, v) }) } @@ -129,11 +158,17 @@ func (c cacheNode) Take(v interface{}, key string, query func(v interface{}) err // query from DB and set cache using given expire, then return the result. func (c cacheNode) TakeWithExpire(v interface{}, key string, query func(v interface{}, expire time.Duration) error) error { + return c.TakeWithExpireCtx(context.Background(), v, key, query) +} + +// TakeWithExpireCtx takes the result from cache first, if not found, +// query from DB and set cache using given expire, then return the result. +func (c cacheNode) TakeWithExpireCtx(ctx context.Context, v interface{}, key string, query func(v interface{}, expire time.Duration) error) error { expire := c.aroundDuration(c.expiry) - return c.doTake(v, key, func(v interface{}) error { + return c.doTake(ctx, v, key, func(v interface{}) error { return query(v, expire) }, func(v interface{}) error { - return c.SetWithExpire(key, v, expire) + return c.SetWithExpireCtx(ctx, key, v, expire) }) } @@ -148,9 +183,9 @@ func (c cacheNode) asyncRetryDelCache(keys ...string) { }, keys...) } -func (c cacheNode) doGetCache(key string, v interface{}) error { +func (c cacheNode) doGetCache(ctx context.Context, key string, v interface{}) error { c.stat.IncrementTotal() - data, err := c.rds.Get(key) + data, err := c.rds.GetCtx(ctx, key) if err != nil { c.stat.IncrementMiss() return err @@ -166,13 +201,15 @@ func (c cacheNode) doGetCache(key string, v interface{}) error { return errPlaceholder } - return c.processCache(key, data, v) + return c.processCache(ctx, key, data, v) } -func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) error, +func (c cacheNode) doTake(ctx context.Context, v interface{}, key string, query func(v interface{}) error, cacheVal func(v interface{}) error) error { + logger := logx.WithContext(ctx) + val, fresh, err := c.barrier.DoEx(key, func() (interface{}, error) { - if err := c.doGetCache(key, v); err != nil { + if err := c.doGetCache(ctx, key, v); err != nil { if err == errPlaceholder { return nil, c.errNotFound } else if err != c.errNotFound { @@ -183,8 +220,8 @@ func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) e } if err = query(v); err == c.errNotFound { - if err = c.setCacheWithNotFound(key); err != nil { - logx.Error(err) + if err = c.setCacheWithNotFound(ctx, key); err != nil { + logger.Error(err) } return nil, c.errNotFound @@ -194,7 +231,7 @@ func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) e } if err = cacheVal(v); err != nil { - logx.Error(err) + logger.Error(err) } } @@ -214,7 +251,9 @@ func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) e return jsonx.Unmarshal(val.([]byte), v) } -func (c cacheNode) processCache(key, data string, v interface{}) error { +func (c cacheNode) processCache(ctx context.Context, key, data string, v interface{}) error { + logger := logx.WithContext(ctx) + err := jsonx.Unmarshal([]byte(data), v) if err == nil { return nil @@ -222,10 +261,10 @@ func (c cacheNode) processCache(key, data string, v interface{}) error { report := fmt.Sprintf("unmarshal cache, node: %s, key: %s, value: %s, error: %v", c.rds.Addr, key, data, err) - logx.Error(report) + logger.Error(report) stat.Report(report) - if _, e := c.rds.Del(key); e != nil { - logx.Errorf("delete invalid cache, node: %s, key: %s, value: %s, error: %v", + if _, e := c.rds.DelCtx(ctx, key); e != nil { + logger.Errorf("delete invalid cache, node: %s, key: %s, value: %s, error: %v", c.rds.Addr, key, data, e) } @@ -233,6 +272,6 @@ func (c cacheNode) processCache(key, data string, v interface{}) error { return c.errNotFound } -func (c cacheNode) setCacheWithNotFound(key string) error { - return c.rds.Setex(key, notFoundPlaceholder, int(c.aroundDuration(c.notFoundExpiry).Seconds())) +func (c cacheNode) setCacheWithNotFound(ctx context.Context, key string) error { + return c.rds.SetexCtx(ctx, key, notFoundPlaceholder, int(c.aroundDuration(c.notFoundExpiry).Seconds())) }