feat: migrate redis breaker into hook (#3982)

master^2
MarkJoyMa 8 months ago committed by GitHub
parent f372b98d96
commit 7d90f906f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,41 @@
package redis
import (
"context"
red "github.com/redis/go-redis/v9"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/lang"
)
var ignoreCmds = map[string]lang.PlaceholderType{
"blpop": {},
}
type breakerHook struct {
brk breaker.Breaker
}
func (h breakerHook) DialHook(next red.DialHook) red.DialHook {
return next
}
func (h breakerHook) ProcessHook(next red.ProcessHook) red.ProcessHook {
return func(ctx context.Context, cmd red.Cmder) error {
if _, ok := ignoreCmds[cmd.Name()]; ok {
return next(ctx, cmd)
}
return h.brk.DoWithAcceptable(func() error {
return next(ctx, cmd)
}, acceptable)
}
}
func (h breakerHook) ProcessPipelineHook(next red.ProcessPipelineHook) red.ProcessPipelineHook {
return func(ctx context.Context, cmds []red.Cmder) error {
return h.brk.DoWithAcceptable(func() error {
return next(ctx, cmds)
}, acceptable)
}
}

@ -0,0 +1,135 @@
package redis
import (
"context"
"errors"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
)
func TestBreakerHook_ProcessHook(t *testing.T) {
t.Run("breakerHookOpen", func(t *testing.T) {
s := miniredis.RunT(t)
rds := MustNewRedis(RedisConf{
Host: s.Addr(),
Type: NodeType,
})
someError := errors.New("ERR some error")
s.SetError(someError.Error())
var err error
for i := 0; i < 1000; i++ {
_, err = rds.Get("key")
if err != nil && err.Error() != someError.Error() {
break
}
}
assert.Equal(t, breaker.ErrServiceUnavailable, err)
})
t.Run("breakerHookClose", func(t *testing.T) {
s := miniredis.RunT(t)
rds := MustNewRedis(RedisConf{
Host: s.Addr(),
Type: NodeType,
})
var err error
for i := 0; i < 1000; i++ {
_, err = rds.Get("key")
if err != nil {
break
}
}
assert.NotEqual(t, breaker.ErrServiceUnavailable, err)
})
t.Run("breakerHook_ignoreCmd", func(t *testing.T) {
s := miniredis.RunT(t)
rds := MustNewRedis(RedisConf{
Host: s.Addr(),
Type: NodeType,
})
someError := errors.New("ERR some error")
s.SetError(someError.Error())
var err error
node, err := getRedis(rds)
assert.NoError(t, err)
for i := 0; i < 1000; i++ {
_, err = rds.Blpop(node, "key")
if err != nil && err.Error() != someError.Error() {
break
}
}
assert.Equal(t, someError.Error(), err.Error())
})
}
func TestBreakerHook_ProcessPipelineHook(t *testing.T) {
t.Run("breakerPipelineHookOpen", func(t *testing.T) {
s := miniredis.RunT(t)
rds := MustNewRedis(RedisConf{
Host: s.Addr(),
Type: NodeType,
})
someError := errors.New("ERR some error")
s.SetError(someError.Error())
var err error
for i := 0; i < 1000; i++ {
err = rds.Pipelined(
func(pipe Pipeliner) error {
pipe.Incr(context.Background(), "pipelined_counter")
pipe.Expire(context.Background(), "pipelined_counter", time.Hour)
pipe.ZAdd(context.Background(), "zadd", Z{Score: 12, Member: "zadd"})
return nil
},
)
if err != nil && err.Error() != someError.Error() {
break
}
}
assert.Equal(t, breaker.ErrServiceUnavailable, err)
})
t.Run("breakerPipelineHookClose", func(t *testing.T) {
s := miniredis.RunT(t)
rds := MustNewRedis(RedisConf{
Host: s.Addr(),
Type: NodeType,
})
var err error
for i := 0; i < 1000; i++ {
err = rds.Pipelined(
func(pipe Pipeliner) error {
pipe.Incr(context.Background(), "pipelined_counter")
pipe.Expire(context.Background(), "pipelined_counter", time.Hour)
pipe.ZAdd(context.Background(), "zadd", Z{Score: 12, Member: "zadd"})
return nil
},
)
if err != nil {
break
}
}
assert.NotEqual(t, breaker.ErrServiceUnavailable, err)
})
}

@ -23,17 +23,18 @@ import (
const spanName = "redis" const spanName = "redis"
var ( var (
durationHook = hook{} defaultDurationHook = durationHook{}
redisCmdsAttributeKey = attribute.Key("redis.cmds") redisCmdsAttributeKey = attribute.Key("redis.cmds")
) )
type hook struct{} type durationHook struct {
}
func (h hook) DialHook(next red.DialHook) red.DialHook { func (h durationHook) DialHook(next red.DialHook) red.DialHook {
return next return next
} }
func (h hook) ProcessHook(next red.ProcessHook) red.ProcessHook { func (h durationHook) ProcessHook(next red.ProcessHook) red.ProcessHook {
return func(ctx context.Context, cmd red.Cmder) error { return func(ctx context.Context, cmd red.Cmder) error {
start := timex.Now() start := timex.Now()
ctx, endSpan := h.startSpan(ctx, cmd) ctx, endSpan := h.startSpan(ctx, cmd)
@ -57,7 +58,7 @@ func (h hook) ProcessHook(next red.ProcessHook) red.ProcessHook {
} }
} }
func (h hook) ProcessPipelineHook(next red.ProcessPipelineHook) red.ProcessPipelineHook { func (h durationHook) ProcessPipelineHook(next red.ProcessPipelineHook) red.ProcessPipelineHook {
return func(ctx context.Context, cmds []red.Cmder) error { return func(ctx context.Context, cmds []red.Cmder) error {
if len(cmds) == 0 { if len(cmds) == 0 {
return next(ctx, cmds) return next(ctx, cmds)
@ -83,7 +84,7 @@ func (h hook) ProcessPipelineHook(next red.ProcessPipelineHook) red.ProcessPipel
} }
} }
func (h hook) startSpan(ctx context.Context, cmds ...red.Cmder) (context.Context, func(err error)) { func (h durationHook) startSpan(ctx context.Context, cmds ...red.Cmder) (context.Context, func(err error)) {
tracer := trace.TracerFromContext(ctx) tracer := trace.TracerFromContext(ctx)
ctx, span := tracer.Start(ctx, ctx, span := tracer.Start(ctx,

@ -21,7 +21,7 @@ func TestHookProcessCase1(t *testing.T) {
tracetest.NewInMemoryExporter(t) tracetest.NewInMemoryExporter(t)
w := logtest.NewCollector(t) w := logtest.NewCollector(t)
err := durationHook.ProcessHook(func(ctx context.Context, cmd red.Cmder) error { err := defaultDurationHook.ProcessHook(func(ctx context.Context, cmd red.Cmder) error {
assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name()) assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
return nil return nil
})(context.Background(), red.NewCmd(context.Background())) })(context.Background(), red.NewCmd(context.Background()))
@ -36,7 +36,7 @@ func TestHookProcessCase2(t *testing.T) {
tracetest.NewInMemoryExporter(t) tracetest.NewInMemoryExporter(t)
w := logtest.NewCollector(t) w := logtest.NewCollector(t)
err := durationHook.ProcessHook(func(ctx context.Context, cmd red.Cmder) error { err := defaultDurationHook.ProcessHook(func(ctx context.Context, cmd red.Cmder) error {
assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name()) assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
time.Sleep(slowThreshold.Load() + time.Millisecond) time.Sleep(slowThreshold.Load() + time.Millisecond)
return nil return nil
@ -54,12 +54,12 @@ func TestHookProcessPipelineCase1(t *testing.T) {
tracetest.NewInMemoryExporter(t) tracetest.NewInMemoryExporter(t)
w := logtest.NewCollector(t) w := logtest.NewCollector(t)
err := durationHook.ProcessPipelineHook(func(ctx context.Context, cmds []red.Cmder) error { err := defaultDurationHook.ProcessPipelineHook(func(ctx context.Context, cmds []red.Cmder) error {
return nil return nil
})(context.Background(), nil) })(context.Background(), nil)
assert.NoError(t, err) assert.NoError(t, err)
err = durationHook.ProcessPipelineHook(func(ctx context.Context, cmds []red.Cmder) error { err = defaultDurationHook.ProcessPipelineHook(func(ctx context.Context, cmds []red.Cmder) error {
assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name()) assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
return nil return nil
})(context.Background(), []red.Cmder{ })(context.Background(), []red.Cmder{
@ -74,7 +74,7 @@ func TestHookProcessPipelineCase2(t *testing.T) {
tracetest.NewInMemoryExporter(t) tracetest.NewInMemoryExporter(t)
w := logtest.NewCollector(t) w := logtest.NewCollector(t)
err := durationHook.ProcessPipelineHook(func(ctx context.Context, cmds []red.Cmder) error { err := defaultDurationHook.ProcessPipelineHook(func(ctx context.Context, cmds []red.Cmder) error {
assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name()) assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
time.Sleep(slowThreshold.Load() + time.Millisecond) time.Sleep(slowThreshold.Load() + time.Millisecond)
return nil return nil
@ -91,7 +91,7 @@ func TestHookProcessPipelineCase2(t *testing.T) {
func TestHookProcessPipelineCase3(t *testing.T) { func TestHookProcessPipelineCase3(t *testing.T) {
te := tracetest.NewInMemoryExporter(t) te := tracetest.NewInMemoryExporter(t)
err := durationHook.ProcessPipelineHook(func(ctx context.Context, cmds []red.Cmder) error { err := defaultDurationHook.ProcessPipelineHook(func(ctx context.Context, cmds []red.Cmder) error {
assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name()) assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
return assert.AnError return assert.AnError
})(context.Background(), []red.Cmder{ })(context.Background(), []red.Cmder{

File diff suppressed because it is too large Load Diff

@ -36,7 +36,7 @@ func (m myHook) ProcessHook(next red.ProcessHook) red.ProcessHook {
if cmd.Name() == "ping" && !m.includePing { if cmd.Name() == "ping" && !m.includePing {
return next(ctx, cmd) return next(ctx, cmd)
} }
return errors.New("hook error") return errors.New("durationHook error")
} }
} }
@ -155,7 +155,7 @@ func TestRedis_NonBlock(t *testing.T) {
t.Run("nonBlock true", func(t *testing.T) { t.Run("nonBlock true", func(t *testing.T) {
s := miniredis.RunT(t) s := miniredis.RunT(t)
// use hook to simulate redis ping error // use durationHook to simulate redis ping error
_, err := NewRedis(RedisConf{ _, err := NewRedis(RedisConf{
Host: s.Addr(), Host: s.Addr(),
NonBlock: true, NonBlock: true,

@ -37,8 +37,11 @@ func getClient(r *Redis) (*red.Client, error) {
MinIdleConns: idleConns, MinIdleConns: idleConns,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
}) })
store.AddHook(durationHook)
for _, hook := range r.hooks { hooks := append([]red.Hook{defaultDurationHook, breakerHook{
brk: r.brk,
}}, r.hooks...)
for _, hook := range hooks {
store.AddHook(hook) store.AddHook(hook)
} }

@ -33,8 +33,11 @@ func getCluster(r *Redis) (*red.ClusterClient, error) {
MinIdleConns: idleConns, MinIdleConns: idleConns,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
}) })
store.AddHook(durationHook)
for _, hook := range r.hooks { hooks := append([]red.Hook{defaultDurationHook, breakerHook{
brk: r.brk,
}}, r.hooks...)
for _, hook := range hooks {
store.AddHook(hook) store.AddHook(hook)
} }

@ -51,7 +51,7 @@ func TestGetCluster(t *testing.T) {
Addr: r.Addr(), Addr: r.Addr(),
Type: ClusterType, Type: ClusterType,
tls: true, tls: true,
hooks: []red.Hook{durationHook}, hooks: []red.Hook{defaultDurationHook},
}) })
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.NotNil(t, c) assert.NotNil(t, c)

Loading…
Cancel
Save