diff --git a/zrpc/internal/balancer/p2c/p2c.go b/zrpc/internal/balancer/p2c/p2c.go index 74335d38..08ea1aad 100644 --- a/zrpc/internal/balancer/p2c/p2c.go +++ b/zrpc/internal/balancer/p2c/p2c.go @@ -72,7 +72,7 @@ type p2cPicker struct { lock sync.Mutex } -func (p *p2cPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { +func (p *p2cPicker) Pick(_ balancer.PickInfo) (balancer.PickResult, error) { p.lock.Lock() defer p.lock.Unlock() diff --git a/zrpc/internal/balancer/p2c/p2c_test.go b/zrpc/internal/balancer/p2c/p2c_test.go index 1dab287c..b92420b0 100644 --- a/zrpc/internal/balancer/p2c/p2c_test.go +++ b/zrpc/internal/balancer/p2c/p2c_test.go @@ -123,6 +123,15 @@ func TestP2cPicker_Pick(t *testing.T) { } } +func TestPickerWithEmptyConns(t *testing.T) { + var picker p2cPicker + _, err := picker.Pick(balancer.PickInfo{ + FullMethodName: "/", + Ctx: context.Background(), + }) + assert.ErrorIs(t, err, balancer.ErrNoSubConnAvailable) +} + type mockClientConn struct { // add random string member to avoid map key equality. id string diff --git a/zrpc/internal/clientinterceptors/durationinterceptor_test.go b/zrpc/internal/clientinterceptors/durationinterceptor_test.go index a851fde3..e6122035 100644 --- a/zrpc/internal/clientinterceptors/durationinterceptor_test.go +++ b/zrpc/internal/clientinterceptors/durationinterceptor_test.go @@ -24,13 +24,82 @@ func TestDurationInterceptor(t *testing.T) { err: errors.New("mock"), }, } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + cc := new(grpc.ClientConn) + err := DurationInterceptor(context.Background(), "/foo", nil, nil, cc, + func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, + opts ...grpc.CallOption) error { + return test.err + }) + assert.Equal(t, test.err, err) + }) + } + + DontLogContentForMethod("/foo") + t.Cleanup(func() { + notLoggingContentMethods.Delete("/foo") + }) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + cc := new(grpc.ClientConn) + err := DurationInterceptor(context.Background(), "/foo", nil, nil, cc, + func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, + opts ...grpc.CallOption) error { + return test.err + }) + assert.Equal(t, test.err, err) + }) + } +} + +func TestDurationInterceptorWithSlowThreshold(t *testing.T) { + SetSlowThreshold(time.Microsecond) + t.Cleanup(func() { + SetSlowThreshold(defaultSlowThreshold) + }) + + tests := []struct { + name string + err error + }{ + { + name: "nil", + err: nil, + }, + { + name: "with error", + err: errors.New("mock"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + cc := new(grpc.ClientConn) + err := DurationInterceptor(context.Background(), "/foo", nil, nil, cc, + func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, + opts ...grpc.CallOption) error { + time.Sleep(time.Millisecond * 10) + return test.err + }) + assert.Equal(t, test.err, err) + }) + } + DontLogContentForMethod("/foo") + t.Cleanup(func() { + notLoggingContentMethods.Delete("/foo") + }) + for _, test := range tests { t.Run(test.name, func(t *testing.T) { cc := new(grpc.ClientConn) err := DurationInterceptor(context.Background(), "/foo", nil, nil, cc, func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + time.Sleep(time.Millisecond * 10) return test.err }) assert.Equal(t, test.err, err) diff --git a/zrpc/internal/clientinterceptors/tracinginterceptor_test.go b/zrpc/internal/clientinterceptors/tracinginterceptor_test.go index c5a7b0a3..603de541 100644 --- a/zrpc/internal/clientinterceptors/tracinginterceptor_test.go +++ b/zrpc/internal/clientinterceptors/tracinginterceptor_test.go @@ -69,6 +69,23 @@ func TestUnaryTracingInterceptor_WithError(t *testing.T) { assert.Equal(t, int32(1), atomic.LoadInt32(&run)) } +func TestUnaryTracingInterceptor_WithStatusError(t *testing.T) { + var run int32 + var wg sync.WaitGroup + wg.Add(1) + cc := new(grpc.ClientConn) + err := UnaryTracingInterceptor(context.Background(), "/foo", nil, nil, cc, + func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, + opts ...grpc.CallOption) error { + defer wg.Done() + atomic.AddInt32(&run, 1) + return status.Error(codes.DataLoss, "dummy") + }) + wg.Wait() + assert.NotNil(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&run)) +} + func TestStreamTracingInterceptor(t *testing.T) { var run int32 var wg sync.WaitGroup diff --git a/zrpc/resolver/internal/kube/targetparser.go b/zrpc/resolver/internal/kube/targetparser.go index 811cae2e..f34e6e49 100644 --- a/zrpc/resolver/internal/kube/targetparser.go +++ b/zrpc/resolver/internal/kube/targetparser.go @@ -1,7 +1,6 @@ package kube import ( - "fmt" "strconv" "strings" @@ -34,10 +33,6 @@ func ParseTarget(target resolver.Target) (Service, error) { endpoints := targets.GetEndpoints(target) if strings.Contains(endpoints, colon) { segs := strings.SplitN(endpoints, colon, 2) - if len(segs) < 2 { - return emptyService, fmt.Errorf("bad endpoint: %s", endpoints) - } - service.Name = segs[0] port, err := strconv.Atoi(segs[1]) if err != nil { diff --git a/zrpc/resolver/internal/kube/targetparser_test.go b/zrpc/resolver/internal/kube/targetparser_test.go index ef69053c..46852d63 100644 --- a/zrpc/resolver/internal/kube/targetparser_test.go +++ b/zrpc/resolver/internal/kube/targetparser_test.go @@ -51,18 +51,24 @@ func TestParseTarget(t *testing.T) { input: "k8s://ns1/my-svc:800a", hasErr: true, }, + { + name: "bad endpoint", + input: "k8s://ns1:800/:", + hasErr: true, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { uri, err := url.Parse(test.input) - assert.Nil(t, err) - svc, err := ParseTarget(resolver.Target{URL: *uri}) - if test.hasErr { - assert.NotNil(t, err) - } else { - assert.Nil(t, err) - assert.Equal(t, test.expect, svc) + if assert.NoError(t, err) { + svc, err := ParseTarget(resolver.Target{URL: *uri}) + if test.hasErr { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + assert.Equal(t, test.expect, svc) + } } }) }