diff --git a/core/mapping/unmarshaler.go b/core/mapping/unmarshaler.go index 216e3634..d11a6060 100644 --- a/core/mapping/unmarshaler.go +++ b/core/mapping/unmarshaler.go @@ -158,7 +158,7 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map refValue := reflect.ValueOf(mapValue) if refValue.Kind() != reflect.Slice { - return fmt.Errorf("%s: %v", fullName, errTypeMismatch) + return newTypeMismatchErrorWithHint(fullName, reflect.Slice.String(), refValue.Type().String()) } if refValue.IsNil() { return nil @@ -180,9 +180,9 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map continue } + valid = true sliceFullName := fmt.Sprintf("%s[%d]", fullName, i) - valid = true switch dereffedBaseKind { case reflect.Struct: target := reflect.New(dereffedBaseType) @@ -319,7 +319,6 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any, for _, key := range refValue.MapKeys() { keythValue := refValue.MapIndex(key) keythData := keythValue.Interface() - mapFullName := fmt.Sprintf("%s[%s]", fullName, key.String()) switch dereffedElemKind { diff --git a/core/mapping/unmarshaler_test.go b/core/mapping/unmarshaler_test.go index 14a132fb..c65e69ff 100644 --- a/core/mapping/unmarshaler_test.go +++ b/core/mapping/unmarshaler_test.go @@ -5081,6 +5081,17 @@ func TestGetValueWithChainedKeys(t *testing.T) { }) } +func TestUnmarshalFromStringSliceForTypeMismatch(t *testing.T) { + var v struct { + Values map[string][]string `key:"values"` + } + assert.Error(t, UnmarshalKey(map[string]any{ + "values": map[string]any{ + "foo": "bar", + }, + }, &v)) +} + func BenchmarkDefaultValue(b *testing.B) { for i := 0; i < b.N; i++ { var a struct { diff --git a/core/mapping/utils.go b/core/mapping/utils.go index 931d3169..ae558959 100644 --- a/core/mapping/utils.go +++ b/core/mapping/utils.go @@ -103,21 +103,21 @@ func convertTypeFromString(kind reflect.Kind, str string) (any, error) { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: intValue, err := strconv.ParseInt(str, 10, 64) if err != nil { - return 0, fmt.Errorf("the value %q cannot parsed as int", str) + return 0, fmt.Errorf("the value %q cannot be parsed as int", str) } return intValue, nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: uintValue, err := strconv.ParseUint(str, 10, 64) if err != nil { - return 0, fmt.Errorf("the value %q cannot parsed as uint", str) + return 0, fmt.Errorf("the value %q cannot be parsed as uint", str) } return uintValue, nil case reflect.Float32, reflect.Float64: floatValue, err := strconv.ParseFloat(str, 64) if err != nil { - return 0, fmt.Errorf("the value %q cannot parsed as float", str) + return 0, fmt.Errorf("the value %q cannot be parsed as float", str) } return floatValue, nil diff --git a/core/stores/sqlx/sqlconn.go b/core/stores/sqlx/sqlconn.go index fdb9d2be..62b5936c 100644 --- a/core/stores/sqlx/sqlconn.go +++ b/core/stores/sqlx/sqlconn.go @@ -291,12 +291,19 @@ func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Contex } func (db *commonSqlConn) acceptable(err error) bool { - ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled + if err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled { + return true + } + + if _, ok := err.(acceptableError); ok { + return true + } + if db.accept == nil { - return ok + return false } - return ok || db.accept(err) + return db.accept(err) } func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error, diff --git a/core/stores/sqlx/sqlconn_test.go b/core/stores/sqlx/sqlconn_test.go index bbe15a92..c05af3cf 100644 --- a/core/stores/sqlx/sqlconn_test.go +++ b/core/stores/sqlx/sqlconn_test.go @@ -236,6 +236,33 @@ func TestStatement(t *testing.T) { }) } +func TestBreakerWithFormatError(t *testing.T) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + conn := NewSqlConnFromDB(db, withMysqlAcceptable()) + for i := 0; i < 1000; i++ { + var val string + if !assert.NotEqual(t, breaker.ErrServiceUnavailable, + conn.QueryRow(&val, "any ?, ?", "foo")) { + break + } + } + }) +} + +func TestBreakerWithScanError(t *testing.T) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + conn := NewSqlConnFromDB(db, withMysqlAcceptable()) + for i := 0; i < 1000; i++ { + rows := sqlmock.NewRows([]string{"foo"}).AddRow("bar") + mock.ExpectQuery("any").WillReturnRows(rows) + var val int + if !assert.NotEqual(t, breaker.ErrServiceUnavailable, conn.QueryRow(&val, "any")) { + break + } + } + }) +} + func buildConn() (mock sqlmock.Sqlmock, err error) { _, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) { var db *sql.DB diff --git a/core/stores/sqlx/utils.go b/core/stores/sqlx/utils.go index dae12faf..c5944517 100644 --- a/core/stores/sqlx/utils.go +++ b/core/stores/sqlx/utils.go @@ -51,7 +51,13 @@ func escape(input string) string { return b.String() } -func format(query string, args ...any) (string, error) { +func format(query string, args ...any) (val string, err error) { + defer func() { + if err != nil { + err = newAcceptableError(err) + } + }() + numArgs := len(args) if numArgs == 0 { return query, nil @@ -66,7 +72,8 @@ func format(query string, args ...any) (string, error) { switch ch { case '?': if argIndex >= numArgs { - return "", fmt.Errorf("%d ? in sql, but less arguments provided", argIndex) + return "", fmt.Errorf("%d ? in sql, but only %d arguments provided", + argIndex+1, numArgs) } writeValue(&b, args[argIndex]) @@ -165,3 +172,17 @@ func writeValue(buf *strings.Builder, arg any) { buf.WriteString(mapping.Repr(v)) } } + +type acceptableError struct { + err error +} + +func newAcceptableError(err error) error { + return acceptableError{ + err: err, + } +} + +func (e acceptableError) Error() string { + return e.err.Error() +}