From f70805ee607b38b352667ce1fe4eabe9e83a40ef Mon Sep 17 00:00:00 2001 From: anqiansong Date: Sun, 28 Aug 2022 18:55:52 +0800 Subject: [PATCH] Add strict flag (#2248) Co-authored-by: Kevin Wan --- tools/goctl/model/cmd.go | 4 + tools/goctl/model/sql/command/command.go | 75 ++++++++++++++----- tools/goctl/model/sql/command/command_test.go | 33 +++++++- tools/goctl/model/sql/converter/types.go | 12 +-- tools/goctl/model/sql/converter/types_test.go | 10 +-- tools/goctl/model/sql/gen/gen.go | 12 +-- tools/goctl/model/sql/gen/gen_test.go | 11 +-- tools/goctl/model/sql/parser/parser.go | 19 ++--- tools/goctl/model/sql/parser/parser_test.go | 7 +- 9 files changed, 126 insertions(+), 57 deletions(-) diff --git a/tools/goctl/model/cmd.go b/tools/goctl/model/cmd.go index 99de7c40..131fe048 100644 --- a/tools/goctl/model/cmd.go +++ b/tools/goctl/model/cmd.go @@ -2,6 +2,7 @@ package model import ( "github.com/spf13/cobra" + "github.com/zeromicro/go-zero/tools/goctl/model/mongo" "github.com/zeromicro/go-zero/tools/goctl/model/sql/command" ) @@ -77,6 +78,7 @@ func init() { pgDatasourceCmd.Flags().StringVarP(&command.VarStringDir, "dir", "d", "", "The target dir") pgDatasourceCmd.Flags().StringVar(&command.VarStringStyle, "style", "", "The file naming format, see [https://github.com/zeromicro/go-zero/tree/master/tools/goctl/config/readme.md]") pgDatasourceCmd.Flags().BoolVar(&command.VarBoolIdea, "idea", false, "For idea plugin [optional]") + pgDatasourceCmd.Flags().BoolVar(&command.VarBoolStrict, "strict", false, "Generate model in strict mode") pgDatasourceCmd.Flags().StringVar(&command.VarStringHome, "home", "", "The goctl home path of the template, --home and --remote cannot be set at the same time, if they are, --remote has higher priority") pgDatasourceCmd.Flags().StringVar(&command.VarStringRemote, "remote", "", "The remote git repo of the template, --home and --remote cannot be set at the same time, if they are, --remote has higher priority\n\tThe git repo directory must be consistent with the https://github.com/zeromicro/go-zero-template directory structure") pgDatasourceCmd.Flags().StringVar(&command.VarStringBranch, "branch", "", "The branch of the remote repo, it does work with --remote") @@ -90,6 +92,8 @@ func init() { mongoCmd.Flags().StringVar(&mongo.VarStringRemote, "remote", "", "The remote git repo of the template, --home and --remote cannot be set at the same time, if they are, --remote has higher priority\nThe git repo directory must be consistent with the https://github.com/zeromicro/go-zero-template directory structure") mongoCmd.Flags().StringVar(&mongo.VarStringBranch, "branch", "", "The branch of the remote repo, it does work with --remote") + mysqlCmd.PersistentFlags().BoolVar(&command.VarBoolStrict, "strict", false, "Generate model in strict mode") + mysqlCmd.AddCommand(datasourceCmd) mysqlCmd.AddCommand(ddlCmd) pgCmd.AddCommand(pgDatasourceCmd) diff --git a/tools/goctl/model/sql/command/command.go b/tools/goctl/model/sql/command/command.go index 9fd65777..a12f2a2b 100644 --- a/tools/goctl/model/sql/command/command.go +++ b/tools/goctl/model/sql/command/command.go @@ -10,6 +10,7 @@ import ( "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/stores/postgres" "github.com/zeromicro/go-zero/core/stores/sqlx" + "github.com/zeromicro/go-zero/tools/goctl/config" "github.com/zeromicro/go-zero/tools/goctl/model/sql/command/migrationnotes" "github.com/zeromicro/go-zero/tools/goctl/model/sql/gen" @@ -47,6 +48,8 @@ var ( VarStringRemote string // VarStringBranch describes the git branch of the repository. VarStringBranch string + // VarBoolStrict describes whether the strict mode is enabled. + VarBoolStrict bool ) var errNotMatched = errors.New("sql not matched") @@ -77,7 +80,16 @@ func MysqlDDL(_ *cobra.Command, _ []string) error { return err } - return fromDDL(src, dir, cfg, cache, idea, database) + arg := ddlArg{ + src: src, + dir: dir, + cfg: cfg, + cache: cache, + idea: idea, + database: database, + strict: VarBoolStrict, + } + return fromDDL(arg) } // MySqlDataSource generates model code from datasource @@ -108,7 +120,16 @@ func MySqlDataSource(_ *cobra.Command, _ []string) error { return err } - return fromMysqlDataSource(url, dir, patterns, cfg, cache, idea) + arg := dataSourceArg{ + url: url, + dir: dir, + tablePat: patterns, + cfg: cfg, + cache: cache, + idea: idea, + strict: VarBoolStrict, + } + return fromMysqlDataSource(arg) } type pattern map[string]struct{} @@ -180,12 +201,20 @@ func PostgreSqlDataSource(_ *cobra.Command, _ []string) error { return err } - return fromPostgreSqlDataSource(url, pattern, dir, schema, cfg, cache, idea) + return fromPostgreSqlDataSource(url, pattern, dir, schema, cfg, cache, idea, VarBoolStrict) } -func fromDDL(src, dir string, cfg *config.Config, cache, idea bool, database string) error { - log := console.NewConsole(idea) - src = strings.TrimSpace(src) +type ddlArg struct { + src, dir string + cfg *config.Config + cache, idea bool + database string + strict bool +} + +func fromDDL(arg ddlArg) error { + log := console.NewConsole(arg.idea) + src := strings.TrimSpace(arg.src) if len(src) == 0 { return errors.New("expected path or path globbing patterns, but nothing found") } @@ -199,13 +228,13 @@ func fromDDL(src, dir string, cfg *config.Config, cache, idea bool, database str return errNotMatched } - generator, err := gen.NewDefaultGenerator(dir, cfg, gen.WithConsoleOption(log)) + generator, err := gen.NewDefaultGenerator(arg.dir, arg.cfg, gen.WithConsoleOption(log)) if err != nil { return err } for _, file := range files { - err = generator.StartFromDDL(file, cache, database) + err = generator.StartFromDDL(file, arg.cache, arg.strict, arg.database) if err != nil { return err } @@ -214,25 +243,33 @@ func fromDDL(src, dir string, cfg *config.Config, cache, idea bool, database str return nil } -func fromMysqlDataSource(url, dir string, tablePat pattern, cfg *config.Config, cache, idea bool) error { - log := console.NewConsole(idea) - if len(url) == 0 { +type dataSourceArg struct { + url, dir string + tablePat pattern + cfg *config.Config + cache, idea bool + strict bool +} + +func fromMysqlDataSource(arg dataSourceArg) error { + log := console.NewConsole(arg.idea) + if len(arg.url) == 0 { log.Error("%v", "expected data source of mysql, but nothing found") return nil } - if len(tablePat) == 0 { + if len(arg.tablePat) == 0 { log.Error("%v", "expected table or table globbing patterns, but nothing found") return nil } - dsn, err := mysql.ParseDSN(url) + dsn, err := mysql.ParseDSN(arg.url) if err != nil { return err } logx.Disable() - databaseSource := strings.TrimSuffix(url, "/"+dsn.DBName) + "/information_schema" + databaseSource := strings.TrimSuffix(arg.url, "/"+dsn.DBName) + "/information_schema" db := sqlx.NewMysql(databaseSource) im := model.NewInformationSchemaModel(db) @@ -243,7 +280,7 @@ func fromMysqlDataSource(url, dir string, tablePat pattern, cfg *config.Config, matchTables := make(map[string]*model.Table) for _, item := range tables { - if !tablePat.Match(item) { + if !arg.tablePat.Match(item) { continue } @@ -264,15 +301,15 @@ func fromMysqlDataSource(url, dir string, tablePat pattern, cfg *config.Config, return errors.New("no tables matched") } - generator, err := gen.NewDefaultGenerator(dir, cfg, gen.WithConsoleOption(log)) + generator, err := gen.NewDefaultGenerator(arg.dir, arg.cfg, gen.WithConsoleOption(log)) if err != nil { return err } - return generator.StartFromInformationSchema(matchTables, cache) + return generator.StartFromInformationSchema(matchTables, arg.cache, arg.strict) } -func fromPostgreSqlDataSource(url, pattern, dir, schema string, cfg *config.Config, cache, idea bool) error { +func fromPostgreSqlDataSource(url, pattern, dir, schema string, cfg *config.Config, cache, idea, strict bool) error { log := console.NewConsole(idea) if len(url) == 0 { log.Error("%v", "expected data source of postgresql, but nothing found") @@ -324,5 +361,5 @@ func fromPostgreSqlDataSource(url, pattern, dir, schema string, cfg *config.Conf return err } - return generator.StartFromInformationSchema(matchTables, cache) + return generator.StartFromInformationSchema(matchTables, cache, strict) } diff --git a/tools/goctl/model/sql/command/command_test.go b/tools/goctl/model/sql/command/command_test.go index bd35713f..330a6630 100644 --- a/tools/goctl/model/sql/command/command_test.go +++ b/tools/goctl/model/sql/command/command_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/tools/goctl/config" "github.com/zeromicro/go-zero/tools/goctl/model/sql/gen" "github.com/zeromicro/go-zero/tools/goctl/util/pathx" @@ -27,12 +28,25 @@ func TestFromDDl(t *testing.T) { err := gen.Clean() assert.Nil(t, err) - err = fromDDL("./user.sql", pathx.MustTempDir(), cfg, true, false, "go_zero") + err = fromDDL(ddlArg{ + src: "./user.sql", + dir: pathx.MustTempDir(), + cfg: cfg, + cache: true, + database: "go-zero", + strict: false, + }) assert.Equal(t, errNotMatched, err) // case dir is not exists unknownDir := filepath.Join(pathx.MustTempDir(), "test", "user.sql") - err = fromDDL(unknownDir, pathx.MustTempDir(), cfg, true, false, "go_zero") + err = fromDDL(ddlArg{ + src: unknownDir, + dir: pathx.MustTempDir(), + cfg: cfg, + cache: true, + database: "go_zero", + }) assert.True(t, func() bool { switch err.(type) { case *os.PathError: @@ -43,7 +57,12 @@ func TestFromDDl(t *testing.T) { }()) // case empty src - err = fromDDL("", pathx.MustTempDir(), cfg, true, false, "go_zero") + err = fromDDL(ddlArg{ + dir: pathx.MustTempDir(), + cfg: cfg, + cache: true, + database: "go_zero", + }) if err != nil { assert.Equal(t, "expected path or path globbing patterns, but nothing found", err.Error()) } @@ -75,7 +94,13 @@ func TestFromDDl(t *testing.T) { filename := filepath.Join(tempDir, "usermodel.go") fromDDL := func(db string) { - err = fromDDL(filepath.Join(tempDir, "user*.sql"), tempDir, cfg, true, false, db) + err = fromDDL(ddlArg{ + src: filepath.Join(tempDir, "user*.sql"), + dir: tempDir, + cfg: cfg, + cache: true, + database: db, + }) assert.Nil(t, err) _, err = os.Stat(filename) diff --git a/tools/goctl/model/sql/converter/types.go b/tools/goctl/model/sql/converter/types.go index 4949656e..173b2316 100644 --- a/tools/goctl/model/sql/converter/types.go +++ b/tools/goctl/model/sql/converter/types.go @@ -132,28 +132,28 @@ var commonMysqlDataTypeMapString = map[string]string{ } // ConvertDataType converts mysql column type into golang type -func ConvertDataType(dataBaseType int, isDefaultNull, unsigned bool) (string, error) { +func ConvertDataType(dataBaseType int, isDefaultNull, unsigned, strict bool) (string, error) { tp, ok := commonMysqlDataTypeMapInt[dataBaseType] if !ok { return "", fmt.Errorf("unsupported database type: %v", dataBaseType) } - return mayConvertNullType(tp, isDefaultNull, unsigned), nil + return mayConvertNullType(tp, isDefaultNull, unsigned, strict), nil } // ConvertStringDataType converts mysql column type into golang type -func ConvertStringDataType(dataBaseType string, isDefaultNull, unsigned bool) (string, error) { +func ConvertStringDataType(dataBaseType string, isDefaultNull, unsigned, strict bool) (string, error) { tp, ok := commonMysqlDataTypeMapString[strings.ToLower(dataBaseType)] if !ok { return "", fmt.Errorf("unsupported database type: %s", dataBaseType) } - return mayConvertNullType(tp, isDefaultNull, unsigned), nil + return mayConvertNullType(tp, isDefaultNull, unsigned, strict), nil } -func mayConvertNullType(goDataType string, isDefaultNull, unsigned bool) string { +func mayConvertNullType(goDataType string, isDefaultNull, unsigned, strict bool) string { if !isDefaultNull { - if unsigned { + if unsigned && strict { ret, ok := unsignedTypeMap[goDataType] if ok { return ret diff --git a/tools/goctl/model/sql/converter/types_test.go b/tools/goctl/model/sql/converter/types_test.go index e16e0613..c5f8353d 100644 --- a/tools/goctl/model/sql/converter/types_test.go +++ b/tools/goctl/model/sql/converter/types_test.go @@ -8,23 +8,23 @@ import ( ) func TestConvertDataType(t *testing.T) { - v, err := ConvertDataType(parser.TinyInt, false, false) + v, err := ConvertDataType(parser.TinyInt, false, false, true) assert.Nil(t, err) assert.Equal(t, "int64", v) - v, err = ConvertDataType(parser.TinyInt, false, true) + v, err = ConvertDataType(parser.TinyInt, false, true, true) assert.Nil(t, err) assert.Equal(t, "uint64", v) - v, err = ConvertDataType(parser.TinyInt, true, false) + v, err = ConvertDataType(parser.TinyInt, true, false, true) assert.Nil(t, err) assert.Equal(t, "sql.NullInt64", v) - v, err = ConvertDataType(parser.Timestamp, false, false) + v, err = ConvertDataType(parser.Timestamp, false, false, true) assert.Nil(t, err) assert.Equal(t, "time.Time", v) - v, err = ConvertDataType(parser.Timestamp, true, false) + v, err = ConvertDataType(parser.Timestamp, true, false, true) assert.Nil(t, err) assert.Equal(t, "sql.NullTime", v) } diff --git a/tools/goctl/model/sql/gen/gen.go b/tools/goctl/model/sql/gen/gen.go index 8a8a17a0..4b7b3ed1 100644 --- a/tools/goctl/model/sql/gen/gen.go +++ b/tools/goctl/model/sql/gen/gen.go @@ -102,8 +102,8 @@ func newDefaultOption() Option { } } -func (g *defaultGenerator) StartFromDDL(filename string, withCache bool, database string) error { - modelList, err := g.genFromDDL(filename, withCache, database) +func (g *defaultGenerator) StartFromDDL(filename string, withCache, strict bool, database string) error { + modelList, err := g.genFromDDL(filename, withCache, strict, database) if err != nil { return err } @@ -111,10 +111,10 @@ func (g *defaultGenerator) StartFromDDL(filename string, withCache bool, databas return g.createFile(modelList) } -func (g *defaultGenerator) StartFromInformationSchema(tables map[string]*model.Table, withCache bool) error { +func (g *defaultGenerator) StartFromInformationSchema(tables map[string]*model.Table, withCache, strict bool) error { m := make(map[string]*codeTuple) for _, each := range tables { - table, err := parser.ConvertDataType(each) + table, err := parser.ConvertDataType(each, strict) if err != nil { return err } @@ -201,11 +201,11 @@ func (g *defaultGenerator) createFile(modelList map[string]*codeTuple) error { } // ret1: key-table name,value-code -func (g *defaultGenerator) genFromDDL(filename string, withCache bool, database string) ( +func (g *defaultGenerator) genFromDDL(filename string, withCache, strict bool, database string) ( map[string]*codeTuple, error, ) { m := make(map[string]*codeTuple) - tables, err := parser.Parse(filename, database) + tables, err := parser.Parse(filename, database, strict) if err != nil { return nil, err } diff --git a/tools/goctl/model/sql/gen/gen_test.go b/tools/goctl/model/sql/gen/gen_test.go index 89b6e714..e73ecc6e 100644 --- a/tools/goctl/model/sql/gen/gen_test.go +++ b/tools/goctl/model/sql/gen/gen_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/require" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/stringx" + "github.com/zeromicro/go-zero/tools/goctl/config" "github.com/zeromicro/go-zero/tools/goctl/model/sql/builderx" "github.com/zeromicro/go-zero/tools/goctl/model/sql/parser" @@ -40,7 +41,7 @@ func TestCacheModel(t *testing.T) { }) assert.Nil(t, err) - err = g.StartFromDDL(sqlFile, true, "go_zero") + err = g.StartFromDDL(sqlFile, true, false, "go_zero") assert.Nil(t, err) assert.True(t, func() bool { _, err := os.Stat(filepath.Join(cacheDir, "TestUserModel.go")) @@ -51,7 +52,7 @@ func TestCacheModel(t *testing.T) { }) assert.Nil(t, err) - err = g.StartFromDDL(sqlFile, false, "go_zero") + err = g.StartFromDDL(sqlFile, false, false, "go_zero") assert.Nil(t, err) assert.True(t, func() bool { _, err := os.Stat(filepath.Join(noCacheDir, "testusermodel.go")) @@ -78,7 +79,7 @@ func TestNamingModel(t *testing.T) { }) assert.Nil(t, err) - err = g.StartFromDDL(sqlFile, true, "go_zero") + err = g.StartFromDDL(sqlFile, true, false, "go_zero") assert.Nil(t, err) assert.True(t, func() bool { _, err := os.Stat(filepath.Join(camelDir, "TestUserModel.go")) @@ -89,7 +90,7 @@ func TestNamingModel(t *testing.T) { }) assert.Nil(t, err) - err = g.StartFromDDL(sqlFile, true, "go_zero") + err = g.StartFromDDL(sqlFile, true, false, "go_zero") assert.Nil(t, err) assert.True(t, func() bool { _, err := os.Stat(filepath.Join(snakeDir, "test_user_model.go")) @@ -186,7 +187,7 @@ func Test_genPublicModel(t *testing.T) { }) require.NoError(t, err) - tables, err := parser.Parse(modelFilename, "") + tables, err := parser.Parse(modelFilename, "", false) require.Equal(t, 1, len(tables)) code, err := g.genModelCustom(*tables[0], false) diff --git a/tools/goctl/model/sql/parser/parser.go b/tools/goctl/model/sql/parser/parser.go index 814888e6..1caab859 100644 --- a/tools/goctl/model/sql/parser/parser.go +++ b/tools/goctl/model/sql/parser/parser.go @@ -8,6 +8,7 @@ import ( "github.com/zeromicro/ddl-parser/parser" "github.com/zeromicro/go-zero/core/collection" + "github.com/zeromicro/go-zero/tools/goctl/model/sql/converter" "github.com/zeromicro/go-zero/tools/goctl/model/sql/model" "github.com/zeromicro/go-zero/tools/goctl/model/sql/util" @@ -61,7 +62,7 @@ func parseNameOriginal(ts []*parser.Table) (nameOriginals [][]string) { } // Parse parses ddl into golang structure -func Parse(filename, database string) ([]*Table, error) { +func Parse(filename, database string, strict bool) ([]*Table, error) { p := parser.NewParser() tables, err := p.From(filename) if err != nil { @@ -124,7 +125,7 @@ func Parse(filename, database string) ([]*Table, error) { return nil, fmt.Errorf("%s: unexpected join primary key", prefix) } - primaryKey, fieldM, err := convertColumns(columns, primaryColumn) + primaryKey, fieldM, err := convertColumns(columns, primaryColumn, strict) if err != nil { return nil, err } @@ -190,7 +191,7 @@ func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string } } -func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, map[string]*Field, error) { +func convertColumns(columns []*parser.Column, primaryColumn string, strict bool) (Primary, map[string]*Field, error) { var ( primaryKey Primary fieldM = make(map[string]*Field) @@ -219,7 +220,7 @@ func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, ma } } - dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull, column.DataType.Unsigned()) + dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull, column.DataType.Unsigned(), strict) if err != nil { return Primary{}, nil, err } @@ -264,10 +265,10 @@ func (t *Table) ContainsTime() bool { } // ConvertDataType converts mysql data type into golang data type -func ConvertDataType(table *model.Table) (*Table, error) { +func ConvertDataType(table *model.Table, strict bool) (*Table, error) { isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES" isPrimaryUnsigned := strings.Contains(table.PrimaryKey.DbColumn.ColumnType, "unsigned") - primaryDataType, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull, isPrimaryUnsigned) + primaryDataType, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull, isPrimaryUnsigned, strict) if err != nil { return nil, err } @@ -292,7 +293,7 @@ func ConvertDataType(table *model.Table) (*Table, error) { AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"), } - fieldM, err := getTableFields(table) + fieldM, err := getTableFields(table, strict) if err != nil { return nil, err } @@ -342,12 +343,12 @@ func ConvertDataType(table *model.Table) (*Table, error) { return &reply, nil } -func getTableFields(table *model.Table) (map[string]*Field, error) { +func getTableFields(table *model.Table, strict bool) (map[string]*Field, error) { fieldM := make(map[string]*Field) for _, each := range table.Columns { isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES" isPrimaryUnsigned := strings.Contains(each.ColumnType, "unsigned") - dt, err := converter.ConvertStringDataType(each.DataType, isDefaultNull, isPrimaryUnsigned) + dt, err := converter.ConvertStringDataType(each.DataType, isDefaultNull, isPrimaryUnsigned, strict) if err != nil { return nil, err } diff --git a/tools/goctl/model/sql/parser/parser_test.go b/tools/goctl/model/sql/parser/parser_test.go index b5907f34..0d205287 100644 --- a/tools/goctl/model/sql/parser/parser_test.go +++ b/tools/goctl/model/sql/parser/parser_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/tools/goctl/model/sql/model" "github.com/zeromicro/go-zero/tools/goctl/model/sql/util" "github.com/zeromicro/go-zero/tools/goctl/util/pathx" @@ -17,7 +18,7 @@ func TestParsePlainText(t *testing.T) { err := ioutil.WriteFile(sqlFile, []byte("plain text"), 0o777) assert.Nil(t, err) - _, err = Parse(sqlFile, "go_zero") + _, err = Parse(sqlFile, "go_zero", false) assert.NotNil(t, err) } @@ -26,7 +27,7 @@ func TestParseSelect(t *testing.T) { err := ioutil.WriteFile(sqlFile, []byte("select * from user"), 0o777) assert.Nil(t, err) - tables, err := Parse(sqlFile, "go_zero") + tables, err := Parse(sqlFile, "go_zero", false) assert.Nil(t, err) assert.Equal(t, 0, len(tables)) } @@ -39,7 +40,7 @@ func TestParseCreateTable(t *testing.T) { err := ioutil.WriteFile(sqlFile, []byte(user), 0o777) assert.Nil(t, err) - tables, err := Parse(sqlFile, "go_zero") + tables, err := Parse(sqlFile, "go_zero", false) assert.Equal(t, 1, len(tables)) table := tables[0] assert.Nil(t, err)