feat: support the specified timeout of rpc methods (#2742)

Co-authored-by: hanzijian <hanzijian@52tt.com>
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
master
vankillua 1 year ago committed by GitHub
parent 2a335c7608
commit 842c4d81cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -110,3 +110,8 @@ func DontLogClientContentForMethod(method string) {
func SetClientSlowThreshold(threshold time.Duration) {
clientinterceptors.SetSlowThreshold(threshold)
}
// WithTimeoutCallOption return a call option with given timeout.
func WithTimeoutCallOption(timeout time.Duration) grpc.CallOption {
return clientinterceptors.WithTimeoutCallOption(timeout)
}

@ -41,32 +41,37 @@ func dialer() func(context.Context, string) (net.Conn, error) {
func TestDepositServer_Deposit(t *testing.T) {
tests := []struct {
name string
amount float32
res *mock.DepositResponse
errCode codes.Code
errMsg string
name string
amount float32
timeoutCallOption time.Duration
res *mock.DepositResponse
errCode codes.Code
errMsg string
}{
{
"invalid request with negative amount",
-1.11,
nil,
codes.InvalidArgument,
fmt.Sprintf("cannot deposit %v", -1.11),
name: "invalid request with negative amount",
amount: -1.11,
errCode: codes.InvalidArgument,
errMsg: fmt.Sprintf("cannot deposit %v", -1.11),
},
{
"valid request with non negative amount",
0.00,
&mock.DepositResponse{Ok: true},
codes.OK,
"",
name: "valid request with non negative amount",
res: &mock.DepositResponse{Ok: true},
errCode: codes.OK,
},
{
"valid request with long handling time",
2000.00,
nil,
codes.DeadlineExceeded,
"context deadline exceeded",
name: "valid request with long handling time",
amount: 2000.00,
errCode: codes.DeadlineExceeded,
errMsg: "context deadline exceeded",
},
{
name: "valid request with timeout call option",
amount: 2000.00,
timeoutCallOption: time.Second * 3,
res: &mock.DepositResponse{Ok: true},
errCode: codes.OK,
errMsg: "",
},
}
@ -156,9 +161,22 @@ func TestDepositServer_Deposit(t *testing.T) {
client := client
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
cli := mock.NewDepositServiceClient(client.Conn())
request := &mock.DepositRequest{Amount: tt.amount}
response, err := cli.Deposit(context.Background(), request)
var (
ctx = context.Background()
response *mock.DepositResponse
err error
)
if tt.timeoutCallOption > 0 {
response, err = cli.Deposit(ctx, request, WithTimeoutCallOption(tt.timeoutCallOption))
} else {
response, err = cli.Deposit(ctx, request)
}
if response != nil {
assert.True(t, len(response.String()) > 0)
if response.GetOk() != tt.res.GetOk() {

@ -17,6 +17,8 @@ type (
ServerMiddlewaresConf = internal.ServerMiddlewaresConf
// StatConf defines the stat config.
StatConf = internal.StatConf
// ServerSpecifiedTimeoutConf defines specified timeout for gRPC method.
ServerSpecifiedTimeoutConf = internal.ServerSpecifiedTimeoutConf
// A RpcClientConf is a rpc client config.
RpcClientConf struct {
@ -45,6 +47,8 @@ type (
// grpc health check switch
Health bool `json:",default=true"`
Middlewares ServerMiddlewaresConf
// setting specified timeout for gRPC method
SpecifiedTimeouts []ServerSpecifiedTimeoutConf `json:",optional"`
}
)

@ -11,13 +11,36 @@ import (
func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if timeout <= 0 {
t := getTimeoutByCallOptions(opts, timeout)
if t <= 0 {
return invoker(ctx, method, req, reply, cc, opts...)
}
ctx, cancel := context.WithTimeout(ctx, timeout)
ctx, cancel := context.WithTimeout(ctx, t)
defer cancel()
return invoker(ctx, method, req, reply, cc, opts...)
}
}
func getTimeoutByCallOptions(callOptions []grpc.CallOption, defaultTimeout time.Duration) time.Duration {
for _, callOption := range callOptions {
if o, ok := callOption.(TimeoutCallOption); ok {
return o.timeout
}
}
return defaultTimeout
}
type TimeoutCallOption struct {
grpc.EmptyCallOption
timeout time.Duration
}
func WithTimeoutCallOption(timeout time.Duration) grpc.CallOption {
return TimeoutCallOption{
timeout: timeout,
}
}

@ -66,3 +66,74 @@ func TestTimeoutInterceptor_panic(t *testing.T) {
})
}
}
func TestTimeoutInterceptor_TimeoutCallOption(t *testing.T) {
type args struct {
interceptorTimeout time.Duration
callOptionTimeout time.Duration
runTime time.Duration
}
var tests = []struct {
name string
args args
wantErr error
}{
{
name: "do not timeout without call option timeout",
args: args{
interceptorTimeout: time.Second,
runTime: time.Millisecond * 50,
},
wantErr: nil,
},
{
name: "timeout without call option timeout",
args: args{
interceptorTimeout: time.Second,
runTime: time.Second * 2,
},
wantErr: context.DeadlineExceeded,
},
{
name: "do not timeout with call option timeout",
args: args{
interceptorTimeout: time.Second,
callOptionTimeout: time.Second * 3,
runTime: time.Second * 2,
},
wantErr: nil,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
interceptor := TimeoutInterceptor(tt.args.interceptorTimeout)
cc := new(grpc.ClientConn)
var co []grpc.CallOption
if tt.args.callOptionTimeout > 0 {
co = append(co, WithTimeoutCallOption(tt.args.callOptionTimeout))
}
err := interceptor(context.Background(), "/foo", nil, nil, cc,
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
opts ...grpc.CallOption) error {
timer := time.NewTimer(tt.args.runTime)
defer timer.Stop()
select {
case <-timer.C:
return nil
case <-ctx.Done():
return ctx.Err()
}
}, co...,
)
t.Logf("error: %+v", err)
assert.EqualValues(t, tt.wantErr, err)
})
}
}

@ -24,4 +24,6 @@ type (
Prometheus bool `json:",default=true"`
Breaker bool `json:",default=true"`
}
ServerSpecifiedTimeoutConf = serverinterceptors.ServerSpecifiedTimeoutConf
)

@ -14,11 +14,23 @@ import (
"google.golang.org/grpc/status"
)
type (
// ServerSpecifiedTimeoutConf defines specified timeout for gRPC method.
ServerSpecifiedTimeoutConf struct {
FullMethod string
Timeout time.Duration
}
specifiedTimeoutCache map[string]time.Duration
)
// UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests.
func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor {
func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerSpecifiedTimeoutConf) grpc.UnaryServerInterceptor {
cache := cacheSpecifiedTimeout(specifiedTimeouts)
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (any, error) {
ctx, cancel := context.WithTimeout(ctx, timeout)
t := getTimeoutByUnaryServerInfo(info, timeout, cache)
ctx, cancel := context.WithTimeout(ctx, t)
defer cancel()
var resp any
@ -59,3 +71,28 @@ func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor
}
}
}
func cacheSpecifiedTimeout(specifiedTimeouts []ServerSpecifiedTimeoutConf) specifiedTimeoutCache {
cache := make(specifiedTimeoutCache, len(specifiedTimeouts))
for _, st := range specifiedTimeouts {
if st.FullMethod != "" {
cache[st.FullMethod] = st.Timeout
}
}
return cache
}
func getTimeoutByUnaryServerInfo(info *grpc.UnaryServerInfo, defaultTimeout time.Duration, specifiedTimeout specifiedTimeoutCache) time.Duration {
if ts, ok := info.Server.(TimeoutStrategy); ok {
return ts.GetTimeoutByFullMethod(info.FullMethod, defaultTimeout)
} else if v, ok := specifiedTimeout[info.FullMethod]; ok {
return v
}
return defaultTimeout
}
type TimeoutStrategy interface {
GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration
}

@ -12,6 +12,11 @@ import (
"google.golang.org/grpc/status"
)
var (
deadlineExceededErr = status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
canceledErr = status.Error(codes.Canceled, context.Canceled.Error())
)
func TestUnaryTimeoutInterceptor(t *testing.T) {
interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10)
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
@ -68,7 +73,7 @@ func TestUnaryTimeoutInterceptor_timeoutExpire(t *testing.T) {
return nil, nil
})
wg.Wait()
assert.EqualValues(t, status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()), err)
assert.EqualValues(t, deadlineExceededErr, err)
}
func TestUnaryTimeoutInterceptor_cancel(t *testing.T) {
@ -88,5 +93,171 @@ func TestUnaryTimeoutInterceptor_cancel(t *testing.T) {
})
wg.Wait()
assert.EqualValues(t, status.Error(codes.Canceled, context.Canceled.Error()), err)
assert.EqualValues(t, canceledErr, err)
}
type tempServer struct {
timeout time.Duration
}
func (s *tempServer) run(duration time.Duration) {
time.Sleep(duration)
}
func (s *tempServer) GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration {
if fullMethod == "/" {
return defaultTimeout
}
return s.timeout
}
func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
type args struct {
interceptorTimeout time.Duration
contextTimeout time.Duration
serverTimeout time.Duration
runTime time.Duration
fullMethod string
}
var tests = []struct {
name string
args args
wantErr error
}{
{
name: "do not timeout with interceptor timeout",
args: args{
interceptorTimeout: time.Second,
contextTimeout: time.Second * 5,
serverTimeout: time.Second * 3,
runTime: time.Millisecond * 50,
fullMethod: "/",
},
wantErr: nil,
},
{
name: "do not timeout with timeout strategy",
args: args{
interceptorTimeout: time.Second,
contextTimeout: time.Second * 5,
serverTimeout: time.Second * 3,
runTime: time.Second * 2,
fullMethod: "/2s",
},
wantErr: nil,
},
{
name: "timeout with interceptor timeout",
args: args{
interceptorTimeout: time.Second,
contextTimeout: time.Second * 5,
serverTimeout: time.Second * 3,
runTime: time.Second * 2,
fullMethod: "/",
},
wantErr: deadlineExceededErr,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
interceptor := UnaryTimeoutInterceptor(tt.args.interceptorTimeout)
ctx, cancel := context.WithTimeout(context.Background(), tt.args.contextTimeout)
defer cancel()
svr := &tempServer{timeout: tt.args.serverTimeout}
_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
Server: svr,
FullMethod: tt.args.fullMethod,
}, func(ctx context.Context, req interface{}) (interface{}, error) {
svr.run(tt.args.runTime)
return nil, nil
})
t.Logf("error: %+v", err)
assert.EqualValues(t, tt.wantErr, err)
})
}
}
func TestUnaryTimeoutInterceptor_SpecifiedTimeout(t *testing.T) {
type args struct {
interceptorTimeout time.Duration
contextTimeout time.Duration
method string
methodTimeout time.Duration
runTime time.Duration
}
var tests = []struct {
name string
args args
wantErr error
}{
{
name: "do not timeout without set timeout for full method",
args: args{
interceptorTimeout: time.Second,
contextTimeout: time.Second * 5,
method: "/run",
runTime: time.Millisecond * 50,
},
wantErr: nil,
},
{
name: "do not timeout with set timeout for full method",
args: args{
interceptorTimeout: time.Second,
contextTimeout: time.Second * 5,
method: "/run/do_not_timeout",
methodTimeout: time.Second * 3,
runTime: time.Second * 2,
},
wantErr: nil,
},
{
name: "timeout with set timeout for full method",
args: args{
interceptorTimeout: time.Second,
contextTimeout: time.Second * 5,
method: "/run/timeout",
methodTimeout: time.Millisecond * 100,
runTime: time.Millisecond * 500,
},
wantErr: deadlineExceededErr,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var specifiedTimeouts []ServerSpecifiedTimeoutConf
if tt.args.methodTimeout > 0 {
specifiedTimeouts = []ServerSpecifiedTimeoutConf{
{
FullMethod: tt.args.method,
Timeout: tt.args.methodTimeout,
},
}
}
interceptor := UnaryTimeoutInterceptor(tt.args.interceptorTimeout, specifiedTimeouts...)
ctx, cancel := context.WithTimeout(context.Background(), tt.args.contextTimeout)
defer cancel()
_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
FullMethod: tt.args.method,
}, func(ctx context.Context, req interface{}) (interface{}, error) {
time.Sleep(tt.args.runTime)
return nil, nil
})
t.Logf("error: %+v", err)
assert.EqualValues(t, tt.wantErr, err)
})
}
}

