diff --git a/zrpc/internal/rpcserver.go b/zrpc/internal/rpcserver.go index 3cf8c4f1..b981edf3 100644 --- a/zrpc/internal/rpcserver.go +++ b/zrpc/internal/rpcserver.go @@ -59,12 +59,10 @@ func (s *rpcServer) Start(register RegisterFn) error { return err } - unaryInterceptors := s.buildUnaryInterceptors() - unaryInterceptors = append(unaryInterceptors, s.unaryInterceptors...) - streamInterceptors := s.buildStreamInterceptors() - streamInterceptors = append(streamInterceptors, s.streamInterceptors...) - options := append(s.options, grpc.ChainUnaryInterceptor(unaryInterceptors...), - grpc.ChainStreamInterceptor(streamInterceptors...)) + unaryInterceptorOption := grpc.ChainUnaryInterceptor(s.buildUnaryInterceptors()...) + streamInterceptorOption := grpc.ChainStreamInterceptor(s.buildStreamInterceptors()...) + + options := append(s.options, unaryInterceptorOption, streamInterceptorOption) server := grpc.NewServer(options...) register(server) @@ -102,7 +100,7 @@ func (s *rpcServer) buildStreamInterceptors() []grpc.StreamServerInterceptor { interceptors = append(interceptors, serverinterceptors.StreamBreakerInterceptor) } - return interceptors + return append(interceptors, s.streamInterceptors...) } func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor { @@ -124,7 +122,7 @@ func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor { interceptors = append(interceptors, serverinterceptors.UnaryBreakerInterceptor) } - return interceptors + return append(interceptors, s.unaryInterceptors...) } // WithMetrics returns a func that sets metrics to a Server. diff --git a/zrpc/internal/rpcserver_test.go b/zrpc/internal/rpcserver_test.go index ef088a99..db358b57 100644 --- a/zrpc/internal/rpcserver_test.go +++ b/zrpc/internal/rpcserver_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "sync" "testing" @@ -58,3 +59,115 @@ func TestRpcServer_WithBadAddress(t *testing.T) { }) assert.NotNil(t, err) } + +func TestRpcServer_buildUnaryInterceptor(t *testing.T) { + tests := []struct { + name string + r *rpcServer + len int + }{ + { + name: "empty", + r: &rpcServer{ + baseRpcServer: &baseRpcServer{}, + }, + len: 0, + }, + { + name: "custom", + r: &rpcServer{ + baseRpcServer: &baseRpcServer{ + unaryInterceptors: []grpc.UnaryServerInterceptor{ + func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler) (interface{}, error) { + return nil, nil + }, + }, + }, + }, + len: 1, + }, + { + name: "middleware", + r: &rpcServer{ + baseRpcServer: &baseRpcServer{ + unaryInterceptors: []grpc.UnaryServerInterceptor{ + func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler) (interface{}, error) { + return nil, nil + }, + }, + }, + middlewares: ServerMiddlewaresConf{ + Trace: true, + Recover: true, + Stat: true, + Prometheus: true, + Breaker: true, + }, + }, + len: 6, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.len, len(test.r.buildUnaryInterceptors())) + }) + } +} + +func TestRpcServer_buildStreamInterceptor(t *testing.T) { + tests := []struct { + name string + r *rpcServer + len int + }{ + { + name: "empty", + r: &rpcServer{ + baseRpcServer: &baseRpcServer{}, + }, + len: 0, + }, + { + name: "custom", + r: &rpcServer{ + baseRpcServer: &baseRpcServer{ + streamInterceptors: []grpc.StreamServerInterceptor{ + func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, + handler grpc.StreamHandler) error { + return nil + }, + }, + }, + }, + len: 1, + }, + { + name: "middleware", + r: &rpcServer{ + baseRpcServer: &baseRpcServer{ + streamInterceptors: []grpc.StreamServerInterceptor{ + func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, + handler grpc.StreamHandler) error { + return nil + }, + }, + }, + middlewares: ServerMiddlewaresConf{ + Trace: true, + Recover: true, + Breaker: true, + }, + }, + len: 4, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.len, len(test.r.buildStreamInterceptors())) + }) + } +}