chore: refactor errors to use errors.Is (#3654)

master
Kevin Wan 1 year ago committed by GitHub
parent 81ae7d36b5
commit 42e0a6f90c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -30,7 +30,7 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
assert.Equal(t, errDummy, GetBreaker("anyone").DoWithAcceptable(func() error { assert.Equal(t, errDummy, GetBreaker("anyone").DoWithAcceptable(func() error {
return errDummy return errDummy
}, func(err error) bool { }, func(err error) bool {
return err == nil || err == errDummy return err == nil || errors.Is(err, errDummy)
})) }))
} }
verify(t, func() bool { verify(t, func() bool {
@ -45,12 +45,12 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
}, func(err error) bool { }, func(err error) bool {
return err == nil return err == nil
}) })
assert.True(t, err == errDummy || err == ErrServiceUnavailable) assert.True(t, errors.Is(err, errDummy) || errors.Is(err, ErrServiceUnavailable))
} }
verify(t, func() bool { verify(t, func() bool {
return ErrServiceUnavailable == Do("another", func() error { return errors.Is(Do("another", func() error {
return nil return nil
}) }), ErrServiceUnavailable)
}) })
} }
@ -75,12 +75,12 @@ func TestBreakersFallback(t *testing.T) {
}, func(err error) error { }, func(err error) error {
return nil return nil
}) })
assert.True(t, err == nil || err == errDummy) assert.True(t, err == nil || errors.Is(err, errDummy))
} }
verify(t, func() bool { verify(t, func() bool {
return ErrServiceUnavailable == Do("fallback", func() error { return errors.Is(Do("fallback", func() error {
return nil return nil
}) }), ErrServiceUnavailable)
}) })
} }
@ -94,12 +94,12 @@ func TestBreakersAcceptableFallback(t *testing.T) {
}, func(err error) bool { }, func(err error) bool {
return err == nil return err == nil
}) })
assert.True(t, err == nil || err == errDummy) assert.True(t, err == nil || errors.Is(err, errDummy))
} }
verify(t, func() bool { verify(t, func() bool {
return ErrServiceUnavailable == Do("acceptablefallback", func() error { return errors.Is(Do("acceptablefallback", func() error {
return nil return nil
}) }), ErrServiceUnavailable)
}) })
} }

