diff --git a/tools/goctl/model/sql/builderx/builder.go b/tools/goctl/model/sql/builderx/builder.go index 3ce38604..01ce067c 100644 --- a/tools/goctl/model/sql/builderx/builder.go +++ b/tools/goctl/model/sql/builderx/builder.go @@ -61,9 +61,9 @@ func FieldNames(in interface{}) []string { // gets us a StructField fi := typ.Field(i) if tagv := fi.Tag.Get(dbTag); tagv != "" { - out = append(out, tagv) + out = append(out, fmt.Sprintf("`%v`", tagv)) } else { - out = append(out, fi.Name) + out = append(out, fmt.Sprintf("`%v`", fi.Name)) } } return out diff --git a/tools/goctl/model/sql/builderx/builder_test.go b/tools/goctl/model/sql/builderx/builder_test.go index 28d1db01..1bbba6a5 100644 --- a/tools/goctl/model/sql/builderx/builder_test.go +++ b/tools/goctl/model/sql/builderx/builder_test.go @@ -28,8 +28,7 @@ var userFields = FieldNames(User{}) func TestFieldNames(t *testing.T) { var u User out := FieldNames(&u) - fmt.Println(out) - actual := []string{"id", "user_name", "sex", "uuid", "age"} + actual := []string{"`id`", "`user_name`", "`sex`", "`uuid`", "`age`"} assert.Equal(t, out, actual) } @@ -54,7 +53,7 @@ func TestBuilderSql(t *testing.T) { sql, args, err := builder.Select(fields...).From("user").Where(eq).ToSQL() fmt.Println(sql, args, err) - actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE id=?" + actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE id=?" actualArgs := []interface{}{"123123"} assert.Equal(t, sql, actualSql) assert.Equal(t, args, actualArgs) @@ -68,7 +67,7 @@ func TestBuildSqlDefaultValue(t *testing.T) { sql, args, err := builder.Select(userFields...).From("user").Where(eq).ToSQL() fmt.Println(sql, args, err) - actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE age=? AND user_name=?" + actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE age=? AND user_name=?" actualArgs := []interface{}{0, ""} assert.Equal(t, sql, actualSql) assert.Equal(t, args, actualArgs) @@ -83,7 +82,7 @@ func TestBuilderSqlIn(t *testing.T) { sql, args, err := builder.Select(userFields...).From("user").Where(in).And(gtU).ToSQL() fmt.Println(sql, args, err) - actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE id IN (?,?,?) AND age>?" + actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE id IN (?,?,?) AND age>?" actualArgs := []interface{}{"1", "2", "3", 18} assert.Equal(t, sql, actualSql) assert.Equal(t, args, actualArgs) @@ -94,7 +93,7 @@ func TestBuildSqlLike(t *testing.T) { sql, args, err := builder.Select(userFields...).From("user").Where(like).ToSQL() fmt.Println(sql, args, err) - actualSql := "SELECT id,user_name,sex,uuid,age FROM user WHERE name LIKE ?" + actualSql := "SELECT `id`,`user_name`,`sex`,`uuid`,`age` FROM user WHERE name LIKE ?" actualArgs := []interface{}{"%wang%"} assert.Equal(t, sql, actualSql) assert.Equal(t, args, actualArgs) diff --git a/tools/goctl/model/sql/example/makefile b/tools/goctl/model/sql/example/makefile index 135d3fdd..7da64fe4 100644 --- a/tools/goctl/model/sql/example/makefile +++ b/tools/goctl/model/sql/example/makefile @@ -1,8 +1,13 @@ #!/bin/bash # generate model with cache from ddl -fromDDL: - goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/user" -cache +fromDDLWithCache: + goctl template clean; + goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/cache/user" -cache; + +fromDDLWithoutCache: + goctl template clean; + goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/nocache/user"; # generate model with cache from data source @@ -12,4 +17,5 @@ datasource=127.0.0.1:3306 database=gozero fromDataSource: - goctl model mysql datasource -url="$(user):$(password)@tcp($(datasource))/$(database)" -table="*" -dir ./model/cache -c -style gozero \ No newline at end of file + goctl template clean; + goctl model mysql datasource -url="$(user):$(password)@tcp($(datasource))/$(database)" -table="*" -dir ./model/cache -c -style gozero; \ No newline at end of file diff --git a/tools/goctl/model/sql/gen/delete.go b/tools/goctl/model/sql/gen/delete.go index a14c3171..f79773d8 100644 --- a/tools/goctl/model/sql/gen/delete.go +++ b/tools/goctl/model/sql/gen/delete.go @@ -36,7 +36,7 @@ func genDelete(table Table, withCache bool) (string, string, error) { "lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(), "dataType": table.PrimaryKey.DataType, "keys": strings.Join(keySet.KeysStr(), "\n"), - "originalPrimaryKey": table.PrimaryKey.Name.Source(), + "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()), "keyValues": strings.Join(keyVariableSet.KeysStr(), ", "), }) if err != nil { diff --git a/tools/goctl/model/sql/gen/findone.go b/tools/goctl/model/sql/gen/findone.go index f7c157a2..d6ff0367 100644 --- a/tools/goctl/model/sql/gen/findone.go +++ b/tools/goctl/model/sql/gen/findone.go @@ -19,7 +19,7 @@ func genFindOne(table Table, withCache bool) (string, string, error) { "withCache": withCache, "upperStartCamelObject": camel, "lowerStartCamelObject": stringx.From(camel).Untitle(), - "originalPrimaryKey": table.PrimaryKey.Name.Source(), + "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()), "lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(), "dataType": table.PrimaryKey.DataType, "cacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].KeyExpression, diff --git a/tools/goctl/model/sql/gen/findonebyfield.go b/tools/goctl/model/sql/gen/findonebyfield.go index 6adea253..c68febcf 100644 --- a/tools/goctl/model/sql/gen/findonebyfield.go +++ b/tools/goctl/model/sql/gen/findonebyfield.go @@ -39,7 +39,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) { "lowerStartCamelObject": stringx.From(camelTableName).Untitle(), "lowerStartCamelField": stringx.From(camelFieldName).Untitle(), "upperStartCamelPrimaryKey": table.PrimaryKey.Name.ToCamel(), - "originalField": field.Name.Source(), + "originalField": wrapWithRawString(field.Name.Source()), }) if err != nil { return nil, err @@ -82,7 +82,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) { "upperStartCamelObject": camelTableName, "primaryKeyLeft": table.CacheKey[table.PrimaryKey.Name.Source()].Left, "lowerStartCamelObject": stringx.From(camelTableName).Untitle(), - "originalPrimaryField": table.PrimaryKey.Name.Source(), + "originalPrimaryField": wrapWithRawString(table.PrimaryKey.Name.Source()), }) if err != nil { return nil, err diff --git a/tools/goctl/model/sql/gen/gen.go b/tools/goctl/model/sql/gen/gen.go index f78753fc..d21df855 100644 --- a/tools/goctl/model/sql/gen/gen.go +++ b/tools/goctl/model/sql/gen/gen.go @@ -21,9 +21,6 @@ import ( const ( pwd = "." createTableFlag = `(?m)^(?i)CREATE\s+TABLE` // ignore case - NamingLower = "lower" - NamingCamel = "camel" - NamingSnake = "snake" ) type ( @@ -280,3 +277,20 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er return output.String(), nil } + +func wrapWithRawString(v string) string { + if v == "`" { + return v + } + + if !strings.HasPrefix(v, "`") { + v = "`" + v + } + + if !strings.HasSuffix(v, "`") { + v = v + "`" + } else if len(v) == 1 { + v = v + "`" + } + return v +} diff --git a/tools/goctl/model/sql/gen/gen_test.go b/tools/goctl/model/sql/gen/gen_test.go index 7de3f1e7..348ad9cd 100644 --- a/tools/goctl/model/sql/gen/gen_test.go +++ b/tools/goctl/model/sql/gen/gen_test.go @@ -1,13 +1,18 @@ package gen import ( + "database/sql" "os" "path/filepath" + "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/tal-tech/go-zero/core/logx" + "github.com/tal-tech/go-zero/core/stringx" "github.com/tal-tech/go-zero/tools/goctl/config" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx" ) var ( @@ -79,3 +84,32 @@ func TestNamingModel(t *testing.T) { return err == nil }()) } + +func TestWrapWithRawString(t *testing.T) { + assert.Equal(t, "``", wrapWithRawString("")) + assert.Equal(t, "``", wrapWithRawString("``")) + assert.Equal(t, "`a`", wrapWithRawString("a")) + assert.Equal(t, "` `", wrapWithRawString(" ")) +} + +func TestFields(t *testing.T) { + type Student struct { + Id int64 `db:"id"` + Name string `db:"name"` + Age sql.NullInt64 `db:"age"` + Score sql.NullFloat64 `db:"score"` + CreateTime time.Time `db:"create_time"` + UpdateTime sql.NullTime `db:"update_time"` + } + var ( + studentFieldNames = builderx.FieldNames(&Student{}) + studentRows = strings.Join(studentFieldNames, ",") + studentRowsExpectAutoSet = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",") + studentRowsWithPlaceHolder = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?" + ) + + assert.Equal(t, []string{"`id`", "`name`", "`age`", "`score`", "`create_time`", "`update_time`"}, studentFieldNames) + assert.Equal(t, "`id`,`name`,`age`,`score`,`create_time`,`update_time`", studentRows) + assert.Equal(t, "`name`,`age`,`score`", studentRowsExpectAutoSet) + assert.Equal(t, "`name`=?,`age`=?,`score`=?", studentRowsWithPlaceHolder) +} diff --git a/tools/goctl/model/sql/gen/new.go b/tools/goctl/model/sql/gen/new.go index 6976ffc9..79d72d3e 100644 --- a/tools/goctl/model/sql/gen/new.go +++ b/tools/goctl/model/sql/gen/new.go @@ -14,7 +14,7 @@ func genNew(table Table, withCache bool) (string, error) { output, err := util.With("new"). Parse(text). Execute(map[string]interface{}{ - "table": table.Name.Source(), + "table": wrapWithRawString(table.Name.Source()), "withCache": withCache, "upperStartCamelObject": table.Name.ToCamel(), }) diff --git a/tools/goctl/model/sql/gen/update.go b/tools/goctl/model/sql/gen/update.go index 6194b388..3eca4135 100644 --- a/tools/goctl/model/sql/gen/update.go +++ b/tools/goctl/model/sql/gen/update.go @@ -35,7 +35,7 @@ func genUpdate(table Table, withCache bool) (string, string, error) { "primaryCacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].DataKeyExpression, "primaryKeyVariable": table.CacheKey[table.PrimaryKey.Name.Source()].Variable, "lowerStartCamelObject": stringx.From(camelTableName).Untitle(), - "originalPrimaryKey": table.PrimaryKey.Name.Source(), + "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()), "expressionValues": strings.Join(expressionValues, ", "), }) if err != nil { diff --git a/tools/goctl/model/sql/gen/vars.go b/tools/goctl/model/sql/gen/vars.go index 78d49641..6ec6ee49 100644 --- a/tools/goctl/model/sql/gen/vars.go +++ b/tools/goctl/model/sql/gen/vars.go @@ -27,7 +27,7 @@ func genVars(table Table, withCache bool) (string, error) { "upperStartCamelObject": camel, "cacheKeys": strings.Join(keys, "\n"), "autoIncrement": table.PrimaryKey.AutoIncrement, - "originalPrimaryKey": table.PrimaryKey.Name.Source(), + "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()), "withCache": withCache, }) if err != nil { diff --git a/tools/goctl/model/sql/model/ddlmodel.go b/tools/goctl/model/sql/model/ddlmodel.go deleted file mode 100644 index a7c3a38f..00000000 --- a/tools/goctl/model/sql/model/ddlmodel.go +++ /dev/null @@ -1,34 +0,0 @@ -package model - -import ( - "github.com/tal-tech/go-zero/core/stores/sqlx" -) - -type ( - DDLModel struct { - conn sqlx.SqlConn - } - DDL struct { - Table string `db:"Table"` - DDL string `db:"Create Table"` - } -) - -func NewDDLModel(conn sqlx.SqlConn) *DDLModel { - return &DDLModel{conn: conn} -} - -func (m *DDLModel) ShowDDL(table ...string) ([]string, error) { - var ddl []string - for _, t := range table { - query := `show create table ` + t - var resp DDL - err := m.conn.QueryRow(&resp, query) - if err != nil { - return nil, err - } - - ddl = append(ddl, resp.DDL) - } - return ddl, nil -} diff --git a/tools/goctl/model/sql/template/vars.go b/tools/goctl/model/sql/template/vars.go index 9d0da1b8..254a6e07 100644 --- a/tools/goctl/model/sql/template/vars.go +++ b/tools/goctl/model/sql/template/vars.go @@ -1,12 +1,14 @@ package template -var Vars = ` +import "fmt" + +var Vars = fmt.Sprintf(` var ( {{.lowerStartCamelObject}}FieldNames = builderx.FieldNames(&{{.upperStartCamelObject}}{}) {{.lowerStartCamelObject}}Rows = strings.Join({{.lowerStartCamelObject}}FieldNames, ",") - {{.lowerStartCamelObject}}RowsExpectAutoSet = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "create_time", "update_time"), ",") - {{.lowerStartCamelObject}}RowsWithPlaceHolder = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, "{{.originalPrimaryKey}}", "create_time", "update_time"), "=?,") + "=?" + {{.lowerStartCamelObject}}RowsExpectAutoSet = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "%screate_time%s", "%supdate_time%s"), ",") + {{.lowerStartCamelObject}}RowsWithPlaceHolder = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, "{{.originalPrimaryKey}}", "%screate_time%s", "%supdate_time%s"), "=?,") + "=?" {{if .withCache}}{{.cacheKeys}}{{end}} ) -` +`, "`", "`", "`", "`", "`", "`", "`", "`") diff --git a/tools/goctl/model/sql/test/model/model_test.go b/tools/goctl/model/sql/test/model/model_test.go new file mode 100644 index 00000000..9542c429 --- /dev/null +++ b/tools/goctl/model/sql/test/model/model_test.go @@ -0,0 +1,235 @@ +package model + +import ( + "database/sql" + "fmt" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/core/stores/cache" + "github.com/tal-tech/go-zero/core/stores/redis" + "github.com/tal-tech/go-zero/core/stores/redis/redistest" + mocksql "github.com/tal-tech/go-zero/tools/goctl/model/sql/test" +) + +func TestStudentModel(t *testing.T) { + var ( + testTimeValue = time.Now() + testTable = "`student`" + testUpdateName = "gozero1" + testRowsAffected int64 = 1 + testInsertId int64 = 1 + ) + + var data Student + data.Id = testInsertId + data.Name = "gozero" + data.Age = sql.NullInt64{ + Int64: 1, + Valid: true, + } + data.Score = sql.NullFloat64{ + Float64: 100, + Valid: true, + } + data.CreateTime = testTimeValue + data.UpdateTime = sql.NullTime{ + Time: testTimeValue, + Valid: true, + } + + err := mockStudent(func(mock sqlmock.Sqlmock) { + mock.ExpectExec(fmt.Sprintf("insert into %s", testTable)). + WithArgs(data.Name, data.Age, data.Score). + WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected)) + }, func(m StudentModel) { + r, err := m.Insert(data) + assert.Nil(t, err) + + lastInsertId, err := r.LastInsertId() + assert.Nil(t, err) + assert.Equal(t, testInsertId, lastInsertId) + + rowsAffected, err := r.RowsAffected() + assert.Nil(t, err) + assert.Equal(t, testRowsAffected, rowsAffected) + }) + assert.Nil(t, err) + + err = mockStudent(func(mock sqlmock.Sqlmock) { + mock.ExpectQuery(fmt.Sprintf("select (.+) from %s", testTable)). + WithArgs(testInsertId). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertId, data.Name, data.Age, data.Score, testTimeValue, testTimeValue)) + }, func(m StudentModel) { + result, err := m.FindOne(testInsertId) + assert.Nil(t, err) + assert.Equal(t, *result, data) + }) + assert.Nil(t, err) + + err = mockStudent(func(mock sqlmock.Sqlmock) { + mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(testUpdateName, data.Age, data.Score, testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected)) + }, func(m StudentModel) { + data.Name = testUpdateName + err := m.Update(data) + assert.Nil(t, err) + }) + assert.Nil(t, err) + + err = mockStudent(func(mock sqlmock.Sqlmock) { + mock.ExpectQuery(fmt.Sprintf("select (.+) from %s ", testTable)). + WithArgs(testInsertId). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertId, data.Name, data.Age, data.Score, testTimeValue, testTimeValue)) + }, func(m StudentModel) { + result, err := m.FindOne(testInsertId) + assert.Nil(t, err) + assert.Equal(t, *result, data) + }) + assert.Nil(t, err) + + err = mockStudent(func(mock sqlmock.Sqlmock) { + mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected)) + }, func(m StudentModel) { + err := m.Delete(testInsertId) + assert.Nil(t, err) + }) + assert.Nil(t, err) +} + +func TestUserModel(t *testing.T) { + var ( + testTimeValue = time.Now() + testTable = "`user`" + testUpdateName = "gozero1" + testUser = "gozero" + testPassword = "test" + testMobile = "test_mobile" + testGender = "男" + testNickname = "test_nickname" + testRowsAffected int64 = 1 + testInsertId int64 = 1 + ) + + var data User + data.Id = testInsertId + data.User = testUser + data.Name = "gozero" + data.Password = testPassword + data.Mobile = testMobile + data.Gender = testGender + data.Nickname = testNickname + data.CreateTime = testTimeValue + data.UpdateTime = testTimeValue + + err := mockUser(func(mock sqlmock.Sqlmock) { + mock.ExpectExec(fmt.Sprintf("insert into %s", testTable)). + WithArgs(data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname). + WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected)) + }, func(m UserModel) { + r, err := m.Insert(data) + assert.Nil(t, err) + + lastInsertId, err := r.LastInsertId() + assert.Nil(t, err) + assert.Equal(t, testInsertId, lastInsertId) + + rowsAffected, err := r.RowsAffected() + assert.Nil(t, err) + assert.Equal(t, testRowsAffected, rowsAffected) + }) + assert.Nil(t, err) + + err = mockUser(func(mock sqlmock.Sqlmock) { + mock.ExpectQuery(fmt.Sprintf("select (.+) from %s", testTable)). + WithArgs(testInsertId). + WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertId, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue)) + }, func(m UserModel) { + result, err := m.FindOne(testInsertId) + assert.Nil(t, err) + assert.Equal(t, *result, data) + }) + assert.Nil(t, err) + + err = mockUser(func(mock sqlmock.Sqlmock) { + mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(data.User, testUpdateName, data.Password, data.Mobile, data.Gender, data.Nickname, testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected)) + }, func(m UserModel) { + data.Name = testUpdateName + err := m.Update(data) + assert.Nil(t, err) + }) + assert.Nil(t, err) + + err = mockUser(func(mock sqlmock.Sqlmock) { + mock.ExpectQuery(fmt.Sprintf("select (.+) from %s ", testTable)). + WithArgs(testInsertId). + WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertId, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue)) + }, func(m UserModel) { + result, err := m.FindOne(testInsertId) + assert.Nil(t, err) + assert.Equal(t, *result, data) + }) + assert.Nil(t, err) + + err = mockUser(func(mock sqlmock.Sqlmock) { + mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected)) + }, func(m UserModel) { + err := m.Delete(testInsertId) + assert.Nil(t, err) + }) + assert.Nil(t, err) +} + +// with cache +func mockStudent(mockFn func(mock sqlmock.Sqlmock), fn func(m StudentModel)) error { + db, mock, err := sqlmock.New() + if err != nil { + return err + } + + defer db.Close() + + mock.ExpectBegin() + mockFn(mock) + mock.ExpectCommit() + + conn := mocksql.NewMockConn(db) + r, clean, err := redistest.CreateRedis() + if err != nil { + return err + } + + defer clean() + + m := NewStudentModel(conn, cache.CacheConf{ + { + RedisConf: redis.RedisConf{ + Host: r.Addr, + Type: "node", + }, + Weight: 100, + }, + }) + fn(m) + return nil +} + +// without cache +func mockUser(mockFn func(mock sqlmock.Sqlmock), fn func(m UserModel)) error { + db, mock, err := sqlmock.New() + if err != nil { + return err + } + + defer db.Close() + + mock.ExpectBegin() + mockFn(mock) + mock.ExpectCommit() + + conn := mocksql.NewMockConn(db) + m := NewUserModel(conn) + fn(m) + return nil +} diff --git a/tools/goctl/model/sql/test/model/studentmodel.go b/tools/goctl/model/sql/test/model/studentmodel.go new file mode 100755 index 00000000..b11ec0ed --- /dev/null +++ b/tools/goctl/model/sql/test/model/studentmodel.go @@ -0,0 +1,105 @@ +package model + +import ( + "database/sql" + "fmt" + "strings" + "time" + + "github.com/tal-tech/go-zero/core/stores/cache" + "github.com/tal-tech/go-zero/core/stores/sqlc" + "github.com/tal-tech/go-zero/core/stores/sqlx" + "github.com/tal-tech/go-zero/core/stringx" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx" +) + +var ( + studentFieldNames = builderx.FieldNames(&Student{}) + studentRows = strings.Join(studentFieldNames, ",") + studentRowsExpectAutoSet = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",") + studentRowsWithPlaceHolder = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?" + + cacheStudentIdPrefix = "cache#Student#id#" +) + +type ( + StudentModel interface { + Insert(data Student) (sql.Result, error) + FindOne(id int64) (*Student, error) + Update(data Student) error + Delete(id int64) error + } + + defaultStudentModel struct { + sqlc.CachedConn + table string + } + + Student struct { + Id int64 `db:"id"` + Name string `db:"name"` + Age sql.NullInt64 `db:"age"` + Score sql.NullFloat64 `db:"score"` + CreateTime time.Time `db:"create_time"` + UpdateTime sql.NullTime `db:"update_time"` + } +) + +func NewStudentModel(conn sqlx.SqlConn, c cache.CacheConf) StudentModel { + return &defaultStudentModel{ + CachedConn: sqlc.NewConn(conn, c), + table: "`student`", + } +} + +func (m *defaultStudentModel) Insert(data Student) (sql.Result, error) { + query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?)", m.table, studentRowsExpectAutoSet) + ret, err := m.ExecNoCache(query, data.Name, data.Age, data.Score) + + return ret, err +} + +func (m *defaultStudentModel) FindOne(id int64) (*Student, error) { + studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, id) + var resp Student + err := m.QueryRow(&resp, studentIdKey, func(conn sqlx.SqlConn, v interface{}) error { + query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", studentRows, m.table) + return conn.QueryRow(v, query, id) + }) + switch err { + case nil: + return &resp, nil + case sqlc.ErrNotFound: + return nil, ErrNotFound + default: + return nil, err + } +} + +func (m *defaultStudentModel) Update(data Student) error { + studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, data.Id) + _, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { + query := fmt.Sprintf("update %s set %s where `id` = ?", m.table, studentRowsWithPlaceHolder) + return conn.Exec(query, data.Name, data.Age, data.Score, data.Id) + }, studentIdKey) + return err +} + +func (m *defaultStudentModel) Delete(id int64) error { + + studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, id) + _, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { + query := fmt.Sprintf("delete from %s where `id` = ?", m.table) + return conn.Exec(query, id) + }, studentIdKey) + return err +} + +func (m *defaultStudentModel) formatPrimary(primary interface{}) string { + return fmt.Sprintf("%s%v", cacheStudentIdPrefix, primary) +} + +func (m *defaultStudentModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error { + query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", studentRows, m.table) + return conn.QueryRow(v, query, primary) +} diff --git a/tools/goctl/model/sql/test/model/usermodel.go b/tools/goctl/model/sql/test/model/usermodel.go new file mode 100755 index 00000000..b6cf9a1f --- /dev/null +++ b/tools/goctl/model/sql/test/model/usermodel.go @@ -0,0 +1,130 @@ +package model + +import ( + "database/sql" + "fmt" + "strings" + "time" + + "github.com/tal-tech/go-zero/core/stores/sqlc" + "github.com/tal-tech/go-zero/core/stores/sqlx" + "github.com/tal-tech/go-zero/core/stringx" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx" +) + +var ( + userFieldNames = builderx.FieldNames(&User{}) + userRows = strings.Join(userFieldNames, ",") + userRowsExpectAutoSet = strings.Join(stringx.Remove(userFieldNames, "`id`", "`create_time`", "`update_time`"), ",") + userRowsWithPlaceHolder = strings.Join(stringx.Remove(userFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?" +) + +type ( + UserModel interface { + Insert(data User) (sql.Result, error) + FindOne(id int64) (*User, error) + FindOneByUser(user string) (*User, error) + FindOneByName(name string) (*User, error) + FindOneByMobile(mobile string) (*User, error) + Update(data User) error + Delete(id int64) error + } + + defaultUserModel struct { + conn sqlx.SqlConn + table string + } + + User struct { + Id int64 `db:"id"` + User string `db:"user"` // 用户 + Name string `db:"name"` // 用户名称 + Password string `db:"password"` // 用户密码 + Mobile string `db:"mobile"` // 手机号 + Gender string `db:"gender"` // 男|女|未公开 + Nickname string `db:"nickname"` // 用户昵称 + CreateTime time.Time `db:"create_time"` + UpdateTime time.Time `db:"update_time"` + } +) + +func NewUserModel(conn sqlx.SqlConn) UserModel { + return &defaultUserModel{ + conn: conn, + table: "`user`", + } +} + +func (m *defaultUserModel) Insert(data User) (sql.Result, error) { + query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?, ?, ?, ?)", m.table, userRowsExpectAutoSet) + ret, err := m.conn.Exec(query, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname) + return ret, err +} + +func (m *defaultUserModel) FindOne(id int64) (*User, error) { + query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", userRows, m.table) + var resp User + err := m.conn.QueryRow(&resp, query, id) + switch err { + case nil: + return &resp, nil + case sqlc.ErrNotFound: + return nil, ErrNotFound + default: + return nil, err + } +} + +func (m *defaultUserModel) FindOneByUser(user string) (*User, error) { + var resp User + query := fmt.Sprintf("select %s from %s where `user` = ? limit 1", userRows, m.table) + err := m.conn.QueryRow(&resp, query, user) + switch err { + case nil: + return &resp, nil + case sqlc.ErrNotFound: + return nil, ErrNotFound + default: + return nil, err + } +} + +func (m *defaultUserModel) FindOneByName(name string) (*User, error) { + var resp User + query := fmt.Sprintf("select %s from %s where `name` = ? limit 1", userRows, m.table) + err := m.conn.QueryRow(&resp, query, name) + switch err { + case nil: + return &resp, nil + case sqlc.ErrNotFound: + return nil, ErrNotFound + default: + return nil, err + } +} + +func (m *defaultUserModel) FindOneByMobile(mobile string) (*User, error) { + var resp User + query := fmt.Sprintf("select %s from %s where `mobile` = ? limit 1", userRows, m.table) + err := m.conn.QueryRow(&resp, query, mobile) + switch err { + case nil: + return &resp, nil + case sqlc.ErrNotFound: + return nil, ErrNotFound + default: + return nil, err + } +} + +func (m *defaultUserModel) Update(data User) error { + query := fmt.Sprintf("update %s set %s where `id` = ?", m.table, userRowsWithPlaceHolder) + _, err := m.conn.Exec(query, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, data.Id) + return err +} + +func (m *defaultUserModel) Delete(id int64) error { + query := fmt.Sprintf("delete from %s where `id` = ?", m.table) + _, err := m.conn.Exec(query, id) + return err +} diff --git a/tools/goctl/model/sql/test/model/vars.go b/tools/goctl/model/sql/test/model/vars.go new file mode 100644 index 00000000..b7a5e88f --- /dev/null +++ b/tools/goctl/model/sql/test/model/vars.go @@ -0,0 +1,5 @@ +package model + +import "github.com/tal-tech/go-zero/core/stores/sqlx" + +var ErrNotFound = sqlx.ErrNotFound diff --git a/tools/goctl/model/sql/test/orm.go b/tools/goctl/model/sql/test/orm.go new file mode 100644 index 00000000..dadfaa75 --- /dev/null +++ b/tools/goctl/model/sql/test/orm.go @@ -0,0 +1,255 @@ +// copy from core/stores/sqlx/orm.go +package mocksql + +import ( + "errors" + "reflect" + "strings" + + "github.com/tal-tech/go-zero/core/mapping" +) + +const tagName = "db" + +var ( + ErrNotMatchDestination = errors.New("not matching destination to scan") + ErrNotReadableValue = errors.New("value not addressable or interfaceable") + ErrNotSettable = errors.New("passed in variable is not settable") + ErrUnsupportedValueType = errors.New("unsupported unmarshal type") +) + +type rowsScanner interface { + Columns() ([]string, error) + Err() error + Next() bool + Scan(v ...interface{}) error +} + +func getTaggedFieldValueMap(v reflect.Value) (map[string]interface{}, error) { + rt := mapping.Deref(v.Type()) + size := rt.NumField() + result := make(map[string]interface{}, size) + + for i := 0; i < size; i++ { + key := parseTagName(rt.Field(i)) + if len(key) == 0 { + return nil, nil + } + + valueField := reflect.Indirect(v).Field(i) + switch valueField.Kind() { + case reflect.Ptr: + if !valueField.CanInterface() { + return nil, ErrNotReadableValue + } + if valueField.IsNil() { + baseValueType := mapping.Deref(valueField.Type()) + valueField.Set(reflect.New(baseValueType)) + } + result[key] = valueField.Interface() + default: + if !valueField.CanAddr() || !valueField.Addr().CanInterface() { + return nil, ErrNotReadableValue + } + result[key] = valueField.Addr().Interface() + } + } + + return result, nil +} + +func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]interface{}, error) { + fields := unwrapFields(v) + if strict && len(columns) < len(fields) { + return nil, ErrNotMatchDestination + } + + taggedMap, err := getTaggedFieldValueMap(v) + if err != nil { + return nil, err + } + + values := make([]interface{}, len(columns)) + if len(taggedMap) == 0 { + for i := 0; i < len(values); i++ { + valueField := fields[i] + switch valueField.Kind() { + case reflect.Ptr: + if !valueField.CanInterface() { + return nil, ErrNotReadableValue + } + if valueField.IsNil() { + baseValueType := mapping.Deref(valueField.Type()) + valueField.Set(reflect.New(baseValueType)) + } + values[i] = valueField.Interface() + default: + if !valueField.CanAddr() || !valueField.Addr().CanInterface() { + return nil, ErrNotReadableValue + } + values[i] = valueField.Addr().Interface() + } + } + } else { + for i, column := range columns { + if tagged, ok := taggedMap[column]; ok { + values[i] = tagged + } else { + var anonymous interface{} + values[i] = &anonymous + } + } + } + + return values, nil +} + +func parseTagName(field reflect.StructField) string { + key := field.Tag.Get(tagName) + if len(key) == 0 { + return "" + } else { + options := strings.Split(key, ",") + return options[0] + } +} + +func unmarshalRow(v interface{}, scanner rowsScanner, strict bool) error { + if !scanner.Next() { + if err := scanner.Err(); err != nil { + return err + } + return ErrNotFound + } + + rv := reflect.ValueOf(v) + if err := mapping.ValidatePtr(&rv); err != nil { + return err + } + + rte := reflect.TypeOf(v).Elem() + rve := rv.Elem() + switch rte.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64, + reflect.String: + if rve.CanSet() { + return scanner.Scan(v) + } else { + return ErrNotSettable + } + case reflect.Struct: + columns, err := scanner.Columns() + if err != nil { + return err + } + if values, err := mapStructFieldsIntoSlice(rve, columns, strict); err != nil { + return err + } else { + return scanner.Scan(values...) + } + default: + return ErrUnsupportedValueType + } +} + +func unmarshalRows(v interface{}, scanner rowsScanner, strict bool) error { + rv := reflect.ValueOf(v) + if err := mapping.ValidatePtr(&rv); err != nil { + return err + } + + rt := reflect.TypeOf(v) + rte := rt.Elem() + rve := rv.Elem() + switch rte.Kind() { + case reflect.Slice: + if rve.CanSet() { + ptr := rte.Elem().Kind() == reflect.Ptr + appendFn := func(item reflect.Value) { + if ptr { + rve.Set(reflect.Append(rve, item)) + } else { + rve.Set(reflect.Append(rve, reflect.Indirect(item))) + } + } + fillFn := func(value interface{}) error { + if rve.CanSet() { + if err := scanner.Scan(value); err != nil { + return err + } else { + appendFn(reflect.ValueOf(value)) + return nil + } + } + return ErrNotSettable + } + + base := mapping.Deref(rte.Elem()) + switch base.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64, + reflect.String: + for scanner.Next() { + value := reflect.New(base) + if err := fillFn(value.Interface()); err != nil { + return err + } + } + case reflect.Struct: + columns, err := scanner.Columns() + if err != nil { + return err + } + + for scanner.Next() { + value := reflect.New(base) + if values, err := mapStructFieldsIntoSlice(value, columns, strict); err != nil { + return err + } else { + if err := scanner.Scan(values...); err != nil { + return err + } else { + appendFn(value) + } + } + } + default: + return ErrUnsupportedValueType + } + + return nil + } else { + return ErrNotSettable + } + default: + return ErrUnsupportedValueType + } +} + +func unwrapFields(v reflect.Value) []reflect.Value { + var fields []reflect.Value + indirect := reflect.Indirect(v) + + for i := 0; i < indirect.NumField(); i++ { + child := indirect.Field(i) + if child.Kind() == reflect.Ptr && child.IsNil() { + baseValueType := mapping.Deref(child.Type()) + child.Set(reflect.New(baseValueType)) + } + + child = reflect.Indirect(child) + childType := indirect.Type().Field(i) + if child.Kind() == reflect.Struct && childType.Anonymous { + fields = append(fields, unwrapFields(child)...) + } else { + fields = append(fields, child) + } + } + + return fields +} diff --git a/tools/goctl/model/sql/test/sqlconn.go b/tools/goctl/model/sql/test/sqlconn.go new file mode 100644 index 00000000..dc6efc40 --- /dev/null +++ b/tools/goctl/model/sql/test/sqlconn.go @@ -0,0 +1,90 @@ +// copy from core/stores/sqlx/sqlconn.go +package mocksql + +import ( + "database/sql" + + "github.com/tal-tech/go-zero/core/stores/sqlx" +) + +type ( + MockConn struct { + db *sql.DB + } + statement struct { + stmt *sql.Stmt + } +) + +func NewMockConn(db *sql.DB) *MockConn { + return &MockConn{db: db} +} + +func (conn *MockConn) Exec(query string, args ...interface{}) (sql.Result, error) { + return exec(conn.db, query, args...) +} + +func (conn *MockConn) Prepare(query string) (sqlx.StmtSession, error) { + st, err := conn.db.Prepare(query) + return statement{stmt: st}, err +} + +func (conn *MockConn) QueryRow(v interface{}, q string, args ...interface{}) error { + return query(conn.db, func(rows *sql.Rows) error { + return unmarshalRow(v, rows, true) + }, q, args...) +} + +func (conn *MockConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error { + return query(conn.db, func(rows *sql.Rows) error { + return unmarshalRow(v, rows, false) + }, q, args...) +} + +func (conn *MockConn) QueryRows(v interface{}, q string, args ...interface{}) error { + return query(conn.db, func(rows *sql.Rows) error { + return unmarshalRows(v, rows, true) + }, q, args...) +} + +func (conn *MockConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error { + return query(conn.db, func(rows *sql.Rows) error { + return unmarshalRows(v, rows, false) + }, q, args...) +} + +func (conn *MockConn) Transact(func(session sqlx.Session) error) error { + return nil +} + +func (s statement) Close() error { + return s.stmt.Close() +} + +func (s statement) Exec(args ...interface{}) (sql.Result, error) { + return execStmt(s.stmt, args...) +} + +func (s statement) QueryRow(v interface{}, args ...interface{}) error { + return queryStmt(s.stmt, func(rows *sql.Rows) error { + return unmarshalRow(v, rows, true) + }, args...) +} + +func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error { + return queryStmt(s.stmt, func(rows *sql.Rows) error { + return unmarshalRow(v, rows, false) + }, args...) +} + +func (s statement) QueryRows(v interface{}, args ...interface{}) error { + return queryStmt(s.stmt, func(rows *sql.Rows) error { + return unmarshalRows(v, rows, true) + }, args...) +} + +func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error { + return queryStmt(s.stmt, func(rows *sql.Rows) error { + return unmarshalRows(v, rows, false) + }, args...) +} diff --git a/tools/goctl/model/sql/test/stmt.go b/tools/goctl/model/sql/test/stmt.go new file mode 100644 index 00000000..47c81486 --- /dev/null +++ b/tools/goctl/model/sql/test/stmt.go @@ -0,0 +1,122 @@ +// copy from core/stores/sqlx/stmt.go + +package mocksql + +import ( + "database/sql" + "fmt" + "time" + + "github.com/tal-tech/go-zero/core/logx" + "github.com/tal-tech/go-zero/core/timex" +) + +const slowThreshold = time.Millisecond * 500 + +func exec(db *sql.DB, q string, args ...interface{}) (sql.Result, error) { + tx, err := db.Begin() + if err != nil { + return nil, err + } + + defer func() { + switch err { + case nil: + err = tx.Commit() + default: + tx.Rollback() + } + }() + + stmt, err := format(q, args...) + if err != nil { + return nil, err + } + + startTime := timex.Now() + result, err := tx.Exec(q, args...) + duration := timex.Since(startTime) + if duration > slowThreshold { + logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt) + } else { + logx.WithDuration(duration).Infof("sql exec: %s", stmt) + } + if err != nil { + logSqlError(stmt, err) + } + + return result, err +} + +func execStmt(conn *sql.Stmt, args ...interface{}) (sql.Result, error) { + stmt := fmt.Sprint(args...) + startTime := timex.Now() + result, err := conn.Exec(args...) + duration := timex.Since(startTime) + if duration > slowThreshold { + logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt) + } else { + logx.WithDuration(duration).Infof("sql execStmt: %s", stmt) + } + if err != nil { + logSqlError(stmt, err) + } + + return result, err +} + +func query(db *sql.DB, scanner func(*sql.Rows) error, q string, args ...interface{}) error { + tx, err := db.Begin() + if err != nil { + return err + } + + defer func() { + switch err { + case nil: + err = tx.Commit() + default: + tx.Rollback() + } + }() + + stmt, err := format(q, args...) + if err != nil { + return err + } + + startTime := timex.Now() + rows, err := tx.Query(q, args...) + duration := timex.Since(startTime) + if duration > slowThreshold { + logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt) + } else { + logx.WithDuration(duration).Infof("sql query: %s", stmt) + } + if err != nil { + logSqlError(stmt, err) + return err + } + defer rows.Close() + + return scanner(rows) +} + +func queryStmt(conn *sql.Stmt, scanner func(*sql.Rows) error, args ...interface{}) error { + stmt := fmt.Sprint(args...) + startTime := timex.Now() + rows, err := conn.Query(args...) + duration := timex.Since(startTime) + if duration > slowThreshold { + logx.WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt) + } else { + logx.WithDuration(duration).Infof("sql queryStmt: %s", stmt) + } + if err != nil { + logSqlError(stmt, err) + return err + } + defer rows.Close() + + return scanner(rows) +} diff --git a/tools/goctl/model/sql/test/utils.go b/tools/goctl/model/sql/test/utils.go new file mode 100644 index 00000000..95a9d9d1 --- /dev/null +++ b/tools/goctl/model/sql/test/utils.go @@ -0,0 +1,105 @@ +// copy from core/stores/sqlx/utils.go +package mocksql + +import ( + "database/sql" + "fmt" + "strings" + + "github.com/tal-tech/go-zero/core/logx" + "github.com/tal-tech/go-zero/core/mapping" +) + +var ErrNotFound = sql.ErrNoRows + +func desensitize(datasource string) string { + // remove account + pos := strings.LastIndex(datasource, "@") + if 0 <= pos && pos+1 < len(datasource) { + datasource = datasource[pos+1:] + } + + return datasource +} + +func escape(input string) string { + var b strings.Builder + + for _, ch := range input { + switch ch { + case '\x00': + b.WriteString(`\x00`) + case '\r': + b.WriteString(`\r`) + case '\n': + b.WriteString(`\n`) + case '\\': + b.WriteString(`\\`) + case '\'': + b.WriteString(`\'`) + case '"': + b.WriteString(`\"`) + case '\x1a': + b.WriteString(`\x1a`) + default: + b.WriteRune(ch) + } + } + + return b.String() +} + +func format(query string, args ...interface{}) (string, error) { + numArgs := len(args) + if numArgs == 0 { + return query, nil + } + + var b strings.Builder + argIndex := 0 + + for _, ch := range query { + if ch == '?' { + if argIndex >= numArgs { + return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex) + } + + arg := args[argIndex] + argIndex++ + + switch v := arg.(type) { + case bool: + if v { + b.WriteByte('1') + } else { + b.WriteByte('0') + } + case string: + b.WriteByte('\'') + b.WriteString(escape(v)) + b.WriteByte('\'') + default: + b.WriteString(mapping.Repr(v)) + } + } else { + b.WriteRune(ch) + } + } + + if argIndex < numArgs { + return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex) + } + + return b.String(), nil +} + +func logInstanceError(datasource string, err error) { + datasource = desensitize(datasource) + logx.Errorf("Error on getting sql instance of %s: %v", datasource, err) +} + +func logSqlError(stmt string, err error) { + if err != nil && err != ErrNotFound { + logx.Errorf("stmt: %s, error: %s", stmt, err.Error()) + } +}