From cca45be3c5906493af9943b485b6f72edc784873 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Sat, 11 Mar 2023 18:03:20 +0800 Subject: [PATCH] chore: refactor orm code (#3015) --- core/stores/sqlx/orm.go | 19 +++++++++------ core/stores/sqlx/orm_test.go | 47 ++++++++++++++++++++++++++---------- 2 files changed, 45 insertions(+), 21 deletions(-) diff --git a/core/stores/sqlx/orm.go b/core/stores/sqlx/orm.go index a4c71d3d..956a41a6 100644 --- a/core/stores/sqlx/orm.go +++ b/core/stores/sqlx/orm.go @@ -34,18 +34,21 @@ func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) { result := make(map[string]any, size) for i := 0; i < size; i++ { - if (rt.Field(i).Type.Kind() == reflect.Struct || rt.Field(i).Type.Kind() == reflect.Ptr) && rt.Field(i).Anonymous { - r, e := getTaggedFieldValueMap(reflect.Indirect(v).Field(i)) - if e != nil { - return nil, e + field := rt.Field(i) + if field.Anonymous && mapping.Deref(field.Type).Kind() == reflect.Struct { + inner, err := getTaggedFieldValueMap(reflect.Indirect(v).Field(i)) + if err != nil { + return nil, err } - for i2, i3 := range r { - result[i2] = i3 + + for key, val := range inner { + result[key] = val } + continue } - key := parseTagName(rt.Field(i)) + key := parseTagName(field) if len(key) == 0 { continue } @@ -125,7 +128,7 @@ func parseTagName(field reflect.StructField) string { } options := strings.Split(key, ",") - return options[0] + return strings.TrimSpace(options[0]) } func unmarshalRow(v any, scanner rowsScanner, strict bool) error { diff --git a/core/stores/sqlx/orm_test.go b/core/stores/sqlx/orm_test.go index c33ca7d6..2efd04d6 100644 --- a/core/stores/sqlx/orm_test.go +++ b/core/stores/sqlx/orm_test.go @@ -1069,19 +1069,19 @@ func TestAnonymousStructPr(t *testing.T) { String: "", Valid: false, }, - ClassName: "实验班", - Discipline: "数学", + ClassName: "experimental class", + Discipline: "math", Score: 100, }, { Name: "second", Age: 3, Grade: sql.NullString{ - String: "大一", + String: "grade one", Valid: true, }, - ClassName: "三班二年", - Discipline: "语文", + ClassName: "class three grade two", + Discipline: "chinese", Score: 99, }, } @@ -1092,12 +1092,22 @@ func TestAnonymousStructPr(t *testing.T) { } runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - rs := sqlmock.NewRows([]string{"name", "age", "grade", "discipline", "class_name", "score"}).AddRow("first", 2, nil, "数学", "实验班", 100). - AddRow("second", 3, "大一", "语文", "三班二年", 99) - mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + rs := sqlmock.NewRows([]string{ + "name", + "age", + "grade", + "discipline", + "class_name", + "score", + }). + AddRow("first", 2, nil, "math", "experimental class", 100). + AddRow("second", 3, "grade one", "chinese", "class three grade two", 99) + mock.ExpectQuery("select (.+) from users where user=?"). + WithArgs("anyone").WillReturnRows(rs) assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) - }, "select name, age,grade,discipline,class_name,score from users where user=?", "anyone")) + }, "select name, age,grade,discipline,class_name,score from users where user=?", + "anyone")) for i, each := range expect { assert.Equal(t, each.Name, value[i].Name) @@ -1109,6 +1119,7 @@ func TestAnonymousStructPr(t *testing.T) { } }) } + func TestAnonymousStructPrError(t *testing.T) { type Score struct { Discipline string `db:"discipline"` @@ -1129,12 +1140,22 @@ func TestAnonymousStructPrError(t *testing.T) { } runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - rs := sqlmock.NewRows([]string{"name", "age", "grade", "discipline", "class_name", "score"}).AddRow("first", 2, nil, "数学", "实验班", 100). - AddRow("second", 3, "大一", "语文", "三班二年", 99) - mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + rs := sqlmock.NewRows([]string{ + "name", + "age", + "grade", + "discipline", + "class_name", + "score", + }). + AddRow("first", 2, nil, "math", "experimental class", 100). + AddRow("second", 3, "grade one", "chinese", "class three grade two", 99) + mock.ExpectQuery("select (.+) from users where user=?"). + WithArgs("anyone").WillReturnRows(rs) assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) - }, "select name, age,grade,discipline,class_name,score from users where user=?", "anyone")) + }, "select name, age,grade,discipline,class_name,score from users where user=?", + "anyone")) if len(value) > 0 { assert.Equal(t, value[0].score, 0) }