@ -69,10 +69,10 @@ func (t *Tree) Add(route string, item any) error {
} }
err := add(t.root, route[1:], item) err := add(t.root, route[1:], item)
switch err { switch {
case errDupItem: case errors.Is(err, errDupItem):
return duplicatedItem(route) return duplicatedItem(route)
case errDupSlash: case errors.Is(err, errDupSlash):
return duplicatedSlash(route) return duplicatedSlash(route)
default: default:
return err return err

@ -96,7 +96,7 @@ func (c cacheNode) Get(key string, val any) error {
// GetCtx gets the cache with key and fills into v. // GetCtx gets the cache with key and fills into v.
func (c cacheNode) GetCtx(ctx context.Context, key string, val any) error { func (c cacheNode) GetCtx(ctx context.Context, key string, val any) error {
err := c.doGetCache(ctx, key, val) err := c.doGetCache(ctx, key, val)
if err == errPlaceholder { if errors.Is(err, errPlaceholder) {
return c.errNotFound return c.errNotFound
} }
@ -210,16 +210,16 @@ func (c cacheNode) doTake(ctx context.Context, v any, key string,
logger := logx.WithContext(ctx) logger := logx.WithContext(ctx)
val, fresh, err := c.barrier.DoEx(key, func() (any, error) { val, fresh, err := c.barrier.DoEx(key, func() (any, error) {
if err := c.doGetCache(ctx, key, v); err != nil { if err := c.doGetCache(ctx, key, v); err != nil {
if err == errPlaceholder { if errors.Is(err, errPlaceholder) {
return nil, c.errNotFound return nil, c.errNotFound
} else if err != c.errNotFound { } else if !errors.Is(err, c.errNotFound) {
// why we just return the error instead of query from db, // why we just return the error instead of query from db,
// because we don't allow the disaster pass to the dbs. // because we don't allow the disaster pass to the dbs.
// fail fast, in case we bring down the dbs. // fail fast, in case we bring down the dbs.
return nil, err return nil, err
} }
if err = query(v); err == c.errNotFound { if err = query(v); errors.Is(err, c.errNotFound) {
if err = c.setCacheWithNotFound(ctx, key); err != nil { if err = c.setCacheWithNotFound(ctx, key); err != nil {
logger.Error(err) logger.Error(err)
} }

@ -3,6 +3,7 @@ package mon
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"time" "time"
"github.com/zeromicro/go-zero/core/breaker" "github.com/zeromicro/go-zero/core/breaker"
@ -562,11 +563,19 @@ func (p keepablePromise) keep(err error) error {
} }
func acceptable(err error) bool { func acceptable(err error) bool {
return err == nil || err == mongo.ErrNoDocuments || err == mongo.ErrNilValue || return err == nil ||
err == mongo.ErrNilDocument || err == mongo.ErrNilCursor || err == mongo.ErrEmptySlice || errors.Is(err, mongo.ErrNoDocuments) ||
errors.Is(err, mongo.ErrNilValue) ||
errors.Is(err, mongo.ErrNilDocument) ||
errors.Is(err, mongo.ErrNilCursor) ||
errors.Is(err, mongo.ErrEmptySlice) ||
// session errors // session errors
err == session.ErrSessionEnded || err == session.ErrNoTransactStarted || errors.Is(err, session.ErrSessionEnded) ||
err == session.ErrTransactInProgress || err == session.ErrAbortAfterCommit || errors.Is(err, session.ErrNoTransactStarted) ||
err == session.ErrAbortTwice || err == session.ErrCommitAfterAbort || errors.Is(err, session.ErrTransactInProgress) ||
err == session.ErrUnackWCUnsupported || err == session.ErrSnapshotTransaction errors.Is(err, session.ErrAbortAfterCommit) ||
errors.Is(err, session.ErrAbortTwice) ||
errors.Is(err, session.ErrCommitAfterAbort) ||
errors.Is(err, session.ErrUnackWCUnsupported) ||
errors.Is(err, session.ErrSnapshotTransaction)
} }

@ -2,6 +2,7 @@ package mon
import ( import (
"context" "context"
"errors"
"github.com/zeromicro/go-zero/core/trace" "github.com/zeromicro/go-zero/core/trace"
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
@ -23,8 +24,8 @@ func startSpan(ctx context.Context, cmd string) (context.Context, oteltrace.Span
func endSpan(span oteltrace.Span, err error) { func endSpan(span oteltrace.Span, err error) {
defer span.End() defer span.End()
if err == nil || err == mongo.ErrNoDocuments || if err == nil || errors.Is(err, mongo.ErrNoDocuments) ||
err == mongo.ErrNilValue || err == mongo.ErrNilDocument { errors.Is(err, mongo.ErrNilValue) || errors.Is(err, mongo.ErrNilDocument) {
span.SetStatus(codes.Ok, "") span.SetStatus(codes.Ok, "")
return return
} }

@ -2849,7 +2849,7 @@ func withHook(hook red.Hook) Option {
} }
func acceptable(err error) bool { func acceptable(err error) bool {
return err == nil || err == red.Nil || err == context.Canceled return err == nil || err == red.Nil || errors.Is(err, context.Canceled)
} }
func getRedis(r *Redis) (RedisNode, error) { func getRedis(r *Redis) (RedisNode, error) {

@ -1,6 +1,10 @@
package sqlx package sqlx
import "github.com/go-sql-driver/mysql" import (
"errors"
"github.com/go-sql-driver/mysql"
)
const ( const (
mysqlDriverName = "mysql" mysqlDriverName = "mysql"
@ -18,7 +22,8 @@ func mysqlAcceptable(err error) bool {
return true return true
} }
myerr, ok := err.(*mysql.MySQLError) var myerr *mysql.MySQLError
ok := errors.As(err, &myerr)
if !ok { if !ok {
return false return false
} }

@ -28,7 +28,7 @@ func TestBreakerOnNotHandlingDuplicateEntry(t *testing.T) {
var found bool var found bool
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
if tryOnDuplicateEntryError(t, nil) == breaker.ErrServiceUnavailable { if errors.Is(tryOnDuplicateEntryError(t, nil), breaker.ErrServiceUnavailable) {
found = true found = true
} }
} }

@ -3,6 +3,7 @@ package sqlx
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"github.com/zeromicro/go-zero/core/breaker" "github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
@ -157,7 +158,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
result, err = exec(ctx, conn, q, args...) result, err = exec(ctx, conn, q, args...)
return err return err
}, db.acceptable) }, db.acceptable)
if err == breaker.ErrServiceUnavailable { if errors.Is(err, breaker.ErrServiceUnavailable) {
metricReqErr.Inc("Exec", "breaker") metricReqErr.Inc("Exec", "breaker")
} }
@ -193,7 +194,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
} }
return nil return nil
}, db.acceptable) }, db.acceptable)
if err == breaker.ErrServiceUnavailable { if errors.Is(err, breaker.ErrServiceUnavailable) {
metricReqErr.Inc("Prepare", "breaker") metricReqErr.Inc("Prepare", "breaker")
} }
@ -283,7 +284,7 @@ func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Contex
err = db.brk.DoWithAcceptable(func() error { err = db.brk.DoWithAcceptable(func() error {
return transact(ctx, db, db.beginTx, fn) return transact(ctx, db, db.beginTx, fn)
}, db.acceptable) }, db.acceptable)
if err == breaker.ErrServiceUnavailable { if errors.Is(err, breaker.ErrServiceUnavailable) {
metricReqErr.Inc("Transact", "breaker") metricReqErr.Inc("Transact", "breaker")
} }
@ -291,11 +292,13 @@ func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Contex
} }
func (db *commonSqlConn) acceptable(err error) bool { func (db *commonSqlConn) acceptable(err error) bool {
if err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled { if err == nil || errors.Is(err, sql.ErrNoRows) || errors.Is(err, sql.ErrTxDone) ||
errors.Is(err, context.Canceled) {
return true return true
} }
if _, ok := err.(acceptableError); ok { var e acceptableError
if errors.As(err, &e) {
return true return true
} }
@ -321,9 +324,9 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
return qerr return qerr
}, q, args...) }, q, args...)
}, func(err error) bool { }, func(err error) bool {
return qerr == err || db.acceptable(err) return errors.Is(err, qerr) || db.acceptable(err)
}) })
if err == breaker.ErrServiceUnavailable { if errors.Is(err, breaker.ErrServiceUnavailable) {
metricReqErr.Inc("queryRows", "breaker") metricReqErr.Inc("queryRows", "breaker")
} }

@ -143,7 +143,7 @@ func logInstanceError(ctx context.Context, datasource string, err error) {
} }
func logSqlError(ctx context.Context, stmt string, err error) { func logSqlError(ctx context.Context, stmt string, err error) {
if err != nil && err != ErrNotFound { if err != nil && !errors.Is(err, ErrNotFound) {
logx.WithContext(ctx).Errorf("stmt: %s, error: %s", stmt, err.Error()) logx.WithContext(ctx).Errorf("stmt: %s, error: %s", stmt, err.Error())
} }
} }

