From 462ddbb145997824eded5c3e1b13c4681f74b483 Mon Sep 17 00:00:00 2001 From: chenquan Date: Wed, 27 Oct 2021 19:46:07 +0800 Subject: [PATCH] Add grpc retry (#1160) * Add grpc retry * Update grpc retry * Add tests * Fix a bug * Add api && some tests * Add comment * Add double check * Add server retry quota * Update optimize code * Fix bug * Update optimize code * Update optimize code * Fix bug --- core/retry/backoff/backoff.go | 31 +++ core/retry/backoff/backoff_test.go | 17 ++ core/retry/options.go | 42 ++++ core/retry/options_test.go | 92 +++++++++ core/retry/retryinterceptor.go | 179 ++++++++++++++++++ core/retry/retryinterceptor_test.go | 25 +++ zrpc/client.go | 5 + zrpc/config.go | 2 + zrpc/internal/client.go | 9 + .../clientinterceptors/retryinterceptor.go | 19 ++ .../retryinterceptor_test.go | 27 +++ zrpc/internal/rpcserver.go | 13 +- zrpc/internal/server.go | 8 +- zrpc/internal/server_test.go | 6 +- .../serverinterceptors/retryinterceptor.go | 33 ++++ .../retryinterceptor_test.go | 40 ++++ zrpc/server.go | 6 +- 17 files changed, 544 insertions(+), 10 deletions(-) create mode 100644 core/retry/backoff/backoff.go create mode 100644 core/retry/backoff/backoff_test.go create mode 100644 core/retry/options.go create mode 100644 core/retry/options_test.go create mode 100644 core/retry/retryinterceptor.go create mode 100644 core/retry/retryinterceptor_test.go create mode 100644 zrpc/internal/clientinterceptors/retryinterceptor.go create mode 100644 zrpc/internal/clientinterceptors/retryinterceptor_test.go create mode 100644 zrpc/internal/serverinterceptors/retryinterceptor.go create mode 100644 zrpc/internal/serverinterceptors/retryinterceptor_test.go diff --git a/core/retry/backoff/backoff.go b/core/retry/backoff/backoff.go new file mode 100644 index 00000000..b3612dd4 --- /dev/null +++ b/core/retry/backoff/backoff.go @@ -0,0 +1,31 @@ +package backoff + +import ( + "math/rand" + "time" +) + +type Func func(attempt int) time.Duration + +// LinearWithJitter waits a 
set period of time, allowing for jitter (fractional adjustment). +func LinearWithJitter(waitBetween time.Duration, jitterFraction float64) Func { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + return func(attempt int) time.Duration { + multiplier := jitterFraction * (r.Float64()*2 - 1) + return time.Duration(float64(waitBetween) * (1 + multiplier)) + } +} + +// Interval waits for a fixed period of time between calls. +func Interval(interval time.Duration) Func { + return func(attempt int) time.Duration { + return interval + } +} + +// Exponential produces increasing intervals for each attempt. +func Exponential(scalar time.Duration) Func { + return func(attempt int) time.Duration { + return scalar * time.Duration((1<<attempt)>>1) + } +} diff --git a/core/retry/backoff/backoff_test.go b/core/retry/backoff/backoff_test.go new file mode 100644 index 00000000..c6b7483b --- /dev/null +++ b/core/retry/backoff/backoff_test.go @@ -0,0 +1,17 @@ +package backoff + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestWaitBetween(t *testing.T) { + fn := Interval(time.Second) + assert.EqualValues(t, time.Second, fn(1)) +} + +func TestExponential(t *testing.T) { + fn := Exponential(time.Second) + assert.EqualValues(t, time.Second, fn(1)) +} diff --git a/core/retry/options.go b/core/retry/options.go new file mode 100644 index 00000000..115473bf --- /dev/null +++ b/core/retry/options.go @@ -0,0 +1,42 @@ +package retry + +import ( + "github.com/tal-tech/go-zero/core/retry/backoff" + "google.golang.org/grpc/codes" + "time" +) + +// WithDisable disables the retry behaviour on this call, or this interceptor. +// +// It's semantically the same as `WithMax(0)`. +func WithDisable() *CallOption { + return WithMax(0) +} + +// WithMax sets the maximum number of retries on this call, or this interceptor. 
+func WithMax(maxRetries int) *CallOption { + return &CallOption{apply: func(options *options) { + options.max = maxRetries + }} +} + +// WithBackoff sets the `backoff.Func` used to control time between retries. +func WithBackoff(backoffFunc backoff.Func) *CallOption { + return &CallOption{apply: func(o *options) { + o.backoffFunc = backoffFunc + }} +} + +// WithCodes sets the status codes that are allowed to be retried. +func WithCodes(retryCodes ...codes.Code) *CallOption { + return &CallOption{apply: func(o *options) { + o.codes = retryCodes + }} +} + +// WithPerRetryTimeout sets the timeout for each retry. +func WithPerRetryTimeout(timeout time.Duration) *CallOption { + return &CallOption{apply: func(o *options) { + o.perCallTimeout = timeout + }} +} diff --git a/core/retry/options_test.go b/core/retry/options_test.go new file mode 100644 index 00000000..adcf25ae --- /dev/null +++ b/core/retry/options_test.go @@ -0,0 +1,92 @@ +package retry + +import ( + "context" + "errors" + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/core/logx" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "testing" + "time" +) + +func TestRetryWithDisable(t *testing.T) { + opt := &options{} + assert.EqualValues(t, &options{}, parseRetryCallOptions(opt, WithDisable())) +} + +func TestRetryWithMax(t *testing.T) { + n := 5 + for i := 0; i < n; i++ { + opt := &options{} + assert.EqualValues(t, &options{max: i}, parseRetryCallOptions(opt, WithMax(i))) + } +} + +func TestRetryWithBackoff(t *testing.T) { + opt := &options{} + + retryCallOptions := parseRetryCallOptions(opt, WithBackoff(func(attempt int) time.Duration { + return time.Millisecond + })) + assert.EqualValues(t, time.Millisecond, retryCallOptions.backoffFunc(1)) + +} + +func TestRetryWithCodes(t *testing.T) { + opt := &options{} + c := []codes.Code{codes.Unknown, codes.NotFound} + options := parseRetryCallOptions(opt, WithCodes(c...)) + assert.EqualValues(t, c, 
options.codes) +} + +func TestRetryWithPerRetryTimeout(t *testing.T) { + opt := &options{} + options := parseRetryCallOptions(opt, WithPerRetryTimeout(time.Millisecond)) + assert.EqualValues(t, time.Millisecond, options.perCallTimeout) +} + +func Test_waitRetryBackoff(t *testing.T) { + + opt := &options{perCallTimeout: time.Second, backoffFunc: func(attempt int) time.Duration { + return time.Second + }} + logger := logx.WithContext(context.Background()) + err := waitRetryBackoff(logger, 1, context.Background(), opt) + assert.NoError(t, err) + ctx, cancelFunc := context.WithTimeout(context.Background(), time.Millisecond) + defer cancelFunc() + err = waitRetryBackoff(logger, 1, ctx, opt) + assert.ErrorIs(t, err, status.FromContextError(context.DeadlineExceeded).Err()) +} + +func Test_isRetriable(t *testing.T) { + assert.False(t, isRetriable(status.FromContextError(context.DeadlineExceeded).Err(), &options{codes: DefaultRetriableCodes})) + assert.True(t, isRetriable(status.Error(codes.ResourceExhausted, ""), &options{codes: DefaultRetriableCodes})) + assert.False(t, isRetriable(errors.New("error"), &options{})) +} + +func Test_perCallContext(t *testing.T) { + opt := &options{perCallTimeout: time.Second, includeRetryHeader: true} + ctx := metadata.NewIncomingContext(context.Background(), map[string][]string{"1": {"1"}}) + callContext := perCallContext(ctx, opt, 1) + md, ok := metadata.FromOutgoingContext(callContext) + assert.True(t, ok) + assert.EqualValues(t, metadata.MD{"1": {"1"}, AttemptMetadataKey: {"1"}}, md) + +} + +func Test_filterCallOptions(t *testing.T) { + grpcEmptyCallOpt := &grpc.EmptyCallOption{} + retryCallOpt := &CallOption{} + options, retryCallOptions := filterCallOptions([]grpc.CallOption{ + grpcEmptyCallOpt, + retryCallOpt, + }) + assert.EqualValues(t, []grpc.CallOption{grpcEmptyCallOpt}, options) + assert.EqualValues(t, []*CallOption{retryCallOpt}, retryCallOptions) + +} diff --git a/core/retry/retryinterceptor.go b/core/retry/retryinterceptor.go 
new file mode 100644 index 00000000..f85bb2f7 --- /dev/null +++ b/core/retry/retryinterceptor.go @@ -0,0 +1,179 @@ +package retry + +import ( + "context" + "github.com/tal-tech/go-zero/core/logx" + "github.com/tal-tech/go-zero/core/retry/backoff" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + "strconv" + "time" +) + +const AttemptMetadataKey = "x-retry-attempt" + +var ( + // DefaultRetriableCodes is the default set of status codes that are retried. + DefaultRetriableCodes = []codes.Code{codes.ResourceExhausted, codes.Unavailable} + // defaultRetryOptions is the default retry configuration. + defaultRetryOptions = &options{ + max: 0, // disabled + perCallTimeout: 0, // disabled + includeRetryHeader: true, + codes: DefaultRetriableCodes, + backoffFunc: backoff.LinearWithJitter(50*time.Millisecond /*jitter*/, 0.10), + } +) + +type ( + // options is the retry configuration. + options struct { + max int + perCallTimeout time.Duration + includeRetryHeader bool + codes []codes.Code + backoffFunc backoff.Func + } + // CallOption is a grpc.CallOption that is local to grpc retry. + CallOption struct { + grpc.EmptyCallOption // make sure we implement private after() and before() methods so we don't panic. 
+ apply func(opt *options) + } +) + +func waitRetryBackoff(logger logx.Logger, attempt int, ctx context.Context, retryOptions *options) error { + var waitTime time.Duration = 0 + if attempt > 0 { + waitTime = retryOptions.backoffFunc(attempt) + } + if waitTime > 0 { + timer := time.NewTimer(waitTime) + logger.Infof("grpc retry attempt: %d, backoff for %v", attempt, waitTime) + select { + case <-ctx.Done(): + timer.Stop() + return status.FromContextError(ctx.Err()).Err() + case <-timer.C: + // double check + err := ctx.Err() + if err != nil { + return status.FromContextError(err).Err() + } + } + } + return nil +} + +func isRetriable(err error, retryOptions *options) bool { + errCode := status.Code(err) + if isContextError(err) { + return false + } + for _, code := range retryOptions.codes { + if code == errCode { + return true + } + } + return false +} + +func isContextError(err error) bool { + code := status.Code(err) + return code == codes.DeadlineExceeded || code == codes.Canceled +} + +func reuseOrNewWithCallOptions(opt *options, retryCallOptions []*CallOption) *options { + if len(retryCallOptions) == 0 { + return opt + } + return parseRetryCallOptions(opt, retryCallOptions...) 
+} + +func parseRetryCallOptions(opt *options, opts ...*CallOption) *options { + for _, option := range opts { + option.apply(opt) + } + return opt +} + +func perCallContext(ctx context.Context, callOpts *options, attempt int) context.Context { + if attempt > 0 { + if callOpts.perCallTimeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, callOpts.perCallTimeout) + _ = cancel + } + if callOpts.includeRetryHeader { + cloneMd := extractIncomingAndClone(ctx) + cloneMd.Set(AttemptMetadataKey, strconv.Itoa(attempt)) + ctx = metadata.NewOutgoingContext(ctx, cloneMd) + } + } + + return ctx +} + +func extractIncomingAndClone(ctx context.Context) metadata.MD { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return metadata.MD{} + } + // clone + return md.Copy() +} + +func filterCallOptions(callOptions []grpc.CallOption) (grpcOptions []grpc.CallOption, retryOptions []*CallOption) { + for _, opt := range callOptions { + if co, ok := opt.(*CallOption); ok { + retryOptions = append(retryOptions, co) + } else { + grpcOptions = append(grpcOptions, opt) + } + } + return grpcOptions, retryOptions +} + +func Do(ctx context.Context, call func(ctx context.Context, opts ...grpc.CallOption) error, opts ...grpc.CallOption) error { + logger := logx.WithContext(ctx) + grpcOpts, retryOpts := filterCallOptions(opts) + callOpts := reuseOrNewWithCallOptions(defaultRetryOptions, retryOpts) + + if callOpts.max == 0 { + return call(ctx, opts...) + } + var lastErr error + for attempt := 0; attempt <= callOpts.max; attempt++ { + if err := waitRetryBackoff(logger, attempt, ctx, callOpts); err != nil { + return err + } + + callCtx := perCallContext(ctx, callOpts, attempt) + lastErr = call(callCtx, grpcOpts...) 
+ + if lastErr == nil { + return nil + } + if attempt == 0 { + logger.Errorf("grpc call failed, got err: %v", lastErr) + } else { + logger.Errorf("grpc retry attempt: %d, got err: %v", attempt, lastErr) + } + if isContextError(lastErr) { + if ctx.Err() != nil { + logger.Errorf("grpc retry attempt: %d, parent context error: %v", attempt, ctx.Err()) + return lastErr + } else if callOpts.perCallTimeout != 0 { + logger.Errorf("grpc retry attempt: %d, context error from retry call", attempt) + continue + } + } + if !isRetriable(lastErr, callOpts) { + return lastErr + } + } + return lastErr +} diff --git a/core/retry/retryinterceptor_test.go b/core/retry/retryinterceptor_test.go new file mode 100644 index 00000000..f5a5756f --- /dev/null +++ b/core/retry/retryinterceptor_test.go @@ -0,0 +1,25 @@ +package retry + +import ( + "context" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "testing" +) + +func TestDo(t *testing.T) { + n := 4 + for i := 0; i < n; i++ { + count := 0 + err := Do(context.Background(), func(ctx context.Context, opts ...grpc.CallOption) error { + count++ + return status.Error(codes.ResourceExhausted, "ResourceExhausted") + + }, WithMax(i)) + assert.Error(t, err) + assert.Equal(t, i+1, count) + } + +} diff --git a/zrpc/client.go b/zrpc/client.go index e8cd14c1..f671d2e7 100644 --- a/zrpc/client.go +++ b/zrpc/client.go @@ -14,6 +14,8 @@ var ( WithDialOption = internal.WithDialOption // WithTimeout is an alias of internal.WithTimeout. WithTimeout = internal.WithTimeout + // WithRetry is an alias of internal.WithRetry. + WithRetry = internal.WithRetry // WithUnaryClientInterceptor is an alias of internal.WithUnaryClientInterceptor. 
WithUnaryClientInterceptor = internal.WithUnaryClientInterceptor ) @@ -52,6 +54,9 @@ func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) { if c.Timeout > 0 { opts = append(opts, WithTimeout(time.Duration(c.Timeout)*time.Millisecond)) } + if c.Retry { + opts = append(opts, WithRetry()) + } opts = append(opts, options...) var target string diff --git a/zrpc/config.go b/zrpc/config.go index 9ec3fd8f..36cbcdba 100644 --- a/zrpc/config.go +++ b/zrpc/config.go @@ -18,6 +18,7 @@ type ( // setting 0 means no timeout Timeout int64 `json:",default=2000"` CpuThreshold int64 `json:",default=900,range=[0:1000]"` + MaxRetries int `json:",range=[0:]"` } // A RpcClientConf is a rpc client config. @@ -27,6 +28,7 @@ type ( Target string `json:",optional"` App string `json:",optional"` Token string `json:",optional"` + Retry bool `json:",optional"` // grpc auto retry Timeout int64 `json:",default=2000"` } ) diff --git a/zrpc/internal/client.go b/zrpc/internal/client.go index 4b942598..861982ef 100644 --- a/zrpc/internal/client.go +++ b/zrpc/internal/client.go @@ -31,6 +31,7 @@ type ( // A ClientOptions is a client options. ClientOptions struct { Timeout time.Duration + Retry bool DialOptions []grpc.DialOption } @@ -72,6 +73,7 @@ func (c *client) buildDialOptions(opts ...ClientOption) []grpc.DialOption { clientinterceptors.PrometheusInterceptor, clientinterceptors.BreakerInterceptor, clientinterceptors.TimeoutInterceptor(cliOpts.Timeout), + clientinterceptors.RetryInterceptor(cliOpts.Retry), ), WithStreamClientInterceptors( clientinterceptors.StreamTracingInterceptor, @@ -117,6 +119,13 @@ func WithTimeout(timeout time.Duration) ClientOption { } } +// WithRetry returns a func to customize a ClientOptions with auto retry. +func WithRetry() ClientOption { + return func(options *ClientOptions) { + options.Retry = true + } +} + // WithUnaryClientInterceptor returns a func to customize a ClientOptions with given interceptor. 
func WithUnaryClientInterceptor(interceptor grpc.UnaryClientInterceptor) ClientOption { return func(options *ClientOptions) { diff --git a/zrpc/internal/clientinterceptors/retryinterceptor.go b/zrpc/internal/clientinterceptors/retryinterceptor.go new file mode 100644 index 00000000..a5cdee6c --- /dev/null +++ b/zrpc/internal/clientinterceptors/retryinterceptor.go @@ -0,0 +1,19 @@ +package clientinterceptors + +import ( + "context" + "github.com/tal-tech/go-zero/core/retry" + "google.golang.org/grpc" +) + +// RetryInterceptor retry interceptor +func RetryInterceptor(enable bool) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if !enable { + return invoker(ctx, method, req, reply, cc, opts...) + } + return retry.Do(ctx, func(ctx context.Context, callOpts ...grpc.CallOption) error { + return invoker(ctx, method, req, reply, cc, callOpts...) + }, opts...) 
+ } +} diff --git a/zrpc/internal/clientinterceptors/retryinterceptor_test.go b/zrpc/internal/clientinterceptors/retryinterceptor_test.go new file mode 100644 index 00000000..7952cb78 --- /dev/null +++ b/zrpc/internal/clientinterceptors/retryinterceptor_test.go @@ -0,0 +1,27 @@ +package clientinterceptors + +import ( + "context" + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/core/retry" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "testing" +) + +func TestRetryInterceptor_WithMax(t *testing.T) { + n := 4 + for i := 0; i < n; i++ { + count := 0 + cc := new(grpc.ClientConn) + err := RetryInterceptor(true)(context.Background(), "/1", nil, nil, cc, + func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + count++ + return status.Error(codes.ResourceExhausted, "ResourceExhausted") + }, retry.WithMax(i)) + assert.Error(t, err) + assert.Equal(t, i+1, count) + } + +} diff --git a/zrpc/internal/rpcserver.go b/zrpc/internal/rpcserver.go index 784e6854..ce6165fe 100644 --- a/zrpc/internal/rpcserver.go +++ b/zrpc/internal/rpcserver.go @@ -14,7 +14,8 @@ type ( ServerOption func(options *rpcServerOptions) rpcServerOptions struct { - metrics *stat.Metrics + metrics *stat.Metrics + MaxRetries int } rpcServer struct { @@ -38,7 +39,7 @@ func NewRpcServer(address string, opts ...ServerOption) Server { } return &rpcServer{ - baseRpcServer: newBaseRpcServer(address, options.metrics), + baseRpcServer: newBaseRpcServer(address, &options), } } @@ -55,6 +56,7 @@ func (s *rpcServer) Start(register RegisterFn) error { unaryInterceptors := []grpc.UnaryServerInterceptor{ serverinterceptors.UnaryTracingInterceptor, + serverinterceptors.RetryInterceptor(s.maxRetries), serverinterceptors.UnaryCrashInterceptor, serverinterceptors.UnaryStatInterceptor(s.metrics), serverinterceptors.UnaryPrometheusInterceptor, @@ -87,3 +89,10 @@ func WithMetrics(metrics 
*stat.Metrics) ServerOption { options.metrics = metrics } } + +// WithMaxRetries returns a func that sets a max retries to a Server. +func WithMaxRetries(maxRetries int) ServerOption { + return func(options *rpcServerOptions) { + options.MaxRetries = maxRetries + } +} diff --git a/zrpc/internal/server.go b/zrpc/internal/server.go index 10b0c818..7f8423b2 100644 --- a/zrpc/internal/server.go +++ b/zrpc/internal/server.go @@ -21,16 +21,18 @@ type ( baseRpcServer struct { address string metrics *stat.Metrics + maxRetries int options []grpc.ServerOption streamInterceptors []grpc.StreamServerInterceptor unaryInterceptors []grpc.UnaryServerInterceptor } ) -func newBaseRpcServer(address string, metrics *stat.Metrics) *baseRpcServer { +func newBaseRpcServer(address string, rpcServerOpts *rpcServerOptions) *baseRpcServer { return &baseRpcServer{ - address: address, - metrics: metrics, + address: address, + metrics: rpcServerOpts.metrics, + maxRetries: rpcServerOpts.MaxRetries, } } diff --git a/zrpc/internal/server_test.go b/zrpc/internal/server_test.go index c4daf9f1..edb563c7 100644 --- a/zrpc/internal/server_test.go +++ b/zrpc/internal/server_test.go @@ -11,7 +11,7 @@ import ( func TestBaseRpcServer_AddOptions(t *testing.T) { metrics := stat.NewMetrics("foo") - server := newBaseRpcServer("foo", metrics) + server := newBaseRpcServer("foo", &rpcServerOptions{metrics: metrics}) server.SetName("bar") var opt grpc.EmptyServerOption server.AddOptions(opt) @@ -20,7 +20,7 @@ func TestBaseRpcServer_AddOptions(t *testing.T) { func TestBaseRpcServer_AddStreamInterceptors(t *testing.T) { metrics := stat.NewMetrics("foo") - server := newBaseRpcServer("foo", metrics) + server := newBaseRpcServer("foo", &rpcServerOptions{metrics: metrics}) server.SetName("bar") var vals []int f := func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { @@ -36,7 +36,7 @@ func TestBaseRpcServer_AddStreamInterceptors(t *testing.T) { func 
TestBaseRpcServer_AddUnaryInterceptors(t *testing.T) { metrics := stat.NewMetrics("foo") - server := newBaseRpcServer("foo", metrics) + server := newBaseRpcServer("foo", &rpcServerOptions{metrics: metrics}) server.SetName("bar") var vals []int f := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( diff --git a/zrpc/internal/serverinterceptors/retryinterceptor.go b/zrpc/internal/serverinterceptors/retryinterceptor.go new file mode 100644 index 00000000..401d099f --- /dev/null +++ b/zrpc/internal/serverinterceptors/retryinterceptor.go @@ -0,0 +1,33 @@ +package serverinterceptors + +import ( + "context" + "github.com/tal-tech/go-zero/core/logx" + "github.com/tal-tech/go-zero/core/retry" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "strconv" +) + +func RetryInterceptor(maxAttempt int) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + var md metadata.MD + requestMd, ok := metadata.FromIncomingContext(ctx) + if ok { + md = requestMd.Copy() + attemptMd := md.Get(retry.AttemptMetadataKey) + if len(attemptMd) != 0 && attemptMd[0] != "" { + if attempt, err := strconv.Atoi(attemptMd[0]); err == nil { + if attempt > maxAttempt { + logx.WithContext(ctx).Errorf("retries exceeded:%d, max retries:%d", attempt, maxAttempt) + return nil, status.Error(codes.FailedPrecondition, "Retries exceeded") + } + } + } + } + + return handler(ctx, req) + } +} diff --git a/zrpc/internal/serverinterceptors/retryinterceptor_test.go b/zrpc/internal/serverinterceptors/retryinterceptor_test.go new file mode 100644 index 00000000..1f0c601b --- /dev/null +++ b/zrpc/internal/serverinterceptors/retryinterceptor_test.go @@ -0,0 +1,40 @@ +package serverinterceptors + +import ( + "context" + "github.com/stretchr/testify/assert" + 
"github.com/tal-tech/go-zero/core/retry" + "google.golang.org/grpc/metadata" + "testing" +) + +func TestRetryInterceptor(t *testing.T) { + t.Run("retries exceeded", func(t *testing.T) { + interceptor := RetryInterceptor(2) + ctx := metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{retry.AttemptMetadataKey: "3"})) + resp, err := interceptor(ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + assert.Error(t, err) + assert.Nil(t, resp) + }) + + t.Run("reasonable retries", func(t *testing.T) { + interceptor := RetryInterceptor(2) + ctx := metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{retry.AttemptMetadataKey: "2"})) + resp, err := interceptor(ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + assert.NoError(t, err) + assert.Nil(t, resp) + }) + t.Run("no retries", func(t *testing.T) { + interceptor := RetryInterceptor(0) + resp, err := interceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + assert.NoError(t, err) + assert.Nil(t, resp) + }) + +} diff --git a/zrpc/server.go b/zrpc/server.go index 0f5fe8a5..e1930dd2 100644 --- a/zrpc/server.go +++ b/zrpc/server.go @@ -38,13 +38,15 @@ func NewServer(c RpcServerConf, register internal.RegisterFn) (*RpcServer, error var server internal.Server metrics := stat.NewMetrics(c.ListenOn) + serverOptions := []internal.ServerOption{internal.WithMetrics(metrics), internal.WithMaxRetries(c.MaxRetries)} + if c.HasEtcd() { - server, err = internal.NewRpcPubServer(c.Etcd.Hosts, c.Etcd.Key, c.ListenOn, internal.WithMetrics(metrics)) + server, err = internal.NewRpcPubServer(c.Etcd.Hosts, c.Etcd.Key, c.ListenOn, serverOptions...) 
if err != nil { return nil, err } } else { - server = internal.NewRpcServer(c.ListenOn, internal.WithMetrics(metrics)) + server = internal.NewRpcServer(c.ListenOn, serverOptions...) } server.SetName(c.Name)