chore: refactor zrpc timeout (#3671)

master
Kevin Wan 1 year ago committed by GitHub
parent 842c4d81cc
commit 922efbfc2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -111,7 +111,7 @@ func SetClientSlowThreshold(threshold time.Duration) {
clientinterceptors.SetSlowThreshold(threshold) clientinterceptors.SetSlowThreshold(threshold)
} }
// WithTimeoutCallOption return a call option with given timeout. // WithCallTimeout return a call option with given timeout to make a method call.
func WithTimeoutCallOption(timeout time.Duration) grpc.CallOption { func WithCallTimeout(timeout time.Duration) grpc.CallOption {
return clientinterceptors.WithTimeoutCallOption(timeout) return clientinterceptors.WithCallTimeout(timeout)
} }

@ -41,12 +41,12 @@ func dialer() func(context.Context, string) (net.Conn, error) {
func TestDepositServer_Deposit(t *testing.T) { func TestDepositServer_Deposit(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
amount float32 amount float32
timeoutCallOption time.Duration timeout time.Duration
res *mock.DepositResponse res *mock.DepositResponse
errCode codes.Code errCode codes.Code
errMsg string errMsg string
}{ }{
{ {
name: "invalid request with negative amount", name: "invalid request with negative amount",
@ -66,12 +66,12 @@ func TestDepositServer_Deposit(t *testing.T) {
errMsg: "context deadline exceeded", errMsg: "context deadline exceeded",
}, },
{ {
name: "valid request with timeout call option", name: "valid request with timeout call option",
amount: 2000.00, amount: 2000.00,
timeoutCallOption: time.Second * 3, timeout: time.Second * 3,
res: &mock.DepositResponse{Ok: true}, res: &mock.DepositResponse{Ok: true},
errCode: codes.OK, errCode: codes.OK,
errMsg: "", errMsg: "",
}, },
} }
@ -171,8 +171,8 @@ func TestDepositServer_Deposit(t *testing.T) {
err error err error
) )
if tt.timeoutCallOption > 0 { if tt.timeout > 0 {
response, err = cli.Deposit(ctx, request, WithTimeoutCallOption(tt.timeoutCallOption)) response, err = cli.Deposit(ctx, request, WithCallTimeout(tt.timeout))
} else { } else {
response, err = cli.Deposit(ctx, request) response, err = cli.Deposit(ctx, request)
} }

@ -17,8 +17,8 @@ type (
ServerMiddlewaresConf = internal.ServerMiddlewaresConf ServerMiddlewaresConf = internal.ServerMiddlewaresConf
// StatConf defines the stat config. // StatConf defines the stat config.
StatConf = internal.StatConf StatConf = internal.StatConf
// ServerSpecifiedTimeoutConf defines specified timeout for gRPC method. // MethodTimeoutConf defines specified timeout for gRPC method.
ServerSpecifiedTimeoutConf = internal.ServerSpecifiedTimeoutConf MethodTimeoutConf = internal.MethodTimeoutConf
// A RpcClientConf is a rpc client config. // A RpcClientConf is a rpc client config.
RpcClientConf struct { RpcClientConf struct {
@ -48,7 +48,7 @@ type (
Health bool `json:",default=true"` Health bool `json:",default=true"`
Middlewares ServerMiddlewaresConf Middlewares ServerMiddlewaresConf
// setting specified timeout for gRPC method // setting specified timeout for gRPC method
SpecifiedTimeouts []ServerSpecifiedTimeoutConf `json:",optional"` MethodTimeouts []MethodTimeoutConf `json:",optional"`
} }
) )

@ -7,11 +7,17 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
) )
// TimeoutCallOption is a call option that controls timeout.
type TimeoutCallOption struct {
grpc.EmptyCallOption
timeout time.Duration
}
// TimeoutInterceptor is an interceptor that controls timeout. // TimeoutInterceptor is an interceptor that controls timeout.
func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor { func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
t := getTimeoutByCallOptions(opts, timeout) t := getTimeoutFromCallOptions(opts, timeout)
if t <= 0 { if t <= 0 {
return invoker(ctx, method, req, reply, cc, opts...) return invoker(ctx, method, req, reply, cc, opts...)
} }
@ -23,24 +29,19 @@ func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
} }
} }
func getTimeoutByCallOptions(callOptions []grpc.CallOption, defaultTimeout time.Duration) time.Duration { // WithCallTimeout returns a call option that controls method call timeout.
for _, callOption := range callOptions { func WithCallTimeout(timeout time.Duration) grpc.CallOption {
if o, ok := callOption.(TimeoutCallOption); ok { return TimeoutCallOption{
timeout: timeout,
}
}
func getTimeoutFromCallOptions(opts []grpc.CallOption, defaultTimeout time.Duration) time.Duration {
for _, opt := range opts {
if o, ok := opt.(TimeoutCallOption); ok {
return o.timeout return o.timeout
} }
} }
return defaultTimeout return defaultTimeout
} }
type TimeoutCallOption struct {
grpc.EmptyCallOption
timeout time.Duration
}
func WithTimeoutCallOption(timeout time.Duration) grpc.CallOption {
return TimeoutCallOption{
timeout: timeout,
}
}