@ -27,7 +27,7 @@ func TestLockedCallDoErr(t *testing.T) {
v, err := g.Do("key", func() (any, error) { v, err := g.Do("key", func() (any, error) {
return nil, someErr return nil, someErr
}) })
if err != someErr { if !errors.Is(err, someErr) {
t.Errorf("Do error = %v; want someErr", err) t.Errorf("Do error = %v; want someErr", err)
} }
if v != nil { if v != nil {

@ -28,7 +28,7 @@ func TestExclusiveCallDoErr(t *testing.T) {
v, err := g.Do("key", func() (any, error) { v, err := g.Do("key", func() (any, error) {
return nil, someErr return nil, someErr
}) })
if err != someErr { if !errors.Is(err, someErr) {
t.Errorf("Do error = %v; want someErr", err) t.Errorf("Do error = %v; want someErr", err)
} }
if v != nil { if v != nil {

@ -3,6 +3,7 @@ package httpx
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"sync" "sync"
@ -141,10 +142,10 @@ func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, a
return return
} }
e, ok := body.(error) switch v := body.(type) {
if ok { case error:
http.Error(w, e.Error(), code) http.Error(w, v.Error(), code)
} else { default:
writeJson(w, code, body) writeJson(w, code, body)
} }
} }
@ -162,7 +163,7 @@ func doWriteJson(w http.ResponseWriter, code int, v any) error {
if n, err := w.Write(bs); err != nil { if n, err := w.Write(bs); err != nil {
// http.ErrHandlerTimeout has been handled by http.TimeoutHandler, // http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
// so it's ignored here. // so it's ignored here.
if err != http.ErrHandlerTimeout { if !errors.Is(err, http.ErrHandlerTimeout) {
return fmt.Errorf("write response failed, error: %w", err) return fmt.Errorf("write response failed, error: %w", err)
} }
} else if n < len(bs) { } else if n < len(bs) {

@ -2,6 +2,7 @@ package internal
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net/http" "net/http"
@ -49,7 +50,7 @@ func start(host string, port int, handler http.Handler, run func(svr *http.Serve
} }
}) })
defer func() { defer func() {
if err == http.ErrServerClosed { if errors.Is(err, http.ErrServerClosed) {
waitForCalled() waitForCalled()
} }
}() }()

@ -2,6 +2,7 @@ package rest
import ( import (
"crypto/tls" "crypto/tls"
"errors"
"net/http" "net/http"
"path" "path"
"time" "time"
@ -307,7 +308,7 @@ func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
func handleError(err error) { func handleError(err error) {
// ErrServerClosed means the server is closed manually // ErrServerClosed means the server is closed manually
if err == nil || err == http.ErrServerClosed { if err == nil || errors.Is(err, http.ErrServerClosed) {
return return
} }

@ -56,7 +56,8 @@ func init() {
pgDatasourceCmdFlags.StringVar(&command.VarStringHome, "home") pgDatasourceCmdFlags.StringVar(&command.VarStringHome, "home")
pgDatasourceCmdFlags.StringVar(&command.VarStringRemote, "remote") pgDatasourceCmdFlags.StringVar(&command.VarStringRemote, "remote")
pgDatasourceCmdFlags.StringVar(&command.VarStringBranch, "branch") pgDatasourceCmdFlags.StringVar(&command.VarStringBranch, "branch")
pgCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns, "ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"}) pgCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns,
"ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"})
mongoCmdFlags.StringSliceVarP(&mongo.VarStringSliceType, "type", "t") mongoCmdFlags.StringSliceVarP(&mongo.VarStringSliceType, "type", "t")
mongoCmdFlags.BoolVarP(&mongo.VarBoolCache, "cache", "c") mongoCmdFlags.BoolVarP(&mongo.VarBoolCache, "cache", "c")
@ -68,7 +69,8 @@ func init() {
mongoCmdFlags.StringVar(&mongo.VarStringBranch, "branch") mongoCmdFlags.StringVar(&mongo.VarStringBranch, "branch")
mysqlCmd.PersistentFlags().BoolVar(&command.VarBoolStrict, "strict") mysqlCmd.PersistentFlags().BoolVar(&command.VarBoolStrict, "strict")
mysqlCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns, "ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"}) mysqlCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns,
"ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"})
mysqlCmd.AddCommand(datasourceCmd, ddlCmd) mysqlCmd.AddCommand(datasourceCmd, ddlCmd)
pgCmd.AddCommand(pgDatasourceCmd) pgCmd.AddCommand(pgDatasourceCmd)

