From 922efbfc2d39476794b3d65966effb182f291023 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Thu, 26 Oct 2023 08:55:26 +0800 Subject: [PATCH] chore: refactor zrpc timeout (#3671) --- zrpc/client.go | 6 ++-- zrpc/client_test.go | 28 +++++++-------- zrpc/config.go | 6 ++-- .../clientinterceptors/timeoutinterceptor.go | 33 +++++++++--------- .../timeoutinterceptor_test.go | 2 +- zrpc/internal/config.go | 3 +- .../serverinterceptors/timeoutinterceptor.go | 34 ++++++++----------- .../timeoutinterceptor_test.go | 22 ++---------- zrpc/server.go | 8 ++--- zrpc/server_test.go | 8 ++--- 10 files changed, 63 insertions(+), 87 deletions(-) diff --git a/zrpc/client.go b/zrpc/client.go index a34ad087..e44abc16 100644 --- a/zrpc/client.go +++ b/zrpc/client.go @@ -111,7 +111,7 @@ func SetClientSlowThreshold(threshold time.Duration) { clientinterceptors.SetSlowThreshold(threshold) } -// WithTimeoutCallOption return a call option with given timeout. -func WithTimeoutCallOption(timeout time.Duration) grpc.CallOption { - return clientinterceptors.WithTimeoutCallOption(timeout) +// WithCallTimeout return a call option with given timeout to make a method call. +func WithCallTimeout(timeout time.Duration) grpc.CallOption { + return clientinterceptors.WithCallTimeout(timeout) } diff --git a/zrpc/client_test.go b/zrpc/client_test.go index 95a0c215..09a06cbb 100644 --- a/zrpc/client_test.go +++ b/zrpc/client_test.go @@ -41,12 +41,12 @@ func dialer() func(context.Context, string) (net.Conn, error) { func TestDepositServer_Deposit(t *testing.T) { tests := []struct { - name string - amount float32 - timeoutCallOption time.Duration - res *mock.DepositResponse - errCode codes.Code - errMsg string + name string + amount float32 + timeout time.Duration + res *mock.DepositResponse + errCode codes.Code + errMsg string }{ { name: "invalid request with negative amount", @@ -66,12 +66,12 @@ func TestDepositServer_Deposit(t *testing.T) { errMsg: "context deadline exceeded", }, { - name: "valid request with timeout call option", - amount: 2000.00, - timeoutCallOption: time.Second * 3, - res: &mock.DepositResponse{Ok: true}, - errCode: codes.OK, - errMsg: "", + name: "valid request with timeout call option", + amount: 2000.00, + timeout: time.Second * 3, + res: &mock.DepositResponse{Ok: true}, + errCode: codes.OK, + errMsg: "", }, } @@ -171,8 +171,8 @@ func TestDepositServer_Deposit(t *testing.T) { err error ) - if tt.timeoutCallOption > 0 { - response, err = cli.Deposit(ctx, request, WithTimeoutCallOption(tt.timeoutCallOption)) + if tt.timeout > 0 { + response, err = cli.Deposit(ctx, request, WithCallTimeout(tt.timeout)) } else { response, err = cli.Deposit(ctx, request) } diff --git a/zrpc/config.go b/zrpc/config.go index 54123e39..84a32160 100644 --- a/zrpc/config.go +++ b/zrpc/config.go @@ -17,8 +17,8 @@ type ( ServerMiddlewaresConf = internal.ServerMiddlewaresConf // StatConf defines the stat config. StatConf = internal.StatConf - // ServerSpecifiedTimeoutConf defines specified timeout for gRPC method. - ServerSpecifiedTimeoutConf = internal.ServerSpecifiedTimeoutConf + // MethodTimeoutConf defines specified timeout for gRPC method. + MethodTimeoutConf = internal.MethodTimeoutConf // A RpcClientConf is a rpc client config. RpcClientConf struct { @@ -48,7 +48,7 @@ type ( Health bool `json:",default=true"` Middlewares ServerMiddlewaresConf // setting specified timeout for gRPC method - SpecifiedTimeouts []ServerSpecifiedTimeoutConf `json:",optional"` + MethodTimeouts []MethodTimeoutConf `json:",optional"` } ) diff --git a/zrpc/internal/clientinterceptors/timeoutinterceptor.go b/zrpc/internal/clientinterceptors/timeoutinterceptor.go index b28f82a0..770030fe 100644 --- a/zrpc/internal/clientinterceptors/timeoutinterceptor.go +++ b/zrpc/internal/clientinterceptors/timeoutinterceptor.go @@ -7,11 +7,17 @@ import ( "google.golang.org/grpc" ) +// TimeoutCallOption is a call option that controls timeout. +type TimeoutCallOption struct { + grpc.EmptyCallOption + timeout time.Duration +} + // TimeoutInterceptor is an interceptor that controls timeout. func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - t := getTimeoutByCallOptions(opts, timeout) + t := getTimeoutFromCallOptions(opts, timeout) if t <= 0 { return invoker(ctx, method, req, reply, cc, opts...) } @@ -23,24 +29,19 @@ func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor { } } -func getTimeoutByCallOptions(callOptions []grpc.CallOption, defaultTimeout time.Duration) time.Duration { - for _, callOption := range callOptions { - if o, ok := callOption.(TimeoutCallOption); ok { +// WithCallTimeout returns a call option that controls method call timeout. +func WithCallTimeout(timeout time.Duration) grpc.CallOption { + return TimeoutCallOption{ + timeout: timeout, + } +} + +func getTimeoutFromCallOptions(opts []grpc.CallOption, defaultTimeout time.Duration) time.Duration { + for _, opt := range opts { + if o, ok := opt.(TimeoutCallOption); ok { return o.timeout } } return defaultTimeout } - -type TimeoutCallOption struct { - grpc.EmptyCallOption - - timeout time.Duration -} - -func WithTimeoutCallOption(timeout time.Duration) grpc.CallOption { - return TimeoutCallOption{ - timeout: timeout, - } -} diff --git a/zrpc/internal/clientinterceptors/timeoutinterceptor_test.go b/zrpc/internal/clientinterceptors/timeoutinterceptor_test.go index 68d654c9..d9347665 100644 --- a/zrpc/internal/clientinterceptors/timeoutinterceptor_test.go +++ b/zrpc/internal/clientinterceptors/timeoutinterceptor_test.go @@ -114,7 +114,7 @@ func TestTimeoutInterceptor_TimeoutCallOption(t *testing.T) { cc := new(grpc.ClientConn) var co []grpc.CallOption if tt.args.callOptionTimeout > 0 { - co = append(co, WithTimeoutCallOption(tt.args.callOptionTimeout)) + co = append(co, WithCallTimeout(tt.args.callOptionTimeout)) } err := interceptor(context.Background(), "/foo", nil, nil, cc, diff --git a/zrpc/internal/config.go b/zrpc/internal/config.go index d2a542de..df141c46 100644 --- a/zrpc/internal/config.go +++ b/zrpc/internal/config.go @@ -25,5 +25,6 @@ type ( Breaker bool `json:",default=true"` } - ServerSpecifiedTimeoutConf = serverinterceptors.ServerSpecifiedTimeoutConf + // MethodTimeoutConf defines specified timeout for gRPC methods. + MethodTimeoutConf = serverinterceptors.MethodTimeoutConf ) diff --git a/zrpc/internal/serverinterceptors/timeoutinterceptor.go b/zrpc/internal/serverinterceptors/timeoutinterceptor.go index 277c89c0..a6eebbeb 100644 --- a/zrpc/internal/serverinterceptors/timeoutinterceptor.go +++ b/zrpc/internal/serverinterceptors/timeoutinterceptor.go @@ -15,21 +15,22 @@ import ( ) type ( - // ServerSpecifiedTimeoutConf defines specified timeout for gRPC method. - ServerSpecifiedTimeoutConf struct { + // MethodTimeoutConf defines specified timeout for gRPC method. + MethodTimeoutConf struct { FullMethod string Timeout time.Duration } - specifiedTimeoutCache map[string]time.Duration + methodTimeouts map[string]time.Duration ) // UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests. -func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerSpecifiedTimeoutConf) grpc.UnaryServerInterceptor { - cache := cacheSpecifiedTimeout(specifiedTimeouts) +func UnaryTimeoutInterceptor(timeout time.Duration, + methodTimeouts ...MethodTimeoutConf) grpc.UnaryServerInterceptor { + timeouts := buildMethodTimeouts(methodTimeouts) return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { - t := getTimeoutByUnaryServerInfo(info, timeout, cache) + t := getTimeoutByUnaryServerInfo(info.FullMethod, timeouts, timeout) ctx, cancel := context.WithTimeout(ctx, t) defer cancel() @@ -72,27 +73,22 @@ func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerS } } -func cacheSpecifiedTimeout(specifiedTimeouts []ServerSpecifiedTimeoutConf) specifiedTimeoutCache { - cache := make(specifiedTimeoutCache, len(specifiedTimeouts)) - for _, st := range specifiedTimeouts { +func buildMethodTimeouts(timeouts []MethodTimeoutConf) methodTimeouts { + mt := make(methodTimeouts, len(timeouts)) + for _, st := range timeouts { if st.FullMethod != "" { - cache[st.FullMethod] = st.Timeout + mt[st.FullMethod] = st.Timeout } } - return cache + return mt } -func getTimeoutByUnaryServerInfo(info *grpc.UnaryServerInfo, defaultTimeout time.Duration, specifiedTimeout specifiedTimeoutCache) time.Duration { - if ts, ok := info.Server.(TimeoutStrategy); ok { - return ts.GetTimeoutByFullMethod(info.FullMethod, defaultTimeout) - } else if v, ok := specifiedTimeout[info.FullMethod]; ok { +func getTimeoutByUnaryServerInfo(method string, timeouts methodTimeouts, + defaultTimeout time.Duration) time.Duration { + if v, ok := timeouts[method]; ok { return v } return defaultTimeout } - -type TimeoutStrategy interface { - GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration -} diff --git a/zrpc/internal/serverinterceptors/timeoutinterceptor_test.go b/zrpc/internal/serverinterceptors/timeoutinterceptor_test.go index fd9e4d14..1469e4c3 100644 --- a/zrpc/internal/serverinterceptors/timeoutinterceptor_test.go +++ b/zrpc/internal/serverinterceptors/timeoutinterceptor_test.go @@ -103,13 +103,6 @@ type tempServer struct { func (s *tempServer) run(duration time.Duration) { time.Sleep(duration) } -func (s *tempServer) GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration { - if fullMethod == "/" { - return defaultTimeout - } - - return s.timeout -} func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) { type args struct { @@ -136,17 +129,6 @@ func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) { }, wantErr: nil, }, - { - name: "do not timeout with timeout strategy", - args: args{ - interceptorTimeout: time.Second, - contextTimeout: time.Second * 5, - serverTimeout: time.Second * 3, - runTime: time.Second * 2, - fullMethod: "/2s", - }, - wantErr: nil, - }, { name: "timeout with interceptor timeout", args: args{ @@ -235,9 +217,9 @@ func TestUnaryTimeoutInterceptor_SpecifiedTimeout(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - var specifiedTimeouts []ServerSpecifiedTimeoutConf + var specifiedTimeouts []MethodTimeoutConf if tt.args.methodTimeout > 0 { - specifiedTimeouts = []ServerSpecifiedTimeoutConf{ + specifiedTimeouts = []MethodTimeoutConf{ { FullMethod: tt.args.method, Timeout: tt.args.methodTimeout, diff --git a/zrpc/server.go b/zrpc/server.go index 9bf4e89b..d891c8e6 100644 --- a/zrpc/server.go +++ b/zrpc/server.go @@ -131,12 +131,8 @@ func setupInterceptors(svr internal.Server, c RpcServerConf, metrics *stat.Metri } if c.Timeout > 0 { - svr.AddUnaryInterceptors( - serverinterceptors.UnaryTimeoutInterceptor( - time.Duration(c.Timeout)*time.Millisecond, - c.SpecifiedTimeouts..., - ), - ) + svr.AddUnaryInterceptors(serverinterceptors.UnaryTimeoutInterceptor( + time.Duration(c.Timeout)*time.Millisecond, c.MethodTimeouts...)) } if c.Auth { diff --git a/zrpc/server_test.go b/zrpc/server_test.go index af5ebc72..e99f224f 100644 --- a/zrpc/server_test.go +++ b/zrpc/server_test.go @@ -40,7 +40,7 @@ func TestServer_setupInterceptors(t *testing.T) { Prometheus: true, Breaker: true, }, - SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{ + MethodTimeouts: []MethodTimeoutConf{ { FullMethod: "/foo", Timeout: 5 * time.Second, @@ -81,7 +81,7 @@ func TestServer(t *testing.T) { Prometheus: true, Breaker: true, }, - SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{ + MethodTimeouts: []MethodTimeoutConf{ { FullMethod: "/foo", Timeout: time.Second, @@ -117,7 +117,7 @@ func TestServerError(t *testing.T) { Prometheus: true, Breaker: true, }, - SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{}, + MethodTimeouts: []MethodTimeoutConf{}, }, func(server *grpc.Server) { }) assert.NotNil(t, err) @@ -144,7 +144,7 @@ func TestServer_HasEtcd(t *testing.T) { Prometheus: true, Breaker: true, }, - SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{}, + MethodTimeouts: []MethodTimeoutConf{}, }, func(server *grpc.Server) { }) svr.AddOptions(grpc.ConnectionTimeout(time.Hour))