diff --git a/tools/goctl/goctl.go b/tools/goctl/goctl.go index 224be3a8..93377e2d 100644 --- a/tools/goctl/goctl.go +++ b/tools/goctl/goctl.go @@ -201,6 +201,10 @@ var ( Name: "new", Usage: `generate rpc demo service`, Flags: []cli.Flag{ + cli.StringFlag{ + Name: "style", + Usage: "the file naming style, lower|camel|snake,default is lower", + }, cli.BoolFlag{ Name: "idea", Usage: "whether the command execution environment is from idea plugin. [optional]", @@ -235,6 +239,10 @@ var ( Name: "dir, d", Usage: `the target path of the code`, }, + cli.StringFlag{ + Name: "style", + Usage: "the file naming style, lower|camel|snake,default is lower", + }, cli.BoolFlag{ Name: "idea", Usage: "whether the command execution environment is from idea plugin. [optional]", @@ -266,7 +274,7 @@ var ( }, cli.StringFlag{ Name: "style", - Usage: "the file naming style, lower|camel|underline,default is lower", + Usage: "the file naming style, lower|camel|snake,default is lower", }, cli.BoolFlag{ Name: "cache, c", diff --git a/tools/goctl/model/sql/builderx/builder.go b/tools/goctl/model/sql/builderx/builder.go index ad80ba65..3ce38604 100644 --- a/tools/goctl/model/sql/builderx/builder.go +++ b/tools/goctl/model/sql/builderx/builder.go @@ -68,30 +68,3 @@ func FieldNames(in interface{}) []string { } return out } -func FieldNamesAlias(in interface{}, alias string) []string { - out := make([]string, 0) - v := reflect.ValueOf(in) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } - // we only accept structs - if v.Kind() != reflect.Struct { - panic(fmt.Errorf("ToMap only accepts structs; got %T", v)) - } - typ := v.Type() - for i := 0; i < v.NumField(); i++ { - // gets us a StructField - fi := typ.Field(i) - tagName := "" - if tagv := fi.Tag.Get(dbTag); tagv != "" { - tagName = tagv - } else { - tagName = fi.Name - } - if len(alias) > 0 { - tagName = alias + "." + tagName - } - out = append(out, tagName) - } - return out -} diff --git a/tools/goctl/model/sql/command/command.go b/tools/goctl/model/sql/command/command.go index fc377ff9..7693a7af 100644 --- a/tools/goctl/model/sql/command/command.go +++ b/tools/goctl/model/sql/command/command.go @@ -17,6 +17,8 @@ import ( "github.com/urfave/cli" ) +var errNotMatched = errors.New("sql not matched") + const ( flagSrc = "src" flagDir = "dir" @@ -33,6 +35,20 @@ func MysqlDDL(ctx *cli.Context) error { cache := ctx.Bool(flagCache) idea := ctx.Bool(flagIdea) namingStyle := strings.TrimSpace(ctx.String(flagStyle)) + return fromDDl(src, dir, namingStyle, cache, idea) +} + +func MyDataSource(ctx *cli.Context) error { + url := strings.TrimSpace(ctx.String(flagUrl)) + dir := strings.TrimSpace(ctx.String(flagDir)) + cache := ctx.Bool(flagCache) + idea := ctx.Bool(flagIdea) + namingStyle := strings.TrimSpace(ctx.String(flagStyle)) + pattern := strings.TrimSpace(ctx.String(flagTable)) + return fromDataSource(url, pattern, dir, namingStyle, cache, idea) +} + +func fromDDl(src, dir, namingStyle string, cache, idea bool) error { log := console.NewConsole(idea) src = strings.TrimSpace(src) if len(src) == 0 { @@ -52,29 +68,29 @@ func MysqlDDL(ctx *cli.Context) error { return err } + if len(files) == 0 { + return errNotMatched + } + var source []string for _, file := range files { data, err := ioutil.ReadFile(file) if err != nil { return err } + source = append(source, string(data)) } - generator := gen.NewDefaultGenerator(strings.Join(source, "\n"), dir, namingStyle, gen.WithConsoleOption(log)) - err = generator.Start(cache) + generator, err := gen.NewDefaultGenerator(dir, namingStyle, gen.WithConsoleOption(log)) if err != nil { - log.Error("%v", err) + return err } - return nil + + err = generator.StartFromDDL(strings.Join(source, "\n"), cache) + return err } -func MyDataSource(ctx *cli.Context) error { - url := strings.TrimSpace(ctx.String(flagUrl)) - dir := strings.TrimSpace(ctx.String(flagDir)) - cache := ctx.Bool(flagCache) - idea := ctx.Bool(flagIdea) - namingStyle := strings.TrimSpace(ctx.String(flagStyle)) - pattern := strings.TrimSpace(ctx.String(flagTable)) +func fromDataSource(url, pattern, dir, namingStyle string, cache, idea bool) error { log := console.NewConsole(idea) if len(url) == 0 { log.Error("%v", "expected data source of mysql, but nothing found") @@ -100,10 +116,8 @@ func MyDataSource(ctx *cli.Context) error { } logx.Disable() - conn := sqlx.NewMysql(url) databaseSource := strings.TrimSuffix(url, "/"+cfg.DBName) + "/information_schema" db := sqlx.NewMysql(databaseSource) - m := model.NewDDLModel(conn) im := model.NewInformationSchemaModel(db) tables, err := im.GetAllTables(cfg.DBName) @@ -111,7 +125,7 @@ func MyDataSource(ctx *cli.Context) error { return err } - var matchTables []string + matchTables := make(map[string][]*model.Column) for _, item := range tables { match, err := filepath.Match(pattern, item) if err != nil { @@ -121,24 +135,22 @@ func MyDataSource(ctx *cli.Context) error { if !match { continue } - - matchTables = append(matchTables, item) + columns, err := im.FindByTableName(cfg.DBName, item) + if err != nil { + return err + } + matchTables[item] = columns } + if len(matchTables) == 0 { return errors.New("no tables matched") } - ddl, err := m.ShowDDL(matchTables...) - if err != nil { - log.Error("%v", err) - return nil - } - - generator := gen.NewDefaultGenerator(strings.Join(ddl, "\n"), dir, namingStyle, gen.WithConsoleOption(log)) - err = generator.Start(cache) + generator, err := gen.NewDefaultGenerator(dir, namingStyle, gen.WithConsoleOption(log)) if err != nil { - log.Error("%v", err) + return err } - return nil + err = generator.StartFromInformationSchema(cfg.DBName, matchTables, cache) + return err } diff --git a/tools/goctl/model/sql/command/command_test.go b/tools/goctl/model/sql/command/command_test.go new file mode 100644 index 00000000..a689cdd0 --- /dev/null +++ b/tools/goctl/model/sql/command/command_test.go @@ -0,0 +1,75 @@ +package command + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen" + "github.com/tal-tech/go-zero/tools/goctl/util" +) + +var sql = "-- 用户表 --\nCREATE TABLE `user` (\n `id` bigint(10) NOT NULL AUTO_INCREMENT,\n `name` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户名称',\n `password` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户密码',\n `mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号',\n `gender` char(5) COLLATE utf8mb4_general_ci NOT NULL COMMENT '男|女|未公开',\n `nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT '用户昵称',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `name_index` (`name`),\n UNIQUE KEY `mobile_index` (`mobile`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;\n\n" + +func TestFromDDl(t *testing.T) { + err := fromDDl("./user.sql", t.TempDir(), gen.NamingCamel, true, false) + assert.Equal(t, errNotMatched, err) + + // case dir is not exists + unknownDir := filepath.Join(t.TempDir(), "test", "user.sql") + err = fromDDl(unknownDir, t.TempDir(), gen.NamingCamel, true, false) + assert.True(t, func() bool { + switch err.(type) { + case *os.PathError: + return true + default: + return false + } + }()) + + // case empty src + err = fromDDl("", t.TempDir(), gen.NamingCamel, true, false) + if err != nil { + assert.Equal(t, "expected path or path globbing patterns, but nothing found", err.Error()) + } + + // case unknown naming style + tmp := filepath.Join(t.TempDir(), "user.sql") + err = fromDDl(tmp, t.TempDir(), "lower1", true, false) + if err != nil { + assert.Equal(t, "unexpected naming style: lower1", err.Error()) + } + + tempDir := filepath.Join(t.TempDir(), "test") + err = util.MkdirIfNotExist(tempDir) + if err != nil { + return + } + + user1Sql := filepath.Join(tempDir, "user1.sql") + user2Sql := filepath.Join(tempDir, "user2.sql") + + err = ioutil.WriteFile(user1Sql, []byte(sql), os.ModePerm) + if err != nil { + return + } + + err = ioutil.WriteFile(user2Sql, []byte(sql), os.ModePerm) + if err != nil { + return + } + + _, err = os.Stat(user1Sql) + assert.Nil(t, err) + + _, err = os.Stat(user2Sql) + assert.Nil(t, err) + + err = fromDDl(filepath.Join(tempDir, "user*.sql"), tempDir, gen.NamingLower, true, false) + assert.Nil(t, err) + + _, err = os.Stat(filepath.Join(tempDir, "usermodel.go")) + assert.Nil(t, err) +} diff --git a/tools/goctl/model/sql/example/generator.sh b/tools/goctl/model/sql/example/generator.sh deleted file mode 100644 index 71ead77b..00000000 --- a/tools/goctl/model/sql/example/generator.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash - -# generate model with cache from ddl -goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/user" -c - -# generate model with cache from data source -#user=root -#password=password -#datasource=127.0.0.1:3306 -#database=test -#goctl model mysql datasource -url="${user}:${password}@tcp(${datasource})/${database}" -table="*" -dir ./model \ No newline at end of file diff --git a/tools/goctl/model/sql/example/makefile b/tools/goctl/model/sql/example/makefile new file mode 100644 index 00000000..8e36f387 --- /dev/null +++ b/tools/goctl/model/sql/example/makefile @@ -0,0 +1,15 @@ +#!/bin/bash + +# generate model with cache from ddl +fromDDL: + goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/user" -c + + +# generate model with cache from data source +user=root +password=password +datasource=127.0.0.1:3306 +database=gozero + +fromDataSource: + goctl model mysql datasource -url="$(user):$(password)@tcp($(datasource))/$(database)" -table="*" -dir ./model/cache -c -style camel \ No newline at end of file diff --git a/tools/goctl/model/sql/gen/delete.go b/tools/goctl/model/sql/gen/delete.go index 4f9bada9..f384b0bd 100644 --- a/tools/goctl/model/sql/gen/delete.go +++ b/tools/goctl/model/sql/gen/delete.go @@ -42,5 +42,6 @@ func genDelete(table Table, withCache bool) (string, error) { if err != nil { return "", err } + return output.String(), nil } diff --git a/tools/goctl/model/sql/gen/field.go b/tools/goctl/model/sql/gen/field.go index d82be07a..f257300e 100644 --- a/tools/goctl/model/sql/gen/field.go +++ b/tools/goctl/model/sql/gen/field.go @@ -15,6 +15,7 @@ func genFields(fields []parser.Field) (string, error) { if err != nil { return "", err } + list = append(list, result) } return strings.Join(list, "\n"), nil @@ -43,5 +44,6 @@ func genField(field parser.Field) (string, error) { if err != nil { return "", err } + return output.String(), nil } diff --git a/tools/goctl/model/sql/gen/findone.go b/tools/goctl/model/sql/gen/findone.go index 284967b8..4088b640 100644 --- a/tools/goctl/model/sql/gen/findone.go +++ b/tools/goctl/model/sql/gen/findone.go @@ -28,5 +28,6 @@ func genFindOne(table Table, withCache bool) (string, error) { if err != nil { return "", err } + return output.String(), nil } diff --git a/tools/goctl/model/sql/gen/gen.go b/tools/goctl/model/sql/gen/gen.go index 752fe0ee..a09a6a50 100644 --- a/tools/goctl/model/sql/gen/gen.go +++ b/tools/goctl/model/sql/gen/gen.go @@ -7,6 +7,7 @@ import ( "path/filepath" "strings" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/model" "github.com/tal-tech/go-zero/tools/goctl/model/sql/parser" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/util" @@ -24,8 +25,8 @@ const ( type ( defaultGenerator struct { - source string - dir string + //source string + dir string console.Console pkg string namingStyle string @@ -33,18 +34,30 @@ type ( Option func(generator *defaultGenerator) ) -func NewDefaultGenerator(source, dir, namingStyle string, opt ...Option) *defaultGenerator { +func NewDefaultGenerator(dir, namingStyle string, opt ...Option) (*defaultGenerator, error) { if dir == "" { dir = pwd } - generator := &defaultGenerator{source: source, dir: dir, namingStyle: namingStyle} + dirAbs, err := filepath.Abs(dir) + if err != nil { + return nil, err + } + + dir = dirAbs + pkg := filepath.Base(dirAbs) + err = util.MkdirIfNotExist(dir) + if err != nil { + return nil, err + } + + generator := &defaultGenerator{dir: dir, namingStyle: namingStyle, pkg: pkg} var optionList []Option optionList = append(optionList, newDefaultOption()) optionList = append(optionList, opt...) for _, fn := range optionList { fn(generator) } - return generator + return generator, nil } func WithConsoleOption(c console.Console) Option { @@ -59,21 +72,45 @@ func newDefaultOption() Option { } } -func (g *defaultGenerator) Start(withCache bool) error { +func (g *defaultGenerator) StartFromDDL(source string, withCache bool) error { + modelList, err := g.genFromDDL(source, withCache) + if err != nil { + return err + } + + return g.createFile(modelList) +} + +func (g *defaultGenerator) StartFromInformationSchema(db string, columns map[string][]*model.Column, withCache bool) error { + m := make(map[string]string) + for tableName, column := range columns { + table, err := parser.ConvertColumn(db, tableName, column) + if err != nil { + return err + } + + code, err := g.genModel(*table, withCache) + if err != nil { + return err + } + + m[table.Name.Source()] = code + } + return g.createFile(m) +} + +func (g *defaultGenerator) createFile(modelList map[string]string) error { dirAbs, err := filepath.Abs(g.dir) if err != nil { return err } + g.dir = dirAbs g.pkg = filepath.Base(dirAbs) err = util.MkdirIfNotExist(dirAbs) if err != nil { return err } - modelList, err := g.genFromDDL(withCache) - if err != nil { - return err - } for tableName, code := range modelList { tn := stringx.From(tableName) @@ -96,6 +133,9 @@ func (g *defaultGenerator) Start(withCache bool) error { } // generate error file filename := filepath.Join(dirAbs, "vars.go") + if g.namingStyle == NamingCamel { + filename = filepath.Join(dirAbs, "Vars.go") + } text, err := util.LoadTemplate(category, errTemplateFile, template.Error) if err != nil { return err @@ -113,8 +153,8 @@ func (g *defaultGenerator) Start(withCache bool) error { } // ret1: key-table name,value-code -func (g *defaultGenerator) genFromDDL(withCache bool) (map[string]string, error) { - ddlList := g.split() +func (g *defaultGenerator) genFromDDL(source string, withCache bool) (map[string]string, error) { + ddlList := g.split(source) m := make(map[string]string) for _, ddl := range ddlList { table, err := parser.Parse(ddl) @@ -139,10 +179,15 @@ type ( ) func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, error) { + if len(in.PrimaryKey.Name.Source()) == 0 { + return "", fmt.Errorf("table %s: missing primary key", in.Name.Source()) + } + text, err := util.LoadTemplate(category, modelTemplateFile, template.Model) if err != nil { return "", err } + t := util.With("model"). Parse(text). GoFmt(true) diff --git a/tools/goctl/model/sql/gen/gen_test.go b/tools/goctl/model/sql/gen/gen_test.go index 656b6e50..6483adaa 100644 --- a/tools/goctl/model/sql/gen/gen_test.go +++ b/tools/goctl/model/sql/gen/gen_test.go @@ -22,15 +22,19 @@ func TestCacheModel(t *testing.T) { defer func() { _ = os.RemoveAll(dir) }() - g := NewDefaultGenerator(source, cacheDir, NamingLower) - err := g.Start(true) + g, err := NewDefaultGenerator(cacheDir, NamingCamel) + assert.Nil(t, err) + + err = g.StartFromDDL(source, true) assert.Nil(t, err) assert.True(t, func() bool { - _, err := os.Stat(filepath.Join(cacheDir, "testuserinfomodel.go")) + _, err := os.Stat(filepath.Join(cacheDir, "TestUserInfoModel.go")) return err == nil }()) - g = NewDefaultGenerator(source, noCacheDir, NamingLower) - err = g.Start(false) + g, err = NewDefaultGenerator(noCacheDir, NamingLower) + assert.Nil(t, err) + + err = g.StartFromDDL(source, false) assert.Nil(t, err) assert.True(t, func() bool { _, err := os.Stat(filepath.Join(noCacheDir, "testuserinfomodel.go")) @@ -47,15 +51,19 @@ func TestNamingModel(t *testing.T) { defer func() { _ = os.RemoveAll(dir) }() - g := NewDefaultGenerator(source, camelDir, NamingCamel) - err := g.Start(true) + g, err := NewDefaultGenerator(camelDir, NamingCamel) + assert.Nil(t, err) + + err = g.StartFromDDL(source, true) assert.Nil(t, err) assert.True(t, func() bool { _, err := os.Stat(filepath.Join(camelDir, "TestUserInfoModel.go")) return err == nil }()) - g = NewDefaultGenerator(source, snakeDir, NamingSnake) - err = g.Start(true) + g, err = NewDefaultGenerator(snakeDir, NamingSnake) + assert.Nil(t, err) + + err = g.StartFromDDL(source, true) assert.Nil(t, err) assert.True(t, func() bool { _, err := os.Stat(filepath.Join(snakeDir, "test_user_info_model.go")) diff --git a/tools/goctl/model/sql/gen/keys_test.go b/tools/goctl/model/sql/gen/keys_test.go index 28219cd5..a2cf4882 100644 --- a/tools/goctl/model/sql/gen/keys_test.go +++ b/tools/goctl/model/sql/gen/keys_test.go @@ -17,7 +17,6 @@ func TestGenCacheKeys(t *testing.T) { Name: stringx.From("id"), DataBaseType: "bigint", DataType: "int64", - IsKey: false, IsPrimaryKey: true, IsUniqueKey: false, Comment: "自增id", @@ -29,7 +28,6 @@ func TestGenCacheKeys(t *testing.T) { Name: stringx.From("mobile"), DataBaseType: "varchar", DataType: "string", - IsKey: false, IsPrimaryKey: false, IsUniqueKey: true, Comment: "手机号", @@ -38,7 +36,6 @@ func TestGenCacheKeys(t *testing.T) { Name: stringx.From("name"), DataBaseType: "varchar", DataType: "string", - IsKey: false, IsPrimaryKey: false, IsUniqueKey: true, Comment: "姓名", @@ -47,7 +44,6 @@ func TestGenCacheKeys(t *testing.T) { Name: stringx.From("createTime"), DataBaseType: "timestamp", DataType: "time.Time", - IsKey: false, IsPrimaryKey: false, IsUniqueKey: false, Comment: "创建时间", @@ -56,7 +52,6 @@ func TestGenCacheKeys(t *testing.T) { Name: stringx.From("updateTime"), DataBaseType: "timestamp", DataType: "time.Time", - IsKey: false, IsPrimaryKey: false, IsUniqueKey: false, Comment: "更新时间", diff --git a/tools/goctl/model/sql/gen/split.go b/tools/goctl/model/sql/gen/split.go index abc8e633..eb3e526b 100644 --- a/tools/goctl/model/sql/gen/split.go +++ b/tools/goctl/model/sql/gen/split.go @@ -4,11 +4,10 @@ import ( "regexp" ) -func (g *defaultGenerator) split() []string { +func (g *defaultGenerator) split(source string) []string { reg := regexp.MustCompile(createTableFlag) - index := reg.FindAllStringIndex(g.source, -1) + index := reg.FindAllStringIndex(source, -1) list := make([]string, 0) - source := g.source for i := len(index) - 1; i >= 0; i-- { subIndex := index[i] if len(subIndex) == 0 { diff --git a/tools/goctl/model/sql/gen/tag.go b/tools/goctl/model/sql/gen/tag.go index 11ce17d6..3c5ff150 100644 --- a/tools/goctl/model/sql/gen/tag.go +++ b/tools/goctl/model/sql/gen/tag.go @@ -22,5 +22,6 @@ func genTag(in string) (string, error) { if err != nil { return "", err } + return output.String(), nil } diff --git a/tools/goctl/model/sql/model/ddlmodel.go b/tools/goctl/model/sql/model/ddlmodel.go index e89aa610..a7c3a38f 100644 --- a/tools/goctl/model/sql/model/ddlmodel.go +++ b/tools/goctl/model/sql/model/ddlmodel.go @@ -27,6 +27,7 @@ func (m *DDLModel) ShowDDL(table ...string) ([]string, error) { if err != nil { return nil, err } + ddl = append(ddl, resp.DDL) } return ddl, nil diff --git a/tools/goctl/model/sql/model/informationschemamodel.go b/tools/goctl/model/sql/model/informationschemamodel.go index 942f6137..fbbc9718 100644 --- a/tools/goctl/model/sql/model/informationschemamodel.go +++ b/tools/goctl/model/sql/model/informationschemamodel.go @@ -8,6 +8,13 @@ type ( InformationSchemaModel struct { conn sqlx.SqlConn } + Column struct { + Name string `db:"COLUMN_NAME"` + DataType string `db:"DATA_TYPE"` + Key string `db:"COLUMN_KEY"` + Extra string `db:"EXTRA"` + Comment string `db:"COLUMN_COMMENT"` + } ) func NewInformationSchemaModel(conn sqlx.SqlConn) *InformationSchemaModel { @@ -21,5 +28,13 @@ func (m *InformationSchemaModel) GetAllTables(database string) ([]string, error) if err != nil { return nil, err } + return tables, nil } + +func (m *InformationSchemaModel) FindByTableName(db, table string) ([]*Column, error) { + querySql := `select COLUMN_NAME,DATA_TYPE,COLUMN_KEY,EXTRA,COLUMN_COMMENT from COLUMNS where TABLE_SCHEMA = ? and TABLE_NAME = ?` + var reply []*Column + err := m.conn.QueryRows(&reply, querySql, db, table) + return reply, err +} diff --git a/tools/goctl/model/sql/parser/parser.go b/tools/goctl/model/sql/parser/parser.go index 2aaea49b..ea38c75f 100644 --- a/tools/goctl/model/sql/parser/parser.go +++ b/tools/goctl/model/sql/parser/parser.go @@ -2,8 +2,10 @@ package parser import ( "fmt" + "strings" "github.com/tal-tech/go-zero/tools/goctl/model/sql/converter" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/model" "github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/xwb1989/sqlparser" ) @@ -34,7 +36,6 @@ type ( Name stringx.String DataBaseType string DataType string - IsKey bool IsPrimaryKey bool IsUniqueKey bool Comment string @@ -123,7 +124,6 @@ func Parse(ddl string) (*Table, error) { field.Comment = comment key, ok := keyMap[column.Name.String()] if ok { - field.IsKey = true field.IsPrimaryKey = key == primary field.IsUniqueKey = key == unique if field.IsPrimaryKey { @@ -151,3 +151,62 @@ func (t *Table) ContainsTime() bool { } return false } + +func ConvertColumn(db, table string, in []*model.Column) (*Table, error) { + var reply Table + reply.Name = stringx.From(table) + keyMap := make(map[string][]*model.Column) + + for _, column := range in { + keyMap[column.Key] = append(keyMap[column.Key], column) + } + primaryColumns := keyMap["PRI"] + if len(primaryColumns) == 0 { + return nil, fmt.Errorf("database:%s, table %s: missing primary key", db, table) + } + + if len(primaryColumns) > 1 { + return nil, fmt.Errorf("database:%s, table %s: only one primary key expected", db, table) + } + + primaryColumn := primaryColumns[0] + primaryFt, err := converter.ConvertDataType(primaryColumn.DataType) + if err != nil { + return nil, err + } + + primaryField := Field{ + Name: stringx.From(primaryColumn.Name), + DataBaseType: primaryColumn.DataType, + DataType: primaryFt, + IsUniqueKey: true, + IsPrimaryKey: true, + Comment: primaryColumn.Comment, + } + reply.PrimaryKey = Primary{ + Field: primaryField, + AutoIncrement: strings.Contains(primaryColumn.Extra, "auto_increment"), + } + for key, columns := range keyMap { + for _, item := range columns { + dt, err := converter.ConvertDataType(item.DataType) + if err != nil { + return nil, err + } + + f := Field{ + Name: stringx.From(item.Name), + DataBaseType: item.DataType, + DataType: dt, + IsPrimaryKey: primaryColumn.Name == item.Name, + Comment: item.Comment, + } + if key == "UNI" { + f.IsUniqueKey = true + } + reply.Fields = append(reply.Fields, f) + } + } + + return &reply, nil +} diff --git a/tools/goctl/model/sql/parser/parser_test.go b/tools/goctl/model/sql/parser/parser_test.go index b64f606c..11c256fb 100644 --- a/tools/goctl/model/sql/parser/parser_test.go +++ b/tools/goctl/model/sql/parser/parser_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/model" ) func TestParsePlainText(t *testing.T) { @@ -23,3 +24,58 @@ func TestParseCreateTable(t *testing.T) { assert.Equal(t, "id", table.PrimaryKey.Name.Source()) assert.Equal(t, true, table.ContainsTime()) } + +func TestConvertColumn(t *testing.T) { + _, err := ConvertColumn("user", "user", []*model.Column{ + { + Name: "id", + DataType: "bigint", + Key: "", + Extra: "", + Comment: "", + }, + }) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "missing primary key") + + _, err = ConvertColumn("user", "user", []*model.Column{ + { + Name: "id", + DataType: "bigint", + Key: "PRI", + Extra: "", + Comment: "", + }, + { + Name: "mobile", + DataType: "varchar", + Key: "PRI", + Extra: "", + Comment: "手机号", + }, + }) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "only one primary key expected") + + table, err := ConvertColumn("user", "user", []*model.Column{ + { + Name: "id", + DataType: "bigint", + Key: "PRI", + Extra: "auto_increment", + Comment: "", + }, + { + Name: "mobile", + DataType: "varchar", + Key: "UNI", + Extra: "", + Comment: "手机号", + }, + }) + assert.Nil(t, err) + assert.True(t, table.PrimaryKey.AutoIncrement && table.PrimaryKey.IsPrimaryKey) + assert.Equal(t, "id", table.PrimaryKey.Name.Source()) + assert.Equal(t, "mobile", table.Fields[1].Name.Source()) + assert.True(t, table.Fields[1].IsUniqueKey) +} diff --git a/tools/goctl/rpc/cli/cli.go b/tools/goctl/rpc/cli/cli.go index 483df6b4..561764d8 100644 --- a/tools/goctl/rpc/cli/cli.go +++ b/tools/goctl/rpc/cli/cli.go @@ -5,7 +5,6 @@ import ( "fmt" "path/filepath" - "github.com/tal-tech/go-zero/tools/goctl/rpc/execx" "github.com/tal-tech/go-zero/tools/goctl/rpc/generator" "github.com/urfave/cli" ) @@ -16,6 +15,7 @@ import ( func Rpc(c *cli.Context) error { src := c.String("src") out := c.String("dir") + style := c.String("style") protoImportPath := c.StringSlice("proto_path") if len(src) == 0 { return errors.New("missing -src") @@ -23,7 +23,13 @@ func Rpc(c *cli.Context) error { if len(out) == 0 { return errors.New("missing -dir") } - g := generator.NewDefaultRpcGenerator() + + namingStyle, valid := generator.IsNamingValid(style) + if !valid { + return fmt.Errorf("unexpected naming style %s", style) + } + + g := generator.NewDefaultRpcGenerator(namingStyle) return g.Generate(src, out, protoImportPath) } @@ -36,6 +42,12 @@ func RpcNew(c *cli.Context) error { return fmt.Errorf("unexpected ext: %s", ext) } + style := c.String("style") + namingStyle, valid := generator.IsNamingValid(style) + if !valid { + return fmt.Errorf("expected naming style [lower|camel|snake], but found %s", style) + } + protoName := name + ".proto" filename := filepath.Join(".", name, protoName) src, err := filepath.Abs(filename) @@ -48,13 +60,7 @@ func RpcNew(c *cli.Context) error { return err } - workDir := filepath.Dir(src) - _, err = execx.Run("go mod init "+name, workDir) - if err != nil { - return err - } - - g := generator.NewDefaultRpcGenerator() + g := generator.NewDefaultRpcGenerator(namingStyle) return g.Generate(src, filepath.Dir(src), nil) } diff --git a/tools/goctl/rpc/generator/base/common.pb.go b/tools/goctl/rpc/generator/base/common.pb.go deleted file mode 100644 index 455529fe..00000000 --- a/tools/goctl/rpc/generator/base/common.pb.go +++ /dev/null @@ -1,75 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// source: common.proto - -package common - -import ( - fmt "fmt" - proto "github.com/golang/protobuf/proto" - math "math" -) - -// Reference imports to suppress errors if they are not otherwise used. -var _ = proto.Marshal -var _ = fmt.Errorf -var _ = math.Inf - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the proto package it is being compiled against. -// A compilation error at this line likely means your copy of the -// proto package needs to be updated. -const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package - -type User struct { - Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *User) Reset() { *m = User{} } -func (m *User) String() string { return proto.CompactTextString(m) } -func (*User) ProtoMessage() {} -func (*User) Descriptor() ([]byte, []int) { - return fileDescriptor_555bd8c177793206, []int{0} -} - -func (m *User) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_User.Unmarshal(m, b) -} -func (m *User) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_User.Marshal(b, m, deterministic) -} -func (m *User) XXX_Merge(src proto.Message) { - xxx_messageInfo_User.Merge(m, src) -} -func (m *User) XXX_Size() int { - return xxx_messageInfo_User.Size(m) -} -func (m *User) XXX_DiscardUnknown() { - xxx_messageInfo_User.DiscardUnknown(m) -} - -var xxx_messageInfo_User proto.InternalMessageInfo - -func (m *User) GetName() string { - if m != nil { - return m.Name - } - return "" -} - -func init() { - proto.RegisterType((*User)(nil), "common.User") -} - -func init() { proto.RegisterFile("common.proto", fileDescriptor_555bd8c177793206) } - -var fileDescriptor_555bd8c177793206 = []byte{ - // 72 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x49, 0xce, 0xcf, 0xcd, - 0xcd, 0xcf, 0xd3, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x83, 0xf0, 0x94, 0xa4, 0xb8, 0x58, - 0x42, 0x8b, 0x53, 0x8b, 0x84, 0x84, 0xb8, 0x58, 0xf2, 0x12, 0x73, 0x53, 0x25, 0x18, 0x15, 0x18, - 0x35, 0x38, 0x83, 0xc0, 0xec, 0x24, 0x36, 0xb0, 0x52, 0x63, 0x40, 0x00, 0x00, 0x00, 0xff, 0xff, - 0x2c, 0x6d, 0x58, 0x59, 0x3a, 0x00, 0x00, 0x00, -} diff --git a/tools/goctl/rpc/generator/filename.go b/tools/goctl/rpc/generator/filename.go index 89617ded..1b07e56d 100644 --- a/tools/goctl/rpc/generator/filename.go +++ b/tools/goctl/rpc/generator/filename.go @@ -6,6 +6,13 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) -func formatFilename(filename string) string { - return strings.ToLower(stringx.From(filename).ToCamel()) +func formatFilename(filename string, style NamingStyle) string { + switch style { + case namingCamel: + return stringx.From(filename).ToCamel() + case namingSnake: + return stringx.From(filename).ToSnake() + default: + return strings.ToLower(stringx.From(filename).ToCamel()) + } } diff --git a/tools/goctl/rpc/generator/filename_test.go b/tools/goctl/rpc/generator/filename_test.go new file mode 100644 index 00000000..6d3618fb --- /dev/null +++ b/tools/goctl/rpc/generator/filename_test.go @@ -0,0 +1,17 @@ +package generator + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFormatFilename(t *testing.T) { + assert.Equal(t, "abc", formatFilename("a_b_c", namingLower)) + assert.Equal(t, "ABC", formatFilename("a_b_c", namingCamel)) + assert.Equal(t, "a_b_c", formatFilename("a_b_c", namingSnake)) + assert.Equal(t, "a", formatFilename("a", namingSnake)) + assert.Equal(t, "A", formatFilename("a", namingCamel)) + // no flag to convert to snake + assert.Equal(t, "abc", formatFilename("abc", namingSnake)) +} diff --git a/tools/goctl/rpc/generator/gen.go b/tools/goctl/rpc/generator/gen.go index d09c6394..95d05f1d 100644 --- a/tools/goctl/rpc/generator/gen.go +++ b/tools/goctl/rpc/generator/gen.go @@ -10,16 +10,18 @@ import ( ) type RpcGenerator struct { - g Generator + g Generator + style NamingStyle } -func NewDefaultRpcGenerator() *RpcGenerator { - return NewRpcGenerator(NewDefaultGenerator()) +func NewDefaultRpcGenerator(style NamingStyle) *RpcGenerator { + return NewRpcGenerator(NewDefaultGenerator(), style) } -func NewRpcGenerator(g Generator) *RpcGenerator { +func NewRpcGenerator(g Generator, style NamingStyle) *RpcGenerator { return &RpcGenerator{ - g: g, + g: g, + style: style, } } @@ -55,42 +57,42 @@ func (g *RpcGenerator) Generate(src, target string, protoImportPath []string) er return err } - err = g.g.GenEtc(dirCtx, proto) + err = g.g.GenEtc(dirCtx, proto, g.style) if err != nil { return err } - err = g.g.GenPb(dirCtx, protoImportPath, proto) + err = g.g.GenPb(dirCtx, protoImportPath, proto, g.style) if err != nil { return err } - err = g.g.GenConfig(dirCtx, proto) + err = g.g.GenConfig(dirCtx, proto, g.style) if err != nil { return err } - err = g.g.GenSvc(dirCtx, proto) + err = g.g.GenSvc(dirCtx, proto, g.style) if err != nil { return err } - err = g.g.GenLogic(dirCtx, proto) + err = g.g.GenLogic(dirCtx, proto, g.style) if err != nil { return err } - err = g.g.GenServer(dirCtx, proto) + err = g.g.GenServer(dirCtx, proto, g.style) if err != nil { return err } - err = g.g.GenMain(dirCtx, proto) + err = g.g.GenMain(dirCtx, proto, g.style) if err != nil { return err } - err = g.g.GenCall(dirCtx, proto) + err = g.g.GenCall(dirCtx, proto, g.style) console.NewColorConsole().MarkDone() diff --git a/tools/goctl/rpc/generator/gen_test.go b/tools/goctl/rpc/generator/gen_test.go index 1381cdd3..7a2e9e90 100644 --- a/tools/goctl/rpc/generator/gen_test.go +++ b/tools/goctl/rpc/generator/gen_test.go @@ -1,128 +1,74 @@ package generator import ( + "go/build" "os" "path/filepath" - "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/core/logx" + "github.com/tal-tech/go-zero/core/stringx" "github.com/tal-tech/go-zero/tools/goctl/rpc/execx" ) -func TestRpcGenerateCaseNilImport(t *testing.T) { +func TestRpcGenerate(t *testing.T) { _ = Clean() dispatcher := NewDefaultGenerator() - if err := dispatcher.Prepare(); err == nil { - g := NewRpcGenerator(dispatcher) - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - - err = g.Generate("./test_stream.proto", abs, nil) - defer func() { - _ = os.RemoveAll(abs) - }() - assert.Nil(t, err) - - _, err = execx.Run("go test "+abs, abs) - assert.Nil(t, err) + err := dispatcher.Prepare() + if err != nil { + logx.Error(err) + return } -} - -func TestRpcGenerateCaseOption(t *testing.T) { - _ = Clean() - dispatcher := NewDefaultGenerator() - if err := dispatcher.Prepare(); err == nil { - g := NewRpcGenerator(dispatcher) - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - - err = g.Generate("./test_option.proto", abs, nil) - defer func() { - _ = os.RemoveAll(abs) - }() - assert.Nil(t, err) - - _, err = execx.Run("go test "+abs, abs) - assert.Nil(t, err) + projectName := stringx.Rand() + g := NewRpcGenerator(dispatcher, namingLower) + + // case go path + src := filepath.Join(build.Default.GOPATH, "src") + _, err = os.Stat(src) + if err != nil { + return } -} - -func TestRpcGenerateCaseWordOption(t *testing.T) { - _ = Clean() - dispatcher := NewDefaultGenerator() - if err := dispatcher.Prepare(); err == nil { - g := NewRpcGenerator(dispatcher) - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - err = g.Generate("./test_word_option.proto", abs, nil) - defer func() { - _ = os.RemoveAll(abs) - }() - assert.Nil(t, err) - - _, err = execx.Run("go test "+abs, abs) - assert.Nil(t, err) + projectDir := filepath.Join(src, projectName) + srcDir := projectDir + defer func() { + _ = os.RemoveAll(srcDir) + }() + err = g.Generate("./test.proto", projectDir, []string{src}) + assert.Nil(t, err) + _, err = execx.Run("go test "+projectName, projectDir) + if err != nil { + assert.Contains(t, err.Error(), "not in GOROOT") } -} -// test keyword go -func TestRpcGenerateCaseGoOption(t *testing.T) { - _ = Clean() - dispatcher := NewDefaultGenerator() - if err := dispatcher.Prepare(); err == nil { - g := NewRpcGenerator(dispatcher) - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - - err = g.Generate("./test_go_option.proto", abs, nil) - defer func() { - _ = os.RemoveAll(abs) - }() - assert.Nil(t, err) - - _, err = execx.Run("go test "+abs, abs) - assert.Nil(t, err) + // case go mod + workDir := t.TempDir() + name := filepath.Base(workDir) + _, err = execx.Run("go mod init "+name, workDir) + if err != nil { + logx.Error(err) + return } -} - -func TestRpcGenerateCaseImport(t *testing.T) { - _ = Clean() - dispatcher := NewDefaultGenerator() - if err := dispatcher.Prepare(); err == nil { - g := NewRpcGenerator(dispatcher) - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - err = g.Generate("./test_import.proto", abs, []string{"./base"}) - defer func() { - _ = os.RemoveAll(abs) - }() - assert.Nil(t, err) - - _, err = execx.Run("go test "+abs, abs) - assert.True(t, func() bool { - return strings.Contains(err.Error(), "package base is not in GOROOT") - }()) + projectDir = filepath.Join(workDir, projectName) + err = g.Generate("./test.proto", projectDir, []string{src}) + assert.Nil(t, err) + _, err = execx.Run("go test "+projectName, projectDir) + if err != nil { + assert.Contains(t, err.Error(), "not in GOROOT") } -} -func TestRpcGenerateCaseServiceRpcNamingSnake(t *testing.T) { - _ = Clean() - dispatcher := NewDefaultGenerator() - if err := dispatcher.Prepare(); err == nil { - g := NewRpcGenerator(dispatcher) - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - - err = g.Generate("./test_service_rpc_naming_snake.proto", abs, nil) - defer func() { - _ = os.RemoveAll(abs) - }() - assert.Nil(t, err) - - _, err = execx.Run("go test "+abs, abs) - assert.Nil(t, err) + // case not in go mod and go path + err = g.Generate("./test.proto", projectDir, []string{src}) + assert.Nil(t, err) + _, err = execx.Run("go test "+projectName, projectDir) + if err != nil { + assert.Contains(t, err.Error(), "not in GOROOT") } + + // invalid directory + projectDir = filepath.Join(t.TempDir(), ".....") + err = g.Generate("./test.proto", projectDir, nil) + assert.NotNil(t, err) } diff --git a/tools/goctl/rpc/generator/gencall.go b/tools/goctl/rpc/generator/gencall.go index 1e55fb9d..e5b01725 100644 --- a/tools/goctl/rpc/generator/gencall.go +++ b/tools/goctl/rpc/generator/gencall.go @@ -59,12 +59,12 @@ func (m *default{{.serviceName}}) {{.method}}(ctx context.Context,in *{{.pbReque ` ) -func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto) error { +func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error { dir := ctx.GetCall() service := proto.Service head := util.GetHead(proto.Name) - filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", formatFilename(service.Name))) + filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", formatFilename(service.Name, namingStyle))) functions, err := g.genFunction(proto.PbPackage, service) if err != nil { return err @@ -81,13 +81,12 @@ func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto) error { } var alias = collection.NewSet() - for _, item := range service.RPC { - alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(item.RequestType), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(item.RequestType)))) - alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(item.ReturnsType), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(item.ReturnsType)))) + for _, item := range proto.Message { + alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(item.Name), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(item.Name)))) } err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{ - "name": formatFilename(service.Name), + "name": formatFilename(service.Name, namingStyle), "alias": strings.Join(alias.KeysStr(), util.NL), "head": head, "filePackage": dir.Base, diff --git a/tools/goctl/rpc/generator/gencall_test.go b/tools/goctl/rpc/generator/gencall_test.go deleted file mode 100644 index 6ae0cefe..00000000 --- a/tools/goctl/rpc/generator/gencall_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package generator - -import ( - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" - "github.com/tal-tech/go-zero/tools/goctl/util" - "github.com/tal-tech/go-zero/tools/goctl/util/ctx" -) - -func TestGenerateCall(t *testing.T) { - _ = Clean() - project := "stream" - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - - dir := filepath.Join(abs, project) - err = util.MkdirIfNotExist(dir) - assert.Nil(t, err) - defer func() { - _ = os.RemoveAll(abs) - }() - - projectCtx, err := ctx.Prepare(dir) - assert.Nil(t, err) - - p := parser.NewDefaultProtoParser() - proto, err := p.Parse("./test_stream.proto") - assert.Nil(t, err) - - dirCtx, err := mkdir(projectCtx, proto) - assert.Nil(t, err) - - g := NewDefaultGenerator() - err = g.Prepare() - if err != nil { - return - } - err = g.GenCall(dirCtx, proto) - assert.Nil(t, err) -} diff --git a/tools/goctl/rpc/generator/genconfig.go b/tools/goctl/rpc/generator/genconfig.go index c9cd4bc4..22cd3d60 100644 --- a/tools/goctl/rpc/generator/genconfig.go +++ b/tools/goctl/rpc/generator/genconfig.go @@ -18,9 +18,9 @@ type Config struct { } ` -func (g *defaultGenerator) GenConfig(ctx DirContext, _ parser.Proto) error { +func (g *defaultGenerator) GenConfig(ctx DirContext, _ parser.Proto, namingStyle NamingStyle) error { dir := ctx.GetConfig() - fileName := filepath.Join(dir.Filename, formatFilename("config")+".go") + fileName := filepath.Join(dir.Filename, formatFilename("config", namingStyle)+".go") if util.FileExists(fileName) { return nil } diff --git a/tools/goctl/rpc/generator/genconfig_test.go b/tools/goctl/rpc/generator/genconfig_test.go deleted file mode 100644 index 39006b10..00000000 --- a/tools/goctl/rpc/generator/genconfig_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package generator - -import ( - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" - "github.com/tal-tech/go-zero/tools/goctl/util" - "github.com/tal-tech/go-zero/tools/goctl/util/ctx" -) - -func TestGenerateConfig(t *testing.T) { - _ = Clean() - project := "stream" - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - - dir := filepath.Join(abs, project) - err = util.MkdirIfNotExist(dir) - assert.Nil(t, err) - defer func() { - _ = os.RemoveAll(abs) - }() - - projectCtx, err := ctx.Prepare(dir) - assert.Nil(t, err) - - p := parser.NewDefaultProtoParser() - proto, err := p.Parse("./test_stream.proto") - assert.Nil(t, err) - - dirCtx, err := mkdir(projectCtx, proto) - assert.Nil(t, err) - - g := NewDefaultGenerator() - err = g.Prepare() - if err != nil { - return - } - err = g.GenConfig(dirCtx, proto) - assert.Nil(t, err) - - // test file exists - err = g.GenConfig(dirCtx, proto) - assert.Nil(t, err) -} diff --git a/tools/goctl/rpc/generator/generator.go b/tools/goctl/rpc/generator/generator.go index a46ed42c..3d49e3de 100644 --- a/tools/goctl/rpc/generator/generator.go +++ b/tools/goctl/rpc/generator/generator.go @@ -4,12 +4,12 @@ import "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" type Generator interface { Prepare() error - GenMain(ctx DirContext, proto parser.Proto) error - GenCall(ctx DirContext, proto parser.Proto) error - GenEtc(ctx DirContext, proto parser.Proto) error - GenConfig(ctx DirContext, proto parser.Proto) error - GenLogic(ctx DirContext, proto parser.Proto) error - GenServer(ctx DirContext, proto parser.Proto) error - GenSvc(ctx DirContext, proto parser.Proto) error - GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto) error + GenMain(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error + GenCall(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error + GenEtc(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error + GenConfig(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error + GenLogic(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error + GenServer(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error + GenSvc(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error + GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto, namingStyle NamingStyle) error } diff --git a/tools/goctl/rpc/generator/genetc.go b/tools/goctl/rpc/generator/genetc.go index ddd36fba..5fa0d1de 100644 --- a/tools/goctl/rpc/generator/genetc.go +++ b/tools/goctl/rpc/generator/genetc.go @@ -3,9 +3,11 @@ package generator import ( "fmt" "path/filepath" + "strings" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) const etcTemplate = `Name: {{.serviceName}}.rpc @@ -16,9 +18,9 @@ Etcd: Key: {{.serviceName}}.rpc ` -func (g *defaultGenerator) GenEtc(ctx DirContext, _ parser.Proto) error { +func (g *defaultGenerator) GenEtc(ctx DirContext, _ parser.Proto, namingStyle NamingStyle) error { dir := ctx.GetEtc() - serviceNameLower := formatFilename(ctx.GetMain().Base) + serviceNameLower := formatFilename(ctx.GetMain().Base, namingStyle) fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.yaml", serviceNameLower)) text, err := util.LoadTemplate(category, etcTemplateFileFile, etcTemplate) @@ -27,6 +29,6 @@ func (g *defaultGenerator) GenEtc(ctx DirContext, _ parser.Proto) error { } return util.With("etc").Parse(text).SaveTo(map[string]interface{}{ - "serviceName": serviceNameLower, + "serviceName": strings.ToLower(stringx.From(ctx.GetMain().Base).ToCamel()), }, fileName, false) } diff --git a/tools/goctl/rpc/generator/genetc_test.go b/tools/goctl/rpc/generator/genetc_test.go deleted file mode 100644 index 457cfed4..00000000 --- a/tools/goctl/rpc/generator/genetc_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package generator - -import ( - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" - "github.com/tal-tech/go-zero/tools/goctl/util" - "github.com/tal-tech/go-zero/tools/goctl/util/ctx" -) - -func TestGenerateEtc(t *testing.T) { - _ = Clean() - project := "stream" - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - - dir := filepath.Join(abs, project) - err = util.MkdirIfNotExist(dir) - assert.Nil(t, err) - defer func() { - _ = os.RemoveAll(abs) - }() - - projectCtx, err := ctx.Prepare(dir) - assert.Nil(t, err) - - p := parser.NewDefaultProtoParser() - proto, err := p.Parse("./test_stream.proto") - assert.Nil(t, err) - - dirCtx, err := mkdir(projectCtx, proto) - assert.Nil(t, err) - - g := NewDefaultGenerator() - err = g.Prepare() - if err != nil { - return - } - - err = g.GenEtc(dirCtx, proto) - assert.Nil(t, err) -} diff --git a/tools/goctl/rpc/generator/genlogic.go b/tools/goctl/rpc/generator/genlogic.go index 54f79585..19c05fb5 100644 --- a/tools/goctl/rpc/generator/genlogic.go +++ b/tools/goctl/rpc/generator/genlogic.go @@ -46,10 +46,10 @@ func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) { ` ) -func (g *defaultGenerator) GenLogic(ctx DirContext, proto parser.Proto) error { +func (g *defaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error { dir := ctx.GetLogic() for _, rpc := range proto.Service.RPC { - filename := filepath.Join(dir.Filename, formatFilename(rpc.Name+"_logic")+".go") + filename := filepath.Join(dir.Filename, formatFilename(rpc.Name+"_logic", namingStyle)+".go") functions, err := g.genLogicFunction(proto.PbPackage, rpc) if err != nil { return err diff --git a/tools/goctl/rpc/generator/genlogic_test.go b/tools/goctl/rpc/generator/genlogic_test.go deleted file mode 100644 index 681c89d7..00000000 --- a/tools/goctl/rpc/generator/genlogic_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package generator - -import ( - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" - "github.com/tal-tech/go-zero/tools/goctl/util" - "github.com/tal-tech/go-zero/tools/goctl/util/ctx" -) - -func TestGenerateLogic(t *testing.T) { - project := "stream" - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - - dir := filepath.Join(abs, project) - err = util.MkdirIfNotExist(dir) - assert.Nil(t, err) - defer func() { - _ = os.RemoveAll(abs) - }() - - projectCtx, err := ctx.Prepare(dir) - assert.Nil(t, err) - - p := parser.NewDefaultProtoParser() - proto, err := p.Parse("./test_stream.proto") - assert.Nil(t, err) - - dirCtx, err := mkdir(projectCtx, proto) - assert.Nil(t, err) - - g := NewDefaultGenerator() - err = g.Prepare() - if err != nil { - return - } - - err = g.GenLogic(dirCtx, proto) - assert.Nil(t, err) -} diff --git a/tools/goctl/rpc/generator/genmain.go b/tools/goctl/rpc/generator/genmain.go index ca83866f..73270a83 100644 --- a/tools/goctl/rpc/generator/genmain.go +++ b/tools/goctl/rpc/generator/genmain.go @@ -45,9 +45,9 @@ func main() { } ` -func (g *defaultGenerator) GenMain(ctx DirContext, proto parser.Proto) error { +func (g *defaultGenerator) GenMain(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error { dir := ctx.GetMain() - serviceNameLower := formatFilename(ctx.GetMain().Base) + serviceNameLower := formatFilename(ctx.GetMain().Base, namingStyle) fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.go", serviceNameLower)) imports := make([]string, 0) pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package) @@ -63,7 +63,7 @@ func (g *defaultGenerator) GenMain(ctx DirContext, proto parser.Proto) error { return util.With("main").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{ "head": head, - "serviceName": serviceNameLower, + "serviceName": strings.ToLower(stringx.From(ctx.GetMain().Base).ToCamel()), "imports": strings.Join(imports, util.NL), "pkg": proto.PbPackage, "serviceNew": stringx.From(proto.Service.Name).ToCamel(), diff --git a/tools/goctl/rpc/generator/genmain_test.go b/tools/goctl/rpc/generator/genmain_test.go deleted file mode 100644 index aed65b02..00000000 --- a/tools/goctl/rpc/generator/genmain_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package generator - -import ( - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" - "github.com/tal-tech/go-zero/tools/goctl/util" - "github.com/tal-tech/go-zero/tools/goctl/util/ctx" -) - -func TestGenerateMain(t *testing.T) { - _ = Clean() - project := "stream" - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - - dir := filepath.Join(abs, project) - err = util.MkdirIfNotExist(dir) - assert.Nil(t, err) - defer func() { - _ = os.RemoveAll(abs) - }() - - projectCtx, err := ctx.Prepare(dir) - assert.Nil(t, err) - - p := parser.NewDefaultProtoParser() - proto, err := p.Parse("./test_stream.proto") - assert.Nil(t, err) - - dirCtx, err := mkdir(projectCtx, proto) - assert.Nil(t, err) - - g := NewDefaultGenerator() - err = g.Prepare() - if err != nil { - return - } - - err = g.GenMain(dirCtx, proto) - assert.Nil(t, err) -} diff --git a/tools/goctl/rpc/generator/genpb.go b/tools/goctl/rpc/generator/genpb.go index 5a541b60..66efaf94 100644 --- a/tools/goctl/rpc/generator/genpb.go +++ b/tools/goctl/rpc/generator/genpb.go @@ -9,7 +9,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" ) -func (g *defaultGenerator) GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto) error { +func (g *defaultGenerator) GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto, namingStyle NamingStyle) error { dir := ctx.GetPb() cw := new(bytes.Buffer) base := filepath.Dir(proto.Src) diff --git a/tools/goctl/rpc/generator/genpb_test.go b/tools/goctl/rpc/generator/genpb_test.go deleted file mode 100644 index 1c423072..00000000 --- a/tools/goctl/rpc/generator/genpb_test.go +++ /dev/null @@ -1,184 +0,0 @@ -package generator - -import ( - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" - "github.com/tal-tech/go-zero/tools/goctl/util" - "github.com/tal-tech/go-zero/tools/goctl/util/ctx" -) - -func TestGenerateCaseNilImport(t *testing.T) { - project := "stream" - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - - dir := filepath.Join(abs, project) - err = util.MkdirIfNotExist(dir) - assert.Nil(t, err) - defer func() { - //_ = os.RemoveAll(abs) - }() - - projectCtx, err := ctx.Prepare(dir) - assert.Nil(t, err) - - p := parser.NewDefaultProtoParser() - proto, err := p.Parse("./test_stream.proto") - assert.Nil(t, err) - - dirCtx, err := mkdir(projectCtx, proto) - assert.Nil(t, err) - - g := NewDefaultGenerator() - if err := g.Prepare(); err == nil { - targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_stream.pb.go") - err = g.GenPb(dirCtx, nil, proto) - assert.Nil(t, err) - assert.True(t, func() bool { - return util.FileExists(targetPb) - }()) - } -} - -func TestGenerateCaseImport(t *testing.T) { - project := "stream" - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - - dir := filepath.Join(abs, project) - err = util.MkdirIfNotExist(dir) - assert.Nil(t, err) - defer func() { - _ = os.RemoveAll(abs) - }() - - projectCtx, err := ctx.Prepare(dir) - assert.Nil(t, err) - - p := parser.NewDefaultProtoParser() - proto, err := p.Parse("./test_stream.proto") - assert.Nil(t, err) - - dirCtx, err := mkdir(projectCtx, proto) - assert.Nil(t, err) - - g := NewDefaultGenerator() - if err := g.Prepare(); err == nil { - err = g.GenPb(dirCtx, nil, proto) - assert.Nil(t, err) - - targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_stream.pb.go") - assert.True(t, func() bool { - return util.FileExists(targetPb) - }()) - } -} - -func TestGenerateCasePathOption(t *testing.T) { - project := "stream" - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - - dir := filepath.Join(abs, project) - err = util.MkdirIfNotExist(dir) - assert.Nil(t, err) - defer func() { - _ = os.RemoveAll(abs) - }() - - projectCtx, err := ctx.Prepare(dir) - assert.Nil(t, err) - - p := parser.NewDefaultProtoParser() - proto, err := p.Parse("./test_option.proto") - assert.Nil(t, err) - - dirCtx, err := mkdir(projectCtx, proto) - assert.Nil(t, err) - - g := NewDefaultGenerator() - if err := g.Prepare(); err == nil { - err = g.GenPb(dirCtx, nil, proto) - assert.Nil(t, err) - - targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_option.pb.go") - assert.True(t, func() bool { - return util.FileExists(targetPb) - }()) - } -} - -func TestGenerateCaseWordOption(t *testing.T) { - project := "stream" - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - - dir := filepath.Join(abs, project) - err = util.MkdirIfNotExist(dir) - assert.Nil(t, err) - defer func() { - _ = os.RemoveAll(abs) - }() - - projectCtx, err := ctx.Prepare(dir) - assert.Nil(t, err) - - p := parser.NewDefaultProtoParser() - proto, err := p.Parse("./test_word_option.proto") - assert.Nil(t, err) - - dirCtx, err := mkdir(projectCtx, proto) - assert.Nil(t, err) - - g := NewDefaultGenerator() - if err := g.Prepare(); err == nil { - - err = g.GenPb(dirCtx, nil, proto) - assert.Nil(t, err) - - targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_word_option.pb.go") - assert.True(t, func() bool { - return util.FileExists(targetPb) - }()) - } -} - -// test keyword go -func TestGenerateCaseGoOption(t *testing.T) { - project := "stream" - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - - dir := filepath.Join(abs, project) - err = util.MkdirIfNotExist(dir) - assert.Nil(t, err) - defer func() { - _ = os.RemoveAll(abs) - }() - - projectCtx, err := ctx.Prepare(dir) - assert.Nil(t, err) - - p := parser.NewDefaultProtoParser() - proto, err := p.Parse("./test_go_option.proto") - assert.Nil(t, err) - - dirCtx, err := mkdir(projectCtx, proto) - assert.Nil(t, err) - - g := NewDefaultGenerator() - if err := g.Prepare(); err == nil { - - err = g.GenPb(dirCtx, nil, proto) - assert.Nil(t, err) - - targetPb := filepath.Join(dirCtx.GetPb().Filename, "test_go_option.pb.go") - assert.True(t, func() bool { - return util.FileExists(targetPb) - }()) - } -} diff --git a/tools/goctl/rpc/generator/genserver.go b/tools/goctl/rpc/generator/genserver.go index 01c35c23..41749598 100644 --- a/tools/goctl/rpc/generator/genserver.go +++ b/tools/goctl/rpc/generator/genserver.go @@ -43,7 +43,7 @@ func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) ( ` ) -func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto) error { +func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error { dir := ctx.GetServer() logicImport := fmt.Sprintf(`"%v"`, ctx.GetLogic().Package) svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package) @@ -54,7 +54,7 @@ func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto) error { head := util.GetHead(proto.Name) service := proto.Service - serverFile := filepath.Join(dir.Filename, formatFilename(service.Name+"_server")+".go") + serverFile := filepath.Join(dir.Filename, formatFilename(service.Name+"_server", namingStyle)+".go") funcList, err := g.genFunctions(proto.PbPackage, service) if err != nil { return err diff --git a/tools/goctl/rpc/generator/genserver_test.go b/tools/goctl/rpc/generator/genserver_test.go deleted file mode 100644 index e5f1e3f6..00000000 --- a/tools/goctl/rpc/generator/genserver_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package generator - -import ( - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" - "github.com/tal-tech/go-zero/tools/goctl/util" - "github.com/tal-tech/go-zero/tools/goctl/util/ctx" -) - -func TestGenerateServer(t *testing.T) { - _ = Clean() - project := "stream" - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - - dir := filepath.Join(abs, project) - err = util.MkdirIfNotExist(dir) - assert.Nil(t, err) - defer func() { - _ = os.RemoveAll(abs) - }() - - projectCtx, err := ctx.Prepare(dir) - assert.Nil(t, err) - - p := parser.NewDefaultProtoParser() - proto, err := p.Parse("./test_stream.proto") - assert.Nil(t, err) - - dirCtx, err := mkdir(projectCtx, proto) - assert.Nil(t, err) - - g := NewDefaultGenerator() - err = g.Prepare() - if err != nil { - return - } - - err = g.GenServer(dirCtx, proto) - assert.Nil(t, err) -} diff --git a/tools/goctl/rpc/generator/gensvc.go b/tools/goctl/rpc/generator/gensvc.go index 86df41b4..79f0ac63 100644 --- a/tools/goctl/rpc/generator/gensvc.go +++ b/tools/goctl/rpc/generator/gensvc.go @@ -23,9 +23,9 @@ func NewServiceContext(c config.Config) *ServiceContext { } ` -func (g *defaultGenerator) GenSvc(ctx DirContext, _ parser.Proto) error { +func (g *defaultGenerator) GenSvc(ctx DirContext, _ parser.Proto, namingStyle NamingStyle) error { dir := ctx.GetSvc() - fileName := filepath.Join(dir.Filename, formatFilename("service_context")+".go") + fileName := filepath.Join(dir.Filename, formatFilename("service_context", namingStyle)+".go") text, err := util.LoadTemplate(category, svcTemplateFile, svcTemplate) if err != nil { return err diff --git a/tools/goctl/rpc/generator/gensvc_test.go b/tools/goctl/rpc/generator/gensvc_test.go deleted file mode 100644 index 6bf43e0d..00000000 --- a/tools/goctl/rpc/generator/gensvc_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package generator - -import ( - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" - "github.com/tal-tech/go-zero/tools/goctl/util" - "github.com/tal-tech/go-zero/tools/goctl/util/ctx" -) - -func TestGenerateSvc(t *testing.T) { - _ = Clean() - project := "stream" - abs, err := filepath.Abs("./test") - assert.Nil(t, err) - - dir := filepath.Join(abs, project) - err = util.MkdirIfNotExist(dir) - assert.Nil(t, err) - defer func() { - _ = os.RemoveAll(abs) - }() - - projectCtx, err := ctx.Prepare(dir) - assert.Nil(t, err) - - p := parser.NewDefaultProtoParser() - proto, err := p.Parse("./test_stream.proto") - assert.Nil(t, err) - - dirCtx, err := mkdir(projectCtx, proto) - assert.Nil(t, err) - - g := NewDefaultGenerator() - err = g.GenSvc(dirCtx, proto) - assert.Nil(t, err) -} diff --git a/tools/goctl/rpc/generator/mkdir_test.go b/tools/goctl/rpc/generator/mkdir_test.go deleted file mode 100644 index ae4f1f3a..00000000 --- a/tools/goctl/rpc/generator/mkdir_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package generator - -import ( - "go/build" - "os" - "path/filepath" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/tal-tech/go-zero/core/stringx" - "github.com/tal-tech/go-zero/tools/goctl/rpc/execx" - "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" - "github.com/tal-tech/go-zero/tools/goctl/util" - "github.com/tal-tech/go-zero/tools/goctl/util/ctx" -) - -func TestMkDirInGoPath(t *testing.T) { - dft := build.Default - gp := dft.GOPATH - if len(gp) == 0 { - return - } - projectName := stringx.Rand() - dir := filepath.Join(gp, "src", projectName) - err := util.MkdirIfNotExist(dir) - if err != nil { - return - } - - defer func() { - _ = os.RemoveAll(dir) - }() - - projectCtx, err := ctx.Prepare(dir) - assert.Nil(t, err) - - p := parser.NewDefaultProtoParser() - proto, err := p.Parse("./test.proto") - assert.Nil(t, err) - - dirCtx, err := mkdir(projectCtx, proto) - assert.Nil(t, err) - internal := filepath.Join(dir, "internal") - assert.True(t, true, func() bool { - return filepath.Join(dir, strings.ToLower(projectName)) == dirCtx.GetCall().Filename && projectName == dirCtx.GetCall().Package - }()) - assert.True(t, true, func() bool { - return filepath.Join(dir, "etc") == dirCtx.GetEtc().Filename && filepath.Join(projectName, "etc") == dirCtx.GetEtc().Package - }()) - assert.True(t, true, func() bool { - return internal == dirCtx.GetInternal().Filename && filepath.Join(projectName, "internal") == dirCtx.GetInternal().Package - }()) - assert.True(t, true, func() bool { - return filepath.Join(internal, "config") == dirCtx.GetConfig().Filename && filepath.Join(projectName, "internal", "config") == dirCtx.GetConfig().Package - }()) - assert.True(t, true, func() bool { - return filepath.Join(internal, "logic") == dirCtx.GetLogic().Filename && filepath.Join(projectName, "internal", "logic") == dirCtx.GetLogic().Package - }()) - assert.True(t, true, func() bool { - return filepath.Join(internal, "server") == dirCtx.GetServer().Filename && filepath.Join(projectName, "internal", "server") == dirCtx.GetServer().Package - }()) - assert.True(t, true, func() bool { - return filepath.Join(internal, "svc") == dirCtx.GetSvc().Filename && filepath.Join(projectName, "internal", "svc") == dirCtx.GetSvc().Package - }()) - assert.True(t, true, func() bool { - return filepath.Join(internal, strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Filename && filepath.Join(projectName, "internal", strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Package - }()) - assert.True(t, true, func() bool { - return dir == dirCtx.GetMain().Filename && projectName == dirCtx.GetMain().Package - }()) -} - -func TestMkDirInGoMod(t *testing.T) { - dft := build.Default - gp := dft.GOPATH - if len(gp) == 0 { - return - } - projectName := stringx.Rand() - dir := filepath.Join(gp, "src", projectName) - err := util.MkdirIfNotExist(dir) - if err != nil { - return - } - - _, err = execx.Run("go mod init "+projectName, dir) - assert.Nil(t, err) - defer func() { - _ = os.RemoveAll(dir) - }() - - projectCtx, err := ctx.Prepare(dir) - assert.Nil(t, err) - - p := parser.NewDefaultProtoParser() - proto, err := p.Parse("./test.proto") - assert.Nil(t, err) - - dirCtx, err := mkdir(projectCtx, proto) - assert.Nil(t, err) - internal := filepath.Join(dir, "internal") - assert.True(t, true, func() bool { - return filepath.Join(dir, strings.ToLower(projectName)) == dirCtx.GetCall().Filename && projectName == dirCtx.GetCall().Package - }()) - assert.True(t, true, func() bool { - return filepath.Join(dir, "etc") == dirCtx.GetEtc().Filename && filepath.Join(projectName, "etc") == dirCtx.GetEtc().Package - }()) - assert.True(t, true, func() bool { - return internal == dirCtx.GetInternal().Filename && filepath.Join(projectName, "internal") == dirCtx.GetInternal().Package - }()) - assert.True(t, true, func() bool { - return filepath.Join(internal, "config") == dirCtx.GetConfig().Filename && filepath.Join(projectName, "internal", "config") == dirCtx.GetConfig().Package - }()) - assert.True(t, true, func() bool { - return filepath.Join(internal, "logic") == dirCtx.GetLogic().Filename && filepath.Join(projectName, "internal", "logic") == dirCtx.GetLogic().Package - }()) - assert.True(t, true, func() bool { - return filepath.Join(internal, "server") == dirCtx.GetServer().Filename && filepath.Join(projectName, "internal", "server") == dirCtx.GetServer().Package - }()) - assert.True(t, true, func() bool { - return filepath.Join(internal, "svc") == dirCtx.GetSvc().Filename && filepath.Join(projectName, "internal", "svc") == dirCtx.GetSvc().Package - }()) - assert.True(t, true, func() bool { - return filepath.Join(internal, strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Filename && filepath.Join(projectName, "internal", strings.ToLower(proto.Service.Name)) == dirCtx.GetPb().Package - }()) - assert.True(t, true, func() bool { - return dir == dirCtx.GetMain().Filename && projectName == dirCtx.GetMain().Package - }()) -} diff --git a/tools/goctl/rpc/generator/naming.go b/tools/goctl/rpc/generator/naming.go new file mode 100644 index 00000000..5a9f87ae --- /dev/null +++ b/tools/goctl/rpc/generator/naming.go @@ -0,0 +1,24 @@ +package generator + +type NamingStyle = string + +const ( + namingLower NamingStyle = "lower" + namingCamel NamingStyle = "camel" + namingSnake NamingStyle = "snake" +) + +// IsNamingValid validates whether the namingStyle is valid or not,return +// namingStyle and true if it is valid, or else return empty string +// and false, and it is a valid value even namingStyle is empty string +func IsNamingValid(namingStyle string) (NamingStyle, bool) { + if len(namingStyle) == 0 { + namingStyle = namingLower + } + switch namingStyle { + case namingLower, namingCamel, namingSnake: + return namingStyle, true + default: + return "", false + } +} diff --git a/tools/goctl/rpc/generator/naming_test.go b/tools/goctl/rpc/generator/naming_test.go new file mode 100644 index 00000000..f3e07cc3 --- /dev/null +++ b/tools/goctl/rpc/generator/naming_test.go @@ -0,0 +1,25 @@ +package generator + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsNamingValid(t *testing.T) { + style, valid := IsNamingValid("") + assert.True(t, valid) + assert.Equal(t, namingLower, style) + + _, valid = IsNamingValid("lower1") + assert.False(t, valid) + + _, valid = IsNamingValid("lower") + assert.True(t, valid) + + _, valid = IsNamingValid("snake") + assert.True(t, valid) + + _, valid = IsNamingValid("camel") + assert.True(t, valid) +} diff --git a/tools/goctl/rpc/generator/prototmpl_test.go b/tools/goctl/rpc/generator/prototmpl_test.go index 627ec6f4..e735c9e5 100644 --- a/tools/goctl/rpc/generator/prototmpl_test.go +++ b/tools/goctl/rpc/generator/prototmpl_test.go @@ -1,7 +1,6 @@ package generator import ( - "os" "path/filepath" "testing" @@ -9,13 +8,13 @@ import ( ) func TestProtoTmpl(t *testing.T) { - out, err := filepath.Abs("./test/test.proto") + _ = Clean() + // exists dir + err := ProtoTmpl(t.TempDir()) assert.Nil(t, err) - defer func() { - _ = os.RemoveAll(filepath.Dir(out)) - }() - err = ProtoTmpl(out) - assert.Nil(t, err) - _, err = os.Stat(out) + + // not exist dir + dir := filepath.Join(t.TempDir(), "test") + err = ProtoTmpl(dir) assert.Nil(t, err) } diff --git a/tools/goctl/rpc/generator/template_test.go b/tools/goctl/rpc/generator/template_test.go index 89200090..53f95ebe 100644 --- a/tools/goctl/rpc/generator/template_test.go +++ b/tools/goctl/rpc/generator/template_test.go @@ -2,6 +2,7 @@ package generator import ( "io/ioutil" + "os" "path/filepath" "testing" @@ -10,87 +11,104 @@ import ( ) func TestGenTemplates(t *testing.T) { - err := util.InitTemplates(category, templates) + _ = Clean() + err := GenTemplates(nil) assert.Nil(t, err) - dir, err := util.GetTemplateDir(category) - assert.Nil(t, err) - file := filepath.Join(dir, "main.tpl") - data, err := ioutil.ReadFile(file) - assert.Nil(t, err) - assert.Equal(t, string(data), mainTemplate) } func TestRevertTemplate(t *testing.T) { - name := "main.tpl" - err := util.InitTemplates(category, templates) - assert.Nil(t, err) - - dir, err := util.GetTemplateDir(category) - assert.Nil(t, err) - - file := filepath.Join(dir, name) - data, err := ioutil.ReadFile(file) - assert.Nil(t, err) - - modifyData := string(data) + "modify" - err = util.CreateTemplate(category, name, modifyData) - assert.Nil(t, err) - - data, err = ioutil.ReadFile(file) - assert.Nil(t, err) - - assert.Equal(t, string(data), modifyData) - - assert.Nil(t, RevertTemplate(name)) - - data, err = ioutil.ReadFile(file) - assert.Nil(t, err) - assert.Equal(t, mainTemplate, string(data)) + _ = Clean() + err := GenTemplates(nil) + assert.Nil(t, err) + fp, err := util.GetTemplateDir(category) + if err != nil { + return + } + mainTpl := filepath.Join(fp, mainTemplateFile) + data, err := ioutil.ReadFile(mainTpl) + if err != nil { + return + } + assert.Equal(t, templates[mainTemplateFile], string(data)) + + err = RevertTemplate("test") + if err != nil { + assert.Equal(t, "test: no such file name", err.Error()) + } + + err = ioutil.WriteFile(mainTpl, []byte("modify"), os.ModePerm) + if err != nil { + return + } + + data, err = ioutil.ReadFile(mainTpl) + if err != nil { + return + } + assert.Equal(t, "modify", string(data)) + + err = RevertTemplate(mainTemplateFile) + assert.Nil(t, err) + + data, err = ioutil.ReadFile(mainTpl) + if err != nil { + return + } + assert.Equal(t, templates[mainTemplateFile], string(data)) } func TestClean(t *testing.T) { - name := "main.tpl" - err := util.InitTemplates(category, templates) + _ = Clean() + err := GenTemplates(nil) + assert.Nil(t, err) + fp, err := util.GetTemplateDir(category) + if err != nil { + return + } + mainTpl := filepath.Join(fp, mainTemplateFile) + _, err = os.Stat(mainTpl) assert.Nil(t, err) - assert.Nil(t, Clean()) - - dir, err := util.GetTemplateDir(category) + err = Clean() assert.Nil(t, err) - file := filepath.Join(dir, name) - _, err = ioutil.ReadFile(file) + _, err = os.Stat(mainTpl) assert.NotNil(t, err) } func TestUpdate(t *testing.T) { - name := "main.tpl" - err := util.InitTemplates(category, templates) - assert.Nil(t, err) - - dir, err := util.GetTemplateDir(category) - assert.Nil(t, err) - - file := filepath.Join(dir, name) - data, err := ioutil.ReadFile(file) - assert.Nil(t, err) - - modifyData := string(data) + "modify" - err = util.CreateTemplate(category, name, modifyData) - assert.Nil(t, err) - - data, err = ioutil.ReadFile(file) - assert.Nil(t, err) - - assert.Equal(t, string(data), modifyData) - - assert.Nil(t, Update(category)) - - data, err = ioutil.ReadFile(file) - assert.Nil(t, err) - assert.Equal(t, mainTemplate, string(data)) + _ = Clean() + err := GenTemplates(nil) + assert.Nil(t, err) + fp, err := util.GetTemplateDir(category) + if err != nil { + return + } + mainTpl := filepath.Join(fp, mainTemplateFile) + + err = ioutil.WriteFile(mainTpl, []byte("modify"), os.ModePerm) + if err != nil { + return + } + + data, err := ioutil.ReadFile(mainTpl) + if err != nil { + return + } + assert.Equal(t, "modify", string(data)) + + err = Update(category) + assert.Nil(t, err) + + data, err = ioutil.ReadFile(mainTpl) + if err != nil { + return + } + assert.Equal(t, templates[mainTemplateFile], string(data)) } func TestGetCategory(t *testing.T) { - assert.Equal(t, category, GetCategory()) + _ = Clean() + result := GetCategory() + assert.Equal(t, category, result) } diff --git a/tools/goctl/rpc/generator/test.proto b/tools/goctl/rpc/generator/test.proto index cdb989ac..7e7d383f 100644 --- a/tools/goctl/rpc/generator/test.proto +++ b/tools/goctl/rpc/generator/test.proto @@ -2,24 +2,61 @@ syntax = "proto3"; package test; -option go_package = "go"; -import "test_base.proto"; +import "base/common.proto"; +import "google/protobuf/any.proto"; -message TestMessage { - base.CommonReq req = 1; +option go_package = "github.com/test"; + +message Req { + string in = 1; + common.User user = 2; + google.protobuf.Any object = 4; +} + +message Reply { + string out = 1; } -message TestReq {} -message TestReply { - base.CommonReply reply = 2; + +message snake_req {} + +message snake_reply {} + +message CamelReq{} + +message CamelReply{} + +message EnumMessage { + enum Enum { + unknown = 0; + male = 1; + female = 2; + } +} + +message CommonReply{} + +message MapReq{ + map m = 1; } -enum TestEnum { - unknown = 0; - male = 1; - female = 2; +message RepeatedReq{ + repeated string id = 1; } -service TestService { - rpc TestRpc (TestReq) returns (TestReply); +service Test_Service { + // service + rpc Service (Req) returns (Reply); + // greet service + rpc GreetService (Req) returns (Reply); + // case snake + rpc snake_service (snake_req) returns (snake_reply); + // case camel + rpc CamelService (CamelReq) returns (CamelReply); + // case enum + rpc EnumService (EnumMessage) returns (CommonReply); + // case map + rpc MapService (MapReq) returns (CommonReply); + // case repeated + rpc RepeatedService (RepeatedReq) returns (CommonReply); } \ No newline at end of file diff --git a/tools/goctl/rpc/generator/test_base.proto b/tools/goctl/rpc/generator/test_base.proto deleted file mode 100644 index 36e0ca5d..00000000 --- a/tools/goctl/rpc/generator/test_base.proto +++ /dev/null @@ -1,12 +0,0 @@ -// test proto -syntax = "proto3"; - -package base; - -message CommonReq { - string in = 1; -} - -message CommonReply { - string out = 1; -} diff --git a/tools/goctl/rpc/generator/test_go_option.proto b/tools/goctl/rpc/generator/test_go_option.proto deleted file mode 100644 index 0a674a23..00000000 --- a/tools/goctl/rpc/generator/test_go_option.proto +++ /dev/null @@ -1,18 +0,0 @@ -// test proto -syntax = "proto3"; - -package stream; - -option go_package="go"; - -message StreamReq { - string name = 1; -} - -message StreamResp { - string greet = 1; -} - -service StreamGreeter { - rpc greet (StreamReq) returns (StreamResp); -} \ No newline at end of file diff --git a/tools/goctl/rpc/generator/test_import.proto b/tools/goctl/rpc/generator/test_import.proto deleted file mode 100644 index ff2cd572..00000000 --- a/tools/goctl/rpc/generator/test_import.proto +++ /dev/null @@ -1,18 +0,0 @@ -// test proto -syntax = "proto3"; - -package greet; -import "base/common.proto"; - -message In { - string name = 1; - common.User user = 2; -} - -message Out { - string greet = 1; -} - -service StreamGreeter { - rpc greet (In) returns (Out); -} \ No newline at end of file diff --git a/tools/goctl/rpc/generator/test_option.proto b/tools/goctl/rpc/generator/test_option.proto deleted file mode 100644 index 36771f2f..00000000 --- a/tools/goctl/rpc/generator/test_option.proto +++ /dev/null @@ -1,18 +0,0 @@ -// test proto -syntax = "proto3"; - -package stream; - -option go_package="github.com/tal-tech/go-zero"; - -message StreamReq { - string name = 1; -} - -message StreamResp { - string greet = 1; -} - -service StreamGreeter { - rpc greet (StreamReq) returns (StreamResp); -} \ No newline at end of file diff --git a/tools/goctl/rpc/generator/test_service_rpc_naming_snake.proto b/tools/goctl/rpc/generator/test_service_rpc_naming_snake.proto deleted file mode 100644 index fe0589e1..00000000 --- a/tools/goctl/rpc/generator/test_service_rpc_naming_snake.proto +++ /dev/null @@ -1,27 +0,0 @@ -// test proto -syntax = "proto3"; - -package snake_package; - -message StreamReq { - string name = 1; -} - -message Stream_Resp { - string greet = 1; -} - -message lowercase { - string in = 1; - string lower = 2; -} - -message CamelCase { - string Camel = 1; -} - -service Stream_Greeter { - rpc snake_service(StreamReq) returns (Stream_Resp); - rpc ServiceCamelCase(CamelCase) returns (CamelCase); - rpc servicelowercase(lowercase) returns (lowercase); -} \ No newline at end of file diff --git a/tools/goctl/rpc/generator/test_stream.proto b/tools/goctl/rpc/generator/test_stream.proto deleted file mode 100644 index 0ea1b15d..00000000 --- a/tools/goctl/rpc/generator/test_stream.proto +++ /dev/null @@ -1,17 +0,0 @@ -// test proto -syntax = "proto3"; - -package stream; - -message StreamReq { - string name = 1; -} - -message StreamResp { - string greet = 1; -} - -service StreamGreeter { - // greet service - rpc greet (StreamReq) returns (StreamResp); -} \ No newline at end of file diff --git a/tools/goctl/rpc/generator/test_word_option.proto b/tools/goctl/rpc/generator/test_word_option.proto deleted file mode 100644 index 4f7753a5..00000000 --- a/tools/goctl/rpc/generator/test_word_option.proto +++ /dev/null @@ -1,18 +0,0 @@ -// test proto -syntax = "proto3"; - -package stream; - -option go_package="user"; - -message StreamReq { - string name = 1; -} - -message StreamResp { - string greet = 1; -} - -service StreamGreeter { - rpc greet(StreamReq) returns (StreamResp); -} \ No newline at end of file diff --git a/tools/goctl/util/stringx/string.go b/tools/goctl/util/stringx/string.go index cb8e204d..bc49de7e 100644 --- a/tools/goctl/util/stringx/string.go +++ b/tools/goctl/util/stringx/string.go @@ -29,9 +29,11 @@ func (s String) IsEmptyOrSpace() bool { func (s String) Lower() string { return strings.ToLower(s.source) } + func (s String) Upper() string { return strings.ToUpper(s.source) } + func (s String) Title() string { if s.IsEmptyOrSpace() { return s.source