From e735915d89a0cb961f9a2cd80443b65ff10ad58c Mon Sep 17 00:00:00 2001 From: "YK.xiong" <767374177@qq.com> Date: Sat, 11 Mar 2023 15:28:09 +0800 Subject: [PATCH] fix QueryRowsPartial getTaggedFieldValueMap func (#2884) Co-authored-by: yongkun.xiong --- core/stores/sqlx/orm.go | 13 ++++- core/stores/sqlx/orm_test.go | 100 +++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 1 deletion(-) diff --git a/core/stores/sqlx/orm.go b/core/stores/sqlx/orm.go index b6e1b3bd..a4c71d3d 100644 --- a/core/stores/sqlx/orm.go +++ b/core/stores/sqlx/orm.go @@ -34,9 +34,20 @@ func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) { result := make(map[string]any, size) for i := 0; i < size; i++ { + if (rt.Field(i).Type.Kind() == reflect.Struct || rt.Field(i).Type.Kind() == reflect.Ptr) && rt.Field(i).Anonymous { + r, e := getTaggedFieldValueMap(reflect.Indirect(v).Field(i)) + if e != nil { + return nil, e + } + for i2, i3 := range r { + result[i2] = i3 + } + continue + } + key := parseTagName(rt.Field(i)) if len(key) == 0 { - return nil, nil + continue } valueField := reflect.Indirect(v).Field(i) diff --git a/core/stores/sqlx/orm_test.go b/core/stores/sqlx/orm_test.go index c6066a76..c33ca7d6 100644 --- a/core/stores/sqlx/orm_test.go +++ b/core/stores/sqlx/orm_test.go @@ -1041,6 +1041,106 @@ func TestUnmarshalRowError(t *testing.T) { } } +func TestAnonymousStructPr(t *testing.T) { + type Score struct { + Discipline string `db:"discipline"` + Score uint `db:"score"` + } + type ClassType struct { + Grade sql.NullString `db:"grade"` + ClassName *string `db:"class_name"` + } + type Class struct { + *ClassType + Score + } + expect := []*struct { + Name string + Age int64 + Grade sql.NullString + Discipline string + Score uint + ClassName string + }{ + { + Name: "first", + Age: 2, + Grade: sql.NullString{ + String: "", + Valid: false, + }, + ClassName: "实验班", + Discipline: "数学", + Score: 100, + }, + { + Name: "second", + Age: 3, + Grade: sql.NullString{ + String: "大一", + Valid: true, + }, + ClassName: "三班二年", + Discipline: "语文", + Score: 99, + }, + } + var value []*struct { + Age int64 `db:"age"` + Class + Name string `db:"name"` + } + + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "age", "grade", "discipline", "class_name", "score"}).AddRow("first", 2, nil, "数学", "实验班", 100). + AddRow("second", 3, "大一", "语文", "三班二年", 99) + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select name, age,grade,discipline,class_name,score from users where user=?", "anyone")) + + for i, each := range expect { + assert.Equal(t, each.Name, value[i].Name) + assert.Equal(t, each.Age, value[i].Age) + assert.Equal(t, each.ClassName, *value[i].Class.ClassName) + assert.Equal(t, each.Discipline, value[i].Score.Discipline) + assert.Equal(t, each.Score, value[i].Score.Score) + assert.Equal(t, each.Grade, value[i].Class.Grade) + } + }) +} +func TestAnonymousStructPrError(t *testing.T) { + type Score struct { + Discipline string `db:"discipline"` + score uint `db:"score"` + } + type ClassType struct { + Grade sql.NullString `db:"grade"` + ClassName *string `db:"class_name"` + } + type Class struct { + *ClassType + Score + } + var value []*struct { + Age int64 `db:"age"` + Class + Name string `db:"name"` + } + + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "age", "grade", "discipline", "class_name", "score"}).AddRow("first", 2, nil, "数学", "实验班", 100). + AddRow("second", 3, "大一", "语文", "三班二年", 99) + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select name, age,grade,discipline,class_name,score from users where user=?", "anyone")) + if len(value) > 0 { + assert.Equal(t, value[0].score, 0) + } + }) +} + func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) { logx.Disable()