diff --git a/zrpc/internal/clientinterceptors/tracinginterceptor.go b/zrpc/internal/clientinterceptors/tracinginterceptor.go index 6567deed..9faf88e7 100644 --- a/zrpc/internal/clientinterceptors/tracinginterceptor.go +++ b/zrpc/internal/clientinterceptors/tracinginterceptor.go @@ -67,10 +67,10 @@ func StreamTracingInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *gr s, ok := status.FromError(err) if ok { span.SetStatus(codes.Error, s.Message()) + span.SetAttributes(ztrace.StatusCodeAttr(s.Code())) } else { span.SetStatus(codes.Error, err.Error()) } - span.SetAttributes(ztrace.StatusCodeAttr(s.Code())) } else { span.SetAttributes(ztrace.StatusCodeAttr(gcodes.OK)) } diff --git a/zrpc/internal/clientinterceptors/tracinginterceptor_test.go b/zrpc/internal/clientinterceptors/tracinginterceptor_test.go index 72bf8b5c..7204b931 100644 --- a/zrpc/internal/clientinterceptors/tracinginterceptor_test.go +++ b/zrpc/internal/clientinterceptors/tracinginterceptor_test.go @@ -10,6 +10,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/tal-tech/go-zero/core/trace" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) func TestOpenTracingInterceptor(t *testing.T) { @@ -80,21 +83,107 @@ func TestStreamTracingInterceptor(t *testing.T) { assert.Equal(t, int32(1), atomic.LoadInt32(&run)) } -func TestStreamTracingInterceptor_WithError(t *testing.T) { - var run int32 +func TestStreamTracingInterceptor_FinishWithNormalError(t *testing.T) { var wg sync.WaitGroup wg.Add(1) cc := new(grpc.ClientConn) - _, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo", + ctx, cancel := context.WithCancel(context.Background()) + stream, err := StreamTracingInterceptor(ctx, nil, cc, "/foo", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { defer wg.Done() - atomic.AddInt32(&run, 1) - return nil, errors.New("dummy") + return nil, nil }) wg.Wait() - assert.NotNil(t, err) - assert.Equal(t, int32(1), atomic.LoadInt32(&run)) + assert.Nil(t, err) + + cancel() + cs := stream.(*clientStream) + <-cs.eventsDone +} + +func TestStreamTracingInterceptor_FinishWithGrpcError(t *testing.T) { + tests := []struct { + name string + event streamEventType + err error + }{ + { + name: "receive event", + event: receiveEndEvent, + err: status.Error(codes.DataLoss, "dummy"), + }, + { + name: "error event", + event: errorEvent, + err: status.Error(codes.DataLoss, "dummy"), + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + var wg sync.WaitGroup + wg.Add(1) + cc := new(grpc.ClientConn) + stream, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo", + func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, + opts ...grpc.CallOption) (grpc.ClientStream, error) { + defer wg.Done() + return &mockedClientStream{ + err: errors.New("dummy"), + }, nil + }) + wg.Wait() + assert.Nil(t, err) + + cs := stream.(*clientStream) + cs.sendStreamEvent(test.event, status.Error(codes.DataLoss, "dummy")) + <-cs.eventsDone + cs.sendStreamEvent(test.event, test.err) + assert.NotNil(t, cs.CloseSend()) + }) + } +} + +func TestStreamTracingInterceptor_WithError(t *testing.T) { + tests := []struct { + name string + err error + }{ + { + name: "normal error", + err: errors.New("dummy"), + }, + { + name: "grpc error", + err: status.Error(codes.DataLoss, "dummy"), + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + var run int32 + var wg sync.WaitGroup + wg.Add(1) + cc := new(grpc.ClientConn) + _, err := StreamTracingInterceptor(context.Background(), nil, cc, "/foo", + func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, + opts ...grpc.CallOption) (grpc.ClientStream, error) { + defer wg.Done() + atomic.AddInt32(&run, 1) + return new(mockedClientStream), test.err + }) + wg.Wait() + assert.NotNil(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&run)) + }) + } } func TestUnaryTracingInterceptor_GrpcFormat(t *testing.T) { @@ -130,3 +219,32 @@ func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) { assert.Nil(t, err) assert.Equal(t, int32(1), atomic.LoadInt32(&run)) } + +type mockedClientStream struct { + md metadata.MD + err error +} + +func (m *mockedClientStream) Header() (metadata.MD, error) { + return m.md, m.err +} + +func (m *mockedClientStream) Trailer() metadata.MD { + panic("implement me") +} + +func (m *mockedClientStream) CloseSend() error { + return m.err +} + +func (m *mockedClientStream) Context() context.Context { + panic("implement me") +} + +func (m *mockedClientStream) SendMsg(v interface{}) error { + panic("implement me") +} + +func (m *mockedClientStream) RecvMsg(v interface{}) error { + panic("implement me") +}