fix: format error should not trigger circuit breaker in sqlx (#3437)

master
Kevin Wan 1 year ago committed by GitHub
parent 05db706c62
commit ff04356704
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -158,7 +158,7 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
refValue := reflect.ValueOf(mapValue) refValue := reflect.ValueOf(mapValue)
if refValue.Kind() != reflect.Slice { if refValue.Kind() != reflect.Slice {
return fmt.Errorf("%s: %v", fullName, errTypeMismatch) return newTypeMismatchErrorWithHint(fullName, reflect.Slice.String(), refValue.Type().String())
} }
if refValue.IsNil() { if refValue.IsNil() {
return nil return nil
@ -180,9 +180,9 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
continue continue
} }
valid = true
sliceFullName := fmt.Sprintf("%s[%d]", fullName, i) sliceFullName := fmt.Sprintf("%s[%d]", fullName, i)
valid = true
switch dereffedBaseKind { switch dereffedBaseKind {
case reflect.Struct: case reflect.Struct:
target := reflect.New(dereffedBaseType) target := reflect.New(dereffedBaseType)
@ -319,7 +319,6 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any,
for _, key := range refValue.MapKeys() { for _, key := range refValue.MapKeys() {
keythValue := refValue.MapIndex(key) keythValue := refValue.MapIndex(key)
keythData := keythValue.Interface() keythData := keythValue.Interface()
mapFullName := fmt.Sprintf("%s[%s]", fullName, key.String()) mapFullName := fmt.Sprintf("%s[%s]", fullName, key.String())
switch dereffedElemKind { switch dereffedElemKind {

@ -5081,6 +5081,17 @@ func TestGetValueWithChainedKeys(t *testing.T) {
}) })
} }
func TestUnmarshalFromStringSliceForTypeMismatch(t *testing.T) {
var v struct {
Values map[string][]string `key:"values"`
}
assert.Error(t, UnmarshalKey(map[string]any{
"values": map[string]any{
"foo": "bar",
},
}, &v))
}
func BenchmarkDefaultValue(b *testing.B) { func BenchmarkDefaultValue(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
var a struct { var a struct {

@ -103,21 +103,21 @@ func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
intValue, err := strconv.ParseInt(str, 10, 64) intValue, err := strconv.ParseInt(str, 10, 64)
if err != nil { if err != nil {
return 0, fmt.Errorf("the value %q cannot parsed as int", str) return 0, fmt.Errorf("the value %q cannot be parsed as int", str)
} }
return intValue, nil return intValue, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
uintValue, err := strconv.ParseUint(str, 10, 64) uintValue, err := strconv.ParseUint(str, 10, 64)
if err != nil { if err != nil {
return 0, fmt.Errorf("the value %q cannot parsed as uint", str) return 0, fmt.Errorf("the value %q cannot be parsed as uint", str)
} }
return uintValue, nil return uintValue, nil
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
floatValue, err := strconv.ParseFloat(str, 64) floatValue, err := strconv.ParseFloat(str, 64)
if err != nil { if err != nil {
return 0, fmt.Errorf("the value %q cannot parsed as float", str) return 0, fmt.Errorf("the value %q cannot be parsed as float", str)
} }
return floatValue, nil return floatValue, nil

@ -291,12 +291,19 @@ func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Contex
} }
func (db *commonSqlConn) acceptable(err error) bool { func (db *commonSqlConn) acceptable(err error) bool {
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled if err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled {
return true
}
if _, ok := err.(acceptableError); ok {
return true
}
if db.accept == nil { if db.accept == nil {
return ok return false
} }
return ok || db.accept(err) return db.accept(err)
} }
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error, func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,

@ -236,6 +236,33 @@ func TestStatement(t *testing.T) {
}) })
} }
func TestBreakerWithFormatError(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := NewSqlConnFromDB(db, withMysqlAcceptable())
for i := 0; i < 1000; i++ {
var val string
if !assert.NotEqual(t, breaker.ErrServiceUnavailable,
conn.QueryRow(&val, "any ?, ?", "foo")) {
break
}
}
})
}
func TestBreakerWithScanError(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := NewSqlConnFromDB(db, withMysqlAcceptable())
for i := 0; i < 1000; i++ {
rows := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
var val int
if !assert.NotEqual(t, breaker.ErrServiceUnavailable, conn.QueryRow(&val, "any")) {
break
}
}
})
}
func buildConn() (mock sqlmock.Sqlmock, err error) { func buildConn() (mock sqlmock.Sqlmock, err error) {
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) { _, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
var db *sql.DB var db *sql.DB

@ -51,7 +51,13 @@ func escape(input string) string {
return b.String() return b.String()
} }
func format(query string, args ...any) (string, error) { func format(query string, args ...any) (val string, err error) {
defer func() {
if err != nil {
err = newAcceptableError(err)
}
}()
numArgs := len(args) numArgs := len(args)
if numArgs == 0 { if numArgs == 0 {
return query, nil return query, nil
@ -66,7 +72,8 @@ func format(query string, args ...any) (string, error) {
switch ch { switch ch {
case '?': case '?':
if argIndex >= numArgs { if argIndex >= numArgs {
return "", fmt.Errorf("%d ? in sql, but less arguments provided", argIndex) return "", fmt.Errorf("%d ? in sql, but only %d arguments provided",
argIndex+1, numArgs)
} }
writeValue(&b, args[argIndex]) writeValue(&b, args[argIndex])
@ -165,3 +172,17 @@ func writeValue(buf *strings.Builder, arg any) {
buf.WriteString(mapping.Repr(v)) buf.WriteString(mapping.Repr(v))
} }
} }
type acceptableError struct {
err error
}
func newAcceptableError(err error) error {
return acceptableError{
err: err,
}
}
func (e acceptableError) Error() string {
return e.err.Error()
}

Loading…
Cancel
Save