diff --git a/zrpc/internal/codes/accept.go b/zrpc/internal/codes/accept.go index 106abbe8..0ecb1275 100644 --- a/zrpc/internal/codes/accept.go +++ b/zrpc/internal/codes/accept.go @@ -1,8 +1,6 @@ package codes import ( - "context" - "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -12,17 +10,6 @@ func Acceptable(err error) bool { switch status.Code(err) { case codes.DeadlineExceeded, codes.Internal, codes.Unavailable, codes.DataLoss: return false - case codes.Unknown: - return acceptableUnknown(err) - default: - return true - } -} - -func acceptableUnknown(err error) bool { - switch err { - case context.DeadlineExceeded: - return false default: return true } diff --git a/zrpc/internal/codes/accept_test.go b/zrpc/internal/codes/accept_test.go index ddca4a08..00750126 100644 --- a/zrpc/internal/codes/accept_test.go +++ b/zrpc/internal/codes/accept_test.go @@ -9,6 +9,7 @@ import ( ) func TestAccept(t *testing.T) { + tests := []struct { name string err error diff --git a/zrpc/internal/serverinterceptors/timeoutinterceptor.go b/zrpc/internal/serverinterceptors/timeoutinterceptor.go index 83ed9977..67fc5d1a 100644 --- a/zrpc/internal/serverinterceptors/timeoutinterceptor.go +++ b/zrpc/internal/serverinterceptors/timeoutinterceptor.go @@ -3,6 +3,8 @@ package serverinterceptors import ( "context" "fmt" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "runtime/debug" "strings" "sync" @@ -46,7 +48,14 @@ func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor defer lock.Unlock() return resp, err case <-ctx.Done(): - return nil, ctx.Err() + err := ctx.Err() + + if err == context.Canceled { + err = status.Error(codes.Canceled, err.Error()) + } else if err == context.DeadlineExceeded { + err = status.Error(codes.DeadlineExceeded, err.Error()) + } + return nil, err } } } diff --git a/zrpc/internal/serverinterceptors/timeoutinterceptor_test.go b/zrpc/internal/serverinterceptors/timeoutinterceptor_test.go index abde9b3d..e483d0c9 100644 --- a/zrpc/internal/serverinterceptors/timeoutinterceptor_test.go +++ b/zrpc/internal/serverinterceptors/timeoutinterceptor_test.go @@ -2,6 +2,8 @@ package serverinterceptors import ( "context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "sync" "testing" "time" @@ -66,5 +68,24 @@ func TestUnaryTimeoutInterceptor_timeoutExpire(t *testing.T) { return nil, nil }) wg.Wait() - assert.Equal(t, context.DeadlineExceeded, err) + assert.EqualValues(t, status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()), err) +} +func TestUnaryTimeoutInterceptor_cancel(t *testing.T) { + const timeout = time.Minute * 10 + interceptor := UnaryTimeoutInterceptor(timeout) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + var wg sync.WaitGroup + wg.Add(1) + _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{ + FullMethod: "/", + }, func(ctx context.Context, req interface{}) (interface{}, error) { + defer wg.Done() + time.Sleep(time.Millisecond * 50) + return nil, nil + }) + + wg.Wait() + assert.EqualValues(t, status.Error(codes.Canceled, context.Canceled.Error()), err) }