You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
246 lines
6.3 KiB
Go
246 lines
6.3 KiB
Go
package serverinterceptors
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
var (
|
|
deadlineExceededErr = status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
|
|
canceledErr = status.Error(codes.Canceled, context.Canceled.Error())
|
|
)
|
|
|
|
func TestUnaryTimeoutInterceptor(t *testing.T) {
|
|
interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10)
|
|
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
|
FullMethod: "/",
|
|
}, func(ctx context.Context, req any) (any, error) {
|
|
return nil, nil
|
|
})
|
|
assert.Nil(t, err)
|
|
}
|
|
|
|
func TestUnaryTimeoutInterceptor_panic(t *testing.T) {
|
|
interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10)
|
|
assert.Panics(t, func() {
|
|
_, _ = interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
|
FullMethod: "/",
|
|
}, func(ctx context.Context, req any) (any, error) {
|
|
panic("any")
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestUnaryTimeoutInterceptor_timeout(t *testing.T) {
|
|
const timeout = time.Millisecond * 10
|
|
interceptor := UnaryTimeoutInterceptor(timeout)
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
|
defer cancel()
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
|
|
FullMethod: "/",
|
|
}, func(ctx context.Context, req any) (any, error) {
|
|
defer wg.Done()
|
|
tm, ok := ctx.Deadline()
|
|
assert.True(t, ok)
|
|
assert.True(t, tm.Before(time.Now().Add(timeout+time.Millisecond)))
|
|
return nil, nil
|
|
})
|
|
wg.Wait()
|
|
assert.Nil(t, err)
|
|
}
|
|
|
|
func TestUnaryTimeoutInterceptor_timeoutExpire(t *testing.T) {
|
|
const timeout = time.Millisecond * 10
|
|
interceptor := UnaryTimeoutInterceptor(timeout)
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
|
defer cancel()
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
|
|
FullMethod: "/",
|
|
}, func(ctx context.Context, req any) (any, error) {
|
|
defer wg.Done()
|
|
time.Sleep(time.Millisecond * 50)
|
|
return nil, nil
|
|
})
|
|
wg.Wait()
|
|
assert.EqualValues(t, deadlineExceededErr, err)
|
|
}
|
|
|
|
func TestUnaryTimeoutInterceptor_cancel(t *testing.T) {
|
|
const timeout = time.Minute * 10
|
|
interceptor := UnaryTimeoutInterceptor(timeout)
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
|
|
FullMethod: "/",
|
|
}, func(ctx context.Context, req any) (any, error) {
|
|
defer wg.Done()
|
|
time.Sleep(time.Millisecond * 50)
|
|
return nil, nil
|
|
})
|
|
|
|
wg.Wait()
|
|
assert.EqualValues(t, canceledErr, err)
|
|
}
|
|
|
|
type tempServer struct {
|
|
timeout time.Duration
|
|
}
|
|
|
|
func (s *tempServer) run(duration time.Duration) {
|
|
time.Sleep(duration)
|
|
}
|
|
|
|
func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
|
|
type args struct {
|
|
interceptorTimeout time.Duration
|
|
contextTimeout time.Duration
|
|
serverTimeout time.Duration
|
|
runTime time.Duration
|
|
|
|
fullMethod string
|
|
}
|
|
var tests = []struct {
|
|
name string
|
|
args args
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "do not timeout with interceptor timeout",
|
|
args: args{
|
|
interceptorTimeout: time.Second,
|
|
contextTimeout: time.Second * 5,
|
|
serverTimeout: time.Second * 3,
|
|
runTime: time.Millisecond * 50,
|
|
fullMethod: "/",
|
|
},
|
|
wantErr: nil,
|
|
},
|
|
{
|
|
name: "timeout with interceptor timeout",
|
|
args: args{
|
|
interceptorTimeout: time.Second,
|
|
contextTimeout: time.Second * 5,
|
|
serverTimeout: time.Second * 3,
|
|
runTime: time.Second * 2,
|
|
fullMethod: "/",
|
|
},
|
|
wantErr: deadlineExceededErr,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
interceptor := UnaryTimeoutInterceptor(tt.args.interceptorTimeout)
|
|
ctx, cancel := context.WithTimeout(context.Background(), tt.args.contextTimeout)
|
|
defer cancel()
|
|
|
|
svr := &tempServer{timeout: tt.args.serverTimeout}
|
|
|
|
_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
|
|
Server: svr,
|
|
FullMethod: tt.args.fullMethod,
|
|
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
svr.run(tt.args.runTime)
|
|
return nil, nil
|
|
})
|
|
t.Logf("error: %+v", err)
|
|
|
|
assert.EqualValues(t, tt.wantErr, err)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUnaryTimeoutInterceptor_SpecifiedTimeout(t *testing.T) {
|
|
type args struct {
|
|
interceptorTimeout time.Duration
|
|
contextTimeout time.Duration
|
|
method string
|
|
methodTimeout time.Duration
|
|
runTime time.Duration
|
|
}
|
|
var tests = []struct {
|
|
name string
|
|
args args
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "do not timeout without set timeout for full method",
|
|
args: args{
|
|
interceptorTimeout: time.Second,
|
|
contextTimeout: time.Second * 5,
|
|
method: "/run",
|
|
runTime: time.Millisecond * 50,
|
|
},
|
|
wantErr: nil,
|
|
},
|
|
{
|
|
name: "do not timeout with set timeout for full method",
|
|
args: args{
|
|
interceptorTimeout: time.Second,
|
|
contextTimeout: time.Second * 5,
|
|
method: "/run/do_not_timeout",
|
|
methodTimeout: time.Second * 3,
|
|
runTime: time.Second * 2,
|
|
},
|
|
wantErr: nil,
|
|
},
|
|
{
|
|
name: "timeout with set timeout for full method",
|
|
args: args{
|
|
interceptorTimeout: time.Second,
|
|
contextTimeout: time.Second * 5,
|
|
method: "/run/timeout",
|
|
methodTimeout: time.Millisecond * 100,
|
|
runTime: time.Millisecond * 500,
|
|
},
|
|
wantErr: deadlineExceededErr,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var specifiedTimeouts []MethodTimeoutConf
|
|
if tt.args.methodTimeout > 0 {
|
|
specifiedTimeouts = []MethodTimeoutConf{
|
|
{
|
|
FullMethod: tt.args.method,
|
|
Timeout: tt.args.methodTimeout,
|
|
},
|
|
}
|
|
}
|
|
|
|
interceptor := UnaryTimeoutInterceptor(tt.args.interceptorTimeout, specifiedTimeouts...)
|
|
ctx, cancel := context.WithTimeout(context.Background(), tt.args.contextTimeout)
|
|
defer cancel()
|
|
|
|
_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
|
|
FullMethod: tt.args.method,
|
|
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
time.Sleep(tt.args.runTime)
|
|
return nil, nil
|
|
})
|
|
t.Logf("error: %+v", err)
|
|
|
|
assert.EqualValues(t, tt.wantErr, err)
|
|
})
|
|
}
|
|
}
|