From 3cdfcb05f17c4c1c979576a2fb5aed584f56d1b7 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Mon, 4 Oct 2021 20:02:25 +0800 Subject: [PATCH] add more tests (#1114) --- .../clientinterceptors/tracinginterceptor.go | 36 +++--- .../tracinginterceptor_test.go | 105 +++++++++++++++++- 2 files changed, 119 insertions(+), 22 deletions(-) diff --git a/zrpc/internal/clientinterceptors/tracinginterceptor.go b/zrpc/internal/clientinterceptors/tracinginterceptor.go index 9faf88e7..668a7751 100644 --- a/zrpc/internal/clientinterceptors/tracinginterceptor.go +++ b/zrpc/internal/clientinterceptors/tracinginterceptor.go @@ -100,6 +100,24 @@ type ( } ) +func (w *clientStream) CloseSend() error { + err := w.ClientStream.CloseSend() + if err != nil { + w.sendStreamEvent(errorEvent, err) + } + + return err +} + +func (w *clientStream) Header() (metadata.MD, error) { + md, err := w.ClientStream.Header() + if err != nil { + w.sendStreamEvent(errorEvent, err) + } + + return md, err +} + func (w *clientStream) RecvMsg(m interface{}) error { err := w.ClientStream.RecvMsg(m) if err == nil && !w.desc.ServerStreams { @@ -127,24 +145,6 @@ func (w *clientStream) SendMsg(m interface{}) error { return err } -func (w *clientStream) Header() (metadata.MD, error) { - md, err := w.ClientStream.Header() - if err != nil { - w.sendStreamEvent(errorEvent, err) - } - - return md, err -} - -func (w *clientStream) CloseSend() error { - err := w.ClientStream.CloseSend() - if err != nil { - w.sendStreamEvent(errorEvent, err) - } - - return err -} - func (w *clientStream) sendStreamEvent(eventType streamEventType, err error) { select { case <-w.eventsDone: diff --git a/zrpc/internal/clientinterceptors/tracinginterceptor_test.go b/zrpc/internal/clientinterceptors/tracinginterceptor_test.go index 7204b931..35cb8757 100644 --- a/zrpc/internal/clientinterceptors/tracinginterceptor_test.go +++ b/zrpc/internal/clientinterceptors/tracinginterceptor_test.go @@ -3,6 +3,7 @@ package clientinterceptors import ( "context" "errors" + "io" "sync" "sync/atomic" "testing" @@ -24,7 +25,8 @@ func TestOpenTracingInterceptor(t *testing.T) { }) cc := new(grpc.ClientConn) - err := UnaryTracingInterceptor(context.Background(), "/ListUser", nil, nil, cc, + ctx := metadata.NewOutgoingContext(context.Background(), metadata.MD{}) + err := UnaryTracingInterceptor(ctx, "/ListUser", nil, nil, cc, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return nil @@ -220,6 +222,101 @@ func TestStreamTracingInterceptor_GrpcFormat(t *testing.T) { assert.Equal(t, int32(1), atomic.LoadInt32(&run)) } +func TestClientStream_RecvMsg(t *testing.T) { + tests := []struct { + name string + serverStreams bool + err error + }{ + { + name: "nil error", + }, + { + name: "EOF", + err: io.EOF, + }, + { + name: "dummy error", + err: errors.New("dummy"), + }, + { + name: "server streams", + serverStreams: true, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + desc := new(grpc.StreamDesc) + desc.ServerStreams = test.serverStreams + stream := wrapClientStream(context.Background(), &mockedClientStream{ + md: nil, + err: test.err, + }, desc) + assert.Equal(t, test.err, stream.RecvMsg(nil)) + }) + } +} + +func TestClientStream_Header(t *testing.T) { + tests := []struct { + name string + err error + }{ + { + name: "nil error", + }, + { + name: "with error", + err: errors.New("dummy"), + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + desc := new(grpc.StreamDesc) + stream := wrapClientStream(context.Background(), &mockedClientStream{ + md: metadata.MD{}, + err: test.err, + }, desc) + _, err := stream.Header() + assert.Equal(t, test.err, err) + }) + } +} + +func TestClientStream_SendMsg(t *testing.T) { + tests := []struct { + name string + err error + }{ + { + name: "nil error", + }, + { + name: "with error", + err: errors.New("dummy"), + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + desc := new(grpc.StreamDesc) + stream := wrapClientStream(context.Background(), &mockedClientStream{ + md: metadata.MD{}, + err: test.err, + }, desc) + assert.Equal(t, test.err, stream.SendMsg(nil)) + }) + } +} + type mockedClientStream struct { md metadata.MD err error @@ -238,13 +335,13 @@ func (m *mockedClientStream) CloseSend() error { } func (m *mockedClientStream) Context() context.Context { - panic("implement me") + return context.Background() } func (m *mockedClientStream) SendMsg(v interface{}) error { - panic("implement me") + return m.err } func (m *mockedClientStream) RecvMsg(v interface{}) error { - panic("implement me") + return m.err }