@ -8,7 +8,8 @@ import (
// Acceptable checks if given error is acceptable. // Acceptable checks if given error is acceptable.
func Acceptable(err error) bool { func Acceptable(err error) bool {
switch status.Code(err) { switch status.Code(err) {
case codes.DeadlineExceeded, codes.Internal, codes.Unavailable, codes.DataLoss, codes.Unimplemented: case codes.DeadlineExceeded, codes.Internal, codes.Unavailable, codes.DataLoss,
codes.Unimplemented, codes.ResourceExhausted:
return false return false
default: default:
return true return true

@ -2,10 +2,13 @@ package serverinterceptors
import ( import (
"context" "context"
"errors"
"github.com/zeromicro/go-zero/core/breaker" "github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/zrpc/internal/codes" "github.com/zeromicro/go-zero/zrpc/internal/codes"
"google.golang.org/grpc" "google.golang.org/grpc"
gcodes "google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
) )
// StreamBreakerInterceptor is an interceptor that acts as a circuit breaker. // StreamBreakerInterceptor is an interceptor that acts as a circuit breaker.
@ -26,6 +29,9 @@ func UnaryBreakerInterceptor(ctx context.Context, req any, info *grpc.UnaryServe
resp, err = handler(ctx, req) resp, err = handler(ctx, req)
return err return err
}, codes.Acceptable) }, codes.Acceptable)
if errors.Is(err, breaker.ErrServiceUnavailable) {
err = status.Error(gcodes.Unavailable, err.Error())
}
return resp, err return resp, err
} }