@ -114,7 +114,7 @@ func TestTimeoutInterceptor_TimeoutCallOption(t *testing.T) {
cc := new(grpc.ClientConn) cc := new(grpc.ClientConn)
var co []grpc.CallOption var co []grpc.CallOption
if tt.args.callOptionTimeout > 0 { if tt.args.callOptionTimeout > 0 {
co = append(co, WithTimeoutCallOption(tt.args.callOptionTimeout)) co = append(co, WithCallTimeout(tt.args.callOptionTimeout))
} }
err := interceptor(context.Background(), "/foo", nil, nil, cc, err := interceptor(context.Background(), "/foo", nil, nil, cc,

@ -25,5 +25,6 @@ type (
Breaker bool `json:",default=true"` Breaker bool `json:",default=true"`
} }
ServerSpecifiedTimeoutConf = serverinterceptors.ServerSpecifiedTimeoutConf // MethodTimeoutConf defines specified timeout for gRPC methods.
MethodTimeoutConf = serverinterceptors.MethodTimeoutConf
) )

@ -15,21 +15,22 @@ import (
) )
type ( type (
// ServerSpecifiedTimeoutConf defines specified timeout for gRPC method. // MethodTimeoutConf defines specified timeout for gRPC method.
ServerSpecifiedTimeoutConf struct { MethodTimeoutConf struct {
FullMethod string FullMethod string
Timeout time.Duration Timeout time.Duration
} }
specifiedTimeoutCache map[string]time.Duration methodTimeouts map[string]time.Duration
) )
// UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests. // UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests.
func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerSpecifiedTimeoutConf) grpc.UnaryServerInterceptor { func UnaryTimeoutInterceptor(timeout time.Duration,
cache := cacheSpecifiedTimeout(specifiedTimeouts) methodTimeouts ...MethodTimeoutConf) grpc.UnaryServerInterceptor {
timeouts := buildMethodTimeouts(methodTimeouts)
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, return func(ctx context.Context, req any, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (any, error) { handler grpc.UnaryHandler) (any, error) {
t := getTimeoutByUnaryServerInfo(info, timeout, cache) t := getTimeoutByUnaryServerInfo(info.FullMethod, timeouts, timeout)
ctx, cancel := context.WithTimeout(ctx, t) ctx, cancel := context.WithTimeout(ctx, t)
defer cancel() defer cancel()
@ -72,27 +73,22 @@ func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerS
} }
} }
func cacheSpecifiedTimeout(specifiedTimeouts []ServerSpecifiedTimeoutConf) specifiedTimeoutCache { func buildMethodTimeouts(timeouts []MethodTimeoutConf) methodTimeouts {
cache := make(specifiedTimeoutCache, len(specifiedTimeouts)) mt := make(methodTimeouts, len(timeouts))
for _, st := range specifiedTimeouts { for _, st := range timeouts {
if st.FullMethod != "" { if st.FullMethod != "" {
cache[st.FullMethod] = st.Timeout mt[st.FullMethod] = st.Timeout
} }
} }
return cache return mt
} }
func getTimeoutByUnaryServerInfo(info *grpc.UnaryServerInfo, defaultTimeout time.Duration, specifiedTimeout specifiedTimeoutCache) time.Duration { func getTimeoutByUnaryServerInfo(method string, timeouts methodTimeouts,
if ts, ok := info.Server.(TimeoutStrategy); ok { defaultTimeout time.Duration) time.Duration {
return ts.GetTimeoutByFullMethod(info.FullMethod, defaultTimeout) if v, ok := timeouts[method]; ok {
} else if v, ok := specifiedTimeout[info.FullMethod]; ok {
return v return v
} }
return defaultTimeout return defaultTimeout
} }
type TimeoutStrategy interface {
GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration
}

