diff --git a/zrpc/client.go b/zrpc/client.go index ab0f924a..a34ad087 100644 --- a/zrpc/client.go +++ b/zrpc/client.go @@ -110,3 +110,8 @@ func DontLogClientContentForMethod(method string) { func SetClientSlowThreshold(threshold time.Duration) { clientinterceptors.SetSlowThreshold(threshold) } + +// WithTimeoutCallOption return a call option with given timeout. +func WithTimeoutCallOption(timeout time.Duration) grpc.CallOption { + return clientinterceptors.WithTimeoutCallOption(timeout) +} diff --git a/zrpc/client_test.go b/zrpc/client_test.go index eb232a82..95a0c215 100644 --- a/zrpc/client_test.go +++ b/zrpc/client_test.go @@ -41,32 +41,37 @@ func dialer() func(context.Context, string) (net.Conn, error) { func TestDepositServer_Deposit(t *testing.T) { tests := []struct { - name string - amount float32 - res *mock.DepositResponse - errCode codes.Code - errMsg string + name string + amount float32 + timeoutCallOption time.Duration + res *mock.DepositResponse + errCode codes.Code + errMsg string }{ { - "invalid request with negative amount", - -1.11, - nil, - codes.InvalidArgument, - fmt.Sprintf("cannot deposit %v", -1.11), + name: "invalid request with negative amount", + amount: -1.11, + errCode: codes.InvalidArgument, + errMsg: fmt.Sprintf("cannot deposit %v", -1.11), }, { - "valid request with non negative amount", - 0.00, - &mock.DepositResponse{Ok: true}, - codes.OK, - "", + name: "valid request with non negative amount", + res: &mock.DepositResponse{Ok: true}, + errCode: codes.OK, }, { - "valid request with long handling time", - 2000.00, - nil, - codes.DeadlineExceeded, - "context deadline exceeded", + name: "valid request with long handling time", + amount: 2000.00, + errCode: codes.DeadlineExceeded, + errMsg: "context deadline exceeded", + }, + { + name: "valid request with timeout call option", + amount: 2000.00, + timeoutCallOption: time.Second * 3, + res: &mock.DepositResponse{Ok: true}, + errCode: codes.OK, + errMsg: "", }, } @@ -156,9 +161,22 @@ func TestDepositServer_Deposit(t *testing.T) { client := client t.Run(tt.name, func(t *testing.T) { t.Parallel() + cli := mock.NewDepositServiceClient(client.Conn()) request := &mock.DepositRequest{Amount: tt.amount} - response, err := cli.Deposit(context.Background(), request) + + var ( + ctx = context.Background() + response *mock.DepositResponse + err error + ) + + if tt.timeoutCallOption > 0 { + response, err = cli.Deposit(ctx, request, WithTimeoutCallOption(tt.timeoutCallOption)) + } else { + response, err = cli.Deposit(ctx, request) + } + if response != nil { assert.True(t, len(response.String()) > 0) if response.GetOk() != tt.res.GetOk() { diff --git a/zrpc/config.go b/zrpc/config.go index afbb4882..54123e39 100644 --- a/zrpc/config.go +++ b/zrpc/config.go @@ -17,6 +17,8 @@ type ( ServerMiddlewaresConf = internal.ServerMiddlewaresConf // StatConf defines the stat config. StatConf = internal.StatConf + // ServerSpecifiedTimeoutConf defines specified timeout for gRPC method. + ServerSpecifiedTimeoutConf = internal.ServerSpecifiedTimeoutConf // A RpcClientConf is a rpc client config. RpcClientConf struct { @@ -45,6 +47,8 @@ type ( // grpc health check switch Health bool `json:",default=true"` Middlewares ServerMiddlewaresConf + // setting specified timeout for gRPC method + SpecifiedTimeouts []ServerSpecifiedTimeoutConf `json:",optional"` } ) diff --git a/zrpc/internal/clientinterceptors/timeoutinterceptor.go b/zrpc/internal/clientinterceptors/timeoutinterceptor.go index 20111fc1..b28f82a0 100644 --- a/zrpc/internal/clientinterceptors/timeoutinterceptor.go +++ b/zrpc/internal/clientinterceptors/timeoutinterceptor.go @@ -11,13 +11,36 @@ import ( func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - if timeout <= 0 { + t := getTimeoutByCallOptions(opts, timeout) + if t <= 0 { return invoker(ctx, method, req, reply, cc, opts...) } - ctx, cancel := context.WithTimeout(ctx, timeout) + ctx, cancel := context.WithTimeout(ctx, t) defer cancel() return invoker(ctx, method, req, reply, cc, opts...) } } + +func getTimeoutByCallOptions(callOptions []grpc.CallOption, defaultTimeout time.Duration) time.Duration { + for _, callOption := range callOptions { + if o, ok := callOption.(TimeoutCallOption); ok { + return o.timeout + } + } + + return defaultTimeout +} + +type TimeoutCallOption struct { + grpc.EmptyCallOption + + timeout time.Duration +} + +func WithTimeoutCallOption(timeout time.Duration) grpc.CallOption { + return TimeoutCallOption{ + timeout: timeout, + } +} diff --git a/zrpc/internal/clientinterceptors/timeoutinterceptor_test.go b/zrpc/internal/clientinterceptors/timeoutinterceptor_test.go index 67b5ab05..68d654c9 100644 --- a/zrpc/internal/clientinterceptors/timeoutinterceptor_test.go +++ b/zrpc/internal/clientinterceptors/timeoutinterceptor_test.go @@ -66,3 +66,74 @@ func TestTimeoutInterceptor_panic(t *testing.T) { }) } } + +func TestTimeoutInterceptor_TimeoutCallOption(t *testing.T) { + type args struct { + interceptorTimeout time.Duration + callOptionTimeout time.Duration + runTime time.Duration + } + var tests = []struct { + name string + args args + wantErr error + }{ + { + name: "do not timeout without call option timeout", + args: args{ + interceptorTimeout: time.Second, + runTime: time.Millisecond * 50, + }, + wantErr: nil, + }, + { + name: "timeout without call option timeout", + args: args{ + interceptorTimeout: time.Second, + runTime: time.Second * 2, + }, + wantErr: context.DeadlineExceeded, + }, + { + name: "do not timeout with call option timeout", + args: args{ + interceptorTimeout: time.Second, + callOptionTimeout: time.Second * 3, + runTime: time.Second * 2, + }, + wantErr: nil, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + interceptor := TimeoutInterceptor(tt.args.interceptorTimeout) + + cc := new(grpc.ClientConn) + var co []grpc.CallOption + if tt.args.callOptionTimeout > 0 { + co = append(co, WithTimeoutCallOption(tt.args.callOptionTimeout)) + } + + err := interceptor(context.Background(), "/foo", nil, nil, cc, + func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, + opts ...grpc.CallOption) error { + timer := time.NewTimer(tt.args.runTime) + defer timer.Stop() + + select { + case <-timer.C: + return nil + case <-ctx.Done(): + return ctx.Err() + } + }, co..., + ) + t.Logf("error: %+v", err) + + assert.EqualValues(t, tt.wantErr, err) + }) + } +} diff --git a/zrpc/internal/config.go b/zrpc/internal/config.go index b60e1f09..d2a542de 100644 --- a/zrpc/internal/config.go +++ b/zrpc/internal/config.go @@ -24,4 +24,6 @@ type ( Prometheus bool `json:",default=true"` Breaker bool `json:",default=true"` } + + ServerSpecifiedTimeoutConf = serverinterceptors.ServerSpecifiedTimeoutConf ) diff --git a/zrpc/internal/serverinterceptors/timeoutinterceptor.go b/zrpc/internal/serverinterceptors/timeoutinterceptor.go index fb652909..277c89c0 100644 --- a/zrpc/internal/serverinterceptors/timeoutinterceptor.go +++ b/zrpc/internal/serverinterceptors/timeoutinterceptor.go @@ -14,11 +14,23 @@ import ( "google.golang.org/grpc/status" ) +type ( + // ServerSpecifiedTimeoutConf defines specified timeout for gRPC method. + ServerSpecifiedTimeoutConf struct { + FullMethod string + Timeout time.Duration + } + + specifiedTimeoutCache map[string]time.Duration +) + // UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests. -func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor { +func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerSpecifiedTimeoutConf) grpc.UnaryServerInterceptor { + cache := cacheSpecifiedTimeout(specifiedTimeouts) return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { - ctx, cancel := context.WithTimeout(ctx, timeout) + t := getTimeoutByUnaryServerInfo(info, timeout, cache) + ctx, cancel := context.WithTimeout(ctx, t) defer cancel() var resp any @@ -59,3 +71,28 @@ func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor } } } + +func cacheSpecifiedTimeout(specifiedTimeouts []ServerSpecifiedTimeoutConf) specifiedTimeoutCache { + cache := make(specifiedTimeoutCache, len(specifiedTimeouts)) + for _, st := range specifiedTimeouts { + if st.FullMethod != "" { + cache[st.FullMethod] = st.Timeout + } + } + + return cache +} + +func getTimeoutByUnaryServerInfo(info *grpc.UnaryServerInfo, defaultTimeout time.Duration, specifiedTimeout specifiedTimeoutCache) time.Duration { + if ts, ok := info.Server.(TimeoutStrategy); ok { + return ts.GetTimeoutByFullMethod(info.FullMethod, defaultTimeout) + } else if v, ok := specifiedTimeout[info.FullMethod]; ok { + return v + } + + return defaultTimeout +} + +type TimeoutStrategy interface { + GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration +} diff --git a/zrpc/internal/serverinterceptors/timeoutinterceptor_test.go b/zrpc/internal/serverinterceptors/timeoutinterceptor_test.go index 0e89eac6..fd9e4d14 100644 --- a/zrpc/internal/serverinterceptors/timeoutinterceptor_test.go +++ b/zrpc/internal/serverinterceptors/timeoutinterceptor_test.go @@ -12,6 +12,11 @@ import ( "google.golang.org/grpc/status" ) +var ( + deadlineExceededErr = status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) + canceledErr = status.Error(codes.Canceled, context.Canceled.Error()) +) + func TestUnaryTimeoutInterceptor(t *testing.T) { interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10) _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ @@ -68,7 +73,7 @@ func TestUnaryTimeoutInterceptor_timeoutExpire(t *testing.T) { return nil, nil }) wg.Wait() - assert.EqualValues(t, status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()), err) + assert.EqualValues(t, deadlineExceededErr, err) } func TestUnaryTimeoutInterceptor_cancel(t *testing.T) { @@ -88,5 +93,171 @@ func TestUnaryTimeoutInterceptor_cancel(t *testing.T) { }) wg.Wait() - assert.EqualValues(t, status.Error(codes.Canceled, context.Canceled.Error()), err) + assert.EqualValues(t, canceledErr, err) +} + +type tempServer struct { + timeout time.Duration +} + +func (s *tempServer) run(duration time.Duration) { + time.Sleep(duration) +} +func (s *tempServer) GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration { + if fullMethod == "/" { + return defaultTimeout + } + + return s.timeout +} + +func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) { + type args struct { + interceptorTimeout time.Duration + contextTimeout time.Duration + serverTimeout time.Duration + runTime time.Duration + + fullMethod string + } + var tests = []struct { + name string + args args + wantErr error + }{ + { + name: "do not timeout with interceptor timeout", + args: args{ + interceptorTimeout: time.Second, + contextTimeout: time.Second * 5, + serverTimeout: time.Second * 3, + runTime: time.Millisecond * 50, + fullMethod: "/", + }, + wantErr: nil, + }, + { + name: "do not timeout with timeout strategy", + args: args{ + interceptorTimeout: time.Second, + contextTimeout: time.Second * 5, + serverTimeout: time.Second * 3, + runTime: time.Second * 2, + fullMethod: "/2s", + }, + wantErr: nil, + }, + { + name: "timeout with interceptor timeout", + args: args{ + interceptorTimeout: time.Second, + contextTimeout: time.Second * 5, + serverTimeout: time.Second * 3, + runTime: time.Second * 2, + fullMethod: "/", + }, + wantErr: deadlineExceededErr, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + interceptor := UnaryTimeoutInterceptor(tt.args.interceptorTimeout) + ctx, cancel := context.WithTimeout(context.Background(), tt.args.contextTimeout) + defer cancel() + + svr := &tempServer{timeout: tt.args.serverTimeout} + + _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{ + Server: svr, + FullMethod: tt.args.fullMethod, + }, func(ctx context.Context, req interface{}) (interface{}, error) { + svr.run(tt.args.runTime) + return nil, nil + }) + t.Logf("error: %+v", err) + + assert.EqualValues(t, tt.wantErr, err) + }) + } +} + +func TestUnaryTimeoutInterceptor_SpecifiedTimeout(t *testing.T) { + type args struct { + interceptorTimeout time.Duration + contextTimeout time.Duration + method string + methodTimeout time.Duration + runTime time.Duration + } + var tests = []struct { + name string + args args + wantErr error + }{ + { + name: "do not timeout without set timeout for full method", + args: args{ + interceptorTimeout: time.Second, + contextTimeout: time.Second * 5, + method: "/run", + runTime: time.Millisecond * 50, + }, + wantErr: nil, + }, + { + name: "do not timeout with set timeout for full method", + args: args{ + interceptorTimeout: time.Second, + contextTimeout: time.Second * 5, + method: "/run/do_not_timeout", + methodTimeout: time.Second * 3, + runTime: time.Second * 2, + }, + wantErr: nil, + }, + { + name: "timeout with set timeout for full method", + args: args{ + interceptorTimeout: time.Second, + contextTimeout: time.Second * 5, + method: "/run/timeout", + methodTimeout: time.Millisecond * 100, + runTime: time.Millisecond * 500, + }, + wantErr: deadlineExceededErr, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var specifiedTimeouts []ServerSpecifiedTimeoutConf + if tt.args.methodTimeout > 0 { + specifiedTimeouts = []ServerSpecifiedTimeoutConf{ + { + FullMethod: tt.args.method, + Timeout: tt.args.methodTimeout, + }, + } + } + + interceptor := UnaryTimeoutInterceptor(tt.args.interceptorTimeout, specifiedTimeouts...) + ctx, cancel := context.WithTimeout(context.Background(), tt.args.contextTimeout) + defer cancel() + + _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{ + FullMethod: tt.args.method, + }, func(ctx context.Context, req interface{}) (interface{}, error) { + time.Sleep(tt.args.runTime) + return nil, nil + }) + t.Logf("error: %+v", err) + + assert.EqualValues(t, tt.wantErr, err) + }) + } } diff --git a/zrpc/server.go b/zrpc/server.go index 9cf5a873..9bf4e89b 100644 --- a/zrpc/server.go +++ b/zrpc/server.go @@ -131,8 +131,12 @@ func setupInterceptors(svr internal.Server, c RpcServerConf, metrics *stat.Metri } if c.Timeout > 0 { - svr.AddUnaryInterceptors(serverinterceptors.UnaryTimeoutInterceptor( - time.Duration(c.Timeout) * time.Millisecond)) + svr.AddUnaryInterceptors( + serverinterceptors.UnaryTimeoutInterceptor( + time.Duration(c.Timeout)*time.Millisecond, + c.SpecifiedTimeouts..., + ), + ) } if c.Auth { diff --git a/zrpc/server_test.go b/zrpc/server_test.go index d35f6c3a..af5ebc72 100644 --- a/zrpc/server_test.go +++ b/zrpc/server_test.go @@ -40,6 +40,12 @@ func TestServer_setupInterceptors(t *testing.T) { Prometheus: true, Breaker: true, }, + SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{ + { + FullMethod: "/foo", + Timeout: 5 * time.Second, + }, + }, } err = setupInterceptors(server, conf, new(stat.Metrics)) assert.Nil(t, err) @@ -75,6 +81,12 @@ func TestServer(t *testing.T) { Prometheus: true, Breaker: true, }, + SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{ + { + FullMethod: "/foo", + Timeout: time.Second, + }, + }, }, func(server *grpc.Server) { }) svr.AddOptions(grpc.ConnectionTimeout(time.Hour)) @@ -105,6 +117,7 @@ func TestServerError(t *testing.T) { Prometheus: true, Breaker: true, }, + SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{}, }, func(server *grpc.Server) { }) assert.NotNil(t, err) @@ -131,6 +144,7 @@ func TestServer_HasEtcd(t *testing.T) { Prometheus: true, Breaker: true, }, + SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{}, }, func(server *grpc.Server) { }) svr.AddOptions(grpc.ConnectionTimeout(time.Hour))