@ -2,11 +2,14 @@ package serverinterceptors
import ( import (
"context" "context"
"errors"
"sync" "sync"
"github.com/zeromicro/go-zero/core/load" "github.com/zeromicro/go-zero/core/load"
"github.com/zeromicro/go-zero/core/stat" "github.com/zeromicro/go-zero/core/stat"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
) )
const serviceType = "rpc" const serviceType = "rpc"
@ -28,11 +31,12 @@ func UnarySheddingInterceptor(shedder load.Shedder, metrics *stat.Metrics) grpc.
if err != nil { if err != nil {
metrics.AddDrop() metrics.AddDrop()
sheddingStat.IncrementDrop() sheddingStat.IncrementDrop()
err = status.Error(codes.ResourceExhausted, err.Error())
return return
} }
defer func() { defer func() {
if err == context.DeadlineExceeded { if errors.Is(err, context.DeadlineExceeded) {
promise.Fail() promise.Fail()
} else { } else {
sheddingStat.IncrementPass() sheddingStat.IncrementPass()

@ -8,6 +8,8 @@ import (
"github.com/zeromicro/go-zero/core/load" "github.com/zeromicro/go-zero/core/load"
"github.com/zeromicro/go-zero/core/stat" "github.com/zeromicro/go-zero/core/stat"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
) )
func TestUnarySheddingInterceptor(t *testing.T) { func TestUnarySheddingInterceptor(t *testing.T) {
@ -33,7 +35,7 @@ func TestUnarySheddingInterceptor(t *testing.T) {
name: "reject", name: "reject",
allow: false, allow: false,
handleErr: nil, handleErr: nil,
expect: load.ErrServiceOverloaded, expect: status.Error(codes.ResourceExhausted, load.ErrServiceOverloaded.Error()),
}, },
} }

@ -2,6 +2,7 @@ package serverinterceptors
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"runtime/debug" "runtime/debug"
"strings" "strings"
@ -49,9 +50,9 @@ func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor
return resp, err return resp, err
case <-ctx.Done(): case <-ctx.Done():
err := ctx.Err() err := ctx.Err()
if err == context.Canceled { if errors.Is(err, context.Canceled) {
err = status.Error(codes.Canceled, err.Error()) err = status.Error(codes.Canceled, err.Error())
} else if err == context.DeadlineExceeded { } else if errors.Is(err, context.DeadlineExceeded) {
err = status.Error(codes.DeadlineExceeded, err.Error()) err = status.Error(codes.DeadlineExceeded, err.Error())
} }
return nil, err return nil, err

Loading…
Cancel
Save