diff --git a/tools/goctl/model/sql/gen/gen.go b/tools/goctl/model/sql/gen/gen.go index 6730142d..0c34edb9 100644 --- a/tools/goctl/model/sql/gen/gen.go +++ b/tools/goctl/model/sql/gen/gen.go @@ -49,6 +49,11 @@ type ( deleteCode string cacheExtra string } + + codeTuple struct { + modelCode string + modelCustomCode string + } ) // NewDefaultGenerator creates an instance for defaultGenerator @@ -109,7 +114,7 @@ func (g *defaultGenerator) StartFromDDL(filename string, withCache bool, databas } func (g *defaultGenerator) StartFromInformationSchema(tables map[string]*model.Table, withCache bool) error { - m := make(map[string]string) + m := make(map[string]*codeTuple) for _, each := range tables { table, err := parser.ConvertDataType(each) if err != nil { @@ -120,14 +125,21 @@ func (g *defaultGenerator) StartFromInformationSchema(tables map[string]*model.T if err != nil { return err } + customCode, err := g.genModelCustom(*table) + if err != nil { + return err + } - m[table.Name.Source()] = code + m[table.Name.Source()] = &codeTuple{ + modelCode: code, + modelCustomCode: customCode, + } } return g.createFile(m) } -func (g *defaultGenerator) createFile(modelList map[string]string) error { +func (g *defaultGenerator) createFile(modelList map[string]*codeTuple) error { dirAbs, err := filepath.Abs(g.dir) if err != nil { return err @@ -140,20 +152,27 @@ func (g *defaultGenerator) createFile(modelList map[string]string) error { return err } - for tableName, code := range modelList { + for tableName, codes := range modelList { tn := stringx.From(tableName) modelFilename, err := format.FileNamingFormat(g.cfg.NamingFormat, fmt.Sprintf("%s_model", tn.Source())) if err != nil { return err } - name := util.SafeString(modelFilename) + ".go" + name := util.SafeString(modelFilename) + "_gen.go" filename := filepath.Join(dirAbs, name) + err = ioutil.WriteFile(filename, []byte(codes.modelCode), os.ModePerm) + if err != nil { + return err + } + + name = util.SafeString(modelFilename) + ".go" + filename = filepath.Join(dirAbs, name) if pathx.FileExists(filename) { g.Warning("%s already exists, ignored.", name) continue } - err = ioutil.WriteFile(filename, []byte(code), os.ModePerm) + err = ioutil.WriteFile(filename, []byte(codes.modelCustomCode), os.ModePerm) if err != nil { return err } @@ -183,8 +202,8 @@ func (g *defaultGenerator) createFile(modelList map[string]string) error { } // ret1: key-table name,value-code -func (g *defaultGenerator) genFromDDL(filename string, withCache bool, database string) (map[string]string, error) { - m := make(map[string]string) +func (g *defaultGenerator) genFromDDL(filename string, withCache bool, database string) (map[string]*codeTuple, error) { + m := make(map[string]*codeTuple) tables, err := parser.Parse(filename, database) if err != nil { return nil, err @@ -195,8 +214,15 @@ func (g *defaultGenerator) genFromDDL(filename string, withCache bool, database if err != nil { return nil, err } + customCode, err := g.genModelCustom(*e) + if err != nil { + return nil, err + } - m[e.Name.Source()] = code + m[e.Name.Source()] = &codeTuple{ + modelCode: code, + modelCustomCode: customCode, + } } return m, nil @@ -292,8 +318,27 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er return output.String(), nil } +func (g *defaultGenerator) genModelCustom(in parser.Table) (string, error) { + text, err := pathx.LoadTemplate(category, modelCustomTemplateFile, template.ModelCustom) + if err != nil { + return "", err + } + t := util.With("model-custom"). + Parse(text). + GoFmt(true) + output, err := t.Execute(map[string]interface{}{ + "pkg": g.pkg, + "upperStartCamelObject": in.Name.ToCamel(), + "lowerStartCamelObject": stringx.From(in.Name.ToCamel()).Untitle(), + }) + if err != nil { + return "", err + } + return output.String(), nil +} + func (g *defaultGenerator) executeModel(table Table, code *code) (*bytes.Buffer, error) { - text, err := pathx.LoadTemplate(category, modelTemplateFile, template.Model) + text, err := pathx.LoadTemplate(category, modelGenTemplateFile, template.ModelGen) if err != nil { return nil, err } diff --git a/tools/goctl/model/sql/gen/gen_test.go b/tools/goctl/model/sql/gen/gen_test.go index f31593b2..5faa6f88 100644 --- a/tools/goctl/model/sql/gen/gen_test.go +++ b/tools/goctl/model/sql/gen/gen_test.go @@ -4,16 +4,19 @@ import ( "database/sql" "io/ioutil" "os" + "path" "path/filepath" "strings" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/stringx" "github.com/zeromicro/go-zero/tools/goctl/config" "github.com/zeromicro/go-zero/tools/goctl/model/sql/builderx" + "github.com/zeromicro/go-zero/tools/goctl/model/sql/parser" "github.com/zeromicro/go-zero/tools/goctl/util/pathx" ) @@ -121,3 +124,28 @@ func TestFields(t *testing.T) { assert.Equal(t, "`name`,`age`,`score`", studentRowsExpectAutoSet) assert.Equal(t, "`name`=?,`age`=?,`score`=?", studentRowsWithPlaceHolder) } + +func Test_genPublicModel(t *testing.T) { + var err error + dir := pathx.MustTempDir() + modelDir := path.Join(dir, "model") + err = os.MkdirAll(modelDir, 0777) + require.NoError(t, err) + defer os.RemoveAll(dir) + + modelFilename := filepath.Join(modelDir, "foo.sql") + err = ioutil.WriteFile(modelFilename, []byte(source), 0777) + require.NoError(t, err) + + g, err := NewDefaultGenerator(modelDir, &config.Config{ + NamingFormat: config.DefaultFormat, + }) + require.NoError(t, err) + + tables, err := parser.Parse(modelFilename, "") + require.Equal(t, 1, len(tables)) + + code, err := g.genModelCustom(*tables[0]) + assert.NoError(t, err) + assert.Equal(t, "package model\n\ntype TestUserModel interface {\n\ttestUserModel\n}\n", code) +} diff --git a/tools/goctl/model/sql/gen/template.go b/tools/goctl/model/sql/gen/template.go index d9a664d4..cdb0f8cb 100644 --- a/tools/goctl/model/sql/gen/template.go +++ b/tools/goctl/model/sql/gen/template.go @@ -22,7 +22,8 @@ const ( importsWithNoCacheTemplateFile = "import-no-cache.tpl" insertTemplateFile = "insert.tpl" insertTemplateMethodFile = "interface-insert.tpl" - modelTemplateFile = "model.tpl" + modelGenTemplateFile = "model-gen.tpl" + modelCustomTemplateFile = "model.tpl" modelNewTemplateFile = "model-new.tpl" tagTemplateFile = "tag.tpl" typesTemplateFile = "types.tpl" @@ -45,7 +46,8 @@ var templates = map[string]string{ importsWithNoCacheTemplateFile: template.ImportsNoCache, insertTemplateFile: template.Insert, insertTemplateMethodFile: template.InsertMethod, - modelTemplateFile: template.Model, + modelGenTemplateFile: template.ModelGen, + modelCustomTemplateFile: template.ModelCustom, modelNewTemplateFile: template.New, tagTemplateFile: template.Tag, typesTemplateFile: template.Types, diff --git a/tools/goctl/model/sql/gen/types.go b/tools/goctl/model/sql/gen/types.go index a76afabc..4c5f65d8 100644 --- a/tools/goctl/model/sql/gen/types.go +++ b/tools/goctl/model/sql/gen/types.go @@ -4,6 +4,7 @@ import ( "github.com/zeromicro/go-zero/tools/goctl/model/sql/template" "github.com/zeromicro/go-zero/tools/goctl/util" "github.com/zeromicro/go-zero/tools/goctl/util/pathx" + "github.com/zeromicro/go-zero/tools/goctl/util/stringx" ) func genTypes(table Table, methods string, withCache bool) (string, error) { @@ -24,6 +25,7 @@ func genTypes(table Table, methods string, withCache bool) (string, error) { "withCache": withCache, "method": methods, "upperStartCamelObject": table.Name.ToCamel(), + "lowerStartCamelObject": stringx.From(table.Name.ToCamel()).Untitle(), "fields": fieldsString, "data": table, }) diff --git a/tools/goctl/model/sql/template/model.go b/tools/goctl/model/sql/template/model.go index 939f3347..d907d915 100644 --- a/tools/goctl/model/sql/template/model.go +++ b/tools/goctl/model/sql/template/model.go @@ -1,7 +1,15 @@ package template -// Model defines a template for model -var Model = `package {{.pkg}} +import ( + "fmt" + + "github.com/zeromicro/go-zero/tools/goctl/util" +) + +// ModelGen defines a template for model +var ModelGen = fmt.Sprintf(`%s + +package {{.pkg}} {{.imports}} {{.vars}} {{.types}} @@ -11,4 +19,11 @@ var Model = `package {{.pkg}} {{.update}} {{.delete}} {{.extraMethod}} -` +`, util.DoNotEditHead) + +// ModelCustom defines a template for extension +var ModelCustom = fmt.Sprintf(`package {{.pkg}} +type {{.upperStartCamelObject}}Model interface { + {{.lowerStartCamelObject}}Model +} +`) diff --git a/tools/goctl/model/sql/template/types.go b/tools/goctl/model/sql/template/types.go index bb752378..870890e7 100644 --- a/tools/goctl/model/sql/template/types.go +++ b/tools/goctl/model/sql/template/types.go @@ -3,7 +3,7 @@ package template // Types defines a template for types in model var Types = ` type ( - {{.upperStartCamelObject}}Model interface{ + {{.lowerStartCamelObject}}Model interface{ {{.method}} } diff --git a/tools/goctl/model/sql/template/vars.go b/tools/goctl/model/sql/template/vars.go index f5309cfa..8ef09f86 100644 --- a/tools/goctl/model/sql/template/vars.go +++ b/tools/goctl/model/sql/template/vars.go @@ -5,6 +5,8 @@ import "fmt" // Vars defines a template for var block in model var Vars = fmt.Sprintf(` var ( + _ {{.upperStartCamelObject}}Model = (*default{{.upperStartCamelObject}}Model)(nil) + {{.lowerStartCamelObject}}FieldNames = builder.RawFieldNames(&{{.upperStartCamelObject}}{}{{if .postgreSql}},true{{end}}) {{.lowerStartCamelObject}}Rows = strings.Join({{.lowerStartCamelObject}}FieldNames, ",") {{.lowerStartCamelObject}}RowsExpectAutoSet = {{if .postgreSql}}strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "%screate_time%s", "%supdate_time%s"), ","){{else}}strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "%screate_time%s", "%supdate_time%s"), ","){{end}} diff --git a/tools/goctl/util/head.go b/tools/goctl/util/head.go index 36829a5e..1f3be01f 100644 --- a/tools/goctl/util/head.go +++ b/tools/goctl/util/head.go @@ -1,5 +1,8 @@ package util +// DoNotEditHead added to the beginning of a file to prompt the user not to edit +var DoNotEditHead = "// Code generated by goctl. DO NOT EDIT!" + var headTemplate = `// Code generated by goctl. DO NOT EDIT! // Source: {{.source}}`