diff --git a/zrpc/client.go b/zrpc/client.go index 7fe63128..790a5a75 100644 --- a/zrpc/client.go +++ b/zrpc/client.go @@ -70,7 +70,7 @@ func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) { return nil, err } - client, err := internal.NewClient(target, opts...) + client, err := internal.NewClient(target, c.Middlewares, opts...) if err != nil { return nil, err } @@ -82,7 +82,14 @@ func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) { // NewClientWithTarget returns a Client with connecting to given target. func NewClientWithTarget(target string, opts ...ClientOption) (Client, error) { - return internal.NewClient(target, opts...) + middlewares := ClientMiddlewaresConf{ + Trace: true, + Duration: true, + Prometheus: true, + Breaker: true, + Timeout: true, + } + return internal.NewClient(target, middlewares, opts...) } // Conn returns the underlying grpc.ClientConn. diff --git a/zrpc/client_test.go b/zrpc/client_test.go index 80fb39e9..3f15ae54 100644 --- a/zrpc/client_test.go +++ b/zrpc/client_test.go @@ -76,6 +76,13 @@ func TestDepositServer_Deposit(t *testing.T) { App: "foo", Token: "bar", Timeout: 1000, + Middlewares: ClientMiddlewaresConf{ + Trace: true, + Duration: true, + Prometheus: true, + Breaker: true, + Timeout: true, + }, }, WithDialOption(grpc.WithContextDialer(dialer())), WithUnaryClientInterceptor(func(ctx context.Context, method string, req, reply interface{}, @@ -90,6 +97,13 @@ func TestDepositServer_Deposit(t *testing.T) { Token: "bar", Timeout: 1000, NonBlock: true, + Middlewares: ClientMiddlewaresConf{ + Trace: true, + Duration: true, + Prometheus: true, + Breaker: true, + Timeout: true, + }, }, WithDialOption(grpc.WithContextDialer(dialer())), WithUnaryClientInterceptor(func(ctx context.Context, method string, req, reply interface{}, @@ -103,6 +117,13 @@ func TestDepositServer_Deposit(t *testing.T) { App: "foo", Token: "bar", Timeout: 1000, + Middlewares: ClientMiddlewaresConf{ + Trace: true, + Duration: true, + Prometheus: true, + Breaker: true, + Timeout: true, + }, }, WithDialOption(grpc.WithTransportCredentials(insecure.NewCredentials())), WithDialOption(grpc.WithContextDialer(dialer())), diff --git a/zrpc/config.go b/zrpc/config.go index 225ee2eb..7f772807 100644 --- a/zrpc/config.go +++ b/zrpc/config.go @@ -4,10 +4,16 @@ import ( "github.com/zeromicro/go-zero/core/discov" "github.com/zeromicro/go-zero/core/service" "github.com/zeromicro/go-zero/core/stores/redis" + "github.com/zeromicro/go-zero/zrpc/internal" "github.com/zeromicro/go-zero/zrpc/resolver" ) type ( + // ClientMiddlewaresConf defines whether to use client middlewares. + ClientMiddlewaresConf = internal.ClientMiddlewaresConf + // ServerMiddlewaresConf defines whether to use server middlewares. + ServerMiddlewaresConf = internal.ServerMiddlewaresConf + // A RpcServerConf is a rpc server config. RpcServerConf struct { service.ServiceConf @@ -20,18 +26,20 @@ type ( Timeout int64 `json:",default=2000"` CpuThreshold int64 `json:",default=900,range=[0:1000]"` // grpc health check switch - Health bool `json:",default=true"` + Health bool `json:",default=true"` + Middlewares ServerMiddlewaresConf } // A RpcClientConf is a rpc client config. RpcClientConf struct { - Etcd discov.EtcdConf `json:",optional,inherit"` - Endpoints []string `json:",optional"` - Target string `json:",optional"` - App string `json:",optional"` - Token string `json:",optional"` - NonBlock bool `json:",optional"` - Timeout int64 `json:",default=2000"` + Etcd discov.EtcdConf `json:",optional,inherit"` + Endpoints []string `json:",optional"` + Target string `json:",optional"` + App string `json:",optional"` + Token string `json:",optional"` + NonBlock bool `json:",optional"` + Timeout int64 `json:",default=2000"` + Middlewares ClientMiddlewaresConf } ) diff --git a/zrpc/internal/client.go b/zrpc/internal/client.go index 6b3151bf..72c38315 100644 --- a/zrpc/internal/client.go +++ b/zrpc/internal/client.go @@ -42,13 +42,17 @@ type ( ClientOption func(options *ClientOptions) client struct { - conn *grpc.ClientConn + conn *grpc.ClientConn + middlewares ClientMiddlewaresConf } ) // NewClient returns a Client. -func NewClient(target string, opts ...ClientOption) (Client, error) { - var cli client +func NewClient(target string, middlewares ClientMiddlewaresConf, + opts ...ClientOption) (Client, error) { + cli := client{ + middlewares: middlewares, + } svcCfg := fmt.Sprintf(`{"loadBalancingPolicy":"%s"}`, p2c.Name) balancerOpt := WithDialOption(grpc.WithDefaultServiceConfig(svcCfg)) @@ -80,21 +84,45 @@ func (c *client) buildDialOptions(opts ...ClientOption) []grpc.DialOption { } options = append(options, - WithUnaryClientInterceptors( - clientinterceptors.UnaryTracingInterceptor, - clientinterceptors.DurationInterceptor, - clientinterceptors.PrometheusInterceptor, - clientinterceptors.BreakerInterceptor, - clientinterceptors.TimeoutInterceptor(cliOpts.Timeout), - ), - WithStreamClientInterceptors( - clientinterceptors.StreamTracingInterceptor, - ), + WithUnaryClientInterceptors(c.buildUnaryInterceptors(cliOpts.Timeout)...), + WithStreamClientInterceptors(c.buildStreamInterceptors()...), ) return append(options, cliOpts.DialOptions...) } +func (c *client) buildStreamInterceptors() []grpc.StreamClientInterceptor { + var interceptors []grpc.StreamClientInterceptor + + if c.middlewares.Trace { + interceptors = append(interceptors, clientinterceptors.StreamTracingInterceptor) + } + + return interceptors +} + +func (c *client) buildUnaryInterceptors(timeout time.Duration) []grpc.UnaryClientInterceptor { + var interceptors []grpc.UnaryClientInterceptor + + if c.middlewares.Trace { + interceptors = append(interceptors, clientinterceptors.UnaryTracingInterceptor) + } + if c.middlewares.Duration { + interceptors = append(interceptors, clientinterceptors.DurationInterceptor) + } + if c.middlewares.Prometheus { + interceptors = append(interceptors, clientinterceptors.PrometheusInterceptor) + } + if c.middlewares.Breaker { + interceptors = append(interceptors, clientinterceptors.BreakerInterceptor) + } + if c.middlewares.Timeout { + interceptors = append(interceptors, clientinterceptors.TimeoutInterceptor(timeout)) + } + + return interceptors +} + func (c *client) dial(server string, opts ...ClientOption) error { options := c.buildDialOptions(opts...) timeCtx, cancel := context.WithTimeout(context.Background(), dialTimeout) diff --git a/zrpc/internal/client_test.go b/zrpc/internal/client_test.go index 487049f6..9587e9ae 100644 --- a/zrpc/internal/client_test.go +++ b/zrpc/internal/client_test.go @@ -2,6 +2,8 @@ package internal import ( "context" + "net" + "strings" "testing" "time" @@ -60,8 +62,62 @@ func TestWithUnaryClientInterceptor(t *testing.T) { } func TestBuildDialOptions(t *testing.T) { - var c client + c := client{ + middlewares: ClientMiddlewaresConf{ + Trace: true, + Duration: true, + Prometheus: true, + Breaker: true, + Timeout: true, + }, + } agent := grpc.WithUserAgent("chrome") opts := c.buildDialOptions(WithDialOption(agent)) assert.Contains(t, opts, agent) } + +func TestClientDial(t *testing.T) { + server := grpc.NewServer() + + go func() { + lis, err := net.Listen("tcp", "localhost:54321") + assert.NoError(t, err) + defer lis.Close() + server.Serve(lis) + }() + + time.Sleep(time.Millisecond) + + c, err := NewClient("localhost:54321", ClientMiddlewaresConf{ + Trace: true, + Duration: true, + Prometheus: true, + Breaker: true, + Timeout: true, + }) + assert.NoError(t, err) + assert.NotNil(t, c.Conn()) + server.Stop() +} + +func TestClientDialFail(t *testing.T) { + _, err := NewClient("localhost:54321", ClientMiddlewaresConf{ + Trace: true, + Duration: true, + Prometheus: true, + Breaker: true, + Timeout: true, + }) + assert.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "localhost:54321")) + + _, err = NewClient("localhost:54321/fail", ClientMiddlewaresConf{ + Trace: true, + Duration: true, + Prometheus: true, + Breaker: true, + Timeout: true, + }) + assert.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "localhost:54321/fail")) +} diff --git a/zrpc/internal/config.go b/zrpc/internal/config.go new file mode 100644 index 00000000..8fc990ff --- /dev/null +++ b/zrpc/internal/config.go @@ -0,0 +1,21 @@ +package internal + +type ( + // ClientMiddlewaresConf defines whether to use client middlewares. + ClientMiddlewaresConf struct { + Trace bool `json:",default=true"` + Duration bool `json:",default=true"` + Prometheus bool `json:",default=true"` + Breaker bool `json:",default=true"` + Timeout bool `json:",default=true"` + } + + // ServerMiddlewaresConf defines whether to use server middlewares. + ServerMiddlewaresConf struct { + Trace bool `json:",default=true"` + Recover bool `json:",default=true"` + Stat bool `json:",default=true"` + Prometheus bool `json:",default=true"` + Breaker bool `json:",default=true"` + } +) diff --git a/zrpc/internal/rpcpubserver.go b/zrpc/internal/rpcpubserver.go index 5a2a5136..61b79a7d 100644 --- a/zrpc/internal/rpcpubserver.go +++ b/zrpc/internal/rpcpubserver.go @@ -14,7 +14,8 @@ const ( ) // NewRpcPubServer returns a Server. -func NewRpcPubServer(etcd discov.EtcdConf, listenOn string, opts ...ServerOption) (Server, error) { +func NewRpcPubServer(etcd discov.EtcdConf, listenOn string, middlewares ServerMiddlewaresConf, + opts ...ServerOption) (Server, error) { registerEtcd := func() error { pubListenOn := figureOutListenOn(listenOn) var pubOpts []discov.PubOption @@ -30,7 +31,7 @@ func NewRpcPubServer(etcd discov.EtcdConf, listenOn string, opts ...ServerOption } server := keepAliveServer{ registerEtcd: registerEtcd, - Server: NewRpcServer(listenOn, opts...), + Server: NewRpcServer(listenOn, middlewares, opts...), } return server, nil diff --git a/zrpc/internal/rpcserver.go b/zrpc/internal/rpcserver.go index 410c0050..72d4c845 100644 --- a/zrpc/internal/rpcserver.go +++ b/zrpc/internal/rpcserver.go @@ -26,12 +26,13 @@ type ( rpcServer struct { *baseRpcServer name string + middlewares ServerMiddlewaresConf healthManager health.Probe } ) // NewRpcServer returns a Server. -func NewRpcServer(addr string, opts ...ServerOption) Server { +func NewRpcServer(addr string, middlewares ServerMiddlewaresConf, opts ...ServerOption) Server { var options rpcServerOptions for _, opt := range opts { opt(&options) @@ -42,6 +43,7 @@ func NewRpcServer(addr string, opts ...ServerOption) Server { return &rpcServer{ baseRpcServer: newBaseRpcServer(addr, &options), + middlewares: middlewares, healthManager: health.NewHealthManager(fmt.Sprintf("%s-%s", probeNamePrefix, addr)), } } @@ -57,19 +59,9 @@ func (s *rpcServer) Start(register RegisterFn) error { return err } - unaryInterceptors := []grpc.UnaryServerInterceptor{ - serverinterceptors.UnaryTracingInterceptor, - serverinterceptors.UnaryCrashInterceptor, - serverinterceptors.UnaryStatInterceptor(s.metrics), - serverinterceptors.UnaryPrometheusInterceptor, - serverinterceptors.UnaryBreakerInterceptor, - } + unaryInterceptors := s.buildUnaryInterceptors() unaryInterceptors = append(unaryInterceptors, s.unaryInterceptors...) - streamInterceptors := []grpc.StreamServerInterceptor{ - serverinterceptors.StreamTracingInterceptor, - serverinterceptors.StreamCrashInterceptor, - serverinterceptors.StreamBreakerInterceptor, - } + streamInterceptors := s.buildStreamInterceptors() streamInterceptors = append(streamInterceptors, s.streamInterceptors...) options := append(s.options, WithUnaryServerInterceptors(unaryInterceptors...), WithStreamServerInterceptors(streamInterceptors...)) @@ -97,6 +89,44 @@ func (s *rpcServer) Start(register RegisterFn) error { return server.Serve(lis) } +func (s *rpcServer) buildStreamInterceptors() []grpc.StreamServerInterceptor { + var interceptors []grpc.StreamServerInterceptor + + if s.middlewares.Trace { + interceptors = append(interceptors, serverinterceptors.StreamTracingInterceptor) + } + if s.middlewares.Recover { + interceptors = append(interceptors, serverinterceptors.StreamRecoverInterceptor) + } + if s.middlewares.Breaker { + interceptors = append(interceptors, serverinterceptors.StreamBreakerInterceptor) + } + + return interceptors +} + +func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor { + var interceptors []grpc.UnaryServerInterceptor + + if s.middlewares.Trace { + interceptors = append(interceptors, serverinterceptors.UnaryTracingInterceptor) + } + if s.middlewares.Recover { + interceptors = append(interceptors, serverinterceptors.UnaryRecoverInterceptor) + } + if s.middlewares.Stat { + interceptors = append(interceptors, serverinterceptors.UnaryStatInterceptor(s.metrics)) + } + if s.middlewares.Prometheus { + interceptors = append(interceptors, serverinterceptors.UnaryPrometheusInterceptor) + } + if s.middlewares.Breaker { + interceptors = append(interceptors, serverinterceptors.UnaryBreakerInterceptor) + } + + return interceptors +} + // WithMetrics returns a func that sets metrics to a Server. func WithMetrics(metrics *stat.Metrics) ServerOption { return func(options *rpcServerOptions) { diff --git a/zrpc/internal/rpcserver_test.go b/zrpc/internal/rpcserver_test.go index 87f922db..c1bab676 100644 --- a/zrpc/internal/rpcserver_test.go +++ b/zrpc/internal/rpcserver_test.go @@ -12,7 +12,13 @@ import ( func TestRpcServer(t *testing.T) { metrics := stat.NewMetrics("foo") - server := NewRpcServer("localhost:54321", WithMetrics(metrics)) + server := NewRpcServer("localhost:54321", ServerMiddlewaresConf{ + Trace: true, + Recover: true, + Stat: true, + Prometheus: true, + Breaker: true, + }, WithMetrics(metrics)) server.SetName("mock") var wg sync.WaitGroup var grpcServer *grpc.Server @@ -36,7 +42,13 @@ func TestRpcServer(t *testing.T) { } func TestRpcServer_WithBadAddress(t *testing.T) { - server := NewRpcServer("localhost:111111") + server := NewRpcServer("localhost:111111", ServerMiddlewaresConf{ + Trace: true, + Recover: true, + Stat: true, + Prometheus: true, + Breaker: true, + }) server.SetName("mock") err := server.Start(func(server *grpc.Server) { mock.RegisterDepositServiceServer(server, new(mock.DepositServer)) diff --git a/zrpc/internal/serverinterceptors/crashinterceptor.go b/zrpc/internal/serverinterceptors/recoverinterceptor.go similarity index 69% rename from zrpc/internal/serverinterceptors/crashinterceptor.go rename to zrpc/internal/serverinterceptors/recoverinterceptor.go index 2d45a3b1..01fde7dc 100644 --- a/zrpc/internal/serverinterceptors/crashinterceptor.go +++ b/zrpc/internal/serverinterceptors/recoverinterceptor.go @@ -10,8 +10,8 @@ import ( "google.golang.org/grpc/status" ) -// StreamCrashInterceptor catches panics in processing stream requests and recovers. -func StreamCrashInterceptor(svr interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, +// StreamRecoverInterceptor catches panics in processing stream requests and recovers. +func StreamRecoverInterceptor(svr interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { defer handleCrash(func(r interface{}) { err = toPanicError(r) @@ -20,8 +20,8 @@ func StreamCrashInterceptor(svr interface{}, stream grpc.ServerStream, _ *grpc.S return handler(svr, stream) } -// UnaryCrashInterceptor catches panics in processing unary requests and recovers. -func UnaryCrashInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, +// UnaryRecoverInterceptor catches panics in processing unary requests and recovers. +func UnaryRecoverInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { defer handleCrash(func(r interface{}) { err = toPanicError(r) diff --git a/zrpc/internal/serverinterceptors/crashinterceptor_test.go b/zrpc/internal/serverinterceptors/recoverinterceptor_test.go similarity index 81% rename from zrpc/internal/serverinterceptors/crashinterceptor_test.go rename to zrpc/internal/serverinterceptors/recoverinterceptor_test.go index 1f5970c9..2325bf14 100644 --- a/zrpc/internal/serverinterceptors/crashinterceptor_test.go +++ b/zrpc/internal/serverinterceptors/recoverinterceptor_test.go @@ -14,7 +14,7 @@ func init() { } func TestStreamCrashInterceptor(t *testing.T) { - err := StreamCrashInterceptor(nil, nil, nil, func( + err := StreamRecoverInterceptor(nil, nil, nil, func( svr interface{}, stream grpc.ServerStream) error { panic("mock panic") }) @@ -22,7 +22,7 @@ func TestStreamCrashInterceptor(t *testing.T) { } func TestUnaryCrashInterceptor(t *testing.T) { - _, err := UnaryCrashInterceptor(context.Background(), nil, nil, + _, err := UnaryRecoverInterceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { panic("mock panic") }) diff --git a/zrpc/server.go b/zrpc/server.go index 6e6b1cfc..99a637ba 100644 --- a/zrpc/server.go +++ b/zrpc/server.go @@ -44,12 +44,12 @@ func NewServer(c RpcServerConf, register internal.RegisterFn) (*RpcServer, error } if c.HasEtcd() { - server, err = internal.NewRpcPubServer(c.Etcd, c.ListenOn, serverOptions...) + server, err = internal.NewRpcPubServer(c.Etcd, c.ListenOn, c.Middlewares, serverOptions...) if err != nil { return nil, err } } else { - server = internal.NewRpcServer(c.ListenOn, serverOptions...) + server = internal.NewRpcServer(c.ListenOn, c.Middlewares, serverOptions...) } server.SetName(c.Name) diff --git a/zrpc/server_test.go b/zrpc/server_test.go index 67d1c2f9..4cdcb8d8 100644 --- a/zrpc/server_test.go +++ b/zrpc/server_test.go @@ -28,6 +28,13 @@ func TestServer_setupInterceptors(t *testing.T) { }, CpuThreshold: 10, Timeout: 100, + Middlewares: ServerMiddlewaresConf{ + Trace: true, + Recover: true, + Stat: true, + Prometheus: true, + Breaker: true, + }, }, new(stat.Metrics)) assert.Nil(t, err) assert.Equal(t, 3, len(server.unaryInterceptors)) @@ -51,11 +58,18 @@ func TestServer(t *testing.T) { StrictControl: false, Timeout: 0, CpuThreshold: 0, + Middlewares: ServerMiddlewaresConf{ + Trace: true, + Recover: true, + Stat: true, + Prometheus: true, + Breaker: true, + }, }, func(server *grpc.Server) { }) svr.AddOptions(grpc.ConnectionTimeout(time.Hour)) - svr.AddUnaryInterceptors(serverinterceptors.UnaryCrashInterceptor) - svr.AddStreamInterceptors(serverinterceptors.StreamCrashInterceptor) + svr.AddUnaryInterceptors(serverinterceptors.UnaryRecoverInterceptor) + svr.AddStreamInterceptors(serverinterceptors.StreamRecoverInterceptor) go svr.Start() svr.Stop() } @@ -74,6 +88,13 @@ func TestServerError(t *testing.T) { }, Auth: true, Redis: redis.RedisKeyConf{}, + Middlewares: ServerMiddlewaresConf{ + Trace: true, + Recover: true, + Stat: true, + Prometheus: true, + Breaker: true, + }, }, func(server *grpc.Server) { }) assert.NotNil(t, err) @@ -93,11 +114,18 @@ func TestServer_HasEtcd(t *testing.T) { Key: "any", }, Redis: redis.RedisKeyConf{}, + Middlewares: ServerMiddlewaresConf{ + Trace: true, + Recover: true, + Stat: true, + Prometheus: true, + Breaker: true, + }, }, func(server *grpc.Server) { }) svr.AddOptions(grpc.ConnectionTimeout(time.Hour)) - svr.AddUnaryInterceptors(serverinterceptors.UnaryCrashInterceptor) - svr.AddStreamInterceptors(serverinterceptors.StreamCrashInterceptor) + svr.AddUnaryInterceptors(serverinterceptors.UnaryRecoverInterceptor) + svr.AddStreamInterceptors(serverinterceptors.StreamRecoverInterceptor) go svr.Start() svr.Stop() } @@ -111,6 +139,13 @@ func TestServer_StartFailed(t *testing.T) { }, }, ListenOn: "localhost:aaa", + Middlewares: ServerMiddlewaresConf{ + Trace: true, + Recover: true, + Stat: true, + Prometheus: true, + Breaker: true, + }, }, func(server *grpc.Server) { })