fix: format error should not trigger circuit breaker in sqlx (#3437)

master
Kevin Wan 1 year ago committed by GitHub
parent 05db706c62
commit ff04356704
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -158,7 +158,7 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
refValue := reflect.ValueOf(mapValue)
if refValue.Kind() != reflect.Slice {
return fmt.Errorf("%s: %v", fullName, errTypeMismatch)
return newTypeMismatchErrorWithHint(fullName, reflect.Slice.String(), refValue.Type().String())
}
if refValue.IsNil() {
return nil
@ -180,9 +180,9 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
continue
}
valid = true
sliceFullName := fmt.Sprintf("%s[%d]", fullName, i)
valid = true
switch dereffedBaseKind {
case reflect.Struct:
target := reflect.New(dereffedBaseType)
@ -319,7 +319,6 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any,
for _, key := range refValue.MapKeys() {
keythValue := refValue.MapIndex(key)
keythData := keythValue.Interface()
mapFullName := fmt.Sprintf("%s[%s]", fullName, key.String())
switch dereffedElemKind {

@ -5081,6 +5081,17 @@ func TestGetValueWithChainedKeys(t *testing.T) {
})
}
func TestUnmarshalFromStringSliceForTypeMismatch(t *testing.T) {
var v struct {
Values map[string][]string `key:"values"`
}
assert.Error(t, UnmarshalKey(map[string]any{
"values": map[string]any{
"foo": "bar",
},
}, &v))
}
func BenchmarkDefaultValue(b *testing.B) {
for i := 0; i < b.N; i++ {
var a struct {

@ -103,21 +103,21 @@ func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
intValue, err := strconv.ParseInt(str, 10, 64)
if err != nil {
return 0, fmt.Errorf("the value %q cannot parsed as int", str)
return 0, fmt.Errorf("the value %q cannot be parsed as int", str)
}
return intValue, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
uintValue, err := strconv.ParseUint(str, 10, 64)
if err != nil {
return 0, fmt.Errorf("the value %q cannot parsed as uint", str)
return 0, fmt.Errorf("the value %q cannot be parsed as uint", str)
}
return uintValue, nil
case reflect.Float32, reflect.Float64:
floatValue, err := strconv.ParseFloat(str, 64)
if err != nil {
return 0, fmt.Errorf("the value %q cannot parsed as float", str)
return 0, fmt.Errorf("the value %q cannot be parsed as float", str)
}
return floatValue, nil

@ -291,12 +291,19 @@ func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Contex
}
func (db *commonSqlConn) acceptable(err error) bool {
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled
if err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled {
return true
}
if _, ok := err.(acceptableError); ok {
return true
}
if db.accept == nil {
return ok
return false
}
return ok || db.accept(err)
return db.accept(err)
}
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,

@ -236,6 +236,33 @@ func TestStatement(t *testing.T) {
})
}
func TestBreakerWithFormatError(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := NewSqlConnFromDB(db, withMysqlAcceptable())
for i := 0; i < 1000; i++ {
var val string
if !assert.NotEqual(t, breaker.ErrServiceUnavailable,
conn.QueryRow(&val, "any ?, ?", "foo")) {
break
}
}
})
}
func TestBreakerWithScanError(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := NewSqlConnFromDB(db, withMysqlAcceptable())
for i := 0; i < 1000; i++ {
rows := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
var val int
if !assert.NotEqual(t, breaker.ErrServiceUnavailable, conn.QueryRow(&val, "any")) {
break
}
}
})
}
func buildConn() (mock sqlmock.Sqlmock, err error) {
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
var db *sql.DB

@ -51,7 +51,13 @@ func escape(input string) string {
return b.String()
}
func format(query string, args ...any) (string, error) {
func format(query string, args ...any) (val string, err error) {
defer func() {
if err != nil {
err = newAcceptableError(err)
}
}()
numArgs := len(args)
if numArgs == 0 {
return query, nil
@ -66,7 +72,8 @@ func format(query string, args ...any) (string, error) {
switch ch {
case '?':
if argIndex >= numArgs {
return "", fmt.Errorf("%d ? in sql, but less arguments provided", argIndex)
return "", fmt.Errorf("%d ? in sql, but only %d arguments provided",
argIndex+1, numArgs)
}
writeValue(&b, args[argIndex])
@ -165,3 +172,17 @@ func writeValue(buf *strings.Builder, arg any) {
buf.WriteString(mapping.Repr(v))
}
}
type acceptableError struct {
err error
}
func newAcceptableError(err error) error {
return acceptableError{
err: err,
}
}
func (e acceptableError) Error() string {
return e.err.Error()
}

Loading…
Cancel
Save