@ -131,8 +131,12 @@ func setupInterceptors(svr internal.Server, c RpcServerConf, metrics *stat.Metri
}
if c.Timeout > 0 {
svr.AddUnaryInterceptors(serverinterceptors.UnaryTimeoutInterceptor(
time.Duration(c.Timeout) * time.Millisecond))
svr.AddUnaryInterceptors(
serverinterceptors.UnaryTimeoutInterceptor(
time.Duration(c.Timeout)*time.Millisecond,
c.SpecifiedTimeouts...,
),
)
}
if c.Auth {

@ -40,6 +40,12 @@ func TestServer_setupInterceptors(t *testing.T) {
Prometheus: true,
Breaker: true,
},
SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{
{
FullMethod: "/foo",
Timeout: 5 * time.Second,
},
},
}
err = setupInterceptors(server, conf, new(stat.Metrics))
assert.Nil(t, err)
@ -75,6 +81,12 @@ func TestServer(t *testing.T) {
Prometheus: true,
Breaker: true,
},
SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{
{
FullMethod: "/foo",
Timeout: time.Second,
},
},
}, func(server *grpc.Server) {
})
svr.AddOptions(grpc.ConnectionTimeout(time.Hour))
@ -105,6 +117,7 @@ func TestServerError(t *testing.T) {
Prometheus: true,
Breaker: true,
},
SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{},
}, func(server *grpc.Server) {
})
assert.NotNil(t, err)
@ -131,6 +144,7 @@ func TestServer_HasEtcd(t *testing.T) {
Prometheus: true,
Breaker: true,
},
SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{},
}, func(server *grpc.Server) {
})
svr.AddOptions(grpc.ConnectionTimeout(time.Hour))

Loading…
Cancel
Save