diff --git a/core/service/serviceconf.go b/core/service/serviceconf.go index 2421ccfa..46bea43d 100644 --- a/core/service/serviceconf.go +++ b/core/service/serviceconf.go @@ -30,7 +30,7 @@ type ServiceConf struct { MetricsUrl string `json:",optional"` Prometheus prometheus.Config `json:",optional"` // TODO: enable it in v1.2.1 - // Telemetry opentelemetry.Config `json:",optional"` + // Telemetry opentelemetry.Config `json:",optional"` } // MustSetUp sets up the service, exits on error. diff --git a/zrpc/internal/client.go b/zrpc/internal/client.go index 945e0234..e8a9c28f 100644 --- a/zrpc/internal/client.go +++ b/zrpc/internal/client.go @@ -67,14 +67,15 @@ func (c *client) buildDialOptions(opts ...ClientOption) []grpc.DialOption { grpc.WithInsecure(), grpc.WithBlock(), WithUnaryClientInterceptors( - clientinterceptors.TracingInterceptor, + clientinterceptors.UnaryTracingInterceptor, + clientinterceptors.UnaryOpenTracingInterceptor(), clientinterceptors.DurationInterceptor, clientinterceptors.PrometheusInterceptor, clientinterceptors.BreakerInterceptor, clientinterceptors.TimeoutInterceptor(cliOpts.Timeout), - clientinterceptors.OpenTracingInterceptor(), ), WithStreamClientInterceptors( + clientinterceptors.StreamTracingInterceptor, clientinterceptors.StreamOpenTracingInterceptor(), ), } diff --git a/zrpc/internal/clientinterceptors/opentracinginterceptor.go b/zrpc/internal/clientinterceptors/opentracinginterceptor.go index 0905e1e2..04bc55b8 100644 --- a/zrpc/internal/clientinterceptors/opentracinginterceptor.go +++ b/zrpc/internal/clientinterceptors/opentracinginterceptor.go @@ -13,8 +13,8 @@ import ( "google.golang.org/grpc/status" ) -// OpenTracingInterceptor returns a grpc.UnaryClientInterceptor for opentelemetry. -func OpenTracingInterceptor() grpc.UnaryClientInterceptor { +// UnaryOpenTracingInterceptor returns a grpc.UnaryClientInterceptor for opentelemetry. +func UnaryOpenTracingInterceptor() grpc.UnaryClientInterceptor { propagator := otel.GetTextMapPropagator() return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { diff --git a/zrpc/internal/clientinterceptors/opentracinginterceptor_test.go b/zrpc/internal/clientinterceptors/opentracinginterceptor_test.go index c883e6a7..1ab5466c 100644 --- a/zrpc/internal/clientinterceptors/opentracinginterceptor_test.go +++ b/zrpc/internal/clientinterceptors/opentracinginterceptor_test.go @@ -18,7 +18,7 @@ func TestOpenTracingInterceptor(t *testing.T) { }) cc := new(grpc.ClientConn) - err := OpenTracingInterceptor()(context.Background(), "/ListUser", nil, nil, cc, + err := UnaryOpenTracingInterceptor()(context.Background(), "/ListUser", nil, nil, cc, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return nil diff --git a/zrpc/internal/clientinterceptors/tracinginterceptor.go b/zrpc/internal/clientinterceptors/tracinginterceptor.go index afd5312f..abf134c1 100644 --- a/zrpc/internal/clientinterceptors/tracinginterceptor.go +++ b/zrpc/internal/clientinterceptors/tracinginterceptor.go @@ -8,8 +8,8 @@ import ( "google.golang.org/grpc/metadata" ) -// TracingInterceptor is an interceptor that handles tracing. -func TracingInterceptor(ctx context.Context, method string, req, reply interface{}, +// UnaryTracingInterceptor is an interceptor that handles tracing. +func UnaryTracingInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { ctx, span := trace.StartClientSpan(ctx, cc.Target(), method) defer span.Finish() @@ -23,3 +23,19 @@ func TracingInterceptor(ctx context.Context, method string, req, reply interface return invoker(ctx, method, req, reply, cc, opts...) } + +// StreamTracingInterceptor is an interceptor that handles tracing for stream requests. +func StreamTracingInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, + method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + ctx, span := trace.StartClientSpan(ctx, cc.Target(), method) + defer span.Finish() + + var pairs []string + span.Visit(func(key, val string) bool { + pairs = append(pairs, key, val) + return true + }) + ctx = metadata.AppendToOutgoingContext(ctx, pairs...) + + return streamer(ctx, desc, cc, method, opts...) +} diff --git a/zrpc/internal/clientinterceptors/tracinginterceptor_test.go b/zrpc/internal/clientinterceptors/tracinginterceptor_test.go index 2d92fc3a..1c448d96 100644 --- a/zrpc/internal/clientinterceptors/tracinginterceptor_test.go +++ b/zrpc/internal/clientinterceptors/tracinginterceptor_test.go @@ -12,12 +12,12 @@ import ( "google.golang.org/grpc/metadata" ) -func TestTracingInterceptor(t *testing.T) { +func TestUnaryTracingInterceptor(t *testing.T) { var run int32 var wg sync.WaitGroup wg.Add(1) cc := new(grpc.ClientConn) - err := TracingInterceptor(context.Background(), "/foo", nil, nil, cc, + err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { defer wg.Done() @@ -29,7 +29,24 @@ func TestTracingInterceptor(t *testing.T) { assert.Equal(t, int32(1), atomic.LoadInt32(&run)) } -func TestTracingInterceptor_GrpcFormat(t *testing.T) { +func TestStreamTracingInterceptor(t *testing.T) { + var run int32 + var wg sync.WaitGroup + wg.Add(1) + cc := new(grpc.ClientConn) + _, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo", + func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, + opts ...grpc.CallOption) (grpc.ClientStream, error) { + defer wg.Done() + atomic.AddInt32(&run, 1) + return nil, nil + }) + wg.Wait() + assert.Nil(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&run)) +} + +func TestUnaryTracingInterceptor_GrpcFormat(t *testing.T) { var run int32 var wg sync.WaitGroup wg.Add(1) @@ -40,7 +57,7 @@ func TestTracingInterceptor_GrpcFormat(t *testing.T) { assert.Nil(t, err) ctx, _ := trace.StartServerSpan(context.Background(), carrier, "user", "/foo") cc := new(grpc.ClientConn) - err = TracingInterceptor(ctx, "/foo", nil, nil, cc, + err = UnaryTracingInterceptor(ctx, "/foo", nil, nil, cc, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { defer wg.Done() @@ -51,3 +68,26 @@ func TestTracingInterceptor_GrpcFormat(t *testing.T) { assert.Nil(t, err) assert.Equal(t, int32(1), atomic.LoadInt32(&run)) } + +func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) { + var run int32 + var wg sync.WaitGroup + wg.Add(1) + md := metadata.New(map[string]string{ + "foo": "bar", + }) + carrier, err := trace.Inject(trace.GrpcFormat, md) + assert.Nil(t, err) + ctx, _ := trace.StartServerSpan(context.Background(), carrier, "user", "/foo") + cc := new(grpc.ClientConn) + _, err = StreamTracingInterceptor(ctx, nil, cc, "/foo", + func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, + opts ...grpc.CallOption) (grpc.ClientStream, error) { + defer wg.Done() + atomic.AddInt32(&run, 1) + return nil, nil + }) + wg.Wait() + assert.Nil(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&run)) +} diff --git a/zrpc/internal/rpcserver.go b/zrpc/internal/rpcserver.go index 3b5c9a3e..b71f6c2f 100644 --- a/zrpc/internal/rpcserver.go +++ b/zrpc/internal/rpcserver.go @@ -55,6 +55,7 @@ func (s *rpcServer) Start(register RegisterFn) error { unaryInterceptors := []grpc.UnaryServerInterceptor{ serverinterceptors.UnaryTracingInterceptor(s.name), + serverinterceptors.UnaryOpenTracingInterceptor(), serverinterceptors.UnaryCrashInterceptor(), serverinterceptors.UnaryStatInterceptor(s.metrics), serverinterceptors.UnaryPrometheusInterceptor(), @@ -62,6 +63,8 @@ func (s *rpcServer) Start(register RegisterFn) error { } unaryInterceptors = append(unaryInterceptors, s.unaryInterceptors...) streamInterceptors := []grpc.StreamServerInterceptor{ + serverinterceptors.StreamTracingInterceptor(s.name), + serverinterceptors.StreamOpenTracingInterceptor(), serverinterceptors.StreamCrashInterceptor, serverinterceptors.StreamBreakerInterceptor, } diff --git a/zrpc/internal/serverinterceptors/tracinginterceptor.go b/zrpc/internal/serverinterceptors/tracinginterceptor.go index 3ec81a27..b2fa6291 100644 --- a/zrpc/internal/serverinterceptors/tracinginterceptor.go +++ b/zrpc/internal/serverinterceptors/tracinginterceptor.go @@ -27,3 +27,23 @@ func UnaryTracingInterceptor(serviceName string) grpc.UnaryServerInterceptor { return handler(ctx, req) } } + +func StreamTracingInterceptor(serviceName string) grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, + handler grpc.StreamHandler) error { + ctx := ss.Context() + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return handler(srv, ss) + } + + carrier, err := trace.Extract(trace.GrpcFormat, md) + if err != nil { + return handler(srv, ss) + } + + ctx, span := trace.StartServerSpan(ctx, carrier, serviceName, info.FullMethod) + defer span.Finish() + return handler(srv, ss) + } +} diff --git a/zrpc/internal/serverinterceptors/tracinginterceptor_test.go b/zrpc/internal/serverinterceptors/tracinginterceptor_test.go index 86fbac36..e1bae3d4 100644 --- a/zrpc/internal/serverinterceptors/tracinginterceptor_test.go +++ b/zrpc/internal/serverinterceptors/tracinginterceptor_test.go @@ -46,3 +46,71 @@ func TestUnaryTracingInterceptor_GrpcFormat(t *testing.T) { wg.Wait() assert.Nil(t, err) } + +func TestStreamTracingInterceptor(t *testing.T) { + interceptor := StreamTracingInterceptor("foo") + var run int32 + var wg sync.WaitGroup + wg.Add(1) + err := interceptor(nil, new(mockedServerStream), nil, + func(srv interface{}, stream grpc.ServerStream) error { + defer wg.Done() + atomic.AddInt32(&run, 1) + return nil + }) + wg.Wait() + assert.Nil(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&run)) +} + +func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) { + interceptor := StreamTracingInterceptor("foo") + var run int32 + var wg sync.WaitGroup + wg.Add(1) + var md metadata.MD + ctx := metadata.NewIncomingContext(context.Background(), md) + stream := mockedServerStream{ctx: ctx} + err := interceptor(nil, &stream, &grpc.StreamServerInfo{ + FullMethod: "/foo", + }, func(srv interface{}, stream grpc.ServerStream) error { + defer wg.Done() + atomic.AddInt32(&run, 1) + return nil + }) + wg.Wait() + assert.Nil(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&run)) +} + +type mockedServerStream struct { + ctx context.Context +} + +func (m *mockedServerStream) SetHeader(md metadata.MD) error { + panic("implement me") +} + +func (m *mockedServerStream) SendHeader(md metadata.MD) error { + panic("implement me") +} + +func (m *mockedServerStream) SetTrailer(md metadata.MD) { + panic("implement me") +} + +func (m *mockedServerStream) Context() context.Context { + if m.ctx == nil { + return context.Background() + } + + return m.ctx +} + +func (m *mockedServerStream) SendMsg(v interface{}) error { + panic("implement me") +} + +func (m *mockedServerStream) RecvMsg(v interface{}) error { + panic("implement me") +}