diff --git a/zrpc/internal/chainclientinterceptors.go b/zrpc/internal/chainclientinterceptors.go index bd3f2d92..c557dd2f 100644 --- a/zrpc/internal/chainclientinterceptors.go +++ b/zrpc/internal/chainclientinterceptors.go @@ -1,83 +1,13 @@ package internal import ( - "context" - "google.golang.org/grpc" ) func WithStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.DialOption { - return grpc.WithStreamInterceptor(chainStreamClientInterceptors(interceptors...)) + return grpc.WithChainStreamInterceptor(interceptors...) } func WithUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.DialOption { - return grpc.WithUnaryInterceptor(chainUnaryClientInterceptors(interceptors...)) -} - -func chainStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.StreamClientInterceptor { - switch len(interceptors) { - case 0: - return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, - streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { - return streamer(ctx, desc, cc, method, opts...) - } - case 1: - return interceptors[0] - default: - last := len(interceptors) - 1 - return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, - method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { - var chainStreamer grpc.Streamer - var current int - - chainStreamer = func(curCtx context.Context, curDesc *grpc.StreamDesc, curCc *grpc.ClientConn, - curMethod string, curOpts ...grpc.CallOption) (grpc.ClientStream, error) { - if current == last { - return streamer(curCtx, curDesc, curCc, curMethod, curOpts...) - } - - current++ - clientStream, err := interceptors[current](curCtx, curDesc, curCc, curMethod, chainStreamer, curOpts...) - current-- - - return clientStream, err - } - - return interceptors[0](ctx, desc, cc, method, chainStreamer, opts...) - } - } -} - -func chainUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor { - switch len(interceptors) { - case 0: - return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, - invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - return invoker(ctx, method, req, reply, cc, opts...) - } - case 1: - return interceptors[0] - default: - last := len(interceptors) - 1 - return func(ctx context.Context, method string, req, reply interface{}, - cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - var chainInvoker grpc.UnaryInvoker - var current int - - chainInvoker = func(curCtx context.Context, curMethod string, curReq, curReply interface{}, - curCc *grpc.ClientConn, curOpts ...grpc.CallOption) error { - if current == last { - return invoker(curCtx, curMethod, curReq, curReply, curCc, curOpts...) - } - - current++ - err := interceptors[current](curCtx, curMethod, curReq, curReply, curCc, chainInvoker, curOpts...) - current-- - - return err - } - - return interceptors[0](ctx, method, req, reply, cc, chainInvoker, opts...) - } - } -} + return grpc.WithChainUnaryInterceptor(interceptors...) +} \ No newline at end of file diff --git a/zrpc/internal/chainclientinterceptors_test.go b/zrpc/internal/chainclientinterceptors_test.go index a1432fe6..b1957c19 100644 --- a/zrpc/internal/chainclientinterceptors_test.go +++ b/zrpc/internal/chainclientinterceptors_test.go @@ -1,11 +1,9 @@ package internal import ( - "context" "testing" "github.com/stretchr/testify/assert" - "google.golang.org/grpc" ) func TestWithStreamClientInterceptors(t *testing.T) { @@ -16,108 +14,4 @@ func TestWithStreamClientInterceptors(t *testing.T) { func TestWithUnaryClientInterceptors(t *testing.T) { opts := WithUnaryClientInterceptors() assert.NotNil(t, opts) -} - -func TestChainStreamClientInterceptors_zero(t *testing.T) { - var vals []int - interceptors := chainStreamClientInterceptors() - _, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo", - func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, - opts ...grpc.CallOption) (grpc.ClientStream, error) { - vals = append(vals, 1) - return nil, nil - }) - assert.Nil(t, err) - assert.ElementsMatch(t, []int{1}, vals) -} - -func TestChainStreamClientInterceptors_one(t *testing.T) { - var vals []int - interceptors := chainStreamClientInterceptors(func(ctx context.Context, desc *grpc.StreamDesc, - cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) ( - grpc.ClientStream, error) { - vals = append(vals, 1) - return streamer(ctx, desc, cc, method, opts...) - }) - _, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo", - func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, - opts ...grpc.CallOption) (grpc.ClientStream, error) { - vals = append(vals, 2) - return nil, nil - }) - assert.Nil(t, err) - assert.ElementsMatch(t, []int{1, 2}, vals) -} - -func TestChainStreamClientInterceptors_more(t *testing.T) { - var vals []int - interceptors := chainStreamClientInterceptors(func(ctx context.Context, desc *grpc.StreamDesc, - cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) ( - grpc.ClientStream, error) { - vals = append(vals, 1) - return streamer(ctx, desc, cc, method, opts...) - }, func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, - streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { - vals = append(vals, 2) - return streamer(ctx, desc, cc, method, opts...) - }) - _, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo", - func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, - opts ...grpc.CallOption) (grpc.ClientStream, error) { - vals = append(vals, 3) - return nil, nil - }) - assert.Nil(t, err) - assert.ElementsMatch(t, []int{1, 2, 3}, vals) -} - -func TestWithUnaryClientInterceptors_zero(t *testing.T) { - var vals []int - interceptors := chainUnaryClientInterceptors() - err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn), - func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, - opts ...grpc.CallOption) error { - vals = append(vals, 1) - return nil - }) - assert.Nil(t, err) - assert.ElementsMatch(t, []int{1}, vals) -} - -func TestWithUnaryClientInterceptors_one(t *testing.T) { - var vals []int - interceptors := chainUnaryClientInterceptors(func(ctx context.Context, method string, req, - reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - vals = append(vals, 1) - return invoker(ctx, method, req, reply, cc, opts...) - }) - err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn), - func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, - opts ...grpc.CallOption) error { - vals = append(vals, 2) - return nil - }) - assert.Nil(t, err) - assert.ElementsMatch(t, []int{1, 2}, vals) -} - -func TestWithUnaryClientInterceptors_more(t *testing.T) { - var vals []int - interceptors := chainUnaryClientInterceptors(func(ctx context.Context, method string, req, - reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - vals = append(vals, 1) - return invoker(ctx, method, req, reply, cc, opts...) - }, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, - invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - vals = append(vals, 2) - return invoker(ctx, method, req, reply, cc, opts...) - }) - err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn), - func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, - opts ...grpc.CallOption) error { - vals = append(vals, 3) - return nil - }) - assert.Nil(t, err) - assert.ElementsMatch(t, []int{1, 2, 3}, vals) -} +} \ No newline at end of file diff --git a/zrpc/internal/chainserverinterceptors.go b/zrpc/internal/chainserverinterceptors.go index 73768944..89fd81d4 100644 --- a/zrpc/internal/chainserverinterceptors.go +++ b/zrpc/internal/chainserverinterceptors.go @@ -1,81 +1,13 @@ package internal import ( - "context" - "google.golang.org/grpc" ) func WithStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.ServerOption { - return grpc.StreamInterceptor(chainStreamServerInterceptors(interceptors...)) + return grpc.ChainStreamInterceptor(interceptors...) } func WithUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.ServerOption { - return grpc.UnaryInterceptor(chainUnaryServerInterceptors(interceptors...)) -} - -func chainStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor { - switch len(interceptors) { - case 0: - return func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, - handler grpc.StreamHandler) error { - return handler(srv, stream) - } - case 1: - return interceptors[0] - default: - last := len(interceptors) - 1 - return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, - handler grpc.StreamHandler) error { - var chainHandler grpc.StreamHandler - var current int - - chainHandler = func(curSrv interface{}, curStream grpc.ServerStream) error { - if current == last { - return handler(curSrv, curStream) - } - - current++ - err := interceptors[current](curSrv, curStream, info, chainHandler) - current-- - - return err - } - - return interceptors[0](srv, stream, info, chainHandler) - } - } -} - -func chainUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor { - switch len(interceptors) { - case 0: - return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( - interface{}, error) { - return handler(ctx, req) - } - case 1: - return interceptors[0] - default: - last := len(interceptors) - 1 - return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( - interface{}, error) { - var chainHandler grpc.UnaryHandler - var current int - - chainHandler = func(curCtx context.Context, curReq interface{}) (interface{}, error) { - if current == last { - return handler(curCtx, curReq) - } - - current++ - resp, err := interceptors[current](curCtx, curReq, info, chainHandler) - current-- - - return resp, err - } - - return interceptors[0](ctx, req, info, chainHandler) - } - } -} + return grpc.ChainUnaryInterceptor(interceptors...) +} \ No newline at end of file diff --git a/zrpc/internal/chainserverinterceptors_test.go b/zrpc/internal/chainserverinterceptors_test.go index 04704e09..05b26a0c 100644 --- a/zrpc/internal/chainserverinterceptors_test.go +++ b/zrpc/internal/chainserverinterceptors_test.go @@ -1,11 +1,9 @@ package internal import ( - "context" "testing" "github.com/stretchr/testify/assert" - "google.golang.org/grpc" ) func TestWithStreamServerInterceptors(t *testing.T) { @@ -16,96 +14,4 @@ func TestWithStreamServerInterceptors(t *testing.T) { func TestWithUnaryServerInterceptors(t *testing.T) { opts := WithUnaryServerInterceptors() assert.NotNil(t, opts) -} - -func TestChainStreamServerInterceptors_zero(t *testing.T) { - var vals []int - interceptors := chainStreamServerInterceptors() - err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error { - vals = append(vals, 1) - return nil - }) - assert.Nil(t, err) - assert.ElementsMatch(t, []int{1}, vals) -} - -func TestChainStreamServerInterceptors_one(t *testing.T) { - var vals []int - interceptors := chainStreamServerInterceptors(func(srv interface{}, ss grpc.ServerStream, - info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - vals = append(vals, 1) - return handler(srv, ss) - }) - err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error { - vals = append(vals, 2) - return nil - }) - assert.Nil(t, err) - assert.ElementsMatch(t, []int{1, 2}, vals) -} - -func TestChainStreamServerInterceptors_more(t *testing.T) { - var vals []int - interceptors := chainStreamServerInterceptors(func(srv interface{}, ss grpc.ServerStream, - info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - vals = append(vals, 1) - return handler(srv, ss) - }, func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - vals = append(vals, 2) - return handler(srv, ss) - }) - err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error { - vals = append(vals, 3) - return nil - }) - assert.Nil(t, err) - assert.ElementsMatch(t, []int{1, 2, 3}, vals) -} - -func TestChainUnaryServerInterceptors_zero(t *testing.T) { - var vals []int - interceptors := chainUnaryServerInterceptors() - _, err := interceptors(context.Background(), nil, nil, - func(ctx context.Context, req interface{}) (interface{}, error) { - vals = append(vals, 1) - return nil, nil - }) - assert.Nil(t, err) - assert.ElementsMatch(t, []int{1}, vals) -} - -func TestChainUnaryServerInterceptors_one(t *testing.T) { - var vals []int - interceptors := chainUnaryServerInterceptors(func(ctx context.Context, req interface{}, - info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { - vals = append(vals, 1) - return handler(ctx, req) - }) - _, err := interceptors(context.Background(), nil, nil, - func(ctx context.Context, req interface{}) (interface{}, error) { - vals = append(vals, 2) - return nil, nil - }) - assert.Nil(t, err) - assert.ElementsMatch(t, []int{1, 2}, vals) -} - -func TestChainUnaryServerInterceptors_more(t *testing.T) { - var vals []int - interceptors := chainUnaryServerInterceptors(func(ctx context.Context, req interface{}, - info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { - vals = append(vals, 1) - return handler(ctx, req) - }, func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, - handler grpc.UnaryHandler) (resp interface{}, err error) { - vals = append(vals, 2) - return handler(ctx, req) - }) - _, err := interceptors(context.Background(), nil, nil, - func(ctx context.Context, req interface{}) (interface{}, error) { - vals = append(vals, 3) - return nil, nil - }) - assert.Nil(t, err) - assert.ElementsMatch(t, []int{1, 2, 3}, vals) -} +} \ No newline at end of file