Add strict flag (#2248)

Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
master
anqiansong 2 years ago committed by GitHub
parent a1466e1707
commit f70805ee60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,6 +2,7 @@ package model
import ( import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/zeromicro/go-zero/tools/goctl/model/mongo" "github.com/zeromicro/go-zero/tools/goctl/model/mongo"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/command" "github.com/zeromicro/go-zero/tools/goctl/model/sql/command"
) )
@ -77,6 +78,7 @@ func init() {
pgDatasourceCmd.Flags().StringVarP(&command.VarStringDir, "dir", "d", "", "The target dir") pgDatasourceCmd.Flags().StringVarP(&command.VarStringDir, "dir", "d", "", "The target dir")
pgDatasourceCmd.Flags().StringVar(&command.VarStringStyle, "style", "", "The file naming format, see [https://github.com/zeromicro/go-zero/tree/master/tools/goctl/config/readme.md]") pgDatasourceCmd.Flags().StringVar(&command.VarStringStyle, "style", "", "The file naming format, see [https://github.com/zeromicro/go-zero/tree/master/tools/goctl/config/readme.md]")
pgDatasourceCmd.Flags().BoolVar(&command.VarBoolIdea, "idea", false, "For idea plugin [optional]") pgDatasourceCmd.Flags().BoolVar(&command.VarBoolIdea, "idea", false, "For idea plugin [optional]")
pgDatasourceCmd.Flags().BoolVar(&command.VarBoolStrict, "strict", false, "Generate model in strict mode")
pgDatasourceCmd.Flags().StringVar(&command.VarStringHome, "home", "", "The goctl home path of the template, --home and --remote cannot be set at the same time, if they are, --remote has higher priority") pgDatasourceCmd.Flags().StringVar(&command.VarStringHome, "home", "", "The goctl home path of the template, --home and --remote cannot be set at the same time, if they are, --remote has higher priority")
pgDatasourceCmd.Flags().StringVar(&command.VarStringRemote, "remote", "", "The remote git repo of the template, --home and --remote cannot be set at the same time, if they are, --remote has higher priority\n\tThe git repo directory must be consistent with the https://github.com/zeromicro/go-zero-template directory structure") pgDatasourceCmd.Flags().StringVar(&command.VarStringRemote, "remote", "", "The remote git repo of the template, --home and --remote cannot be set at the same time, if they are, --remote has higher priority\n\tThe git repo directory must be consistent with the https://github.com/zeromicro/go-zero-template directory structure")
pgDatasourceCmd.Flags().StringVar(&command.VarStringBranch, "branch", "", "The branch of the remote repo, it does work with --remote") pgDatasourceCmd.Flags().StringVar(&command.VarStringBranch, "branch", "", "The branch of the remote repo, it does work with --remote")
@ -90,6 +92,8 @@ func init() {
mongoCmd.Flags().StringVar(&mongo.VarStringRemote, "remote", "", "The remote git repo of the template, --home and --remote cannot be set at the same time, if they are, --remote has higher priority\nThe git repo directory must be consistent with the https://github.com/zeromicro/go-zero-template directory structure") mongoCmd.Flags().StringVar(&mongo.VarStringRemote, "remote", "", "The remote git repo of the template, --home and --remote cannot be set at the same time, if they are, --remote has higher priority\nThe git repo directory must be consistent with the https://github.com/zeromicro/go-zero-template directory structure")
mongoCmd.Flags().StringVar(&mongo.VarStringBranch, "branch", "", "The branch of the remote repo, it does work with --remote") mongoCmd.Flags().StringVar(&mongo.VarStringBranch, "branch", "", "The branch of the remote repo, it does work with --remote")
mysqlCmd.PersistentFlags().BoolVar(&command.VarBoolStrict, "strict", false, "Generate model in strict mode")
mysqlCmd.AddCommand(datasourceCmd) mysqlCmd.AddCommand(datasourceCmd)
mysqlCmd.AddCommand(ddlCmd) mysqlCmd.AddCommand(ddlCmd)
pgCmd.AddCommand(pgDatasourceCmd) pgCmd.AddCommand(pgDatasourceCmd)

@ -10,6 +10,7 @@ import (
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/postgres" "github.com/zeromicro/go-zero/core/stores/postgres"
"github.com/zeromicro/go-zero/core/stores/sqlx" "github.com/zeromicro/go-zero/core/stores/sqlx"
"github.com/zeromicro/go-zero/tools/goctl/config" "github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/command/migrationnotes" "github.com/zeromicro/go-zero/tools/goctl/model/sql/command/migrationnotes"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/gen" "github.com/zeromicro/go-zero/tools/goctl/model/sql/gen"
@ -47,6 +48,8 @@ var (
VarStringRemote string VarStringRemote string
// VarStringBranch describes the git branch of the repository. // VarStringBranch describes the git branch of the repository.
VarStringBranch string VarStringBranch string
// VarBoolStrict describes whether the strict mode is enabled.
VarBoolStrict bool
) )
var errNotMatched = errors.New("sql not matched") var errNotMatched = errors.New("sql not matched")
@ -77,7 +80,16 @@ func MysqlDDL(_ *cobra.Command, _ []string) error {
return err return err
} }
return fromDDL(src, dir, cfg, cache, idea, database) arg := ddlArg{
src: src,
dir: dir,
cfg: cfg,
cache: cache,
idea: idea,
database: database,
strict: VarBoolStrict,
}
return fromDDL(arg)
} }
// MySqlDataSource generates model code from datasource // MySqlDataSource generates model code from datasource
@ -108,7 +120,16 @@ func MySqlDataSource(_ *cobra.Command, _ []string) error {
return err return err
} }
return fromMysqlDataSource(url, dir, patterns, cfg, cache, idea) arg := dataSourceArg{
url: url,
dir: dir,
tablePat: patterns,
cfg: cfg,
cache: cache,
idea: idea,
strict: VarBoolStrict,
}
return fromMysqlDataSource(arg)
} }
type pattern map[string]struct{} type pattern map[string]struct{}
@ -180,12 +201,20 @@ func PostgreSqlDataSource(_ *cobra.Command, _ []string) error {
return err return err
} }
return fromPostgreSqlDataSource(url, pattern, dir, schema, cfg, cache, idea) return fromPostgreSqlDataSource(url, pattern, dir, schema, cfg, cache, idea, VarBoolStrict)
} }
func fromDDL(src, dir string, cfg *config.Config, cache, idea bool, database string) error { type ddlArg struct {
log := console.NewConsole(idea) src, dir string
src = strings.TrimSpace(src) cfg *config.Config
cache, idea bool
database string
strict bool
}
func fromDDL(arg ddlArg) error {
log := console.NewConsole(arg.idea)
src := strings.TrimSpace(arg.src)
if len(src) == 0 { if len(src) == 0 {
return errors.New("expected path or path globbing patterns, but nothing found") return errors.New("expected path or path globbing patterns, but nothing found")
} }
@ -199,13 +228,13 @@ func fromDDL(src, dir string, cfg *config.Config, cache, idea bool, database str
return errNotMatched return errNotMatched
} }
generator, err := gen.NewDefaultGenerator(dir, cfg, gen.WithConsoleOption(log)) generator, err := gen.NewDefaultGenerator(arg.dir, arg.cfg, gen.WithConsoleOption(log))
if err != nil { if err != nil {
return err return err
} }
for _, file := range files { for _, file := range files {
err = generator.StartFromDDL(file, cache, database) err = generator.StartFromDDL(file, arg.cache, arg.strict, arg.database)
if err != nil { if err != nil {
return err return err
} }
@ -214,25 +243,33 @@ func fromDDL(src, dir string, cfg *config.Config, cache, idea bool, database str
return nil return nil
} }
func fromMysqlDataSource(url, dir string, tablePat pattern, cfg *config.Config, cache, idea bool) error { type dataSourceArg struct {
log := console.NewConsole(idea) url, dir string
if len(url) == 0 { tablePat pattern
cfg *config.Config
cache, idea bool
strict bool
}
func fromMysqlDataSource(arg dataSourceArg) error {
log := console.NewConsole(arg.idea)
if len(arg.url) == 0 {
log.Error("%v", "expected data source of mysql, but nothing found") log.Error("%v", "expected data source of mysql, but nothing found")
return nil return nil
} }
if len(tablePat) == 0 { if len(arg.tablePat) == 0 {
log.Error("%v", "expected table or table globbing patterns, but nothing found") log.Error("%v", "expected table or table globbing patterns, but nothing found")
return nil return nil
} }
dsn, err := mysql.ParseDSN(url) dsn, err := mysql.ParseDSN(arg.url)
if err != nil { if err != nil {
return err return err
} }
logx.Disable() logx.Disable()
databaseSource := strings.TrimSuffix(url, "/"+dsn.DBName) + "/information_schema" databaseSource := strings.TrimSuffix(arg.url, "/"+dsn.DBName) + "/information_schema"
db := sqlx.NewMysql(databaseSource) db := sqlx.NewMysql(databaseSource)
im := model.NewInformationSchemaModel(db) im := model.NewInformationSchemaModel(db)
@ -243,7 +280,7 @@ func fromMysqlDataSource(url, dir string, tablePat pattern, cfg *config.Config,
matchTables := make(map[string]*model.Table) matchTables := make(map[string]*model.Table)
for _, item := range tables { for _, item := range tables {
if !tablePat.Match(item) { if !arg.tablePat.Match(item) {
continue continue
} }
@ -264,15 +301,15 @@ func fromMysqlDataSource(url, dir string, tablePat pattern, cfg *config.Config,
return errors.New("no tables matched") return errors.New("no tables matched")
} }
generator, err := gen.NewDefaultGenerator(dir, cfg, gen.WithConsoleOption(log)) generator, err := gen.NewDefaultGenerator(arg.dir, arg.cfg, gen.WithConsoleOption(log))
if err != nil { if err != nil {
return err return err
} }
return generator.StartFromInformationSchema(matchTables, cache) return generator.StartFromInformationSchema(matchTables, arg.cache, arg.strict)
} }
func fromPostgreSqlDataSource(url, pattern, dir, schema string, cfg *config.Config, cache, idea bool) error { func fromPostgreSqlDataSource(url, pattern, dir, schema string, cfg *config.Config, cache, idea, strict bool) error {
log := console.NewConsole(idea) log := console.NewConsole(idea)
if len(url) == 0 { if len(url) == 0 {
log.Error("%v", "expected data source of postgresql, but nothing found") log.Error("%v", "expected data source of postgresql, but nothing found")
@ -324,5 +361,5 @@ func fromPostgreSqlDataSource(url, pattern, dir, schema string, cfg *config.Conf
return err return err
} }
return generator.StartFromInformationSchema(matchTables, cache) return generator.StartFromInformationSchema(matchTables, cache, strict)
} }

@ -10,6 +10,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/tools/goctl/config" "github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/gen" "github.com/zeromicro/go-zero/tools/goctl/model/sql/gen"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx" "github.com/zeromicro/go-zero/tools/goctl/util/pathx"
@ -27,12 +28,25 @@ func TestFromDDl(t *testing.T) {
err := gen.Clean() err := gen.Clean()
assert.Nil(t, err) assert.Nil(t, err)
err = fromDDL("./user.sql", pathx.MustTempDir(), cfg, true, false, "go_zero") err = fromDDL(ddlArg{
src: "./user.sql",
dir: pathx.MustTempDir(),
cfg: cfg,
cache: true,
database: "go-zero",
strict: false,
})
assert.Equal(t, errNotMatched, err) assert.Equal(t, errNotMatched, err)
// case dir is not exists // case dir is not exists
unknownDir := filepath.Join(pathx.MustTempDir(), "test", "user.sql") unknownDir := filepath.Join(pathx.MustTempDir(), "test", "user.sql")
err = fromDDL(unknownDir, pathx.MustTempDir(), cfg, true, false, "go_zero") err = fromDDL(ddlArg{
src: unknownDir,
dir: pathx.MustTempDir(),
cfg: cfg,
cache: true,
database: "go_zero",
})
assert.True(t, func() bool { assert.True(t, func() bool {
switch err.(type) { switch err.(type) {
case *os.PathError: case *os.PathError:
@ -43,7 +57,12 @@ func TestFromDDl(t *testing.T) {
}()) }())
// case empty src // case empty src
err = fromDDL("", pathx.MustTempDir(), cfg, true, false, "go_zero") err = fromDDL(ddlArg{
dir: pathx.MustTempDir(),
cfg: cfg,
cache: true,
database: "go_zero",
})
if err != nil { if err != nil {
assert.Equal(t, "expected path or path globbing patterns, but nothing found", err.Error()) assert.Equal(t, "expected path or path globbing patterns, but nothing found", err.Error())
} }
@ -75,7 +94,13 @@ func TestFromDDl(t *testing.T) {
filename := filepath.Join(tempDir, "usermodel.go") filename := filepath.Join(tempDir, "usermodel.go")
fromDDL := func(db string) { fromDDL := func(db string) {
err = fromDDL(filepath.Join(tempDir, "user*.sql"), tempDir, cfg, true, false, db) err = fromDDL(ddlArg{
src: filepath.Join(tempDir, "user*.sql"),
dir: tempDir,
cfg: cfg,
cache: true,
database: db,
})
assert.Nil(t, err) assert.Nil(t, err)
_, err = os.Stat(filename) _, err = os.Stat(filename)

@ -132,28 +132,28 @@ var commonMysqlDataTypeMapString = map[string]string{
} }
// ConvertDataType converts mysql column type into golang type // ConvertDataType converts mysql column type into golang type
func ConvertDataType(dataBaseType int, isDefaultNull, unsigned bool) (string, error) { func ConvertDataType(dataBaseType int, isDefaultNull, unsigned, strict bool) (string, error) {
tp, ok := commonMysqlDataTypeMapInt[dataBaseType] tp, ok := commonMysqlDataTypeMapInt[dataBaseType]
if !ok { if !ok {
return "", fmt.Errorf("unsupported database type: %v", dataBaseType) return "", fmt.Errorf("unsupported database type: %v", dataBaseType)
} }
return mayConvertNullType(tp, isDefaultNull, unsigned), nil return mayConvertNullType(tp, isDefaultNull, unsigned, strict), nil
} }
// ConvertStringDataType converts mysql column type into golang type // ConvertStringDataType converts mysql column type into golang type
func ConvertStringDataType(dataBaseType string, isDefaultNull, unsigned bool) (string, error) { func ConvertStringDataType(dataBaseType string, isDefaultNull, unsigned, strict bool) (string, error) {
tp, ok := commonMysqlDataTypeMapString[strings.ToLower(dataBaseType)] tp, ok := commonMysqlDataTypeMapString[strings.ToLower(dataBaseType)]
if !ok { if !ok {
return "", fmt.Errorf("unsupported database type: %s", dataBaseType) return "", fmt.Errorf("unsupported database type: %s", dataBaseType)
} }
return mayConvertNullType(tp, isDefaultNull, unsigned), nil return mayConvertNullType(tp, isDefaultNull, unsigned, strict), nil
} }
func mayConvertNullType(goDataType string, isDefaultNull, unsigned bool) string { func mayConvertNullType(goDataType string, isDefaultNull, unsigned, strict bool) string {
if !isDefaultNull { if !isDefaultNull {
if unsigned { if unsigned && strict {
ret, ok := unsignedTypeMap[goDataType] ret, ok := unsignedTypeMap[goDataType]
if ok { if ok {
return ret return ret

@ -8,23 +8,23 @@ import (
) )
func TestConvertDataType(t *testing.T) { func TestConvertDataType(t *testing.T) {
v, err := ConvertDataType(parser.TinyInt, false, false) v, err := ConvertDataType(parser.TinyInt, false, false, true)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "int64", v) assert.Equal(t, "int64", v)
v, err = ConvertDataType(parser.TinyInt, false, true) v, err = ConvertDataType(parser.TinyInt, false, true, true)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "uint64", v) assert.Equal(t, "uint64", v)
v, err = ConvertDataType(parser.TinyInt, true, false) v, err = ConvertDataType(parser.TinyInt, true, false, true)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "sql.NullInt64", v) assert.Equal(t, "sql.NullInt64", v)
v, err = ConvertDataType(parser.Timestamp, false, false) v, err = ConvertDataType(parser.Timestamp, false, false, true)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "time.Time", v) assert.Equal(t, "time.Time", v)
v, err = ConvertDataType(parser.Timestamp, true, false) v, err = ConvertDataType(parser.Timestamp, true, false, true)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "sql.NullTime", v) assert.Equal(t, "sql.NullTime", v)
} }

@ -102,8 +102,8 @@ func newDefaultOption() Option {
} }
} }
func (g *defaultGenerator) StartFromDDL(filename string, withCache bool, database string) error { func (g *defaultGenerator) StartFromDDL(filename string, withCache, strict bool, database string) error {
modelList, err := g.genFromDDL(filename, withCache, database) modelList, err := g.genFromDDL(filename, withCache, strict, database)
if err != nil { if err != nil {
return err return err
} }
@ -111,10 +111,10 @@ func (g *defaultGenerator) StartFromDDL(filename string, withCache bool, databas
return g.createFile(modelList) return g.createFile(modelList)
} }
func (g *defaultGenerator) StartFromInformationSchema(tables map[string]*model.Table, withCache bool) error { func (g *defaultGenerator) StartFromInformationSchema(tables map[string]*model.Table, withCache, strict bool) error {
m := make(map[string]*codeTuple) m := make(map[string]*codeTuple)
for _, each := range tables { for _, each := range tables {
table, err := parser.ConvertDataType(each) table, err := parser.ConvertDataType(each, strict)
if err != nil { if err != nil {
return err return err
} }
@ -201,11 +201,11 @@ func (g *defaultGenerator) createFile(modelList map[string]*codeTuple) error {
} }
// ret1: key-table name,value-code // ret1: key-table name,value-code
func (g *defaultGenerator) genFromDDL(filename string, withCache bool, database string) ( func (g *defaultGenerator) genFromDDL(filename string, withCache, strict bool, database string) (
map[string]*codeTuple, error, map[string]*codeTuple, error,
) { ) {
m := make(map[string]*codeTuple) m := make(map[string]*codeTuple)
tables, err := parser.Parse(filename, database) tables, err := parser.Parse(filename, database, strict)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -15,6 +15,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stringx" "github.com/zeromicro/go-zero/core/stringx"
"github.com/zeromicro/go-zero/tools/goctl/config" "github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/builderx" "github.com/zeromicro/go-zero/tools/goctl/model/sql/builderx"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/parser" "github.com/zeromicro/go-zero/tools/goctl/model/sql/parser"
@ -40,7 +41,7 @@ func TestCacheModel(t *testing.T) {
}) })
assert.Nil(t, err) assert.Nil(t, err)
err = g.StartFromDDL(sqlFile, true, "go_zero") err = g.StartFromDDL(sqlFile, true, false, "go_zero")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, func() bool { assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(cacheDir, "TestUserModel.go")) _, err := os.Stat(filepath.Join(cacheDir, "TestUserModel.go"))
@ -51,7 +52,7 @@ func TestCacheModel(t *testing.T) {
}) })
assert.Nil(t, err) assert.Nil(t, err)
err = g.StartFromDDL(sqlFile, false, "go_zero") err = g.StartFromDDL(sqlFile, false, false, "go_zero")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, func() bool { assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(noCacheDir, "testusermodel.go")) _, err := os.Stat(filepath.Join(noCacheDir, "testusermodel.go"))
@ -78,7 +79,7 @@ func TestNamingModel(t *testing.T) {
}) })
assert.Nil(t, err) assert.Nil(t, err)
err = g.StartFromDDL(sqlFile, true, "go_zero") err = g.StartFromDDL(sqlFile, true, false, "go_zero")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, func() bool { assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(camelDir, "TestUserModel.go")) _, err := os.Stat(filepath.Join(camelDir, "TestUserModel.go"))
@ -89,7 +90,7 @@ func TestNamingModel(t *testing.T) {
}) })
assert.Nil(t, err) assert.Nil(t, err)
err = g.StartFromDDL(sqlFile, true, "go_zero") err = g.StartFromDDL(sqlFile, true, false, "go_zero")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, func() bool { assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(snakeDir, "test_user_model.go")) _, err := os.Stat(filepath.Join(snakeDir, "test_user_model.go"))
@ -186,7 +187,7 @@ func Test_genPublicModel(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
tables, err := parser.Parse(modelFilename, "") tables, err := parser.Parse(modelFilename, "", false)
require.Equal(t, 1, len(tables)) require.Equal(t, 1, len(tables))
code, err := g.genModelCustom(*tables[0], false) code, err := g.genModelCustom(*tables[0], false)

@ -8,6 +8,7 @@ import (
"github.com/zeromicro/ddl-parser/parser" "github.com/zeromicro/ddl-parser/parser"
"github.com/zeromicro/go-zero/core/collection" "github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/converter" "github.com/zeromicro/go-zero/tools/goctl/model/sql/converter"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/model" "github.com/zeromicro/go-zero/tools/goctl/model/sql/model"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/util" "github.com/zeromicro/go-zero/tools/goctl/model/sql/util"
@ -61,7 +62,7 @@ func parseNameOriginal(ts []*parser.Table) (nameOriginals [][]string) {
} }
// Parse parses ddl into golang structure // Parse parses ddl into golang structure
func Parse(filename, database string) ([]*Table, error) { func Parse(filename, database string, strict bool) ([]*Table, error) {
p := parser.NewParser() p := parser.NewParser()
tables, err := p.From(filename) tables, err := p.From(filename)
if err != nil { if err != nil {
@ -124,7 +125,7 @@ func Parse(filename, database string) ([]*Table, error) {
return nil, fmt.Errorf("%s: unexpected join primary key", prefix) return nil, fmt.Errorf("%s: unexpected join primary key", prefix)
} }
primaryKey, fieldM, err := convertColumns(columns, primaryColumn) primaryKey, fieldM, err := convertColumns(columns, primaryColumn, strict)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -190,7 +191,7 @@ func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string
} }
} }
func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, map[string]*Field, error) { func convertColumns(columns []*parser.Column, primaryColumn string, strict bool) (Primary, map[string]*Field, error) {
var ( var (
primaryKey Primary primaryKey Primary
fieldM = make(map[string]*Field) fieldM = make(map[string]*Field)
@ -219,7 +220,7 @@ func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, ma
} }
} }
dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull, column.DataType.Unsigned()) dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull, column.DataType.Unsigned(), strict)
if err != nil { if err != nil {
return Primary{}, nil, err return Primary{}, nil, err
} }
@ -264,10 +265,10 @@ func (t *Table) ContainsTime() bool {
} }
// ConvertDataType converts mysql data type into golang data type // ConvertDataType converts mysql data type into golang data type
func ConvertDataType(table *model.Table) (*Table, error) { func ConvertDataType(table *model.Table, strict bool) (*Table, error) {
isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES" isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES"
isPrimaryUnsigned := strings.Contains(table.PrimaryKey.DbColumn.ColumnType, "unsigned") isPrimaryUnsigned := strings.Contains(table.PrimaryKey.DbColumn.ColumnType, "unsigned")
primaryDataType, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull, isPrimaryUnsigned) primaryDataType, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull, isPrimaryUnsigned, strict)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -292,7 +293,7 @@ func ConvertDataType(table *model.Table) (*Table, error) {
AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"), AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
} }
fieldM, err := getTableFields(table) fieldM, err := getTableFields(table, strict)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -342,12 +343,12 @@ func ConvertDataType(table *model.Table) (*Table, error) {
return &reply, nil return &reply, nil
} }
func getTableFields(table *model.Table) (map[string]*Field, error) { func getTableFields(table *model.Table, strict bool) (map[string]*Field, error) {
fieldM := make(map[string]*Field) fieldM := make(map[string]*Field)
for _, each := range table.Columns { for _, each := range table.Columns {
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES" isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
isPrimaryUnsigned := strings.Contains(each.ColumnType, "unsigned") isPrimaryUnsigned := strings.Contains(each.ColumnType, "unsigned")
dt, err := converter.ConvertStringDataType(each.DataType, isDefaultNull, isPrimaryUnsigned) dt, err := converter.ConvertStringDataType(each.DataType, isDefaultNull, isPrimaryUnsigned, strict)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -7,6 +7,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/model" "github.com/zeromicro/go-zero/tools/goctl/model/sql/model"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/util" "github.com/zeromicro/go-zero/tools/goctl/model/sql/util"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx" "github.com/zeromicro/go-zero/tools/goctl/util/pathx"
@ -17,7 +18,7 @@ func TestParsePlainText(t *testing.T) {
err := ioutil.WriteFile(sqlFile, []byte("plain text"), 0o777) err := ioutil.WriteFile(sqlFile, []byte("plain text"), 0o777)
assert.Nil(t, err) assert.Nil(t, err)
_, err = Parse(sqlFile, "go_zero") _, err = Parse(sqlFile, "go_zero", false)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@ -26,7 +27,7 @@ func TestParseSelect(t *testing.T) {
err := ioutil.WriteFile(sqlFile, []byte("select * from user"), 0o777) err := ioutil.WriteFile(sqlFile, []byte("select * from user"), 0o777)
assert.Nil(t, err) assert.Nil(t, err)
tables, err := Parse(sqlFile, "go_zero") tables, err := Parse(sqlFile, "go_zero", false)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 0, len(tables)) assert.Equal(t, 0, len(tables))
} }
@ -39,7 +40,7 @@ func TestParseCreateTable(t *testing.T) {
err := ioutil.WriteFile(sqlFile, []byte(user), 0o777) err := ioutil.WriteFile(sqlFile, []byte(user), 0o777)
assert.Nil(t, err) assert.Nil(t, err)
tables, err := Parse(sqlFile, "go_zero") tables, err := Parse(sqlFile, "go_zero", false)
assert.Equal(t, 1, len(tables)) assert.Equal(t, 1, len(tables))
table := tables[0] table := tables[0]
assert.Nil(t, err) assert.Nil(t, err)

Loading…
Cancel
Save