@ -103,13 +103,6 @@ type tempServer struct {
func (s *tempServer) run(duration time.Duration) { func (s *tempServer) run(duration time.Duration) {
time.Sleep(duration) time.Sleep(duration)
} }
func (s *tempServer) GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration {
if fullMethod == "/" {
return defaultTimeout
}
return s.timeout
}
func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) { func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
type args struct { type args struct {
@ -136,17 +129,6 @@ func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
}, },
wantErr: nil, wantErr: nil,
}, },
{
name: "do not timeout with timeout strategy",
args: args{
interceptorTimeout: time.Second,
contextTimeout: time.Second * 5,
serverTimeout: time.Second * 3,
runTime: time.Second * 2,
fullMethod: "/2s",
},
wantErr: nil,
},
{ {
name: "timeout with interceptor timeout", name: "timeout with interceptor timeout",
args: args{ args: args{
@ -235,9 +217,9 @@ func TestUnaryTimeoutInterceptor_SpecifiedTimeout(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
var specifiedTimeouts []ServerSpecifiedTimeoutConf var specifiedTimeouts []MethodTimeoutConf
if tt.args.methodTimeout > 0 { if tt.args.methodTimeout > 0 {
specifiedTimeouts = []ServerSpecifiedTimeoutConf{ specifiedTimeouts = []MethodTimeoutConf{
{ {
FullMethod: tt.args.method, FullMethod: tt.args.method,
Timeout: tt.args.methodTimeout, Timeout: tt.args.methodTimeout,

@ -131,12 +131,8 @@ func setupInterceptors(svr internal.Server, c RpcServerConf, metrics *stat.Metri
} }
if c.Timeout > 0 { if c.Timeout > 0 {
svr.AddUnaryInterceptors( svr.AddUnaryInterceptors(serverinterceptors.UnaryTimeoutInterceptor(
serverinterceptors.UnaryTimeoutInterceptor( time.Duration(c.Timeout)*time.Millisecond, c.MethodTimeouts...))
time.Duration(c.Timeout)*time.Millisecond,
c.SpecifiedTimeouts...,
),
)
} }
if c.Auth { if c.Auth {

@ -40,7 +40,7 @@ func TestServer_setupInterceptors(t *testing.T) {
Prometheus: true, Prometheus: true,
Breaker: true, Breaker: true,
}, },
SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{ MethodTimeouts: []MethodTimeoutConf{
{ {
FullMethod: "/foo", FullMethod: "/foo",
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
@ -81,7 +81,7 @@ func TestServer(t *testing.T) {
Prometheus: true, Prometheus: true,
Breaker: true, Breaker: true,
}, },
SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{ MethodTimeouts: []MethodTimeoutConf{
{ {
FullMethod: "/foo", FullMethod: "/foo",
Timeout: time.Second, Timeout: time.Second,
@ -117,7 +117,7 @@ func TestServerError(t *testing.T) {
Prometheus: true, Prometheus: true,
Breaker: true, Breaker: true,
}, },
SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{}, MethodTimeouts: []MethodTimeoutConf{},
}, func(server *grpc.Server) { }, func(server *grpc.Server) {
}) })
assert.NotNil(t, err) assert.NotNil(t, err)
@ -144,7 +144,7 @@ func TestServer_HasEtcd(t *testing.T) {
Prometheus: true, Prometheus: true,
Breaker: true, Breaker: true,
}, },
SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{}, MethodTimeouts: []MethodTimeoutConf{},
}, func(server *grpc.Server) { }, func(server *grpc.Server) {
}) })
svr.AddOptions(grpc.ConnectionTimeout(time.Hour)) svr.AddOptions(grpc.ConnectionTimeout(time.Hour))

Loading…
Cancel
Save