feat: converge grpc interceptor processing (#2830)

* feat: converge grpc interceptor processing

* x

* x
master
MarkJoyMa 2 years ago committed by GitHub
parent 3c0dc8435e
commit dd117ce9cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -59,12 +59,10 @@ func (s *rpcServer) Start(register RegisterFn) error {
return err return err
} }
unaryInterceptors := s.buildUnaryInterceptors() unaryInterceptorOption := grpc.ChainUnaryInterceptor(s.buildUnaryInterceptors()...)
unaryInterceptors = append(unaryInterceptors, s.unaryInterceptors...) streamInterceptorOption := grpc.ChainStreamInterceptor(s.buildStreamInterceptors()...)
streamInterceptors := s.buildStreamInterceptors()
streamInterceptors = append(streamInterceptors, s.streamInterceptors...) options := append(s.options, unaryInterceptorOption, streamInterceptorOption)
options := append(s.options, grpc.ChainUnaryInterceptor(unaryInterceptors...),
grpc.ChainStreamInterceptor(streamInterceptors...))
server := grpc.NewServer(options...) server := grpc.NewServer(options...)
register(server) register(server)
@ -102,7 +100,7 @@ func (s *rpcServer) buildStreamInterceptors() []grpc.StreamServerInterceptor {
interceptors = append(interceptors, serverinterceptors.StreamBreakerInterceptor) interceptors = append(interceptors, serverinterceptors.StreamBreakerInterceptor)
} }
return interceptors return append(interceptors, s.streamInterceptors...)
} }
func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor { func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor {
@ -124,7 +122,7 @@ func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor {
interceptors = append(interceptors, serverinterceptors.UnaryBreakerInterceptor) interceptors = append(interceptors, serverinterceptors.UnaryBreakerInterceptor)
} }
return interceptors return append(interceptors, s.unaryInterceptors...)
} }
// WithMetrics returns a func that sets metrics to a Server. // WithMetrics returns a func that sets metrics to a Server.

@ -1,6 +1,7 @@
package internal package internal
import ( import (
"context"
"sync" "sync"
"testing" "testing"
@ -58,3 +59,115 @@ func TestRpcServer_WithBadAddress(t *testing.T) {
}) })
assert.NotNil(t, err) assert.NotNil(t, err)
} }
func TestRpcServer_buildUnaryInterceptor(t *testing.T) {
tests := []struct {
name string
r *rpcServer
len int
}{
{
name: "empty",
r: &rpcServer{
baseRpcServer: &baseRpcServer{},
},
len: 0,
},
{
name: "custom",
r: &rpcServer{
baseRpcServer: &baseRpcServer{
unaryInterceptors: []grpc.UnaryServerInterceptor{
func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (interface{}, error) {
return nil, nil
},
},
},
},
len: 1,
},
{
name: "middleware",
r: &rpcServer{
baseRpcServer: &baseRpcServer{
unaryInterceptors: []grpc.UnaryServerInterceptor{
func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (interface{}, error) {
return nil, nil
},
},
},
middlewares: ServerMiddlewaresConf{
Trace: true,
Recover: true,
Stat: true,
Prometheus: true,
Breaker: true,
},
},
len: 6,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
assert.Equal(t, test.len, len(test.r.buildUnaryInterceptors()))
})
}
}
func TestRpcServer_buildStreamInterceptor(t *testing.T) {
tests := []struct {
name string
r *rpcServer
len int
}{
{
name: "empty",
r: &rpcServer{
baseRpcServer: &baseRpcServer{},
},
len: 0,
},
{
name: "custom",
r: &rpcServer{
baseRpcServer: &baseRpcServer{
streamInterceptors: []grpc.StreamServerInterceptor{
func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
handler grpc.StreamHandler) error {
return nil
},
},
},
},
len: 1,
},
{
name: "middleware",
r: &rpcServer{
baseRpcServer: &baseRpcServer{
streamInterceptors: []grpc.StreamServerInterceptor{
func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
handler grpc.StreamHandler) error {
return nil
},
},
},
middlewares: ServerMiddlewaresConf{
Trace: true,
Recover: true,
Breaker: true,
},
},
len: 4,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
assert.Equal(t, test.len, len(test.r.buildStreamInterceptors()))
})
}
}

Loading…
Cancel
Save