diff --git a/rpcx/internal/chainclientinterceptors_test.go b/rpcx/internal/chainclientinterceptors_test.go index 7ca6019e..a1432fe6 100644 --- a/rpcx/internal/chainclientinterceptors_test.go +++ b/rpcx/internal/chainclientinterceptors_test.go @@ -2,7 +2,6 @@ package internal import ( "context" - "sync/atomic" "testing" "github.com/stretchr/testify/assert" @@ -20,28 +19,105 @@ func TestWithUnaryClientInterceptors(t *testing.T) { } func TestChainStreamClientInterceptors_zero(t *testing.T) { + var vals []int interceptors := chainStreamClientInterceptors() _, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + vals = append(vals, 1) return nil, nil }) assert.Nil(t, err) + assert.ElementsMatch(t, []int{1}, vals) } func TestChainStreamClientInterceptors_one(t *testing.T) { - var called int32 + var vals []int interceptors := chainStreamClientInterceptors(func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) ( grpc.ClientStream, error) { - atomic.AddInt32(&called, 1) - return nil, nil + vals = append(vals, 1) + return streamer(ctx, desc, cc, method, opts...) }) _, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + vals = append(vals, 2) return nil, nil }) assert.Nil(t, err) - assert.Equal(t, int32(1), atomic.LoadInt32(&called)) + assert.ElementsMatch(t, []int{1, 2}, vals) +} + +func TestChainStreamClientInterceptors_more(t *testing.T) { + var vals []int + interceptors := chainStreamClientInterceptors(func(ctx context.Context, desc *grpc.StreamDesc, + cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) ( + grpc.ClientStream, error) { + vals = append(vals, 1) + return streamer(ctx, desc, cc, method, opts...) + }, func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, + streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + vals = append(vals, 2) + return streamer(ctx, desc, cc, method, opts...) + }) + _, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo", + func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, + opts ...grpc.CallOption) (grpc.ClientStream, error) { + vals = append(vals, 3) + return nil, nil + }) + assert.Nil(t, err) + assert.ElementsMatch(t, []int{1, 2, 3}, vals) +} + +func TestWithUnaryClientInterceptors_zero(t *testing.T) { + var vals []int + interceptors := chainUnaryClientInterceptors() + err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn), + func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, + opts ...grpc.CallOption) error { + vals = append(vals, 1) + return nil + }) + assert.Nil(t, err) + assert.ElementsMatch(t, []int{1}, vals) +} + +func TestWithUnaryClientInterceptors_one(t *testing.T) { + var vals []int + interceptors := chainUnaryClientInterceptors(func(ctx context.Context, method string, req, + reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + vals = append(vals, 1) + return invoker(ctx, method, req, reply, cc, opts...) + }) + err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn), + func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, + opts ...grpc.CallOption) error { + vals = append(vals, 2) + return nil + }) + assert.Nil(t, err) + assert.ElementsMatch(t, []int{1, 2}, vals) +} + +func TestWithUnaryClientInterceptors_more(t *testing.T) { + var vals []int + interceptors := chainUnaryClientInterceptors(func(ctx context.Context, method string, req, + reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + vals = append(vals, 1) + return invoker(ctx, method, req, reply, cc, opts...) + }, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + vals = append(vals, 2) + return invoker(ctx, method, req, reply, cc, opts...) + }) + err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn), + func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, + opts ...grpc.CallOption) error { + vals = append(vals, 3) + return nil + }) + assert.Nil(t, err) + assert.ElementsMatch(t, []int{1, 2, 3}, vals) } diff --git a/rpcx/internal/chainserverinterceptors_test.go b/rpcx/internal/chainserverinterceptors_test.go new file mode 100644 index 00000000..04704e09 --- /dev/null +++ b/rpcx/internal/chainserverinterceptors_test.go @@ -0,0 +1,111 @@ +package internal + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" +) + +func TestWithStreamServerInterceptors(t *testing.T) { + opts := WithStreamServerInterceptors() + assert.NotNil(t, opts) +} + +func TestWithUnaryServerInterceptors(t *testing.T) { + opts := WithUnaryServerInterceptors() + assert.NotNil(t, opts) +} + +func TestChainStreamServerInterceptors_zero(t *testing.T) { + var vals []int + interceptors := chainStreamServerInterceptors() + err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error { + vals = append(vals, 1) + return nil + }) + assert.Nil(t, err) + assert.ElementsMatch(t, []int{1}, vals) +} + +func TestChainStreamServerInterceptors_one(t *testing.T) { + var vals []int + interceptors := chainStreamServerInterceptors(func(srv interface{}, ss grpc.ServerStream, + info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + vals = append(vals, 1) + return handler(srv, ss) + }) + err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error { + vals = append(vals, 2) + return nil + }) + assert.Nil(t, err) + assert.ElementsMatch(t, []int{1, 2}, vals) +} + +func TestChainStreamServerInterceptors_more(t *testing.T) { + var vals []int + interceptors := chainStreamServerInterceptors(func(srv interface{}, ss grpc.ServerStream, + info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + vals = append(vals, 1) + return handler(srv, ss) + }, func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + vals = append(vals, 2) + return handler(srv, ss) + }) + err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error { + vals = append(vals, 3) + return nil + }) + assert.Nil(t, err) + assert.ElementsMatch(t, []int{1, 2, 3}, vals) +} + +func TestChainUnaryServerInterceptors_zero(t *testing.T) { + var vals []int + interceptors := chainUnaryServerInterceptors() + _, err := interceptors(context.Background(), nil, nil, + func(ctx context.Context, req interface{}) (interface{}, error) { + vals = append(vals, 1) + return nil, nil + }) + assert.Nil(t, err) + assert.ElementsMatch(t, []int{1}, vals) +} + +func TestChainUnaryServerInterceptors_one(t *testing.T) { + var vals []int + interceptors := chainUnaryServerInterceptors(func(ctx context.Context, req interface{}, + info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + vals = append(vals, 1) + return handler(ctx, req) + }) + _, err := interceptors(context.Background(), nil, nil, + func(ctx context.Context, req interface{}) (interface{}, error) { + vals = append(vals, 2) + return nil, nil + }) + assert.Nil(t, err) + assert.ElementsMatch(t, []int{1, 2}, vals) +} + +func TestChainUnaryServerInterceptors_more(t *testing.T) { + var vals []int + interceptors := chainUnaryServerInterceptors(func(ctx context.Context, req interface{}, + info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + vals = append(vals, 1) + return handler(ctx, req) + }, func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler) (resp interface{}, err error) { + vals = append(vals, 2) + return handler(ctx, req) + }) + _, err := interceptors(context.Background(), nil, nil, + func(ctx context.Context, req interface{}) (interface{}, error) { + vals = append(vals, 3) + return nil, nil + }) + assert.Nil(t, err) + assert.ElementsMatch(t, []int{1, 2, 3}, vals) +}