diff --git a/tools/goctl/model/sql/builderx/builder.go b/tools/goctl/model/sql/builderx/builder.go index 803f5b42..8e4ddd14 100644 --- a/tools/goctl/model/sql/builderx/builder.go +++ b/tools/goctl/model/sql/builderx/builder.go @@ -23,10 +23,12 @@ func ToMap(in interface{}) map[string]interface{} { if v.Kind() == reflect.Ptr { v = v.Elem() } + // we only accept structs if v.Kind() != reflect.Struct { panic(fmt.Errorf("ToMap only accepts structs; got %T", v)) } + typ := v.Type() for i := 0; i < v.NumField(); i++ { // gets us a StructField @@ -43,6 +45,7 @@ func ToMap(in interface{}) map[string]interface{} { out[tagv] = current } } + return out } @@ -53,10 +56,12 @@ func FieldNames(in interface{}) []string { if v.Kind() == reflect.Ptr { v = v.Elem() } + // we only accept structs if v.Kind() != reflect.Struct { panic(fmt.Errorf("ToMap only accepts structs; got %T", v)) } + typ := v.Type() for i := 0; i < v.NumField(); i++ { // gets us a StructField @@ -67,6 +72,7 @@ func FieldNames(in interface{}) []string { out = append(out, fi.Name) } } + return out } @@ -76,10 +82,12 @@ func RawFieldNames(in interface{}) []string { if v.Kind() == reflect.Ptr { v = v.Elem() } + // we only accept structs if v.Kind() != reflect.Struct { panic(fmt.Errorf("ToMap only accepts structs; got %T", v)) } + typ := v.Type() for i := 0; i < v.NumField(); i++ { // gets us a StructField @@ -90,5 +98,6 @@ func RawFieldNames(in interface{}) []string { out = append(out, fmt.Sprintf(`"%s"`, fi.Name)) } } + return out } diff --git a/tools/goctl/model/sql/builderx/builder_test.go b/tools/goctl/model/sql/builderx/builder_test.go index 09f9bc47..5b21136a 100644 --- a/tools/goctl/model/sql/builderx/builder_test.go +++ b/tools/goctl/model/sql/builderx/builder_test.go @@ -8,34 +8,32 @@ import ( "github.com/stretchr/testify/assert" ) -type ( - User struct { - // 自增id - Id string `db:"id" json:"id,omitempty"` - // 姓名 - UserName string `db:"user_name" json:"userName,omitempty"` - // 1男,2女 - Sex int `db:"sex" json:"sex,omitempty"` - - Uuid string `db:"uuid" uuid:"uuid,omitempty"` - - Age int `db:"age" json:"age"` - } -) +type mockedUser struct { + // 自增id + Id string `db:"id" json:"id,omitempty"` + // 姓名 + UserName string `db:"user_name" json:"userName,omitempty"` + // 1男,2女 + Sex int `db:"sex" json:"sex,omitempty"` + Uuid string `db:"uuid" uuid:"uuid,omitempty"` + Age int `db:"age" json:"age"` +} -var userFieldsWithRawStringQuote = RawFieldNames(User{}) -var userFieldsWithoutRawStringQuote = FieldNames(User{}) +var ( + userFieldsWithRawStringQuote = RawFieldNames(mockedUser{}) + userFieldsWithoutRawStringQuote = FieldNames(mockedUser{}) +) func TestFieldNames(t *testing.T) { t.Run("old", func(t *testing.T) { - var u User + var u mockedUser out := FieldNames(&u) expected := []string{"id", "user_name", "sex", "uuid", "age"} assert.Equal(t, expected, out) }) t.Run("new", func(t *testing.T) { - var u User + var u mockedUser out := RawFieldNames(&u) expected := []string{"`id`", "`user_name`", "`sex`", "`uuid`", "`age`"} assert.Equal(t, expected, out) @@ -43,7 +41,7 @@ func TestFieldNames(t *testing.T) { } func TestNewEq(t *testing.T) { - u := &User{ + u := &mockedUser{ Id: "123456", UserName: "wahaha", } @@ -55,7 +53,7 @@ func TestNewEq(t *testing.T) { // @see https://github.com/go-xorm/builder func TestBuilderSql(t *testing.T) { - u := &User{ + u := &mockedUser{ Id: "123123", } fields := RawFieldNames(u) @@ -96,7 +94,7 @@ func TestBuildSqlDefaultValue(t *testing.T) { } func TestBuilderSqlIn(t *testing.T) { - u := &User{ + u := &mockedUser{ Age: 18, } gtU := NewGt(u) diff --git a/tools/goctl/model/sql/command/command.go b/tools/goctl/model/sql/command/command.go index 87fde7aa..6a94ad3f 100644 --- a/tools/goctl/model/sql/command/command.go +++ b/tools/goctl/model/sql/command/command.go @@ -17,8 +17,6 @@ import ( "github.com/urfave/cli" ) -var errNotMatched = errors.New("sql not matched") - const ( flagSrc = "src" flagDir = "dir" @@ -29,6 +27,8 @@ const ( flagStyle = "style" ) +var errNotMatched = errors.New("sql not matched") + func MysqlDDL(ctx *cli.Context) error { src := ctx.String(flagSrc) dir := ctx.String(flagDir) @@ -39,6 +39,7 @@ func MysqlDDL(ctx *cli.Context) error { if err != nil { return err } + return fromDDl(src, dir, cfg, cache, idea) } @@ -82,13 +83,13 @@ func fromDDl(src, dir string, cfg *config.Config, cache, idea bool) error { source = append(source, string(data)) } + generator, err := gen.NewDefaultGenerator(dir, cfg, gen.WithConsoleOption(log)) if err != nil { return err } - err = generator.StartFromDDL(strings.Join(source, "\n"), cache) - return err + return generator.StartFromDDL(strings.Join(source, "\n"), cache) } func fromDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bool) error { @@ -144,6 +145,5 @@ func fromDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bo return err } - err = generator.StartFromInformationSchema(dsn.DBName, matchTables, cache) - return err + return generator.StartFromInformationSchema(dsn.DBName, matchTables, cache) } diff --git a/tools/goctl/model/sql/command/command_test.go b/tools/goctl/model/sql/command/command_test.go index 74da0a73..a78ae555 100644 --- a/tools/goctl/model/sql/command/command_test.go +++ b/tools/goctl/model/sql/command/command_test.go @@ -11,10 +11,12 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/util" ) -var sql = "-- 用户表 --\nCREATE TABLE `user` (\n `id` bigint(10) NOT NULL AUTO_INCREMENT,\n `name` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户名称',\n `password` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户密码',\n `mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号',\n `gender` char(5) COLLATE utf8mb4_general_ci NOT NULL COMMENT '男|女|未公开',\n `nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT '用户昵称',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `name_index` (`name`),\n UNIQUE KEY `mobile_index` (`mobile`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;\n\n" -var cfg = &config.Config{ - NamingFormat: "gozero", -} +var ( + sql = "-- 用户表 --\nCREATE TABLE `user` (\n `id` bigint(10) NOT NULL AUTO_INCREMENT,\n `name` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户名称',\n `password` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户密码',\n `mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号',\n `gender` char(5) COLLATE utf8mb4_general_ci NOT NULL COMMENT '男|女|未公开',\n `nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT '用户昵称',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `name_index` (`name`),\n UNIQUE KEY `mobile_index` (`mobile`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;\n\n" + cfg = &config.Config{ + NamingFormat: "gozero", + } +) func TestFromDDl(t *testing.T) { err := fromDDl("./user.sql", t.TempDir(), cfg, true, false) diff --git a/tools/goctl/model/sql/converter/types.go b/tools/goctl/model/sql/converter/types.go index c052d0ec..d0718362 100644 --- a/tools/goctl/model/sql/converter/types.go +++ b/tools/goctl/model/sql/converter/types.go @@ -5,41 +5,39 @@ import ( "strings" ) -var ( - commonMysqlDataTypeMap = map[string]string{ - // For consistency, all integer types are converted to int64 - // number - "bool": "int64", - "boolean": "int64", - "tinyint": "int64", - "smallint": "int64", - "mediumint": "int64", - "int": "int64", - "integer": "int64", - "bigint": "int64", - "float": "float64", - "double": "float64", - "decimal": "float64", - // date&time - "date": "time.Time", - "datetime": "time.Time", - "timestamp": "time.Time", - "time": "string", - "year": "int64", - // string - "char": "string", - "varchar": "string", - "binary": "string", - "varbinary": "string", - "tinytext": "string", - "text": "string", - "mediumtext": "string", - "longtext": "string", - "enum": "string", - "set": "string", - "json": "string", - } -) +var commonMysqlDataTypeMap = map[string]string{ + // For consistency, all integer types are converted to int64 + // number + "bool": "int64", + "boolean": "int64", + "tinyint": "int64", + "smallint": "int64", + "mediumint": "int64", + "int": "int64", + "integer": "int64", + "bigint": "int64", + "float": "float64", + "double": "float64", + "decimal": "float64", + // date&time + "date": "time.Time", + "datetime": "time.Time", + "timestamp": "time.Time", + "time": "string", + "year": "int64", + // string + "char": "string", + "varchar": "string", + "binary": "string", + "varbinary": "string", + "tinytext": "string", + "text": "string", + "mediumtext": "string", + "longtext": "string", + "enum": "string", + "set": "string", + "json": "string", +} func ConvertDataType(dataBaseType string, isDefaultNull bool) (string, error) { tp, ok := commonMysqlDataTypeMap[strings.ToLower(dataBaseType)] diff --git a/tools/goctl/model/sql/gen/delete.go b/tools/goctl/model/sql/gen/delete.go index f79773d8..3ae91818 100644 --- a/tools/goctl/model/sql/gen/delete.go +++ b/tools/goctl/model/sql/gen/delete.go @@ -58,5 +58,6 @@ func genDelete(table Table, withCache bool) (string, string, error) { if err != nil { return "", "", err } + return output.String(), deleteMethodOut.String(), nil } diff --git a/tools/goctl/model/sql/gen/field.go b/tools/goctl/model/sql/gen/field.go index f257300e..0a734057 100644 --- a/tools/goctl/model/sql/gen/field.go +++ b/tools/goctl/model/sql/gen/field.go @@ -10,6 +10,7 @@ import ( func genFields(fields []parser.Field) (string, error) { var list []string + for _, field := range fields { result, err := genField(field) if err != nil { @@ -18,6 +19,7 @@ func genFields(fields []parser.Field) (string, error) { list = append(list, result) } + return strings.Join(list, "\n"), nil } diff --git a/tools/goctl/model/sql/gen/findone.go b/tools/goctl/model/sql/gen/findone.go index d6ff0367..472b80ce 100644 --- a/tools/goctl/model/sql/gen/findone.go +++ b/tools/goctl/model/sql/gen/findone.go @@ -44,5 +44,6 @@ func genFindOne(table Table, withCache bool) (string, string, error) { if err != nil { return "", "", err } + return output.String(), findOneMethod.String(), nil } diff --git a/tools/goctl/model/sql/gen/gen.go b/tools/goctl/model/sql/gen/gen.go index d21df855..f29d5992 100644 --- a/tools/goctl/model/sql/gen/gen.go +++ b/tools/goctl/model/sql/gen/gen.go @@ -25,7 +25,7 @@ const ( type ( defaultGenerator struct { - //source string + // source string dir string console.Console pkg string @@ -57,6 +57,7 @@ func NewDefaultGenerator(dir string, cfg *config.Config, opt ...Option) (*defaul for _, fn := range optionList { fn(generator) } + return generator, nil } @@ -96,6 +97,7 @@ func (g *defaultGenerator) StartFromInformationSchema(db string, columns map[str m[table.Name.Source()] = code } + return g.createFile(m) } @@ -130,6 +132,7 @@ func (g *defaultGenerator) createFile(modelList map[string]string) error { return err } } + // generate error file varFilename, err := format.FileNamingFormat(g.cfg.NamingFormat, "vars") if err != nil { @@ -168,16 +171,15 @@ func (g *defaultGenerator) genFromDDL(source string, withCache bool) (map[string } m[table.Name.Source()] = code } + return m, nil } -type ( - Table struct { - parser.Table - CacheKey map[string]Key - ContainsUniqueKey bool - } -) +type Table struct { + parser.Table + CacheKey map[string]Key + ContainsUniqueKey bool +} func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, error) { if len(in.PrimaryKey.Name.Source()) == 0 { @@ -292,5 +294,6 @@ func wrapWithRawString(v string) string { } else if len(v) == 1 { v = v + "`" } + return v } diff --git a/tools/goctl/model/sql/gen/keys.go b/tools/goctl/model/sql/gen/keys.go index 4d2ecd11..e77c59eb 100644 --- a/tools/goctl/model/sql/gen/keys.go +++ b/tools/goctl/model/sql/gen/keys.go @@ -8,20 +8,18 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) -type ( - // tableName:user - // {{prefix}}=cache - // key:id - Key struct { - VarExpression string // cacheUserIdPrefix = "cache#User#id#" - Left string // cacheUserIdPrefix - Right string // cache#user#id# - Variable string // userIdKey - KeyExpression string // userIdKey: = fmt.Sprintf("cache#user#id#%v", userId) - DataKeyExpression string // userIdKey: = fmt.Sprintf("cache#user#id#%v", data.userId) - RespKeyExpression string // userIdKey: = fmt.Sprintf("cache#user#id#%v", resp.userId) - } -) +// tableName:user +// {{prefix}}=cache +// key:id +type Key struct { + VarExpression string // cacheUserIdPrefix = "cache#User#id#" + Left string // cacheUserIdPrefix + Right string // cache#user#id# + Variable string // userIdKey + KeyExpression string // userIdKey: = fmt.Sprintf("cache#user#id#%v", userId) + DataKeyExpression string // userIdKey: = fmt.Sprintf("cache#user#id#%v", data.userId) + RespKeyExpression string // userIdKey: = fmt.Sprintf("cache#user#id#%v", resp.userId) +} // key-数据库原始字段名,value-缓存key相关数据 func genCacheKeys(table parser.Table) (map[string]Key, error) { @@ -42,6 +40,7 @@ func genCacheKeys(table parser.Table) (map[string]Key, error) { if strings.ToLower(lowerStartCamelTableName) == strings.ToLower(camelFieldName) { variable = fmt.Sprintf("%sKey", lowerStartCamelTableName) } + m[field.Name.Source()] = Key{ VarExpression: fmt.Sprintf(`%s = "%s"`, left, right), Left: left, @@ -53,5 +52,6 @@ func genCacheKeys(table parser.Table) (map[string]Key, error) { } } } + return m, nil } diff --git a/tools/goctl/model/sql/gen/keys_test.go b/tools/goctl/model/sql/gen/keys_test.go index 261be176..95868f8d 100644 --- a/tools/goctl/model/sql/gen/keys_test.go +++ b/tools/goctl/model/sql/gen/keys_test.go @@ -68,5 +68,4 @@ func TestGenCacheKeys(t *testing.T) { assert.Equal(t, fmt.Sprintf(`user%sKey`, name.ToCamel()), key.Variable) assert.Equal(t, `user`+name.ToCamel()+`Key := fmt.Sprintf("%s%v", cacheUser`+name.ToCamel()+`Prefix,`+name.Untitle()+`)`, key.KeyExpression) } - } diff --git a/tools/goctl/model/sql/gen/split.go b/tools/goctl/model/sql/gen/split.go index eb3e526b..f3358540 100644 --- a/tools/goctl/model/sql/gen/split.go +++ b/tools/goctl/model/sql/gen/split.go @@ -1,13 +1,12 @@ package gen -import ( - "regexp" -) +import "regexp" func (g *defaultGenerator) split(source string) []string { reg := regexp.MustCompile(createTableFlag) index := reg.FindAllStringIndex(source, -1) list := make([]string, 0) + for i := len(index) - 1; i >= 0; i-- { subIndex := index[i] if len(subIndex) == 0 { @@ -18,5 +17,6 @@ func (g *defaultGenerator) split(source string) []string { list = append(list, ddl) source = source[:start] } + return list } diff --git a/tools/goctl/model/sql/gen/tag.go b/tools/goctl/model/sql/gen/tag.go index 3c5ff150..fe09622a 100644 --- a/tools/goctl/model/sql/gen/tag.go +++ b/tools/goctl/model/sql/gen/tag.go @@ -9,16 +9,15 @@ func genTag(in string) (string, error) { if in == "" { return in, nil } + text, err := util.LoadTemplate(category, tagTemplateFile, template.Tag) if err != nil { return "", err } - output, err := util.With("tag"). - Parse(text). - Execute(map[string]interface{}{ - "field": in, - }) + output, err := util.With("tag").Parse(text).Execute(map[string]interface{}{ + "field": in, + }) if err != nil { return "", err } diff --git a/tools/goctl/model/sql/gen/template.go b/tools/goctl/model/sql/gen/template.go index 65e3a38f..c9078e6b 100644 --- a/tools/goctl/model/sql/gen/template.go +++ b/tools/goctl/model/sql/gen/template.go @@ -71,6 +71,7 @@ func RevertTemplate(name string) error { if !ok { return fmt.Errorf("%s: no such file name", name) } + return util.CreateTemplate(category, name, content) } @@ -79,5 +80,6 @@ func Update() error { if err != nil { return err } + return util.InitTemplates(category, templates) } diff --git a/tools/goctl/model/sql/gen/update.go b/tools/goctl/model/sql/gen/update.go index 6e854167..3e5c99f8 100644 --- a/tools/goctl/model/sql/gen/update.go +++ b/tools/goctl/model/sql/gen/update.go @@ -15,11 +15,14 @@ func genUpdate(table Table, withCache bool) (string, string, error) { if camel == "CreateTime" || camel == "UpdateTime" { continue } + if field.IsPrimaryKey { continue } + expressionValues = append(expressionValues, "data."+camel) } + expressionValues = append(expressionValues, "data."+table.PrimaryKey.Name.ToCamel()) camelTableName := table.Name.ToCamel() text, err := util.LoadTemplate(category, updateTemplateFile, template.Update) diff --git a/tools/goctl/model/sql/model/informationschemamodel.go b/tools/goctl/model/sql/model/informationschemamodel.go index c566a17a..3f889729 100644 --- a/tools/goctl/model/sql/model/informationschemamodel.go +++ b/tools/goctl/model/sql/model/informationschemamodel.go @@ -8,6 +8,7 @@ type ( InformationSchemaModel struct { conn sqlx.SqlConn } + Column struct { Name string `db:"COLUMN_NAME"` DataType string `db:"DATA_TYPE"`