From d894b88c3e4927e5969c879b1170740e8e44216e Mon Sep 17 00:00:00 2001 From: anqiansong Date: Mon, 1 Mar 2021 17:29:07 +0800 Subject: [PATCH] feature 1.1.5 (#411) --- tools/goctl/api/new/newservice.go | 5 + tools/goctl/api/parser/g4/ast/ast.go | 16 - tools/goctl/api/parser/g4/ast/service.go | 2 - tools/goctl/api/parser/g4/ast/type.go | 7 - tools/goctl/model/sql/command/command.go | 15 +- tools/goctl/model/sql/example/makefile | 4 +- tools/goctl/model/sql/example/sql/user.sql | 4 +- tools/goctl/model/sql/gen/delete.go | 14 +- tools/goctl/model/sql/gen/field.go | 4 +- tools/goctl/model/sql/gen/findone.go | 4 +- tools/goctl/model/sql/gen/findonebyfield.go | 59 +++- tools/goctl/model/sql/gen/gen.go | 30 +- tools/goctl/model/sql/gen/gen_test.go | 15 +- tools/goctl/model/sql/gen/insert.go | 26 +- tools/goctl/model/sql/gen/keys.go | 196 +++++++++--- tools/goctl/model/sql/gen/keys_test.go | 168 +++++++--- tools/goctl/model/sql/gen/update.go | 6 +- tools/goctl/model/sql/gen/vars.go | 24 +- .../model/sql/model/informationschemamodel.go | 167 +++++++++- tools/goctl/model/sql/parser/parser.go | 292 +++++++++++------- tools/goctl/model/sql/parser/parser_test.go | 198 ++++++++---- tools/goctl/model/sql/template/find.go | 4 +- .../goctl/model/sql/test/model/model_test.go | 120 ++++--- .../model/sql/test/model/studentmodel.go | 64 ++-- tools/goctl/model/sql/test/model/usermodel.go | 20 +- tools/goctl/model/sql/test/utils.go | 1 + tools/goctl/util/stringx/string.go | 5 + 27 files changed, 1032 insertions(+), 438 deletions(-) diff --git a/tools/goctl/api/new/newservice.go b/tools/goctl/api/new/newservice.go index 2be81233..86e87ae3 100644 --- a/tools/goctl/api/new/newservice.go +++ b/tools/goctl/api/new/newservice.go @@ -1,6 +1,7 @@ package new import ( + "errors" "os" "path/filepath" "strings" @@ -35,6 +36,10 @@ func CreateServiceCommand(c *cli.Context) error { dirName = "greet" } + if strings.Contains(dirName, "-") { + return errors.New("api new command service name not support strikethrough, because this will used by function name") + } + abs, err := filepath.Abs(dirName) if err != nil { return err diff --git a/tools/goctl/api/parser/g4/ast/ast.go b/tools/goctl/api/parser/g4/ast/ast.go index 97e25b3f..9a608930 100644 --- a/tools/goctl/api/parser/g4/ast/ast.go +++ b/tools/goctl/api/parser/g4/ast/ast.go @@ -7,7 +7,6 @@ import ( "github.com/antlr/antlr4/runtime/Go/antlr" "github.com/tal-tech/go-zero/tools/goctl/api/parser/g4/gen/api" - "github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/util/console" ) @@ -323,18 +322,3 @@ func (v *ApiVisitor) getHiddenTokensToRight(t TokenStream, channel int) []Expr { return list } - -func (v *ApiVisitor) exportCheck(expr Expr) { - if expr == nil || !expr.IsNotNil() { - return - } - - if api.IsBasicType(expr.Text()) { - return - } - - if util.UnExport(expr.Text()) { - v.log.Warning("%s line %d:%d unexported declaration '%s', use %s instead", expr.Prefix(), expr.Line(), - expr.Column(), expr.Text(), strings.Title(expr.Text())) - } -} diff --git a/tools/goctl/api/parser/g4/ast/service.go b/tools/goctl/api/parser/g4/ast/service.go index d248ee9c..3ac7ae4b 100644 --- a/tools/goctl/api/parser/g4/ast/service.go +++ b/tools/goctl/api/parser/g4/ast/service.go @@ -219,7 +219,6 @@ func (v *ApiVisitor) VisitBody(ctx *api.BodyContext) interface{} { if api.IsGolangKeyWord(idRxpr.Text()) { v.panic(idRxpr, fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", idRxpr.Text())) } - v.exportCheck(idRxpr) return &Body{ Lp: v.newExprWithToken(ctx.GetLp()), @@ -250,7 +249,6 @@ func (v *ApiVisitor) VisitReplybody(ctx *api.ReplybodyContext) interface{} { default: v.panic(dt.Expr(), fmt.Sprintf("unsupport %s", dt.Expr().Text())) } - v.log.Warning("%s %d:%d deprecated array type near '%s'", v.prefix, dataType.ArrayExpr.Line(), dataType.ArrayExpr.Column(), dataType.ArrayExpr.Text()) case *Literal: lit := dataType.Literal.Text() if api.IsGolangKeyWord(dataType.Literal.Text()) { diff --git a/tools/goctl/api/parser/g4/ast/type.go b/tools/goctl/api/parser/g4/ast/type.go index 72eeeeac..ffad1c52 100644 --- a/tools/goctl/api/parser/g4/ast/type.go +++ b/tools/goctl/api/parser/g4/ast/type.go @@ -153,7 +153,6 @@ func (v *ApiVisitor) VisitTypeBlockBody(ctx *api.TypeBlockBodyContext) interface func (v *ApiVisitor) VisitTypeStruct(ctx *api.TypeStructContext) interface{} { var st TypeStruct st.Name = v.newExprWithToken(ctx.GetStructName()) - v.exportCheck(st.Name) if util.UnExport(ctx.GetStructName().GetText()) { @@ -189,7 +188,6 @@ func (v *ApiVisitor) VisitTypeStruct(ctx *api.TypeStructContext) interface{} { func (v *ApiVisitor) VisitTypeBlockStruct(ctx *api.TypeBlockStructContext) interface{} { var st TypeStruct st.Name = v.newExprWithToken(ctx.GetStructName()) - v.exportCheck(st.Name) if ctx.GetStructToken() != nil { structExpr := v.newExprWithToken(ctx.GetStructToken()) @@ -261,7 +259,6 @@ func (v *ApiVisitor) VisitField(ctx *api.FieldContext) interface{} { func (v *ApiVisitor) VisitNormalField(ctx *api.NormalFieldContext) interface{} { var field TypeField field.Name = v.newExprWithToken(ctx.GetFieldName()) - v.exportCheck(field.Name) iDataTypeContext := ctx.DataType() if iDataTypeContext != nil { @@ -289,7 +286,6 @@ func (v *ApiVisitor) VisitAnonymousFiled(ctx *api.AnonymousFiledContext) interfa field.IsAnonymous = true if ctx.GetStar() != nil { nameExpr := v.newExprWithTerminalNode(ctx.ID()) - v.exportCheck(nameExpr) field.DataType = &Pointer{ PointerExpr: v.newExprWithText(ctx.GetStar().GetText()+ctx.ID().GetText(), start.GetLine(), start.GetColumn(), start.GetStart(), stop.GetStop()), Star: v.newExprWithToken(ctx.GetStar()), @@ -297,7 +293,6 @@ func (v *ApiVisitor) VisitAnonymousFiled(ctx *api.AnonymousFiledContext) interfa } } else { nameExpr := v.newExprWithTerminalNode(ctx.ID()) - v.exportCheck(nameExpr) field.DataType = &Literal{Literal: nameExpr} } field.DocExpr = v.getDoc(ctx) @@ -309,7 +304,6 @@ func (v *ApiVisitor) VisitAnonymousFiled(ctx *api.AnonymousFiledContext) interfa func (v *ApiVisitor) VisitDataType(ctx *api.DataTypeContext) interface{} { if ctx.ID() != nil { idExpr := v.newExprWithTerminalNode(ctx.ID()) - v.exportCheck(idExpr) return &Literal{Literal: idExpr} } if ctx.MapType() != nil { @@ -337,7 +331,6 @@ func (v *ApiVisitor) VisitDataType(ctx *api.DataTypeContext) interface{} { // VisitPointerType implements from api.BaseApiParserVisitor func (v *ApiVisitor) VisitPointerType(ctx *api.PointerTypeContext) interface{} { nameExpr := v.newExprWithTerminalNode(ctx.ID()) - v.exportCheck(nameExpr) return &Pointer{ PointerExpr: v.newExprWithText(ctx.GetText(), ctx.GetStar().GetLine(), ctx.GetStar().GetColumn(), ctx.GetStar().GetStart(), ctx.ID().GetSymbol().GetStop()), Star: v.newExprWithToken(ctx.GetStar()), diff --git a/tools/goctl/model/sql/command/command.go b/tools/goctl/model/sql/command/command.go index 183b92e1..273716e6 100644 --- a/tools/goctl/model/sql/command/command.go +++ b/tools/goctl/model/sql/command/command.go @@ -121,7 +121,7 @@ func fromDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bo return err } - matchTables := make(map[string][]*model.Column) + matchTables := make(map[string]*model.Table) for _, item := range tables { match, err := filepath.Match(pattern, item) if err != nil { @@ -131,11 +131,18 @@ func fromDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bo if !match { continue } - columns, err := im.FindByTableName(dsn.DBName, item) + + columnData, err := im.FindColumns(dsn.DBName, item) + if err != nil { + return err + } + + table, err := columnData.Convert() if err != nil { return err } - matchTables[item] = columns + + matchTables[item] = table } if len(matchTables) == 0 { @@ -147,5 +154,5 @@ func fromDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bo return err } - return generator.StartFromInformationSchema(dsn.DBName, matchTables, cache) + return generator.StartFromInformationSchema(matchTables, cache) } diff --git a/tools/goctl/model/sql/example/makefile b/tools/goctl/model/sql/example/makefile index fa80267c..17fccaac 100644 --- a/tools/goctl/model/sql/example/makefile +++ b/tools/goctl/model/sql/example/makefile @@ -11,8 +11,8 @@ fromDDLWithoutCache: # generate model with cache from data source -user=root -password=password +user=ugozero +password= datasource=127.0.0.1:3306 database=gozero diff --git a/tools/goctl/model/sql/example/sql/user.sql b/tools/goctl/model/sql/example/sql/user.sql index e6640c46..5aaec5e6 100644 --- a/tools/goctl/model/sql/example/sql/user.sql +++ b/tools/goctl/model/sql/example/sql/user.sql @@ -17,10 +17,12 @@ CREATE TABLE `user` ( CREATE TABLE `student` ( `id` bigint NOT NULL AUTO_INCREMENT, + `class` varchar(255) COLLATE utf8mb4_bin NOT NULL DEFAULT '', `name` varchar(255) COLLATE utf8mb4_bin NOT NULL DEFAULT '', `age` tinyint DEFAULT NULL, `score` float(10,0) DEFAULT NULL, `create_time` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, `update_time` timestamp NULL DEFAULT NULL, - PRIMARY KEY (`id`) USING BTREE + PRIMARY KEY (`id`) USING BTREE, + UNIQUE KEY `class_name_index` (`class`,`name`) ) ENGINE=InnoDB AUTO_INCREMENT=4 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin; \ No newline at end of file diff --git a/tools/goctl/model/sql/gen/delete.go b/tools/goctl/model/sql/gen/delete.go index 3ae91818..f6758ce2 100644 --- a/tools/goctl/model/sql/gen/delete.go +++ b/tools/goctl/model/sql/gen/delete.go @@ -12,13 +12,11 @@ import ( func genDelete(table Table, withCache bool) (string, string, error) { keySet := collection.NewSet() keyVariableSet := collection.NewSet() - for fieldName, key := range table.CacheKey { - if fieldName == table.PrimaryKey.Name.Source() { - keySet.AddStr(key.KeyExpression) - } else { - keySet.AddStr(key.DataKeyExpression) - } - keyVariableSet.AddStr(key.Variable) + keySet.AddStr(table.PrimaryCacheKey.KeyExpression) + keyVariableSet.AddStr(table.PrimaryCacheKey.KeyLeft) + for _, key := range table.UniqueCacheKey { + keySet.AddStr(key.DataKeyExpression) + keyVariableSet.AddStr(key.KeyLeft) } camel := table.Name.ToCamel() @@ -32,7 +30,7 @@ func genDelete(table Table, withCache bool) (string, string, error) { Execute(map[string]interface{}{ "upperStartCamelObject": camel, "withCache": withCache, - "containsIndexCache": table.ContainsUniqueKey, + "containsIndexCache": table.ContainsUniqueCacheKey, "lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(), "dataType": table.PrimaryKey.DataType, "keys": strings.Join(keySet.KeysStr(), "\n"), diff --git a/tools/goctl/model/sql/gen/field.go b/tools/goctl/model/sql/gen/field.go index 0a734057..e319bb33 100644 --- a/tools/goctl/model/sql/gen/field.go +++ b/tools/goctl/model/sql/gen/field.go @@ -8,7 +8,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/util" ) -func genFields(fields []parser.Field) (string, error) { +func genFields(fields []*parser.Field) (string, error) { var list []string for _, field := range fields { @@ -23,7 +23,7 @@ func genFields(fields []parser.Field) (string, error) { return strings.Join(list, "\n"), nil } -func genField(field parser.Field) (string, error) { +func genField(field *parser.Field) (string, error) { tag, err := genTag(field.Name.Source()) if err != nil { return "", err diff --git a/tools/goctl/model/sql/gen/findone.go b/tools/goctl/model/sql/gen/findone.go index 472b80ce..e9219206 100644 --- a/tools/goctl/model/sql/gen/findone.go +++ b/tools/goctl/model/sql/gen/findone.go @@ -22,8 +22,8 @@ func genFindOne(table Table, withCache bool) (string, string, error) { "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()), "lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(), "dataType": table.PrimaryKey.DataType, - "cacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].KeyExpression, - "cacheKeyVariable": table.CacheKey[table.PrimaryKey.Name.Source()].Variable, + "cacheKey": table.PrimaryCacheKey.KeyExpression, + "cacheKeyVariable": table.PrimaryCacheKey.KeyLeft, }) if err != nil { return "", "", err diff --git a/tools/goctl/model/sql/gen/findonebyfield.go b/tools/goctl/model/sql/gen/findonebyfield.go index c68febcf..9c23261a 100644 --- a/tools/goctl/model/sql/gen/findonebyfield.go +++ b/tools/goctl/model/sql/gen/findonebyfield.go @@ -24,22 +24,40 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) { t := util.With("findOneByField").Parse(text) var list []string camelTableName := table.Name.ToCamel() - for _, field := range table.Fields { - if field.IsPrimaryKey || !field.IsUniqueKey { - continue + for _, key := range table.UniqueCacheKey { + var inJoin, paramJoin, argJoin Join + for _, f := range key.Fields { + param := stringx.From(f.Name.ToCamel()).Untitle() + inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType)) + paramJoin = append(paramJoin, param) + argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source()))) } - camelFieldName := field.Name.ToCamel() + var in string + if len(inJoin) > 0 { + in = inJoin.With(", ").Source() + } + + var paramJoinString string + if len(paramJoin) > 0 { + paramJoinString = paramJoin.With(",").Source() + } + + var originalFieldString string + if len(argJoin) > 0 { + originalFieldString = argJoin.With(" and ").Source() + } + output, err := t.Execute(map[string]interface{}{ "upperStartCamelObject": camelTableName, - "upperField": camelFieldName, - "in": fmt.Sprintf("%s %s", stringx.From(camelFieldName).Untitle(), field.DataType), + "upperField": key.FieldNameJoin.Camel().With("").Source(), + "in": in, "withCache": withCache, - "cacheKey": table.CacheKey[field.Name.Source()].KeyExpression, - "cacheKeyVariable": table.CacheKey[field.Name.Source()].Variable, + "cacheKey": key.KeyExpression, + "cacheKeyVariable": key.KeyLeft, "lowerStartCamelObject": stringx.From(camelTableName).Untitle(), - "lowerStartCamelField": stringx.From(camelFieldName).Untitle(), + "lowerStartCamelField": paramJoinString, "upperStartCamelPrimaryKey": table.PrimaryKey.Name.ToCamel(), - "originalField": wrapWithRawString(field.Name.Source()), + "originalField": originalFieldString, }) if err != nil { return nil, err @@ -55,15 +73,22 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) { t = util.With("findOneByFieldMethod").Parse(text) var listMethod []string - for _, field := range table.Fields { - if field.IsPrimaryKey || !field.IsUniqueKey { - continue + for _, key := range table.UniqueCacheKey { + var inJoin, paramJoin Join + for _, f := range key.Fields { + param := stringx.From(f.Name.ToCamel()).Untitle() + inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType)) + paramJoin = append(paramJoin, param) + } + + var in string + if len(inJoin) > 0 { + in = inJoin.With(", ").Source() } - camelFieldName := field.Name.ToCamel() output, err := t.Execute(map[string]interface{}{ "upperStartCamelObject": camelTableName, - "upperField": camelFieldName, - "in": fmt.Sprintf("%s %s", stringx.From(camelFieldName).Untitle(), field.DataType), + "upperField": key.FieldNameJoin.Camel().With("").Source(), + "in": in, }) if err != nil { return nil, err @@ -80,7 +105,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) { out, err := util.With("findOneByFieldExtraMethod").Parse(text).Execute(map[string]interface{}{ "upperStartCamelObject": camelTableName, - "primaryKeyLeft": table.CacheKey[table.PrimaryKey.Name.Source()].Left, + "primaryKeyLeft": table.PrimaryCacheKey.VarLeft, "lowerStartCamelObject": stringx.From(camelTableName).Untitle(), "originalPrimaryField": wrapWithRawString(table.PrimaryKey.Name.Source()), }) diff --git a/tools/goctl/model/sql/gen/gen.go b/tools/goctl/model/sql/gen/gen.go index 0968ace6..a92e107b 100644 --- a/tools/goctl/model/sql/gen/gen.go +++ b/tools/goctl/model/sql/gen/gen.go @@ -99,10 +99,10 @@ func (g *defaultGenerator) StartFromDDL(source string, withCache bool) error { return g.createFile(modelList) } -func (g *defaultGenerator) StartFromInformationSchema(db string, columns map[string][]*model.Column, withCache bool) error { +func (g *defaultGenerator) StartFromInformationSchema(tables map[string]*model.Table, withCache bool) error { m := make(map[string]string) - for tableName, column := range columns { - table, err := parser.ConvertColumn(db, tableName, column) + for _, each := range tables { + table, err := parser.ConvertDataType(each) if err != nil { return err } @@ -182,10 +182,12 @@ func (g *defaultGenerator) genFromDDL(source string, withCache bool) (map[string if err != nil { return nil, err } + code, err := g.genModel(*table, withCache) if err != nil { return nil, err } + m[table.Name.Source()] = code } @@ -195,8 +197,9 @@ func (g *defaultGenerator) genFromDDL(source string, withCache bool) (map[string // Table defines mysql table type Table struct { parser.Table - CacheKey map[string]Key - ContainsUniqueKey bool + PrimaryCacheKey Key + UniqueCacheKey []Key + ContainsUniqueCacheKey bool } func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, error) { @@ -204,10 +207,7 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er return "", fmt.Errorf("table %s: missing primary key", in.Name.Source()) } - m, err := genCacheKeys(in) - if err != nil { - return "", err - } + primaryKey, uniqueKey := genCacheKeys(in) importsCode, err := genImports(withCache, in.ContainsTime()) if err != nil { @@ -216,15 +216,9 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er var table Table table.Table = in - table.CacheKey = m - var containsUniqueCache = false - for _, item := range table.Fields { - if item.IsUniqueKey { - containsUniqueCache = true - break - } - } - table.ContainsUniqueKey = containsUniqueCache + table.PrimaryCacheKey = primaryKey + table.UniqueCacheKey = uniqueKey + table.ContainsUniqueCacheKey = len(uniqueKey) > 0 varsCode, err := genVars(table, withCache) if err != nil { diff --git a/tools/goctl/model/sql/gen/gen_test.go b/tools/goctl/model/sql/gen/gen_test.go index b32f2300..4a888c6e 100644 --- a/tools/goctl/model/sql/gen/gen_test.go +++ b/tools/goctl/model/sql/gen/gen_test.go @@ -16,18 +16,15 @@ import ( ) var ( - source = "CREATE TABLE `test_user_info` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `nanosecond` bigint NOT NULL DEFAULT '0',\n `data` varchar(255) DEFAULT '',\n `content` json DEFAULT NULL,\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `nanosecond_unique` (`nanosecond`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;" + source = "CREATE TABLE `test_user` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL,\n `class` bigint NOT NULL,\n `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `mobile_unique` (`mobile`),\n UNIQUE KEY `class_name_unique` (`class`,`name`),\n KEY `create_index` (`create_time`),\n KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;" ) func TestCacheModel(t *testing.T) { logx.Disable() _ = Clean() - dir, _ := filepath.Abs("./testmodel") + dir := filepath.Join(t.TempDir(), "./testmodel") cacheDir := filepath.Join(dir, "cache") noCacheDir := filepath.Join(dir, "nocache") - defer func() { - _ = os.RemoveAll(dir) - }() g, err := NewDefaultGenerator(cacheDir, &config.Config{ NamingFormat: "GoZero", }) @@ -36,7 +33,7 @@ func TestCacheModel(t *testing.T) { err = g.StartFromDDL(source, true) assert.Nil(t, err) assert.True(t, func() bool { - _, err := os.Stat(filepath.Join(cacheDir, "TestUserInfoModel.go")) + _, err := os.Stat(filepath.Join(cacheDir, "TestUserModel.go")) return err == nil }()) g, err = NewDefaultGenerator(noCacheDir, &config.Config{ @@ -47,7 +44,7 @@ func TestCacheModel(t *testing.T) { err = g.StartFromDDL(source, false) assert.Nil(t, err) assert.True(t, func() bool { - _, err := os.Stat(filepath.Join(noCacheDir, "testuserinfomodel.go")) + _, err := os.Stat(filepath.Join(noCacheDir, "testusermodel.go")) return err == nil }()) } @@ -69,7 +66,7 @@ func TestNamingModel(t *testing.T) { err = g.StartFromDDL(source, true) assert.Nil(t, err) assert.True(t, func() bool { - _, err := os.Stat(filepath.Join(camelDir, "TestUserInfoModel.go")) + _, err := os.Stat(filepath.Join(camelDir, "TestUserModel.go")) return err == nil }()) g, err = NewDefaultGenerator(snakeDir, &config.Config{ @@ -80,7 +77,7 @@ func TestNamingModel(t *testing.T) { err = g.StartFromDDL(source, true) assert.Nil(t, err) assert.True(t, func() bool { - _, err := os.Stat(filepath.Join(snakeDir, "test_user_info_model.go")) + _, err := os.Stat(filepath.Join(snakeDir, "test_user_model.go")) return err == nil }()) } diff --git a/tools/goctl/model/sql/gen/insert.go b/tools/goctl/model/sql/gen/insert.go index e6ff7cbc..a82d46cb 100644 --- a/tools/goctl/model/sql/gen/insert.go +++ b/tools/goctl/model/sql/gen/insert.go @@ -12,12 +12,9 @@ import ( func genInsert(table Table, withCache bool) (string, string, error) { keySet := collection.NewSet() keyVariableSet := collection.NewSet() - for fieldName, key := range table.CacheKey { - if fieldName == table.PrimaryKey.Name.Source() { - continue - } + for _, key := range table.UniqueCacheKey { keySet.AddStr(key.DataKeyExpression) - keyVariableSet.AddStr(key.Variable) + keyVariableSet.AddStr(key.KeyLeft) } expressions := make([]string, 0) @@ -27,12 +24,17 @@ func genInsert(table Table, withCache bool) (string, string, error) { if camel == "CreateTime" || camel == "UpdateTime" { continue } - if field.IsPrimaryKey && table.PrimaryKey.AutoIncrement { - continue + + if field.Name.Source() == table.PrimaryKey.Name.Source() { + if table.PrimaryKey.AutoIncrement { + continue + } } + expressions = append(expressions, "?") expressionValues = append(expressionValues, "data."+camel) } + camel := table.Name.ToCamel() text, err := util.LoadTemplate(category, insertTemplateFile, template.Insert) if err != nil { @@ -43,7 +45,7 @@ func genInsert(table Table, withCache bool) (string, string, error) { Parse(text). Execute(map[string]interface{}{ "withCache": withCache, - "containsIndexCache": table.ContainsUniqueKey, + "containsIndexCache": table.ContainsUniqueCacheKey, "upperStartCamelObject": camel, "lowerStartCamelObject": stringx.From(camel).Untitle(), "expression": strings.Join(expressions, ", "), @@ -61,11 +63,9 @@ func genInsert(table Table, withCache bool) (string, string, error) { return "", "", err } - insertMethodOutput, err := util.With("insertMethod"). - Parse(text). - Execute(map[string]interface{}{ - "upperStartCamelObject": camel, - }) + insertMethodOutput, err := util.With("insertMethod").Parse(text).Execute(map[string]interface{}{ + "upperStartCamelObject": camel, + }) if err != nil { return "", "", err } diff --git a/tools/goctl/model/sql/gen/keys.go b/tools/goctl/model/sql/gen/keys.go index 04220bcc..50cb37fe 100644 --- a/tools/goctl/model/sql/gen/keys.go +++ b/tools/goctl/model/sql/gen/keys.go @@ -2,61 +2,163 @@ package gen import ( "fmt" + "sort" "strings" "github.com/tal-tech/go-zero/tools/goctl/model/sql/parser" "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) -// Key defines cache key variable for generating code +// Key describes cache key type Key struct { - // VarExpression likes cacheUserIdPrefix = "cache#User#id#" + // VarLeft describes the varible of cache key expression which likes cacheUserIdPrefix + VarLeft string + // VarRight describes the value of cache key expression which likes "cache#user#id#" + VarRight string + // VarExpression describes the cache key expression which likes cacheUserIdPrefix = "cache#user#id#" VarExpression string - // Left likes cacheUserIdPrefix - Left string - // Right likes cache#user#id# - Right string - // Variable likes userIdKey - Variable string - // KeyExpression likes userIdKey: = fmt.Sprintf("cache#user#id#%v", userId) + // KeyLeft describes the varible of key definiation expression which likes userKey + KeyLeft string + // KeyRight describes the value of key definiation expression which likes fmt.Sprintf("%s%v", cacheUserPrefix, user) + KeyRight string + // DataKeyRight describes data key likes fmt.Sprintf("%s%v", cacheUserPrefix, data.User) + DataKeyRight string + // KeyExpression describes key expression likes userKey := fmt.Sprintf("%s%v", cacheUserPrefix, user) KeyExpression string - // DataKeyExpression likes userIdKey: = fmt.Sprintf("cache#user#id#%v", data.userId) + // DataKeyExpression describes data key expression likes userKey := fmt.Sprintf("%s%v", cacheUserPrefix, data.User) DataKeyExpression string - // RespKeyExpression likes userIdKey: = fmt.Sprintf("cache#user#id#%v", resp.userId) - RespKeyExpression string -} - -// key-数据库原始字段名,value-缓存key相关数据 -func genCacheKeys(table parser.Table) (map[string]Key, error) { - fields := table.Fields - m := make(map[string]Key) - camelTableName := table.Name.ToCamel() - lowerStartCamelTableName := stringx.From(camelTableName).Untitle() - for _, field := range fields { - if field.IsUniqueKey || field.IsPrimaryKey { - camelFieldName := field.Name.ToCamel() - lowerStartCamelFieldName := stringx.From(camelFieldName).Untitle() - left := fmt.Sprintf("cache%s%sPrefix", camelTableName, camelFieldName) - if strings.ToLower(camelFieldName) == strings.ToLower(camelTableName) { - left = fmt.Sprintf("cache%sPrefix", camelTableName) - } - right := fmt.Sprintf("cache#%s#%s#", camelTableName, lowerStartCamelFieldName) - variable := fmt.Sprintf("%s%sKey", lowerStartCamelTableName, camelFieldName) - if strings.ToLower(lowerStartCamelTableName) == strings.ToLower(camelFieldName) { - variable = fmt.Sprintf("%sKey", lowerStartCamelTableName) - } - - m[field.Name.Source()] = Key{ - VarExpression: fmt.Sprintf(`%s = "%s"`, left, right), - Left: left, - Right: right, - Variable: variable, - KeyExpression: fmt.Sprintf(`%s := fmt.Sprintf("%s%s", %s,%s)`, variable, "%s", "%v", left, lowerStartCamelFieldName), - DataKeyExpression: fmt.Sprintf(`%s := fmt.Sprintf("%s%s",%s, data.%s)`, variable, "%s", "%v", left, camelFieldName), - RespKeyExpression: fmt.Sprintf(`%s := fmt.Sprintf("%s%s", %s,resp.%s)`, variable, "%s", "%v", left, camelFieldName), - } - } - } - - return m, nil + // FieldNameJoin describes the filed slice of table + FieldNameJoin Join + // Fields describes the fields of table + Fields []*parser.Field +} + +// Join describes an alias of string slice +type Join []string + +func genCacheKeys(table parser.Table) (Key, []Key) { + var primaryKey Key + var uniqueKey []Key + primaryKey = genCacheKey(table.Name, []*parser.Field{&table.PrimaryKey.Field}) + for _, each := range table.UniqueIndex { + uniqueKey = append(uniqueKey, genCacheKey(table.Name, each)) + } + sort.Slice(uniqueKey, func(i, j int) bool { + return uniqueKey[i].VarLeft < uniqueKey[j].VarLeft + }) + + return primaryKey, uniqueKey +} + +func genCacheKey(table stringx.String, in []*parser.Field) Key { + var ( + varLeftJoin, varRightJon, fieldNameJoin Join + varLeft, varRight, varExpression string + + keyLeftJoin, keyRightJoin, keyRightArgJoin, dataRightJoin Join + keyLeft, keyRight, dataKeyRight, keyExpression, dataKeyExpression string + ) + + varLeftJoin = append(varLeftJoin, "cache", table.Source()) + varRightJon = append(varRightJon, "cache", table.Source()) + keyLeftJoin = append(keyLeftJoin, table.Source()) + + for _, each := range in { + varLeftJoin = append(varLeftJoin, each.Name.Source()) + varRightJon = append(varRightJon, each.Name.Source()) + keyLeftJoin = append(keyLeftJoin, each.Name.Source()) + keyRightJoin = append(keyRightJoin, stringx.From(each.Name.ToCamel()).Untitle()) + keyRightArgJoin = append(keyRightArgJoin, "%v") + dataRightJoin = append(dataRightJoin, "data."+each.Name.ToCamel()) + fieldNameJoin = append(fieldNameJoin, each.Name.Source()) + } + varLeftJoin = append(varLeftJoin, "prefix") + keyLeftJoin = append(keyLeftJoin, "key") + + varLeft = varLeftJoin.Camel().With("").Untitle() + varRight = fmt.Sprintf(`"%s"`, varRightJon.Camel().Untitle().With("#").Source()+"#") + varExpression = fmt.Sprintf(`%s = %s`, varLeft, varRight) + + keyLeft = keyLeftJoin.Camel().With("").Untitle() + keyRight = fmt.Sprintf(`fmt.Sprintf("%s%s", %s, %s)`, "%s", keyRightArgJoin.With("").Source(), varLeft, keyRightJoin.With(", ").Source()) + dataKeyRight = fmt.Sprintf(`fmt.Sprintf("%s%s", %s, %s)`, "%s", keyRightArgJoin.With("").Source(), varLeft, dataRightJoin.With(", ").Source()) + keyExpression = fmt.Sprintf("%s := %s", keyLeft, keyRight) + dataKeyExpression = fmt.Sprintf("%s := %s", keyLeft, dataKeyRight) + + return Key{ + VarLeft: varLeft, + VarRight: varRight, + VarExpression: varExpression, + KeyLeft: keyLeft, + KeyRight: keyRight, + DataKeyRight: dataKeyRight, + KeyExpression: keyExpression, + DataKeyExpression: dataKeyExpression, + Fields: in, + FieldNameJoin: fieldNameJoin, + } +} + +// Title convert items into Title and return +func (j Join) Title() Join { + var join Join + for _, each := range j { + join = append(join, stringx.From(each).Title()) + } + + return join +} + +// Camel convert items into Camel and return +func (j Join) Camel() Join { + var join Join + for _, each := range j { + join = append(join, stringx.From(each).ToCamel()) + } + return join +} + +// Snake convert items into Snake and return +func (j Join) Snake() Join { + var join Join + for _, each := range j { + join = append(join, stringx.From(each).ToSnake()) + } + + return join +} + +// Snake convert items into Untitle and return +func (j Join) Untitle() Join { + var join Join + for _, each := range j { + join = append(join, stringx.From(each).Untitle()) + } + + return join +} + +// Upper convert items into Upper and return +func (j Join) Upper() Join { + var join Join + for _, each := range j { + join = append(join, stringx.From(each).Upper()) + } + + return join +} + +// Lower convert items into Lower and return +func (j Join) Lower() Join { + var join Join + for _, each := range j { + join = append(join, stringx.From(each).Lower()) + } + + return join +} + +// With convert items into With and return +func (j Join) With(sep string) stringx.String { + return stringx.From(strings.Join(j, sep)) } diff --git a/tools/goctl/model/sql/gen/keys_test.go b/tools/goctl/model/sql/gen/keys_test.go index 95868f8d..536eb6f2 100644 --- a/tools/goctl/model/sql/gen/keys_test.go +++ b/tools/goctl/model/sql/gen/keys_test.go @@ -1,7 +1,7 @@ package gen import ( - "fmt" + "sort" "testing" "github.com/stretchr/testify/assert" @@ -10,62 +10,156 @@ import ( ) func TestGenCacheKeys(t *testing.T) { - m, err := genCacheKeys(parser.Table{ + primaryField := &parser.Field{ + Name: stringx.From("id"), + DataBaseType: "bigint", + DataType: "int64", + Comment: "自增id", + SeqInIndex: 1, + } + mobileField := &parser.Field{ + Name: stringx.From("mobile"), + DataBaseType: "varchar", + DataType: "string", + Comment: "手机号", + SeqInIndex: 1, + } + classField := &parser.Field{ + Name: stringx.From("class"), + DataBaseType: "varchar", + DataType: "string", + Comment: "班级", + SeqInIndex: 1, + } + nameField := &parser.Field{ + Name: stringx.From("name"), + DataBaseType: "varchar", + DataType: "string", + Comment: "姓名", + SeqInIndex: 2, + } + primariCacheKey, uniqueCacheKey := genCacheKeys(parser.Table{ Name: stringx.From("user"), PrimaryKey: parser.Primary{ - Field: parser.Field{ - Name: stringx.From("id"), - DataBaseType: "bigint", - DataType: "int64", - IsPrimaryKey: true, - IsUniqueKey: false, - Comment: "自增id", - }, + Field: *primaryField, AutoIncrement: true, }, - Fields: []parser.Field{ - { - Name: stringx.From("mobile"), - DataBaseType: "varchar", - DataType: "string", - IsPrimaryKey: false, - IsUniqueKey: true, - Comment: "手机号", + UniqueIndex: map[string][]*parser.Field{ + "mobile_unique": []*parser.Field{ + mobileField, }, - { - Name: stringx.From("name"), - DataBaseType: "varchar", - DataType: "string", - IsPrimaryKey: false, - IsUniqueKey: true, - Comment: "姓名", + "class_name_unique": []*parser.Field{ + classField, + nameField, }, + }, + NormalIndex: nil, + Fields: []*parser.Field{ + primaryField, + mobileField, + classField, + nameField, { Name: stringx.From("createTime"), DataBaseType: "timestamp", DataType: "time.Time", - IsPrimaryKey: false, - IsUniqueKey: false, Comment: "创建时间", }, { Name: stringx.From("updateTime"), DataBaseType: "timestamp", DataType: "time.Time", - IsPrimaryKey: false, - IsUniqueKey: false, Comment: "更新时间", }, }, }) - assert.Nil(t, err) - for fieldName, key := range m { - name := stringx.From(fieldName) - assert.Equal(t, fmt.Sprintf(`cacheUser%sPrefix = "cache#User#%s#"`, name.ToCamel(), name.Untitle()), key.VarExpression) - assert.Equal(t, fmt.Sprintf(`cacheUser%sPrefix`, name.ToCamel()), key.Left) - assert.Equal(t, fmt.Sprintf(`cache#User#%s#`, name.Untitle()), key.Right) - assert.Equal(t, fmt.Sprintf(`user%sKey`, name.ToCamel()), key.Variable) - assert.Equal(t, `user`+name.ToCamel()+`Key := fmt.Sprintf("%s%v", cacheUser`+name.ToCamel()+`Prefix,`+name.Untitle()+`)`, key.KeyExpression) + t.Run("primaryCacheKey", func(t *testing.T) { + assert.Equal(t, true, func() bool { + return cacheKeyEqual(primariCacheKey, Key{ + VarLeft: "cacheUserIdPrefix", + VarRight: `"cache#user#id#"`, + VarExpression: `cacheUserIdPrefix = "cache#user#id#"`, + KeyLeft: "userIdKey", + KeyRight: `fmt.Sprintf("%s%v", cacheUserIdPrefix, id)`, + DataKeyRight: `fmt.Sprintf("%s%v", cacheUserIdPrefix, data.Id)`, + KeyExpression: `userIdKey := fmt.Sprintf("%s%v", cacheUserIdPrefix, id)`, + DataKeyExpression: `userIdKey := fmt.Sprintf("%s%v", cacheUserIdPrefix, data.Id)`, + FieldNameJoin: []string{"id"}, + }) + }()) + }) + + t.Run("uniqueCacheKey", func(t *testing.T) { + assert.Equal(t, true, func() bool { + expected := []Key{ + { + VarLeft: "cacheUserClassNamePrefix", + VarRight: `"cache#user#class#name#"`, + VarExpression: `cacheUserClassNamePrefix = "cache#user#class#name#"`, + KeyLeft: "userClassNameKey", + KeyRight: `fmt.Sprintf("%s%v%v", cacheUserClassNamePrefix, class, name)`, + DataKeyRight: `fmt.Sprintf("%s%v%v", cacheUserClassNamePrefix, data.Class, data.Name)`, + KeyExpression: `userClassNameKey := fmt.Sprintf("%s%v%v", cacheUserClassNamePrefix, class, name)`, + DataKeyExpression: `userClassNameKey := fmt.Sprintf("%s%v%v", cacheUserClassNamePrefix, data.Class, data.Name)`, + FieldNameJoin: []string{"class", "name"}, + }, + { + VarLeft: "cacheUserMobilePrefix", + VarRight: `"cache#user#mobile#"`, + VarExpression: `cacheUserMobilePrefix = "cache#user#mobile#"`, + KeyLeft: "userMobileKey", + KeyRight: `fmt.Sprintf("%s%v", cacheUserMobilePrefix, mobile)`, + DataKeyRight: `fmt.Sprintf("%s%v", cacheUserMobilePrefix, data.Mobile)`, + KeyExpression: `userMobileKey := fmt.Sprintf("%s%v", cacheUserMobilePrefix, mobile)`, + DataKeyExpression: `userMobileKey := fmt.Sprintf("%s%v", cacheUserMobilePrefix, data.Mobile)`, + FieldNameJoin: []string{"mobile"}, + }, + } + sort.Slice(uniqueCacheKey, func(i, j int) bool { + return uniqueCacheKey[i].VarLeft < uniqueCacheKey[j].VarLeft + }) + + if len(expected) != len(uniqueCacheKey) { + return false + } + + for index, each := range uniqueCacheKey { + expecting := expected[index] + if !cacheKeyEqual(expecting, each) { + return false + } + } + + return true + }()) + }) + +} + +func cacheKeyEqual(k1 Key, k2 Key) bool { + k1Join := k1.FieldNameJoin + k2Join := k2.FieldNameJoin + sort.Strings(k1Join) + sort.Strings(k2Join) + if len(k1Join) != len(k2Join) { + return false + } + + for index, each := range k1Join { + k2Item := k2Join[index] + if each != k2Item { + return false + } } + + return k1.VarLeft == k2.VarLeft && + k1.VarRight == k2.VarRight && + k1.VarExpression == k2.VarExpression && + k1.KeyLeft == k2.KeyLeft && + k1.KeyRight == k2.KeyRight && + k1.DataKeyRight == k2.DataKeyRight && + k1.DataKeyExpression == k2.DataKeyExpression && + k1.KeyExpression == k2.KeyExpression + } diff --git a/tools/goctl/model/sql/gen/update.go b/tools/goctl/model/sql/gen/update.go index 3e5c99f8..9e80689b 100644 --- a/tools/goctl/model/sql/gen/update.go +++ b/tools/goctl/model/sql/gen/update.go @@ -16,7 +16,7 @@ func genUpdate(table Table, withCache bool) (string, string, error) { continue } - if field.IsPrimaryKey { + if field.Name.Source() == table.PrimaryKey.Name.Source() { continue } @@ -35,8 +35,8 @@ func genUpdate(table Table, withCache bool) (string, string, error) { Execute(map[string]interface{}{ "withCache": withCache, "upperStartCamelObject": camelTableName, - "primaryCacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].DataKeyExpression, - "primaryKeyVariable": table.CacheKey[table.PrimaryKey.Name.Source()].Variable, + "primaryCacheKey": table.PrimaryCacheKey.DataKeyExpression, + "primaryKeyVariable": table.PrimaryCacheKey.KeyLeft, "lowerStartCamelObject": stringx.From(camelTableName).Untitle(), "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()), "expressionValues": strings.Join(expressionValues, ", "), diff --git a/tools/goctl/model/sql/gen/vars.go b/tools/goctl/model/sql/gen/vars.go index 6ec6ee49..6be4fc70 100644 --- a/tools/goctl/model/sql/gen/vars.go +++ b/tools/goctl/model/sql/gen/vars.go @@ -10,26 +10,26 @@ import ( func genVars(table Table, withCache bool) (string, error) { keys := make([]string, 0) - for _, v := range table.CacheKey { + keys = append(keys, table.PrimaryCacheKey.VarExpression) + for _, v := range table.UniqueCacheKey { keys = append(keys, v.VarExpression) } + camel := table.Name.ToCamel() text, err := util.LoadTemplate(category, varTemplateFile, template.Vars) if err != nil { return "", err } - output, err := util.With("var"). - Parse(text). - GoFmt(true). - Execute(map[string]interface{}{ - "lowerStartCamelObject": stringx.From(camel).Untitle(), - "upperStartCamelObject": camel, - "cacheKeys": strings.Join(keys, "\n"), - "autoIncrement": table.PrimaryKey.AutoIncrement, - "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()), - "withCache": withCache, - }) + output, err := util.With("var").Parse(text). + GoFmt(true).Execute(map[string]interface{}{ + "lowerStartCamelObject": stringx.From(camel).Untitle(), + "upperStartCamelObject": camel, + "cacheKeys": strings.Join(keys, "\n"), + "autoIncrement": table.PrimaryKey.AutoIncrement, + "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()), + "withCache": withCache, + }) if err != nil { return "", err } diff --git a/tools/goctl/model/sql/model/informationschemamodel.go b/tools/goctl/model/sql/model/informationschemamodel.go index 2b7b1ba7..5a20ec66 100644 --- a/tools/goctl/model/sql/model/informationschemamodel.go +++ b/tools/goctl/model/sql/model/informationschemamodel.go @@ -1,6 +1,13 @@ package model -import "github.com/tal-tech/go-zero/core/stores/sqlx" +import ( + "fmt" + "sort" + + "github.com/tal-tech/go-zero/core/stores/sqlx" +) + +const indexPri = "PRIMARY" type ( // InformationSchemaModel defines information schema model @@ -10,13 +17,53 @@ type ( // Column defines column in table Column struct { - Name string `db:"COLUMN_NAME"` - DataType string `db:"DATA_TYPE"` - Key string `db:"COLUMN_KEY"` - Extra string `db:"EXTRA"` - Comment string `db:"COLUMN_COMMENT"` - ColumnDefault interface{} `db:"COLUMN_DEFAULT"` - IsNullAble string `db:"IS_NULLABLE"` + *DbColumn + Index *DbIndex + } + + // DbColumn defines column info of columns + DbColumn struct { + Name string `db:"COLUMN_NAME"` + DataType string `db:"DATA_TYPE"` + Extra string `db:"EXTRA"` + Comment string `db:"COLUMN_COMMENT"` + ColumnDefault interface{} `db:"COLUMN_DEFAULT"` + IsNullAble string `db:"IS_NULLABLE"` + OrdinalPosition int `db:"ORDINAL_POSITION"` + } + + // DbIndex defines index of columns in information_schema.statistic + DbIndex struct { + IndexName string `db:"INDEX_NAME"` + NonUnique int `db:"NON_UNIQUE"` + SeqInIndex int `db:"SEQ_IN_INDEX"` + } + + // ColumnData describes the columns of table + ColumnData struct { + Db string + Table string + Columns []*Column + } + + // Table describes mysql table which contains database name, table name, columns, keys + Table struct { + Db string + Table string + Columns []*Column + // Primary key not included + UniqueIndex map[string][]*Column + PrimaryKey *Column + NormalIndex map[string][]*Column + } + + // IndexType describes an alias of string + IndexType string + + // Index describes a column index + Index struct { + IndexType IndexType + Columns []*Column } ) @@ -37,10 +84,102 @@ func (m *InformationSchemaModel) GetAllTables(database string) ([]string, error) return tables, nil } -// FindByTableName finds out the target table by name -func (m *InformationSchemaModel) FindByTableName(db, table string) ([]*Column, error) { - querySQL := `select COLUMN_NAME,COLUMN_DEFAULT,IS_NULLABLE,DATA_TYPE,COLUMN_KEY,EXTRA,COLUMN_COMMENT from COLUMNS where TABLE_SCHEMA = ? and TABLE_NAME = ?` - var reply []*Column - err := m.conn.QueryRows(&reply, querySQL, db, table) - return reply, err +// FindColumns return columns in specified database and table +func (m *InformationSchemaModel) FindColumns(db, table string) (*ColumnData, error) { + querySql := `SELECT c.COLUMN_NAME,c.DATA_TYPE,EXTRA,c.COLUMN_COMMENT,c.COLUMN_DEFAULT,c.IS_NULLABLE,c.ORDINAL_POSITION from COLUMNS c WHERE c.TABLE_SCHEMA = ? and c.TABLE_NAME = ? ` + var reply []*DbColumn + err := m.conn.QueryRowsPartial(&reply, querySql, db, table) + if err != nil { + return nil, err + } + + var list []*Column + for _, item := range reply { + index, err := m.FindIndex(db, table, item.Name) + if err != nil { + if err != sqlx.ErrNotFound { + return nil, err + } + continue + } + + if len(index) > 0 { + for _, i := range index { + list = append(list, &Column{ + DbColumn: item, + Index: i, + }) + } + } else { + list = append(list, &Column{ + DbColumn: item, + }) + } + } + + sort.Slice(list, func(i, j int) bool { + return list[i].OrdinalPosition < list[j].OrdinalPosition + }) + + var columnData ColumnData + columnData.Db = db + columnData.Table = table + columnData.Columns = list + return &columnData, nil +} + +func (m *InformationSchemaModel) FindIndex(db, table, column string) ([]*DbIndex, error) { + querySql := `SELECT s.INDEX_NAME,s.NON_UNIQUE,s.SEQ_IN_INDEX from STATISTICS s WHERE s.TABLE_SCHEMA = ? and s.TABLE_NAME = ? and s.COLUMN_NAME = ?` + var reply []*DbIndex + err := m.conn.QueryRowsPartial(&reply, querySql, db, table, column) + if err != nil { + return nil, err + } + + return reply, nil +} + +// Convert converts column data into Table +func (c *ColumnData) Convert() (*Table, error) { + var table Table + table.Table = c.Table + table.Db = c.Db + table.Columns = c.Columns + table.UniqueIndex = map[string][]*Column{} + table.NormalIndex = map[string][]*Column{} + + m := make(map[string][]*Column) + for _, each := range c.Columns { + if each.Index != nil { + m[each.Index.IndexName] = append(m[each.Index.IndexName], each) + } + } + + primaryColumns := m[indexPri] + if len(primaryColumns) == 0 { + return nil, fmt.Errorf("db:%s, table:%s, missing primary key", c.Db, c.Table) + } + + if len(primaryColumns) > 1 { + return nil, fmt.Errorf("db:%s, table:%s, joint primary key is not supported", c.Db, c.Table) + } + + table.PrimaryKey = primaryColumns[0] + for indexName, columns := range m { + if indexName == indexPri { + continue + } + + for _, one := range columns { + if one.Index != nil { + if one.Index.NonUnique == 0 { + table.UniqueIndex[indexName] = columns + } else { + table.NormalIndex[indexName] = columns + } + } + } + } + + return &table, nil } diff --git a/tools/goctl/model/sql/parser/parser.go b/tools/goctl/model/sql/parser/parser.go index e8c24327..ce454c44 100644 --- a/tools/goctl/model/sql/parser/parser.go +++ b/tools/goctl/model/sql/parser/parser.go @@ -2,30 +2,27 @@ package parser import ( "fmt" + "sort" "strings" + "github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/tools/goctl/model/sql/converter" "github.com/tal-tech/go-zero/tools/goctl/model/sql/model" + "github.com/tal-tech/go-zero/tools/goctl/util/console" "github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/xwb1989/sqlparser" ) -const ( - _ = iota - primary - unique - normal - spatial -) - const timeImport = "time.Time" type ( // Table describes a mysql table Table struct { - Name stringx.String - PrimaryKey Primary - Fields []Field + Name stringx.String + PrimaryKey Primary + UniqueIndex map[string][]*Field + NormalIndex map[string][]*Field + Fields []*Field } // Primary describes a primary key @@ -36,12 +33,12 @@ type ( // Field describes a table field Field struct { - Name stringx.String - DataBaseType string - DataType string - IsPrimaryKey bool - IsUniqueKey bool - Comment string + Name stringx.String + DataBaseType string + DataType string + Comment string + SeqInIndex int + OrdinalPosition int } // KeyType types alias of int @@ -73,34 +70,58 @@ func Parse(ddl string) (*Table, error) { columns := tableSpec.Columns indexes := tableSpec.Indexes - keyMap, err := getIndexKeyType(indexes) + primaryColumn, uniqueKeyMap, normalKeyMap, err := convertIndexes(indexes) if err != nil { return nil, err } - fields, primaryKey, err := convertFileds(columns, keyMap) + fields, primaryKey, fieldM, err := convertColumns(columns, primaryColumn) if err != nil { return nil, err } + var ( + uniqueIndex = make(map[string][]*Field) + normalIndex = make(map[string][]*Field) + ) + for indexName, each := range uniqueKeyMap { + for _, columnName := range each { + uniqueIndex[indexName] = append(uniqueIndex[indexName], fieldM[columnName]) + } + } + + for indexName, each := range normalKeyMap { + for _, columnName := range each { + normalIndex[indexName] = append(normalIndex[indexName], fieldM[columnName]) + } + } + return &Table{ - Name: stringx.From(tableName), - PrimaryKey: primaryKey, - Fields: fields, + Name: stringx.From(tableName), + PrimaryKey: primaryKey, + UniqueIndex: uniqueIndex, + NormalIndex: normalIndex, + Fields: fields, }, nil } -func convertFileds(columns []*sqlparser.ColumnDefinition, keyMap map[string]KeyType) ([]Field, Primary, error) { - var fields []Field - var primaryKey Primary +func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) ([]*Field, Primary, map[string]*Field, error) { + var ( + fields []*Field + primaryKey Primary + fieldM = make(map[string]*Field) + ) + for _, column := range columns { if column == nil { continue } + var comment string if column.Type.Comment != nil { comment = string(column.Type.Comment.Val) } + var isDefaultNull = true if column.Type.NotNull { isDefaultNull = false @@ -111,9 +132,10 @@ func convertFileds(columns []*sqlparser.ColumnDefinition, keyMap map[string]KeyT isDefaultNull = false } } + dataType, err := converter.ConvertDataType(column.Type.Type, isDefaultNull) if err != nil { - return nil, primaryKey, err + return nil, Primary{}, nil, err } var field Field @@ -121,60 +143,75 @@ func convertFileds(columns []*sqlparser.ColumnDefinition, keyMap map[string]KeyT field.DataBaseType = column.Type.Type field.DataType = dataType field.Comment = comment - key, ok := keyMap[column.Name.String()] - if ok { - field.IsPrimaryKey = key == primary - field.IsUniqueKey = key == unique - if field.IsPrimaryKey { - primaryKey.Field = field - if column.Type.Autoincrement { - primaryKey.AutoIncrement = true - } + + if field.Name.Source() == primaryColumn { + primaryKey = Primary{ + Field: field, + AutoIncrement: bool(column.Type.Autoincrement), } } - fields = append(fields, field) + + fields = append(fields, &field) + fieldM[field.Name.Source()] = &field } - return fields, primaryKey, nil + return fields, primaryKey, fieldM, nil } -func getIndexKeyType(indexes []*sqlparser.IndexDefinition) (map[string]KeyType, error) { - keyMap := make(map[string]KeyType) +func convertIndexes(indexes []*sqlparser.IndexDefinition) (string, map[string][]string, map[string][]string, error) { + var primaryColumn string + uniqueKeyMap := make(map[string][]string) + normalKeyMap := make(map[string][]string) + + isCreateTimeOrUpdateTime := func(name string) bool { + camelColumnName := stringx.From(name).ToCamel() + // by default, createTime|updateTime findOne is not used. + return camelColumnName == "CreateTime" || camelColumnName == "UpdateTime" + } + for _, index := range indexes { info := index.Info if info == nil { continue } + + indexName := index.Info.Name.String() if info.Primary { if len(index.Columns) > 1 { - return nil, errPrimaryKey + return "", nil, nil, errPrimaryKey + } + columnName := index.Columns[0].Column.String() + if isCreateTimeOrUpdateTime(columnName) { + continue } - keyMap[index.Columns[0].Column.String()] = primary - continue - } - // can optimize - if len(index.Columns) > 1 { + primaryColumn = columnName continue - } - column := index.Columns[0] - columnName := column.Column.String() - camelColumnName := stringx.From(columnName).ToCamel() - // by default, createTime|updateTime findOne is not used. - if camelColumnName == "CreateTime" || camelColumnName == "UpdateTime" { - continue - } - if info.Unique { - keyMap[columnName] = unique + } else if info.Unique { + for _, each := range index.Columns { + columnName := each.Column.String() + if isCreateTimeOrUpdateTime(columnName) { + break + } + + uniqueKeyMap[indexName] = append(uniqueKeyMap[indexName], columnName) + } } else if info.Spatial { - keyMap[columnName] = spatial + // do nothing } else { - keyMap[columnName] = normal + for _, each := range index.Columns { + columnName := each.Column.String() + if isCreateTimeOrUpdateTime(columnName) { + break + } + + normalKeyMap[indexName] = append(normalKeyMap[indexName], each.Column.String()) + } } } - return keyMap, nil + return primaryColumn, uniqueKeyMap, normalKeyMap, nil } -// ContainsTime determines whether the table field contains time.Time +// ContainsTime returns true if contains golang type time.Time func (t *Table) ContainsTime() bool { for _, item := range t.Fields { if item.DataType == timeImport { @@ -184,63 +221,110 @@ func (t *Table) ContainsTime() bool { return false } -// ConvertColumn provides type conversion for mysql clolumn, primary key lookup -func ConvertColumn(db, table string, in []*model.Column) (*Table, error) { - var reply Table - reply.Name = stringx.From(table) - keyMap := make(map[string][]*model.Column) - - for _, column := range in { - keyMap[column.Key] = append(keyMap[column.Key], column) +// ConvertDataType converts mysql data type into golang data type +func ConvertDataType(table *model.Table) (*Table, error) { + isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES" + primaryDataType, err := converter.ConvertDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull) + if err != nil { + return nil, err } - primaryColumns := keyMap["PRI"] - if len(primaryColumns) == 0 { - return nil, fmt.Errorf("database:%s, table %s: missing primary key", db, table) + + var reply Table + reply.UniqueIndex = map[string][]*Field{} + reply.NormalIndex = map[string][]*Field{} + reply.Name = stringx.From(table.Table) + seqInIndex := 0 + if table.PrimaryKey.Index != nil { + seqInIndex = table.PrimaryKey.Index.SeqInIndex } - if len(primaryColumns) > 1 { - return nil, fmt.Errorf("database:%s, table %s: only one primary key expected", db, table) + reply.PrimaryKey = Primary{ + Field: Field{ + Name: stringx.From(table.PrimaryKey.Name), + DataBaseType: table.PrimaryKey.DataType, + DataType: primaryDataType, + Comment: table.PrimaryKey.Comment, + SeqInIndex: seqInIndex, + OrdinalPosition: table.PrimaryKey.OrdinalPosition, + }, + AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"), } - primaryColumn := primaryColumns[0] - isDefaultNull := primaryColumn.ColumnDefault == nil && primaryColumn.IsNullAble == "YES" - primaryFt, err := converter.ConvertDataType(primaryColumn.DataType, isDefaultNull) - if err != nil { - return nil, err + fieldM := make(map[string]*Field) + for _, each := range table.Columns { + isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES" + dt, err := converter.ConvertDataType(each.DataType, isDefaultNull) + if err != nil { + return nil, err + } + columnSeqInIndex := 0 + if each.Index != nil { + columnSeqInIndex = each.Index.SeqInIndex + } + + field := &Field{ + Name: stringx.From(each.Name), + DataBaseType: each.DataType, + DataType: dt, + Comment: each.Comment, + SeqInIndex: columnSeqInIndex, + OrdinalPosition: each.OrdinalPosition, + } + fieldM[each.Name] = field } - primaryField := Field{ - Name: stringx.From(primaryColumn.Name), - DataBaseType: primaryColumn.DataType, - DataType: primaryFt, - IsUniqueKey: true, - IsPrimaryKey: true, - Comment: primaryColumn.Comment, + for _, each := range fieldM { + reply.Fields = append(reply.Fields, each) } - reply.PrimaryKey = Primary{ - Field: primaryField, - AutoIncrement: strings.Contains(primaryColumn.Extra, "auto_increment"), - } - for key, columns := range keyMap { - for _, item := range columns { - isColumnDefaultNull := item.ColumnDefault == nil && item.IsNullAble == "YES" - dt, err := converter.ConvertDataType(item.DataType, isColumnDefaultNull) - if err != nil { - return nil, err - } + sort.Slice(reply.Fields, func(i, j int) bool { + return reply.Fields[i].OrdinalPosition < reply.Fields[j].OrdinalPosition + }) - f := Field{ - Name: stringx.From(item.Name), - DataBaseType: item.DataType, - DataType: dt, - IsPrimaryKey: primaryColumn.Name == item.Name, - Comment: item.Comment, + uniqueIndexSet := collection.NewSet() + log := console.NewColorConsole() + for indexName, each := range table.UniqueIndex { + sort.Slice(each, func(i, j int) bool { + if each[i].Index != nil { + return each[i].Index.SeqInIndex < each[j].Index.SeqInIndex } - if key == "UNI" { - f.IsUniqueKey = true + return false + }) + + if len(each) == 1 { + one := each[0] + if one.Name == table.PrimaryKey.Name { + log.Warning("duplicate unique index with primary key, %s", one.Name) + continue } - reply.Fields = append(reply.Fields, f) } + + var list []*Field + var uniqueJoin []string + for _, c := range each { + list = append(list, fieldM[c.Name]) + uniqueJoin = append(uniqueJoin, c.Name) + } + + uniqueKey := strings.Join(uniqueJoin, ",") + if uniqueIndexSet.Contains(uniqueKey) { + log.Warning("duplicate unique index, %s", uniqueKey) + continue + } + + reply.UniqueIndex[indexName] = list + } + + for indexName, each := range table.NormalIndex { + var list []*Field + for _, c := range each { + list = append(list, fieldM[c.Name]) + } + + sort.Slice(list, func(i, j int) bool { + return list[i].SeqInIndex < list[j].SeqInIndex + }) + + reply.NormalIndex[indexName] = list } return &reply, nil diff --git a/tools/goctl/model/sql/parser/parser_test.go b/tools/goctl/model/sql/parser/parser_test.go index def99a15..de2faf22 100644 --- a/tools/goctl/model/sql/parser/parser_test.go +++ b/tools/goctl/model/sql/parser/parser_test.go @@ -1,10 +1,12 @@ package parser import ( + "sort" "testing" "github.com/stretchr/testify/assert" "github.com/tal-tech/go-zero/tools/goctl/model/sql/model" + "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) func TestParsePlainText(t *testing.T) { @@ -18,68 +20,158 @@ func TestParseSelect(t *testing.T) { } func TestParseCreateTable(t *testing.T) { - table, err := Parse("CREATE TABLE `user_snake` (\n `id` bigint(10) NOT NULL AUTO_INCREMENT,\n `name` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户名称',\n `password` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户密码',\n `mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号',\n `gender` char(5) COLLATE utf8mb4_general_ci NOT NULL COMMENT '男|女|未公开',\n `nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT '用户昵称',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `name_index` (`name`),\n KEY `mobile_index` (`mobile`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;") + table, err := Parse("CREATE TABLE `test_user` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL,\n `class` bigint NOT NULL,\n `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `mobile_unique` (`mobile`),\n UNIQUE KEY `class_name_unique` (`class`,`name`),\n KEY `create_index` (`create_time`),\n KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;") assert.Nil(t, err) - assert.Equal(t, "user_snake", table.Name.Source()) + assert.Equal(t, "test_user", table.Name.Source()) assert.Equal(t, "id", table.PrimaryKey.Name.Source()) assert.Equal(t, true, table.ContainsTime()) + assert.Equal(t, true, func() bool { + mobileUniqueIndex, ok := table.UniqueIndex["mobile_unique"] + if !ok { + return false + } + + classNameUniqueIndex, ok := table.UniqueIndex["class_name_unique"] + if !ok { + return false + } + + equal := func(f1, f2 []*Field) bool { + sort.Slice(f1, func(i, j int) bool { + return f1[i].Name.Source() < f1[j].Name.Source() + }) + sort.Slice(f2, func(i, j int) bool { + return f2[i].Name.Source() < f2[j].Name.Source() + }) + + if len(f2) != len(f2) { + return false + } + + for index, f := range f1 { + if f1[index].Name.Source() != f.Name.Source() { + return false + } + } + return true + } + + if !equal(mobileUniqueIndex, []*Field{ + { + Name: stringx.From("mobile"), + DataBaseType: "varchar", + DataType: "string", + SeqInIndex: 1, + }, + }) { + return false + } + + return equal(classNameUniqueIndex, []*Field{ + { + Name: stringx.From("class"), + DataBaseType: "bigint", + DataType: "int64", + SeqInIndex: 1, + }, + { + Name: stringx.From("name"), + DataBaseType: "varchar", + DataType: "string", + SeqInIndex: 2, + }, + }) + }()) } func TestConvertColumn(t *testing.T) { - _, err := ConvertColumn("user", "user", []*model.Column{ - { - Name: "id", - DataType: "bigint", - Key: "", - Extra: "", - Comment: "", - }, + t.Run("missingPrimaryKey", func(t *testing.T) { + columnData := model.ColumnData{ + Db: "user", + Table: "user", + Columns: []*model.Column{ + { + DbColumn: &model.DbColumn{ + Name: "id", + DataType: "bigint", + }, + }, + }, + } + _, err := columnData.Convert() + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "missing primary key") }) - assert.NotNil(t, err) - assert.Contains(t, err.Error(), "missing primary key") - _, err = ConvertColumn("user", "user", []*model.Column{ - { - Name: "id", - DataType: "bigint", - Key: "PRI", - Extra: "", - Comment: "", - }, - { - Name: "mobile", - DataType: "varchar", - Key: "PRI", - Extra: "", - Comment: "手机号", - }, + t.Run("jointPrimaryKey", func(t *testing.T) { + columnData := model.ColumnData{ + Db: "user", + Table: "user", + Columns: []*model.Column{ + { + DbColumn: &model.DbColumn{ + Name: "id", + DataType: "bigint", + }, + Index: &model.DbIndex{ + IndexName: "PRIMARY", + }, + }, + { + DbColumn: &model.DbColumn{ + Name: "mobile", + DataType: "varchar", + Comment: "手机号", + }, + Index: &model.DbIndex{ + IndexName: "PRIMARY", + }, + }, + }, + } + _, err := columnData.Convert() + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "joint primary key is not supported") }) - assert.NotNil(t, err) - assert.Contains(t, err.Error(), "only one primary key expected") - table, err := ConvertColumn("user", "user", []*model.Column{ - { - Name: "id", - DataType: "bigint", - Key: "PRI", - Extra: "auto_increment", - Comment: "", - }, - { - Name: "mobile", - DataType: "varchar", - Key: "UNI", - Extra: "", - Comment: "手机号", - }, - }) - assert.Nil(t, err) - assert.True(t, table.PrimaryKey.AutoIncrement && table.PrimaryKey.IsPrimaryKey) - assert.Equal(t, "id", table.PrimaryKey.Name.Source()) - for _, item := range table.Fields { - if item.Name.Source() == "mobile" { - assert.True(t, item.IsUniqueKey) - break + t.Run("normal", func(t *testing.T) { + columnData := model.ColumnData{ + Db: "user", + Table: "user", + Columns: []*model.Column{ + { + DbColumn: &model.DbColumn{ + Name: "id", + DataType: "bigint", + Extra: "auto_increment", + }, + Index: &model.DbIndex{ + IndexName: "PRIMARY", + SeqInIndex: 1, + }, + }, + { + DbColumn: &model.DbColumn{ + Name: "mobile", + DataType: "varchar", + Comment: "手机号", + }, + Index: &model.DbIndex{ + IndexName: "mobile_unique", + SeqInIndex: 1, + }, + }, + }, } - } + + table, err := columnData.Convert() + assert.Nil(t, err) + assert.True(t, table.PrimaryKey.Index.IndexName == "PRIMARY" && table.PrimaryKey.Name == "id") + for _, item := range table.Columns { + if item.Name == "mobile" { + assert.True(t, item.Index.NonUnique == 0) + break + } + } + }) } diff --git a/tools/goctl/model/sql/template/find.go b/tools/goctl/model/sql/template/find.go index d8b0b1ad..2e23531a 100644 --- a/tools/goctl/model/sql/template/find.go +++ b/tools/goctl/model/sql/template/find.go @@ -36,7 +36,7 @@ func (m *default{{.upperStartCamelObject}}Model) FindOneBy{{.upperField}}({{.in} {{if .withCache}}{{.cacheKey}} var resp {{.upperStartCamelObject}} err := m.QueryRowIndex(&resp, {{.cacheKeyVariable}}, m.formatPrimary, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) { - query := fmt.Sprintf("select %s from %s where {{.originalField}} = ? limit 1", {{.lowerStartCamelObject}}Rows, m.table) + query := fmt.Sprintf("select %s from %s where {{.originalField}} limit 1", {{.lowerStartCamelObject}}Rows, m.table) if err := conn.QueryRow(&resp, query, {{.lowerStartCamelField}}); err != nil { return nil, err } @@ -51,7 +51,7 @@ func (m *default{{.upperStartCamelObject}}Model) FindOneBy{{.upperField}}({{.in} return nil, err } }{{else}}var resp {{.upperStartCamelObject}} - query := fmt.Sprintf("select %s from %s where {{.originalField}} = ? limit 1", {{.lowerStartCamelObject}}Rows, m.table ) + query := fmt.Sprintf("select %s from %s where {{.originalField}} limit 1", {{.lowerStartCamelObject}}Rows, m.table ) err := m.conn.QueryRow(&resp, query, {{.lowerStartCamelField}}) switch err { case nil: diff --git a/tools/goctl/model/sql/test/model/model_test.go b/tools/goctl/model/sql/test/model/model_test.go index 4f043f49..6b51de5e 100644 --- a/tools/goctl/model/sql/test/model/model_test.go +++ b/tools/goctl/model/sql/test/model/model_test.go @@ -2,6 +2,7 @@ package model import ( "database/sql" + "encoding/json" "fmt" "testing" "time" @@ -20,11 +21,13 @@ func TestStudentModel(t *testing.T) { testTable = "`student`" testUpdateName = "gozero1" testRowsAffected int64 = 1 - testInsertID int64 = 1 + testInsertId int64 = 1 + class = "一年级1班" ) var data Student - data.ID = testInsertID + data.Id = testInsertId + data.Class = class data.Name = "gozero" data.Age = sql.NullInt64{ Int64: 1, @@ -42,15 +45,15 @@ func TestStudentModel(t *testing.T) { err := mockStudent(func(mock sqlmock.Sqlmock) { mock.ExpectExec(fmt.Sprintf("insert into %s", testTable)). - WithArgs(data.Name, data.Age, data.Score). - WillReturnResult(sqlmock.NewResult(testInsertID, testRowsAffected)) - }, func(m StudentModel) { + WithArgs(data.Class, data.Name, data.Age, data.Score). + WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected)) + }, func(m StudentModel, redis *redis.Redis) { r, err := m.Insert(data) assert.Nil(t, err) - lastInsertID, err := r.LastInsertId() + lastInsertId, err := r.LastInsertId() assert.Nil(t, err) - assert.Equal(t, testInsertID, lastInsertID) + assert.Equal(t, testInsertId, lastInsertId) rowsAffected, err := r.RowsAffected() assert.Nil(t, err) @@ -60,42 +63,85 @@ func TestStudentModel(t *testing.T) { err = mockStudent(func(mock sqlmock.Sqlmock) { mock.ExpectQuery(fmt.Sprintf("select (.+) from %s", testTable)). - WithArgs(testInsertID). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertID, data.Name, data.Age, data.Score, testTimeValue, testTimeValue)) - }, func(m StudentModel) { - result, err := m.FindOne(testInsertID) + WithArgs(testInsertId). + WillReturnRows(sqlmock.NewRows([]string{"id", "class", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertId, data.Class, data.Name, data.Age, data.Score, testTimeValue, testTimeValue)) + }, func(m StudentModel, redis *redis.Redis) { + result, err := m.FindOne(testInsertId) assert.Nil(t, err) assert.Equal(t, *result, data) + + var resp Student + val, err := redis.Get(fmt.Sprintf("%s%v", cacheStudentIdPrefix, testInsertId)) + assert.Nil(t, err) + err = json.Unmarshal([]byte(val), &resp) + assert.Nil(t, err) + assert.Equal(t, resp.Name, data.Name) }) assert.Nil(t, err) err = mockStudent(func(mock sqlmock.Sqlmock) { - mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(testUpdateName, data.Age, data.Score, testInsertID).WillReturnResult(sqlmock.NewResult(testInsertID, testRowsAffected)) - }, func(m StudentModel) { + mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(data.Class, testUpdateName, data.Age, data.Score, testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected)) + }, func(m StudentModel, redis *redis.Redis) { data.Name = testUpdateName err := m.Update(data) assert.Nil(t, err) + + val, err := redis.Get(fmt.Sprintf("%s%v", cacheStudentIdPrefix, testInsertId)) + assert.Nil(t, err) + assert.Equal(t, "", val) }) assert.Nil(t, err) + data.Name = testUpdateName err = mockStudent(func(mock sqlmock.Sqlmock) { mock.ExpectQuery(fmt.Sprintf("select (.+) from %s ", testTable)). - WithArgs(testInsertID). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertID, data.Name, data.Age, data.Score, testTimeValue, testTimeValue)) - }, func(m StudentModel) { - result, err := m.FindOne(testInsertID) + WithArgs(testInsertId). + WillReturnRows(sqlmock.NewRows([]string{"id", "class", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertId, data.Class, data.Name, data.Age, data.Score, testTimeValue, testTimeValue)) + }, func(m StudentModel, redis *redis.Redis) { + result, err := m.FindOne(testInsertId) assert.Nil(t, err) assert.Equal(t, *result, data) + + var resp Student + val, err := redis.Get(fmt.Sprintf("%s%v", cacheStudentIdPrefix, testInsertId)) + assert.Nil(t, err) + err = json.Unmarshal([]byte(val), &resp) + assert.Nil(t, err) + assert.Equal(t, testUpdateName, data.Name) }) assert.Nil(t, err) err = mockStudent(func(mock sqlmock.Sqlmock) { - mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertID).WillReturnResult(sqlmock.NewResult(testInsertID, testRowsAffected)) - }, func(m StudentModel) { - err := m.Delete(testInsertID) + mock.ExpectQuery(fmt.Sprintf("select (.+) from %s ", testTable)). + WithArgs(class, testUpdateName). + WillReturnRows(sqlmock.NewRows([]string{"id", "class", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertId, data.Class, data.Name, data.Age, data.Score, testTimeValue, testTimeValue)) + }, func(m StudentModel, redis *redis.Redis) { + result, err := m.FindOneByClassName(class, testUpdateName) + assert.Nil(t, err) + assert.Equal(t, *result, data) + + val, err := redis.Get(fmt.Sprintf("%s%v%v", cacheStudentClassNamePrefix, class, testUpdateName)) assert.Nil(t, err) + assert.Equal(t, "1", val) }) assert.Nil(t, err) + + err = mockStudent(func(mock sqlmock.Sqlmock) { + mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected)) + }, func(m StudentModel, redis *redis.Redis) { + err = m.Delete(testInsertId, class, testUpdateName) + assert.Nil(t, err) + + val, err := redis.Get(fmt.Sprintf("%s%v", cacheStudentIdPrefix, testInsertId)) + assert.Nil(t, err) + assert.Equal(t, "", val) + + val, err = redis.Get(fmt.Sprintf("%s%v%v", cacheStudentClassNamePrefix, class, testUpdateName)) + assert.Nil(t, err) + assert.Equal(t, "", val) + }) + + assert.Nil(t, err) } func TestUserModel(t *testing.T) { @@ -109,11 +155,11 @@ func TestUserModel(t *testing.T) { testGender = "男" testNickname = "test_nickname" testRowsAffected int64 = 1 - testInsertID int64 = 1 + testInsertId int64 = 1 ) var data User - data.ID = testInsertID + data.ID = testInsertId data.User = testUser data.Name = "gozero" data.Password = testPassword @@ -126,14 +172,14 @@ func TestUserModel(t *testing.T) { err := mockUser(func(mock sqlmock.Sqlmock) { mock.ExpectExec(fmt.Sprintf("insert into %s", testTable)). WithArgs(data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname). - WillReturnResult(sqlmock.NewResult(testInsertID, testRowsAffected)) + WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected)) }, func(m UserModel) { r, err := m.Insert(data) assert.Nil(t, err) - lastInsertID, err := r.LastInsertId() + lastInsertId, err := r.LastInsertId() assert.Nil(t, err) - assert.Equal(t, testInsertID, lastInsertID) + assert.Equal(t, testInsertId, lastInsertId) rowsAffected, err := r.RowsAffected() assert.Nil(t, err) @@ -143,17 +189,17 @@ func TestUserModel(t *testing.T) { err = mockUser(func(mock sqlmock.Sqlmock) { mock.ExpectQuery(fmt.Sprintf("select (.+) from %s", testTable)). - WithArgs(testInsertID). - WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertID, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue)) + WithArgs(testInsertId). + WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertId, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue)) }, func(m UserModel) { - result, err := m.FindOne(testInsertID) + result, err := m.FindOne(testInsertId) assert.Nil(t, err) assert.Equal(t, *result, data) }) assert.Nil(t, err) err = mockUser(func(mock sqlmock.Sqlmock) { - mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(data.User, testUpdateName, data.Password, data.Mobile, data.Gender, data.Nickname, testInsertID).WillReturnResult(sqlmock.NewResult(testInsertID, testRowsAffected)) + mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(data.User, testUpdateName, data.Password, data.Mobile, data.Gender, data.Nickname, testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected)) }, func(m UserModel) { data.Name = testUpdateName err := m.Update(data) @@ -163,26 +209,26 @@ func TestUserModel(t *testing.T) { err = mockUser(func(mock sqlmock.Sqlmock) { mock.ExpectQuery(fmt.Sprintf("select (.+) from %s ", testTable)). - WithArgs(testInsertID). - WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertID, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue)) + WithArgs(testInsertId). + WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertId, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue)) }, func(m UserModel) { - result, err := m.FindOne(testInsertID) + result, err := m.FindOne(testInsertId) assert.Nil(t, err) assert.Equal(t, *result, data) }) assert.Nil(t, err) err = mockUser(func(mock sqlmock.Sqlmock) { - mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertID).WillReturnResult(sqlmock.NewResult(testInsertID, testRowsAffected)) + mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected)) }, func(m UserModel) { - err := m.Delete(testInsertID) + err := m.Delete(testInsertId) assert.Nil(t, err) }) assert.Nil(t, err) } // with cache -func mockStudent(mockFn func(mock sqlmock.Sqlmock), fn func(m StudentModel)) error { +func mockStudent(mockFn func(mock sqlmock.Sqlmock), fn func(m StudentModel, r *redis.Redis)) error { db, mock, err := sqlmock.New() if err != nil { return err @@ -211,7 +257,9 @@ func mockStudent(mockFn func(mock sqlmock.Sqlmock), fn func(m StudentModel)) err Weight: 100, }, }) - fn(m) + mock.ExpectBegin() + fn(m, r) + mock.ExpectCommit() return nil } diff --git a/tools/goctl/model/sql/test/model/studentmodel.go b/tools/goctl/model/sql/test/model/studentmodel.go index f8aaeeec..2acac963 100755 --- a/tools/goctl/model/sql/test/model/studentmodel.go +++ b/tools/goctl/model/sql/test/model/studentmodel.go @@ -19,16 +19,19 @@ var ( studentRowsExpectAutoSet = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",") studentRowsWithPlaceHolder = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?" - cacheStudentIDPrefix = "cache#Student#id#" + cacheStudentIdPrefix = "cache#student#id#" + cacheStudentClassNamePrefix = "cache#student#class#name#" ) type ( - // StudentModel defines a model for Student + // StudentModel only for test StudentModel interface { Insert(data Student) (sql.Result, error) FindOne(id int64) (*Student, error) + FindOneByClassName(class string, name string) (*Student, error) Update(data Student) error - Delete(id int64) error + // only for test + Delete(id int64, className, studentName string) error } defaultStudentModel struct { @@ -36,9 +39,10 @@ type ( table string } - // Student defines an data structure for mysql + // Student only for test Student struct { - ID int64 `db:"id"` + Id int64 `db:"id"` + Class string `db:"class"` Name string `db:"name"` Age sql.NullInt64 `db:"age"` Score sql.NullFloat64 `db:"score"` @@ -47,7 +51,7 @@ type ( } ) -// NewStudentModel creates an instance for StudentModel +// NewStudentModel only for test func NewStudentModel(conn sqlx.SqlConn, c cache.CacheConf) StudentModel { return &defaultStudentModel{ CachedConn: sqlc.NewConn(conn, c), @@ -56,16 +60,18 @@ func NewStudentModel(conn sqlx.SqlConn, c cache.CacheConf) StudentModel { } func (m *defaultStudentModel) Insert(data Student) (sql.Result, error) { - query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?)", m.table, studentRowsExpectAutoSet) - ret, err := m.ExecNoCache(query, data.Name, data.Age, data.Score) - + studentClassNameKey := fmt.Sprintf("%s%v%v", cacheStudentClassNamePrefix, data.Class, data.Name) + ret, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { + query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?, ?)", m.table, studentRowsExpectAutoSet) + return conn.Exec(query, data.Class, data.Name, data.Age, data.Score) + }, studentClassNameKey) return ret, err } func (m *defaultStudentModel) FindOne(id int64) (*Student, error) { - studentIDKey := fmt.Sprintf("%s%v", cacheStudentIDPrefix, id) + studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, id) var resp Student - err := m.QueryRow(&resp, studentIDKey, func(conn sqlx.SqlConn, v interface{}) error { + err := m.QueryRow(&resp, studentIdKey, func(conn sqlx.SqlConn, v interface{}) error { query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", studentRows, m.table) return conn.QueryRow(v, query, id) }) @@ -79,27 +85,47 @@ func (m *defaultStudentModel) FindOne(id int64) (*Student, error) { } } +func (m *defaultStudentModel) FindOneByClassName(class string, name string) (*Student, error) { + studentClassNameKey := fmt.Sprintf("%s%v%v", cacheStudentClassNamePrefix, class, name) + var resp Student + err := m.QueryRowIndex(&resp, studentClassNameKey, m.formatPrimary, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) { + query := fmt.Sprintf("select %s from %s where `class` = ? and `name` = ? limit 1", studentRows, m.table) + if err := conn.QueryRow(&resp, query, class, name); err != nil { + return nil, err + } + return resp.Id, nil + }, m.queryPrimary) + switch err { + case nil: + return &resp, nil + case sqlc.ErrNotFound: + return nil, ErrNotFound + default: + return nil, err + } +} + func (m *defaultStudentModel) Update(data Student) error { - studentIDKey := fmt.Sprintf("%s%v", cacheStudentIDPrefix, data.ID) + studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, data.Id) _, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { query := fmt.Sprintf("update %s set %s where `id` = ?", m.table, studentRowsWithPlaceHolder) - return conn.Exec(query, data.Name, data.Age, data.Score, data.ID) - }, studentIDKey) + return conn.Exec(query, data.Class, data.Name, data.Age, data.Score, data.Id) + }, studentIdKey) return err } -func (m *defaultStudentModel) Delete(id int64) error { - - studentIDKey := fmt.Sprintf("%s%v", cacheStudentIDPrefix, id) +func (m *defaultStudentModel) Delete(id int64, className, studentName string) error { + studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, id) + studentClassNameKey := fmt.Sprintf("%s%v%v", cacheStudentClassNamePrefix, className, studentName) _, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { query := fmt.Sprintf("delete from %s where `id` = ?", m.table) return conn.Exec(query, id) - }, studentIDKey) + }, studentIdKey, studentClassNameKey) return err } func (m *defaultStudentModel) formatPrimary(primary interface{}) string { - return fmt.Sprintf("%s%v", cacheStudentIDPrefix, primary) + return fmt.Sprintf("%s%v", cacheStudentIdPrefix, primary) } func (m *defaultStudentModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error { diff --git a/tools/goctl/model/sql/test/model/usermodel.go b/tools/goctl/model/sql/test/model/usermodel.go index b3f773fe..6735274c 100755 --- a/tools/goctl/model/sql/test/model/usermodel.go +++ b/tools/goctl/model/sql/test/model/usermodel.go @@ -13,10 +13,10 @@ import ( ) var ( - userFieldNames = builderx.FieldNames(&User{}) + userFieldNames = builderx.RawFieldNames(&User{}) userRows = strings.Join(userFieldNames, ",") - userRowsExpectAutoSet = strings.Join(stringx.Remove(userFieldNames, "id", "create_time", "update_time"), ",") - userRowsWithPlaceHolder = strings.Join(stringx.Remove(userFieldNames, "id", "create_time", "update_time"), "=?,") + "=?" + userRowsExpectAutoSet = strings.Join(stringx.Remove(userFieldNames, "`id`", "`create_time`", "`update_time`"), ",") + userRowsWithPlaceHolder = strings.Join(stringx.Remove(userFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?" ) type ( @@ -25,8 +25,8 @@ type ( Insert(data User) (sql.Result, error) FindOne(id int64) (*User, error) FindOneByUser(user string) (*User, error) - FindOneByName(name string) (*User, error) FindOneByMobile(mobile string) (*User, error) + FindOneByName(name string) (*User, error) Update(data User) error Delete(id int64) error } @@ -92,10 +92,10 @@ func (m *defaultUserModel) FindOneByUser(user string) (*User, error) { } } -func (m *defaultUserModel) FindOneByName(name string) (*User, error) { +func (m *defaultUserModel) FindOneByMobile(mobile string) (*User, error) { var resp User - query := fmt.Sprintf("select %s from %s where `name` = ? limit 1", userRows, m.table) - err := m.conn.QueryRow(&resp, query, name) + query := fmt.Sprintf("select %s from %s where `mobile` = ? limit 1", userRows, m.table) + err := m.conn.QueryRow(&resp, query, mobile) switch err { case nil: return &resp, nil @@ -106,10 +106,10 @@ func (m *defaultUserModel) FindOneByName(name string) (*User, error) { } } -func (m *defaultUserModel) FindOneByMobile(mobile string) (*User, error) { +func (m *defaultUserModel) FindOneByName(name string) (*User, error) { var resp User - query := fmt.Sprintf("select %s from %s where `mobile` = ? limit 1", userRows, m.table) - err := m.conn.QueryRow(&resp, query, mobile) + query := fmt.Sprintf("select %s from %s where `name` = ? limit 1", userRows, m.table) + err := m.conn.QueryRow(&resp, query, name) switch err { case nil: return &resp, nil diff --git a/tools/goctl/model/sql/test/utils.go b/tools/goctl/model/sql/test/utils.go index 2ca42ec5..94927967 100644 --- a/tools/goctl/model/sql/test/utils.go +++ b/tools/goctl/model/sql/test/utils.go @@ -11,6 +11,7 @@ import ( "github.com/tal-tech/go-zero/core/mapping" ) +// ErrNotFound is the alias of sql.ErrNoRows var ErrNotFound = sql.ErrNoRows func desensitize(datasource string) string { diff --git a/tools/goctl/util/stringx/string.go b/tools/goctl/util/stringx/string.go index 001e4bdd..3df98449 100644 --- a/tools/goctl/util/stringx/string.go +++ b/tools/goctl/util/stringx/string.go @@ -32,6 +32,11 @@ func (s String) Lower() string { return strings.ToLower(s.source) } +// Upper calls the strings.ToUpper +func (s String) Upper() string { + return strings.ToUpper(s.source) +} + // ReplaceAll calls the strings.ReplaceAll func (s String) ReplaceAll(old, new string) string { return strings.ReplaceAll(s.source, old, new)