feat: add middlewares config for zrpc (#2766)

* feat: add middlewares config for zrpc

* chore: add tests

* chore: improve codecov

* chore: improve codecov
master
Kevin Wan 2 years ago committed by GitHub
parent ade6f9ee46
commit 26c541b9cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -70,7 +70,7 @@ func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) {
return nil, err return nil, err
} }
client, err := internal.NewClient(target, opts...) client, err := internal.NewClient(target, c.Middlewares, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -82,7 +82,14 @@ func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) {
// NewClientWithTarget returns a Client with connecting to given target. // NewClientWithTarget returns a Client with connecting to given target.
func NewClientWithTarget(target string, opts ...ClientOption) (Client, error) { func NewClientWithTarget(target string, opts ...ClientOption) (Client, error) {
return internal.NewClient(target, opts...) middlewares := ClientMiddlewaresConf{
Trace: true,
Duration: true,
Prometheus: true,
Breaker: true,
Timeout: true,
}
return internal.NewClient(target, middlewares, opts...)
} }
// Conn returns the underlying grpc.ClientConn. // Conn returns the underlying grpc.ClientConn.

@ -76,6 +76,13 @@ func TestDepositServer_Deposit(t *testing.T) {
App: "foo", App: "foo",
Token: "bar", Token: "bar",
Timeout: 1000, Timeout: 1000,
Middlewares: ClientMiddlewaresConf{
Trace: true,
Duration: true,
Prometheus: true,
Breaker: true,
Timeout: true,
},
}, },
WithDialOption(grpc.WithContextDialer(dialer())), WithDialOption(grpc.WithContextDialer(dialer())),
WithUnaryClientInterceptor(func(ctx context.Context, method string, req, reply interface{}, WithUnaryClientInterceptor(func(ctx context.Context, method string, req, reply interface{},
@ -90,6 +97,13 @@ func TestDepositServer_Deposit(t *testing.T) {
Token: "bar", Token: "bar",
Timeout: 1000, Timeout: 1000,
NonBlock: true, NonBlock: true,
Middlewares: ClientMiddlewaresConf{
Trace: true,
Duration: true,
Prometheus: true,
Breaker: true,
Timeout: true,
},
}, },
WithDialOption(grpc.WithContextDialer(dialer())), WithDialOption(grpc.WithContextDialer(dialer())),
WithUnaryClientInterceptor(func(ctx context.Context, method string, req, reply interface{}, WithUnaryClientInterceptor(func(ctx context.Context, method string, req, reply interface{},
@ -103,6 +117,13 @@ func TestDepositServer_Deposit(t *testing.T) {
App: "foo", App: "foo",
Token: "bar", Token: "bar",
Timeout: 1000, Timeout: 1000,
Middlewares: ClientMiddlewaresConf{
Trace: true,
Duration: true,
Prometheus: true,
Breaker: true,
Timeout: true,
},
}, },
WithDialOption(grpc.WithTransportCredentials(insecure.NewCredentials())), WithDialOption(grpc.WithTransportCredentials(insecure.NewCredentials())),
WithDialOption(grpc.WithContextDialer(dialer())), WithDialOption(grpc.WithContextDialer(dialer())),

@ -4,10 +4,16 @@ import (
"github.com/zeromicro/go-zero/core/discov" "github.com/zeromicro/go-zero/core/discov"
"github.com/zeromicro/go-zero/core/service" "github.com/zeromicro/go-zero/core/service"
"github.com/zeromicro/go-zero/core/stores/redis" "github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/zrpc/internal"
"github.com/zeromicro/go-zero/zrpc/resolver" "github.com/zeromicro/go-zero/zrpc/resolver"
) )
type ( type (
// ClientMiddlewaresConf defines whether to use client middlewares.
ClientMiddlewaresConf = internal.ClientMiddlewaresConf
// ServerMiddlewaresConf defines whether to use server middlewares.
ServerMiddlewaresConf = internal.ServerMiddlewaresConf
// A RpcServerConf is a rpc server config. // A RpcServerConf is a rpc server config.
RpcServerConf struct { RpcServerConf struct {
service.ServiceConf service.ServiceConf
@ -20,18 +26,20 @@ type (
Timeout int64 `json:",default=2000"` Timeout int64 `json:",default=2000"`
CpuThreshold int64 `json:",default=900,range=[0:1000]"` CpuThreshold int64 `json:",default=900,range=[0:1000]"`
// grpc health check switch // grpc health check switch
Health bool `json:",default=true"` Health bool `json:",default=true"`
Middlewares ServerMiddlewaresConf
} }
// A RpcClientConf is a rpc client config. // A RpcClientConf is a rpc client config.
RpcClientConf struct { RpcClientConf struct {
Etcd discov.EtcdConf `json:",optional,inherit"` Etcd discov.EtcdConf `json:",optional,inherit"`
Endpoints []string `json:",optional"` Endpoints []string `json:",optional"`
Target string `json:",optional"` Target string `json:",optional"`
App string `json:",optional"` App string `json:",optional"`
Token string `json:",optional"` Token string `json:",optional"`
NonBlock bool `json:",optional"` NonBlock bool `json:",optional"`
Timeout int64 `json:",default=2000"` Timeout int64 `json:",default=2000"`
Middlewares ClientMiddlewaresConf
} }
) )

@ -42,13 +42,17 @@ type (
ClientOption func(options *ClientOptions) ClientOption func(options *ClientOptions)
client struct { client struct {
conn *grpc.ClientConn conn *grpc.ClientConn
middlewares ClientMiddlewaresConf
} }
) )
// NewClient returns a Client. // NewClient returns a Client.
func NewClient(target string, opts ...ClientOption) (Client, error) { func NewClient(target string, middlewares ClientMiddlewaresConf,
var cli client opts ...ClientOption) (Client, error) {
cli := client{
middlewares: middlewares,
}
svcCfg := fmt.Sprintf(`{"loadBalancingPolicy":"%s"}`, p2c.Name) svcCfg := fmt.Sprintf(`{"loadBalancingPolicy":"%s"}`, p2c.Name)
balancerOpt := WithDialOption(grpc.WithDefaultServiceConfig(svcCfg)) balancerOpt := WithDialOption(grpc.WithDefaultServiceConfig(svcCfg))
@ -80,21 +84,45 @@ func (c *client) buildDialOptions(opts ...ClientOption) []grpc.DialOption {
} }
options = append(options, options = append(options,
WithUnaryClientInterceptors( WithUnaryClientInterceptors(c.buildUnaryInterceptors(cliOpts.Timeout)...),
clientinterceptors.UnaryTracingInterceptor, WithStreamClientInterceptors(c.buildStreamInterceptors()...),
clientinterceptors.DurationInterceptor,
clientinterceptors.PrometheusInterceptor,
clientinterceptors.BreakerInterceptor,
clientinterceptors.TimeoutInterceptor(cliOpts.Timeout),
),
WithStreamClientInterceptors(
clientinterceptors.StreamTracingInterceptor,
),
) )
return append(options, cliOpts.DialOptions...) return append(options, cliOpts.DialOptions...)
} }
func (c *client) buildStreamInterceptors() []grpc.StreamClientInterceptor {
var interceptors []grpc.StreamClientInterceptor
if c.middlewares.Trace {
interceptors = append(interceptors, clientinterceptors.StreamTracingInterceptor)
}
return interceptors
}
func (c *client) buildUnaryInterceptors(timeout time.Duration) []grpc.UnaryClientInterceptor {
var interceptors []grpc.UnaryClientInterceptor
if c.middlewares.Trace {
interceptors = append(interceptors, clientinterceptors.UnaryTracingInterceptor)
}
if c.middlewares.Duration {
interceptors = append(interceptors, clientinterceptors.DurationInterceptor)
}
if c.middlewares.Prometheus {
interceptors = append(interceptors, clientinterceptors.PrometheusInterceptor)
}
if c.middlewares.Breaker {
interceptors = append(interceptors, clientinterceptors.BreakerInterceptor)
}
if c.middlewares.Timeout {
interceptors = append(interceptors, clientinterceptors.TimeoutInterceptor(timeout))
}
return interceptors
}
func (c *client) dial(server string, opts ...ClientOption) error { func (c *client) dial(server string, opts ...ClientOption) error {
options := c.buildDialOptions(opts...) options := c.buildDialOptions(opts...)
timeCtx, cancel := context.WithTimeout(context.Background(), dialTimeout) timeCtx, cancel := context.WithTimeout(context.Background(), dialTimeout)

@ -2,6 +2,8 @@ package internal
import ( import (
"context" "context"
"net"
"strings"
"testing" "testing"
"time" "time"
@ -60,8 +62,62 @@ func TestWithUnaryClientInterceptor(t *testing.T) {
} }
func TestBuildDialOptions(t *testing.T) { func TestBuildDialOptions(t *testing.T) {
var c client c := client{
middlewares: ClientMiddlewaresConf{
Trace: true,
Duration: true,
Prometheus: true,
Breaker: true,
Timeout: true,
},
}
agent := grpc.WithUserAgent("chrome") agent := grpc.WithUserAgent("chrome")
opts := c.buildDialOptions(WithDialOption(agent)) opts := c.buildDialOptions(WithDialOption(agent))
assert.Contains(t, opts, agent) assert.Contains(t, opts, agent)
} }
func TestClientDial(t *testing.T) {
server := grpc.NewServer()
go func() {
lis, err := net.Listen("tcp", "localhost:54321")
assert.NoError(t, err)
defer lis.Close()
server.Serve(lis)
}()
time.Sleep(time.Millisecond)
c, err := NewClient("localhost:54321", ClientMiddlewaresConf{
Trace: true,
Duration: true,
Prometheus: true,
Breaker: true,
Timeout: true,
})
assert.NoError(t, err)
assert.NotNil(t, c.Conn())
server.Stop()
}
func TestClientDialFail(t *testing.T) {
_, err := NewClient("localhost:54321", ClientMiddlewaresConf{
Trace: true,
Duration: true,
Prometheus: true,
Breaker: true,
Timeout: true,
})
assert.Error(t, err)
assert.True(t, strings.Contains(err.Error(), "localhost:54321"))
_, err = NewClient("localhost:54321/fail", ClientMiddlewaresConf{
Trace: true,
Duration: true,
Prometheus: true,
Breaker: true,
Timeout: true,
})
assert.Error(t, err)
assert.True(t, strings.Contains(err.Error(), "localhost:54321/fail"))
}

@ -0,0 +1,21 @@
package internal
type (
// ClientMiddlewaresConf defines whether to use client middlewares.
ClientMiddlewaresConf struct {
Trace bool `json:",default=true"`
Duration bool `json:",default=true"`
Prometheus bool `json:",default=true"`
Breaker bool `json:",default=true"`
Timeout bool `json:",default=true"`
}
// ServerMiddlewaresConf defines whether to use server middlewares.
ServerMiddlewaresConf struct {
Trace bool `json:",default=true"`
Recover bool `json:",default=true"`
Stat bool `json:",default=true"`
Prometheus bool `json:",default=true"`
Breaker bool `json:",default=true"`
}
)

@ -14,7 +14,8 @@ const (
) )
// NewRpcPubServer returns a Server. // NewRpcPubServer returns a Server.
func NewRpcPubServer(etcd discov.EtcdConf, listenOn string, opts ...ServerOption) (Server, error) { func NewRpcPubServer(etcd discov.EtcdConf, listenOn string, middlewares ServerMiddlewaresConf,
opts ...ServerOption) (Server, error) {
registerEtcd := func() error { registerEtcd := func() error {
pubListenOn := figureOutListenOn(listenOn) pubListenOn := figureOutListenOn(listenOn)
var pubOpts []discov.PubOption var pubOpts []discov.PubOption
@ -30,7 +31,7 @@ func NewRpcPubServer(etcd discov.EtcdConf, listenOn string, opts ...ServerOption
} }
server := keepAliveServer{ server := keepAliveServer{
registerEtcd: registerEtcd, registerEtcd: registerEtcd,
Server: NewRpcServer(listenOn, opts...), Server: NewRpcServer(listenOn, middlewares, opts...),
} }
return server, nil return server, nil

@ -26,12 +26,13 @@ type (
rpcServer struct { rpcServer struct {
*baseRpcServer *baseRpcServer
name string name string
middlewares ServerMiddlewaresConf
healthManager health.Probe healthManager health.Probe
} }
) )
// NewRpcServer returns a Server. // NewRpcServer returns a Server.
func NewRpcServer(addr string, opts ...ServerOption) Server { func NewRpcServer(addr string, middlewares ServerMiddlewaresConf, opts ...ServerOption) Server {
var options rpcServerOptions var options rpcServerOptions
for _, opt := range opts { for _, opt := range opts {
opt(&options) opt(&options)
@ -42,6 +43,7 @@ func NewRpcServer(addr string, opts ...ServerOption) Server {
return &rpcServer{ return &rpcServer{
baseRpcServer: newBaseRpcServer(addr, &options), baseRpcServer: newBaseRpcServer(addr, &options),
middlewares: middlewares,
healthManager: health.NewHealthManager(fmt.Sprintf("%s-%s", probeNamePrefix, addr)), healthManager: health.NewHealthManager(fmt.Sprintf("%s-%s", probeNamePrefix, addr)),
} }
} }
@ -57,19 +59,9 @@ func (s *rpcServer) Start(register RegisterFn) error {
return err return err
} }
unaryInterceptors := []grpc.UnaryServerInterceptor{ unaryInterceptors := s.buildUnaryInterceptors()
serverinterceptors.UnaryTracingInterceptor,
serverinterceptors.UnaryCrashInterceptor,
serverinterceptors.UnaryStatInterceptor(s.metrics),
serverinterceptors.UnaryPrometheusInterceptor,
serverinterceptors.UnaryBreakerInterceptor,
}
unaryInterceptors = append(unaryInterceptors, s.unaryInterceptors...) unaryInterceptors = append(unaryInterceptors, s.unaryInterceptors...)
streamInterceptors := []grpc.StreamServerInterceptor{ streamInterceptors := s.buildStreamInterceptors()
serverinterceptors.StreamTracingInterceptor,
serverinterceptors.StreamCrashInterceptor,
serverinterceptors.StreamBreakerInterceptor,
}
streamInterceptors = append(streamInterceptors, s.streamInterceptors...) streamInterceptors = append(streamInterceptors, s.streamInterceptors...)
options := append(s.options, WithUnaryServerInterceptors(unaryInterceptors...), options := append(s.options, WithUnaryServerInterceptors(unaryInterceptors...),
WithStreamServerInterceptors(streamInterceptors...)) WithStreamServerInterceptors(streamInterceptors...))
@ -97,6 +89,44 @@ func (s *rpcServer) Start(register RegisterFn) error {
return server.Serve(lis) return server.Serve(lis)
} }
func (s *rpcServer) buildStreamInterceptors() []grpc.StreamServerInterceptor {
var interceptors []grpc.StreamServerInterceptor
if s.middlewares.Trace {
interceptors = append(interceptors, serverinterceptors.StreamTracingInterceptor)
}
if s.middlewares.Recover {
interceptors = append(interceptors, serverinterceptors.StreamRecoverInterceptor)
}
if s.middlewares.Breaker {
interceptors = append(interceptors, serverinterceptors.StreamBreakerInterceptor)
}
return interceptors
}
func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor {
var interceptors []grpc.UnaryServerInterceptor
if s.middlewares.Trace {
interceptors = append(interceptors, serverinterceptors.UnaryTracingInterceptor)
}
if s.middlewares.Recover {
interceptors = append(interceptors, serverinterceptors.UnaryRecoverInterceptor)
}
if s.middlewares.Stat {
interceptors = append(interceptors, serverinterceptors.UnaryStatInterceptor(s.metrics))
}
if s.middlewares.Prometheus {
interceptors = append(interceptors, serverinterceptors.UnaryPrometheusInterceptor)
}
if s.middlewares.Breaker {
interceptors = append(interceptors, serverinterceptors.UnaryBreakerInterceptor)
}
return interceptors
}
// WithMetrics returns a func that sets metrics to a Server. // WithMetrics returns a func that sets metrics to a Server.
func WithMetrics(metrics *stat.Metrics) ServerOption { func WithMetrics(metrics *stat.Metrics) ServerOption {
return func(options *rpcServerOptions) { return func(options *rpcServerOptions) {

@ -12,7 +12,13 @@ import (
func TestRpcServer(t *testing.T) { func TestRpcServer(t *testing.T) {
metrics := stat.NewMetrics("foo") metrics := stat.NewMetrics("foo")
server := NewRpcServer("localhost:54321", WithMetrics(metrics)) server := NewRpcServer("localhost:54321", ServerMiddlewaresConf{
Trace: true,
Recover: true,
Stat: true,
Prometheus: true,
Breaker: true,
}, WithMetrics(metrics))
server.SetName("mock") server.SetName("mock")
var wg sync.WaitGroup var wg sync.WaitGroup
var grpcServer *grpc.Server var grpcServer *grpc.Server
@ -36,7 +42,13 @@ func TestRpcServer(t *testing.T) {
} }
func TestRpcServer_WithBadAddress(t *testing.T) { func TestRpcServer_WithBadAddress(t *testing.T) {
server := NewRpcServer("localhost:111111") server := NewRpcServer("localhost:111111", ServerMiddlewaresConf{
Trace: true,
Recover: true,
Stat: true,
Prometheus: true,
Breaker: true,
})
server.SetName("mock") server.SetName("mock")
err := server.Start(func(server *grpc.Server) { err := server.Start(func(server *grpc.Server) {
mock.RegisterDepositServiceServer(server, new(mock.DepositServer)) mock.RegisterDepositServiceServer(server, new(mock.DepositServer))

@ -10,8 +10,8 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
// StreamCrashInterceptor catches panics in processing stream requests and recovers. // StreamRecoverInterceptor catches panics in processing stream requests and recovers.
func StreamCrashInterceptor(svr interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, func StreamRecoverInterceptor(svr interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo,
handler grpc.StreamHandler) (err error) { handler grpc.StreamHandler) (err error) {
defer handleCrash(func(r interface{}) { defer handleCrash(func(r interface{}) {
err = toPanicError(r) err = toPanicError(r)
@ -20,8 +20,8 @@ func StreamCrashInterceptor(svr interface{}, stream grpc.ServerStream, _ *grpc.S
return handler(svr, stream) return handler(svr, stream)
} }
// UnaryCrashInterceptor catches panics in processing unary requests and recovers. // UnaryRecoverInterceptor catches panics in processing unary requests and recovers.
func UnaryCrashInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, func UnaryRecoverInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (resp interface{}, err error) { handler grpc.UnaryHandler) (resp interface{}, err error) {
defer handleCrash(func(r interface{}) { defer handleCrash(func(r interface{}) {
err = toPanicError(r) err = toPanicError(r)

@ -14,7 +14,7 @@ func init() {
} }
func TestStreamCrashInterceptor(t *testing.T) { func TestStreamCrashInterceptor(t *testing.T) {
err := StreamCrashInterceptor(nil, nil, nil, func( err := StreamRecoverInterceptor(nil, nil, nil, func(
svr interface{}, stream grpc.ServerStream) error { svr interface{}, stream grpc.ServerStream) error {
panic("mock panic") panic("mock panic")
}) })
@ -22,7 +22,7 @@ func TestStreamCrashInterceptor(t *testing.T) {
} }
func TestUnaryCrashInterceptor(t *testing.T) { func TestUnaryCrashInterceptor(t *testing.T) {
_, err := UnaryCrashInterceptor(context.Background(), nil, nil, _, err := UnaryRecoverInterceptor(context.Background(), nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) { func(ctx context.Context, req interface{}) (interface{}, error) {
panic("mock panic") panic("mock panic")
}) })

@ -44,12 +44,12 @@ func NewServer(c RpcServerConf, register internal.RegisterFn) (*RpcServer, error
} }
if c.HasEtcd() { if c.HasEtcd() {
server, err = internal.NewRpcPubServer(c.Etcd, c.ListenOn, serverOptions...) server, err = internal.NewRpcPubServer(c.Etcd, c.ListenOn, c.Middlewares, serverOptions...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
server = internal.NewRpcServer(c.ListenOn, serverOptions...) server = internal.NewRpcServer(c.ListenOn, c.Middlewares, serverOptions...)
} }
server.SetName(c.Name) server.SetName(c.Name)

@ -28,6 +28,13 @@ func TestServer_setupInterceptors(t *testing.T) {
}, },
CpuThreshold: 10, CpuThreshold: 10,
Timeout: 100, Timeout: 100,
Middlewares: ServerMiddlewaresConf{
Trace: true,
Recover: true,
Stat: true,
Prometheus: true,
Breaker: true,
},
}, new(stat.Metrics)) }, new(stat.Metrics))
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 3, len(server.unaryInterceptors)) assert.Equal(t, 3, len(server.unaryInterceptors))
@ -51,11 +58,18 @@ func TestServer(t *testing.T) {
StrictControl: false, StrictControl: false,
Timeout: 0, Timeout: 0,
CpuThreshold: 0, CpuThreshold: 0,
Middlewares: ServerMiddlewaresConf{
Trace: true,
Recover: true,
Stat: true,
Prometheus: true,
Breaker: true,
},
}, func(server *grpc.Server) { }, func(server *grpc.Server) {
}) })
svr.AddOptions(grpc.ConnectionTimeout(time.Hour)) svr.AddOptions(grpc.ConnectionTimeout(time.Hour))
svr.AddUnaryInterceptors(serverinterceptors.UnaryCrashInterceptor) svr.AddUnaryInterceptors(serverinterceptors.UnaryRecoverInterceptor)
svr.AddStreamInterceptors(serverinterceptors.StreamCrashInterceptor) svr.AddStreamInterceptors(serverinterceptors.StreamRecoverInterceptor)
go svr.Start() go svr.Start()
svr.Stop() svr.Stop()
} }
@ -74,6 +88,13 @@ func TestServerError(t *testing.T) {
}, },
Auth: true, Auth: true,
Redis: redis.RedisKeyConf{}, Redis: redis.RedisKeyConf{},
Middlewares: ServerMiddlewaresConf{
Trace: true,
Recover: true,
Stat: true,
Prometheus: true,
Breaker: true,
},
}, func(server *grpc.Server) { }, func(server *grpc.Server) {
}) })
assert.NotNil(t, err) assert.NotNil(t, err)
@ -93,11 +114,18 @@ func TestServer_HasEtcd(t *testing.T) {
Key: "any", Key: "any",
}, },
Redis: redis.RedisKeyConf{}, Redis: redis.RedisKeyConf{},
Middlewares: ServerMiddlewaresConf{
Trace: true,
Recover: true,
Stat: true,
Prometheus: true,
Breaker: true,
},
}, func(server *grpc.Server) { }, func(server *grpc.Server) {
}) })
svr.AddOptions(grpc.ConnectionTimeout(time.Hour)) svr.AddOptions(grpc.ConnectionTimeout(time.Hour))
svr.AddUnaryInterceptors(serverinterceptors.UnaryCrashInterceptor) svr.AddUnaryInterceptors(serverinterceptors.UnaryRecoverInterceptor)
svr.AddStreamInterceptors(serverinterceptors.StreamCrashInterceptor) svr.AddStreamInterceptors(serverinterceptors.StreamRecoverInterceptor)
go svr.Start() go svr.Start()
svr.Stop() svr.Stop()
} }
@ -111,6 +139,13 @@ func TestServer_StartFailed(t *testing.T) {
}, },
}, },
ListenOn: "localhost:aaa", ListenOn: "localhost:aaa",
Middlewares: ServerMiddlewaresConf{
Trace: true,
Recover: true,
Stat: true,
Prometheus: true,
Breaker: true,
},
}, func(server *grpc.Server) { }, func(server *grpc.Server) {
}) })

Loading…
Cancel
Save