diff --git a/go.mod b/go.mod index 36c5e000..9e81d824 100644 --- a/go.mod +++ b/go.mod @@ -35,8 +35,8 @@ require ( github.com/spaolacci/murmur3 v1.1.0 github.com/stretchr/testify v1.7.0 github.com/urfave/cli v1.22.5 - github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 github.com/zeromicro/antlr v0.0.1 + github.com/zeromicro/ddl-parser v0.0.0-20210712021150-63520aca7348 // indirect go.etcd.io/etcd/api/v3 v3.5.0 go.etcd.io/etcd/client/v3 v3.5.0 go.uber.org/automaxprocs v1.3.0 diff --git a/go.sum b/go.sum index 39a86f53..3aaf5df7 100644 --- a/go.sum +++ b/go.sum @@ -15,6 +15,8 @@ github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGn github.com/alicebob/miniredis/v2 v2.14.1 h1:GjlbSeoJ24bzdLRs13HoMEeaRZx9kg5nHoRW7QV/nCs= github.com/alicebob/miniredis/v2 v2.14.1/go.mod h1:uS970Sw5Gs9/iK3yBg0l9Uj9s25wXxSpQUE9EaJ/Blg= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= +github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec h1:EEyRvzmpEUZ+I8WmD5cw/vY8EqhambkOqy5iFr0908A= +github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -225,8 +227,6 @@ github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5Cc github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/urfave/cli v1.22.5 h1:lNq9sAHXK2qfdI8W+GRItjCEkI+2oR4d+MEHy1CKXoU= github.com/urfave/cli v1.22.5/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= -github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 h1:zzrxE1FKn5ryBNl9eKOeqQ58Y/Qpo3Q9QNxKHX5uzzQ= -github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2/go.mod h1:hzfGeIUDq/j97IG+FhNqkowIyEcD88LrW6fyU3K3WqY= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= @@ -234,6 +234,10 @@ github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb h1:ZkM6LRnq40pR1Ox github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb/go.mod h1:gqRgreBUhTSL0GeU64rtZ3Uq3wtjOa/TB2YfrtkCbVQ= github.com/zeromicro/antlr v0.0.1 h1:CQpIn/dc0pUjgGQ81y98s/NGOm2Hfru2NNio2I9mQgk= github.com/zeromicro/antlr v0.0.1/go.mod h1:nfpjEwFR6Q4xGDJMcZnCL9tEfQRgszMwu3rDz2Z+p5M= +github.com/zeromicro/ddl-parser v0.0.0-20210710132903-bc9dbb9789b1 h1:zItUIfobEHTYD9X0fAt9QWEWIFWDa8CypF+Z62zIR+M= +github.com/zeromicro/ddl-parser v0.0.0-20210710132903-bc9dbb9789b1/go.mod h1:ISU/8NuPyEpl9pa17Py9TBPetMjtsiHrb9f5XGiYbo8= +github.com/zeromicro/ddl-parser v0.0.0-20210712021150-63520aca7348 h1:OhxL9tn28gDeJVzreIUiE5oVxZCjL3tBJ0XBNw8p5R8= +github.com/zeromicro/ddl-parser v0.0.0-20210712021150-63520aca7348/go.mod h1:ISU/8NuPyEpl9pa17Py9TBPetMjtsiHrb9f5XGiYbo8= go.etcd.io/etcd/api/v3 v3.5.0 h1:GsV3S+OfZEOCNXdtNkBSR7kgLobAa/SO6tCxRa0GAYw= go.etcd.io/etcd/api/v3 v3.5.0/go.mod h1:cbVKeC6lCfl7j/8jBhAK6aIYO9XOjdptoxU/nLQcPvs= go.etcd.io/etcd/client/pkg/v3 v3.5.0 h1:2aQv6F436YnN7I4VbI8PPYrBhu+SmrTaADcf8Mi/6PU= diff --git a/tools/goctl/model/sql/command/command.go b/tools/goctl/model/sql/command/command.go index 273716e6..213973e0 100644 --- a/tools/goctl/model/sql/command/command.go +++ b/tools/goctl/model/sql/command/command.go @@ -2,7 +2,6 @@ package command import ( "errors" - "io/ioutil" "path/filepath" "strings" @@ -76,22 +75,19 @@ func fromDDl(src, dir string, cfg *config.Config, cache, idea bool) error { return errNotMatched } - var source []string + generator, err := gen.NewDefaultGenerator(dir, cfg, gen.WithConsoleOption(log)) + if err != nil { + return err + } + for _, file := range files { - data, err := ioutil.ReadFile(file) + err = generator.StartFromDDL(file, cache) if err != nil { return err } - - source = append(source, string(data)) - } - - generator, err := gen.NewDefaultGenerator(dir, cfg, gen.WithConsoleOption(log)) - if err != nil { - return err } - return generator.StartFromDDL(strings.Join(source, "\n"), cache) + return nil } func fromDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bool) error { diff --git a/tools/goctl/model/sql/converter/types.go b/tools/goctl/model/sql/converter/types.go index cf6fb3ea..664490d3 100644 --- a/tools/goctl/model/sql/converter/types.go +++ b/tools/goctl/model/sql/converter/types.go @@ -3,9 +3,53 @@ package converter import ( "fmt" "strings" + + "github.com/zeromicro/ddl-parser/parser" ) -var commonMysqlDataTypeMap = map[string]string{ +var commonMysqlDataTypeMap = map[int]string{ + // For consistency, all integer types are converted to int64 + // number + parser.Bool: "int64", + parser.Boolean: "int64", + parser.TinyInt: "int64", + parser.SmallInt: "int64", + parser.MediumInt: "int64", + parser.Int: "int64", + parser.MiddleInt: "int64", + parser.Int1: "int64", + parser.Int2: "int64", + parser.Int3: "int64", + parser.Int4: "int64", + parser.Int8: "int64", + parser.Integer: "int64", + parser.BigInt: "int64", + parser.Float: "float64", + parser.Float4: "float64", + parser.Float8: "float64", + parser.Double: "float64", + parser.Decimal: "float64", + // date&time + parser.Date: "time.Time", + parser.DateTime: "time.Time", + parser.Timestamp: "time.Time", + parser.Time: "string", + parser.Year: "int64", + // string + parser.Char: "string", + parser.VarChar: "string", + parser.Binary: "string", + parser.VarBinary: "string", + parser.TinyText: "string", + parser.Text: "string", + parser.MediumText: "string", + parser.LongText: "string", + parser.Enum: "string", + parser.Set: "string", + parser.Json: "string", +} + +var commonMysqlDataTypeMap2 = map[string]string{ // For consistency, all integer types are converted to int64 // number "bool": "int64", @@ -40,10 +84,20 @@ var commonMysqlDataTypeMap = map[string]string{ } // ConvertDataType converts mysql column type into golang type -func ConvertDataType(dataBaseType string, isDefaultNull bool) (string, error) { - tp, ok := commonMysqlDataTypeMap[strings.ToLower(dataBaseType)] +func ConvertDataType(dataBaseType int, isDefaultNull bool) (string, error) { + tp, ok := commonMysqlDataTypeMap[dataBaseType] + if !ok { + return "", fmt.Errorf("unsupported database type: %v", dataBaseType) + } + + return mayConvertNullType(tp, isDefaultNull), nil +} + +// ConvertStringDataType converts mysql column type into golang type +func ConvertStringDataType(dataBaseType string, isDefaultNull bool) (string, error) { + tp, ok := commonMysqlDataTypeMap2[strings.ToLower(dataBaseType)] if !ok { - return "", fmt.Errorf("unexpected database type: %s", dataBaseType) + return "", fmt.Errorf("unsupported database type: %s", dataBaseType) } return mayConvertNullType(tp, isDefaultNull), nil diff --git a/tools/goctl/model/sql/converter/types_test.go b/tools/goctl/model/sql/converter/types_test.go index 8c2122dd..bc1fbf54 100644 --- a/tools/goctl/model/sql/converter/types_test.go +++ b/tools/goctl/model/sql/converter/types_test.go @@ -4,25 +4,23 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/zeromicro/ddl-parser/parser" ) func TestConvertDataType(t *testing.T) { - v, err := ConvertDataType("tinyint", false) + v, err := ConvertDataType(parser.TinyInt, false) assert.Nil(t, err) assert.Equal(t, "int64", v) - v, err = ConvertDataType("tinyint", true) + v, err = ConvertDataType(parser.TinyInt, true) assert.Nil(t, err) assert.Equal(t, "sql.NullInt64", v) - v, err = ConvertDataType("timestamp", false) + v, err = ConvertDataType(parser.Timestamp, false) assert.Nil(t, err) assert.Equal(t, "time.Time", v) - v, err = ConvertDataType("timestamp", true) + v, err = ConvertDataType(parser.Timestamp, true) assert.Nil(t, err) assert.Equal(t, "sql.NullTime", v) - - _, err = ConvertDataType("float32", false) - assert.NotNil(t, err) } diff --git a/tools/goctl/model/sql/gen/gen.go b/tools/goctl/model/sql/gen/gen.go index a0860184..3cb4f287 100644 --- a/tools/goctl/model/sql/gen/gen.go +++ b/tools/goctl/model/sql/gen/gen.go @@ -90,8 +90,8 @@ func newDefaultOption() Option { } } -func (g *defaultGenerator) StartFromDDL(source string, withCache bool) error { - modelList, err := g.genFromDDL(source, withCache) +func (g *defaultGenerator) StartFromDDL(filename string, withCache bool) error { + modelList, err := g.genFromDDL(filename, withCache) if err != nil { return err } @@ -174,21 +174,20 @@ func (g *defaultGenerator) createFile(modelList map[string]string) error { } // ret1: key-table name,value-code -func (g *defaultGenerator) genFromDDL(source string, withCache bool) (map[string]string, error) { - ddlList := g.split(source) +func (g *defaultGenerator) genFromDDL(filename string, withCache bool) (map[string]string, error) { m := make(map[string]string) - for _, ddl := range ddlList { - table, err := parser.Parse(ddl) - if err != nil { - return nil, err - } + tables, err := parser.Parse(filename) + if err != nil { + return nil, err + } - code, err := g.genModel(*table, withCache) + for _, e := range tables { + code, err := g.genModel(*e, withCache) if err != nil { return nil, err } - m[table.Name.Source()] = code + m[e.Name.Source()] = code } return m, nil diff --git a/tools/goctl/model/sql/gen/gen_test.go b/tools/goctl/model/sql/gen/gen_test.go index 024f5c10..38c59d34 100644 --- a/tools/goctl/model/sql/gen/gen_test.go +++ b/tools/goctl/model/sql/gen/gen_test.go @@ -2,6 +2,7 @@ package gen import ( "database/sql" + "io/ioutil" "os" "path/filepath" "strings" @@ -20,6 +21,11 @@ var source = "CREATE TABLE `test_user` (\n `id` bigint NOT NULL AUTO_INCREMENT, func TestCacheModel(t *testing.T) { logx.Disable() _ = Clean() + + sqlFile := filepath.Join(t.TempDir(), "tmp.sql") + err := ioutil.WriteFile(sqlFile, []byte(source), 0777) + assert.Nil(t, err) + dir := filepath.Join(t.TempDir(), "./testmodel") cacheDir := filepath.Join(dir, "cache") noCacheDir := filepath.Join(dir, "nocache") @@ -28,7 +34,7 @@ func TestCacheModel(t *testing.T) { }) assert.Nil(t, err) - err = g.StartFromDDL(source, true) + err = g.StartFromDDL(sqlFile, true) assert.Nil(t, err) assert.True(t, func() bool { _, err := os.Stat(filepath.Join(cacheDir, "TestUserModel.go")) @@ -39,7 +45,7 @@ func TestCacheModel(t *testing.T) { }) assert.Nil(t, err) - err = g.StartFromDDL(source, false) + err = g.StartFromDDL(sqlFile, false) assert.Nil(t, err) assert.True(t, func() bool { _, err := os.Stat(filepath.Join(noCacheDir, "testusermodel.go")) @@ -50,6 +56,11 @@ func TestCacheModel(t *testing.T) { func TestNamingModel(t *testing.T) { logx.Disable() _ = Clean() + + sqlFile := filepath.Join(t.TempDir(), "tmp.sql") + err := ioutil.WriteFile(sqlFile, []byte(source), 0777) + assert.Nil(t, err) + dir, _ := filepath.Abs("./testmodel") camelDir := filepath.Join(dir, "camel") snakeDir := filepath.Join(dir, "snake") @@ -61,7 +72,7 @@ func TestNamingModel(t *testing.T) { }) assert.Nil(t, err) - err = g.StartFromDDL(source, true) + err = g.StartFromDDL(sqlFile, true) assert.Nil(t, err) assert.True(t, func() bool { _, err := os.Stat(filepath.Join(camelDir, "TestUserModel.go")) @@ -72,7 +83,7 @@ func TestNamingModel(t *testing.T) { }) assert.Nil(t, err) - err = g.StartFromDDL(source, true) + err = g.StartFromDDL(sqlFile, true) assert.Nil(t, err) assert.True(t, func() bool { _, err := os.Stat(filepath.Join(snakeDir, "test_user_model.go")) diff --git a/tools/goctl/model/sql/gen/keys_test.go b/tools/goctl/model/sql/gen/keys_test.go index 3721c95b..510ba82e 100644 --- a/tools/goctl/model/sql/gen/keys_test.go +++ b/tools/goctl/model/sql/gen/keys_test.go @@ -11,32 +11,28 @@ import ( func TestGenCacheKeys(t *testing.T) { primaryField := &parser.Field{ - Name: stringx.From("id"), - DataBaseType: "bigint", - DataType: "int64", - Comment: "自增id", - SeqInIndex: 1, + Name: stringx.From("id"), + DataType: "int64", + Comment: "自增id", + SeqInIndex: 1, } mobileField := &parser.Field{ - Name: stringx.From("mobile"), - DataBaseType: "varchar", - DataType: "string", - Comment: "手机号", - SeqInIndex: 1, + Name: stringx.From("mobile"), + DataType: "string", + Comment: "手机号", + SeqInIndex: 1, } classField := &parser.Field{ - Name: stringx.From("class"), - DataBaseType: "varchar", - DataType: "string", - Comment: "班级", - SeqInIndex: 1, + Name: stringx.From("class"), + DataType: "string", + Comment: "班级", + SeqInIndex: 1, } nameField := &parser.Field{ - Name: stringx.From("name"), - DataBaseType: "varchar", - DataType: "string", - Comment: "姓名", - SeqInIndex: 2, + Name: stringx.From("name"), + DataType: "string", + Comment: "姓名", + SeqInIndex: 2, } primariCacheKey, uniqueCacheKey := genCacheKeys(parser.Table{ Name: stringx.From("user"), @@ -53,23 +49,20 @@ func TestGenCacheKeys(t *testing.T) { nameField, }, }, - NormalIndex: nil, Fields: []*parser.Field{ primaryField, mobileField, classField, nameField, { - Name: stringx.From("createTime"), - DataBaseType: "timestamp", - DataType: "time.Time", - Comment: "创建时间", + Name: stringx.From("createTime"), + DataType: "time.Time", + Comment: "创建时间", }, { - Name: stringx.From("updateTime"), - DataBaseType: "timestamp", - DataType: "time.Time", - Comment: "更新时间", + Name: stringx.From("updateTime"), + DataType: "time.Time", + Comment: "更新时间", }, }, }) diff --git a/tools/goctl/model/sql/parser/error.go b/tools/goctl/model/sql/parser/error.go deleted file mode 100644 index 301be6c0..00000000 --- a/tools/goctl/model/sql/parser/error.go +++ /dev/null @@ -1,11 +0,0 @@ -package parser - -import ( - "errors" -) - -var ( - errUnsupportDDL = errors.New("unexpected type") - errTableBodyNotFound = errors.New("create table spec not found") - errPrimaryKey = errors.New("unexpected join primary key") -) diff --git a/tools/goctl/model/sql/parser/parser.go b/tools/goctl/model/sql/parser/parser.go index 6654f487..53a16a8e 100644 --- a/tools/goctl/model/sql/parser/parser.go +++ b/tools/goctl/model/sql/parser/parser.go @@ -2,6 +2,7 @@ package parser import ( "fmt" + "path/filepath" "sort" "strings" @@ -11,7 +12,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/model/sql/util" "github.com/tal-tech/go-zero/tools/goctl/util/console" "github.com/tal-tech/go-zero/tools/goctl/util/stringx" - "github.com/xwb1989/sqlparser" + "github.com/zeromicro/ddl-parser/parser" ) const timeImport = "time.Time" @@ -22,7 +23,6 @@ type ( Name stringx.String PrimaryKey Primary UniqueIndex map[string][]*Field - NormalIndex map[string][]*Field Fields []*Field } @@ -35,7 +35,6 @@ type ( // Field describes a table field Field struct { Name stringx.String - DataBaseType string DataType string Comment string SeqInIndex int @@ -47,73 +46,115 @@ type ( ) // Parse parses ddl into golang structure -func Parse(ddl string) (*Table, error) { - stmt, err := sqlparser.ParseStrictDDL(ddl) +func Parse(filename string) ([]*Table, error) { + p := parser.NewParser() + tables, err := p.From(filename) if err != nil { return nil, err } - ddlStmt, ok := stmt.(*sqlparser.DDL) - if !ok { - return nil, errUnsupportDDL + indexNameGen := func(column ...string) string { + return strings.Join(column, "_") } - action := ddlStmt.Action - if action != sqlparser.CreateStr { - return nil, fmt.Errorf("expected [CREATE] action,but found: %s", action) - } + prefix := filepath.Base(filename) + var list []*Table + for _, e := range tables { + columns := e.Columns - tableName := ddlStmt.NewName.Name.String() - tableSpec := ddlStmt.TableSpec - if tableSpec == nil { - return nil, errTableBodyNotFound - } + var ( + primaryColumnSet = collection.NewSet() - columns := tableSpec.Columns - indexes := tableSpec.Indexes - primaryColumn, uniqueKeyMap, normalKeyMap, err := convertIndexes(indexes) - if err != nil { - return nil, err - } + primaryColumn string + uniqueKeyMap = make(map[string][]string) + normalKeyMap = make(map[string][]string) + ) - primaryKey, fieldM, err := convertColumns(columns, primaryColumn) - if err != nil { - return nil, err - } + for _, column := range columns { + if column.Constraint != nil { + if column.Constraint.Primary { + primaryColumnSet.AddStr(column.Name) + } - var fields []*Field - for _, e := range fieldM { - fields = append(fields, e) - } + if column.Constraint.Unique { + indexName := indexNameGen(column.Name, "unique") + uniqueKeyMap[indexName] = []string{column.Name} + } - var ( - uniqueIndex = make(map[string][]*Field) - normalIndex = make(map[string][]*Field) - ) + if column.Constraint.Key { + indexName := indexNameGen(column.Name, "idx") + uniqueKeyMap[indexName] = []string{column.Name} + } + } + } - for indexName, each := range uniqueKeyMap { - for _, columnName := range each { - uniqueIndex[indexName] = append(uniqueIndex[indexName], fieldM[columnName]) + for _, e := range e.Constraints { + if len(e.ColumnPrimaryKey) > 1 { + return nil, fmt.Errorf("%s: unexpected join primary key", prefix) + } + + if len(e.ColumnPrimaryKey) == 1 { + primaryColumn = e.ColumnPrimaryKey[0] + primaryColumnSet.AddStr(e.ColumnPrimaryKey[0]) + } + + if len(e.ColumnUniqueKey) > 0 { + list := append([]string(nil), e.ColumnUniqueKey...) + list = append(list, "unique") + indexName := indexNameGen(list...) + uniqueKeyMap[indexName] = e.ColumnUniqueKey + } } - } - for indexName, each := range normalKeyMap { - for _, columnName := range each { - normalIndex[indexName] = append(normalIndex[indexName], fieldM[columnName]) + if primaryColumnSet.Count() > 1 { + return nil, fmt.Errorf("%s: unexpected join primary key", prefix) } + + primaryKey, fieldM, err := convertColumns(columns, primaryColumn) + if err != nil { + return nil, err + } + + var fields []*Field + // sort + for _, c := range columns { + field, ok := fieldM[c.Name] + if ok { + fields = append(fields, field) + } + } + + var ( + uniqueIndex = make(map[string][]*Field) + normalIndex = make(map[string][]*Field) + ) + + for indexName, each := range uniqueKeyMap { + for _, columnName := range each { + uniqueIndex[indexName] = append(uniqueIndex[indexName], fieldM[columnName]) + } + } + + for indexName, each := range normalKeyMap { + for _, columnName := range each { + normalIndex[indexName] = append(normalIndex[indexName], fieldM[columnName]) + } + } + + checkDuplicateUniqueIndex(uniqueIndex, e.Name) + + list = append(list, &Table{ + Name: stringx.From(e.Name), + PrimaryKey: primaryKey, + UniqueIndex: uniqueIndex, + Fields: fields, + }) } - checkDuplicateUniqueIndex(uniqueIndex, tableName, normalIndex) - return &Table{ - Name: stringx.From(tableName), - PrimaryKey: primaryKey, - UniqueIndex: uniqueIndex, - NormalIndex: normalIndex, - Fields: fields, - }, nil + return list, nil } -func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string, normalIndex map[string][]*Field) { +func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string) { log := console.NewColorConsole() uniqueSet := collection.NewSet() for k, i := range uniqueIndex { @@ -131,26 +172,9 @@ func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string uniqueSet.AddStr(joinRet) } - - normalIndexSet := collection.NewSet() - for k, i := range normalIndex { - var list []string - for _, e := range i { - list = append(list, e.Name.Source()) - } - - joinRet := strings.Join(list, ",") - if normalIndexSet.Contains(joinRet) { - log.Warning("table %s: duplicate index %s", tableName, joinRet) - delete(normalIndex, k) - continue - } - - normalIndexSet.Add(joinRet) - } } -func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) { +func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, map[string]*Field, error) { var ( primaryKey Primary fieldM = make(map[string]*Field) @@ -161,35 +185,35 @@ func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) continue } - var comment string - if column.Type.Comment != nil { - comment = string(column.Type.Comment.Val) - } + var ( + comment string + isDefaultNull bool + ) - isDefaultNull := true - if column.Type.NotNull { - isDefaultNull = false - } else { - if column.Type.Default != nil { + if column.Constraint != nil { + comment = column.Constraint.Comment + isDefaultNull = !column.Constraint.HasDefaultValue + if column.Name == primaryColumn && column.Constraint.AutoIncrement { isDefaultNull = false } } - dataType, err := converter.ConvertDataType(column.Type.Type, isDefaultNull) + dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull) if err != nil { return Primary{}, nil, err } var field Field - field.Name = stringx.From(column.Name.String()) - field.DataBaseType = column.Type.Type + field.Name = stringx.From(column.Name) field.DataType = dataType field.Comment = util.TrimNewLine(comment) if field.Name.Source() == primaryColumn { primaryKey = Primary{ - Field: field, - AutoIncrement: bool(column.Type.Autoincrement), + Field: field, + } + if column.Constraint != nil { + primaryKey.AutoIncrement = column.Constraint.AutoIncrement } } @@ -198,60 +222,6 @@ func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) return primaryKey, fieldM, nil } -func convertIndexes(indexes []*sqlparser.IndexDefinition) (string, map[string][]string, map[string][]string, error) { - var primaryColumn string - uniqueKeyMap := make(map[string][]string) - normalKeyMap := make(map[string][]string) - - isCreateTimeOrUpdateTime := func(name string) bool { - camelColumnName := stringx.From(name).ToCamel() - // by default, createTime|updateTime findOne is not used. - return camelColumnName == "CreateTime" || camelColumnName == "UpdateTime" - } - - for _, index := range indexes { - info := index.Info - if info == nil { - continue - } - - indexName := index.Info.Name.String() - if info.Primary { - if len(index.Columns) > 1 { - return "", nil, nil, errPrimaryKey - } - columnName := index.Columns[0].Column.String() - if isCreateTimeOrUpdateTime(columnName) { - continue - } - - primaryColumn = columnName - continue - } else if info.Unique { - for _, each := range index.Columns { - columnName := each.Column.String() - if isCreateTimeOrUpdateTime(columnName) { - break - } - - uniqueKeyMap[indexName] = append(uniqueKeyMap[indexName], columnName) - } - } else if info.Spatial { - // do nothing - } else { - for _, each := range index.Columns { - columnName := each.Column.String() - if isCreateTimeOrUpdateTime(columnName) { - break - } - - normalKeyMap[indexName] = append(normalKeyMap[indexName], each.Column.String()) - } - } - } - return primaryColumn, uniqueKeyMap, normalKeyMap, nil -} - // ContainsTime returns true if contains golang type time.Time func (t *Table) ContainsTime() bool { for _, item := range t.Fields { @@ -265,14 +235,13 @@ func (t *Table) ContainsTime() bool { // ConvertDataType converts mysql data type into golang data type func ConvertDataType(table *model.Table) (*Table, error) { isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES" - primaryDataType, err := converter.ConvertDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull) + primaryDataType, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull) if err != nil { return nil, err } var reply Table reply.UniqueIndex = map[string][]*Field{} - reply.NormalIndex = map[string][]*Field{} reply.Name = stringx.From(table.Table) seqInIndex := 0 if table.PrimaryKey.Index != nil { @@ -282,7 +251,6 @@ func ConvertDataType(table *model.Table) (*Table, error) { reply.PrimaryKey = Primary{ Field: Field{ Name: stringx.From(table.PrimaryKey.Name), - DataBaseType: table.PrimaryKey.DataType, DataType: primaryDataType, Comment: table.PrimaryKey.Comment, SeqInIndex: seqInIndex, @@ -338,29 +306,6 @@ func ConvertDataType(table *model.Table) (*Table, error) { reply.UniqueIndex[indexName] = list } - normalIndexSet := collection.NewSet() - for indexName, each := range table.NormalIndex { - var list []*Field - var normalJoin []string - for _, c := range each { - list = append(list, fieldM[c.Name]) - normalJoin = append(normalJoin, c.Name) - } - - normalKey := strings.Join(normalJoin, ",") - if normalIndexSet.Contains(normalKey) { - log.Warning("table %s: duplicate index, %s", table.Table, normalKey) - continue - } - - normalIndexSet.AddStr(normalKey) - sort.Slice(list, func(i, j int) bool { - return list[i].SeqInIndex < list[j].SeqInIndex - }) - - reply.NormalIndex[indexName] = list - } - return &reply, nil } @@ -368,7 +313,7 @@ func getTableFields(table *model.Table) (map[string]*Field, error) { fieldM := make(map[string]*Field) for _, each := range table.Columns { isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES" - dt, err := converter.ConvertDataType(each.DataType, isDefaultNull) + dt, err := converter.ConvertStringDataType(each.DataType, isDefaultNull) if err != nil { return nil, err } @@ -379,7 +324,6 @@ func getTableFields(table *model.Table) (map[string]*Field, error) { field := &Field{ Name: stringx.From(each.Name), - DataBaseType: each.DataType, DataType: dt, Comment: each.Comment, SeqInIndex: columnSeqInIndex, diff --git a/tools/goctl/model/sql/parser/parser_test.go b/tools/goctl/model/sql/parser/parser_test.go index d13da8fb..02834cb1 100644 --- a/tools/goctl/model/sql/parser/parser_test.go +++ b/tools/goctl/model/sql/parser/parser_test.go @@ -1,88 +1,47 @@ package parser import ( - "sort" + "io/ioutil" + "path/filepath" "testing" "github.com/stretchr/testify/assert" "github.com/tal-tech/go-zero/tools/goctl/model/sql/model" "github.com/tal-tech/go-zero/tools/goctl/model/sql/util" - "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) func TestParsePlainText(t *testing.T) { - _, err := Parse("plain text") + sqlFile := filepath.Join(t.TempDir(), "tmp.sql") + err := ioutil.WriteFile(sqlFile, []byte("plain text"), 0777) + assert.Nil(t, err) + + _, err = Parse(sqlFile) assert.NotNil(t, err) } func TestParseSelect(t *testing.T) { - _, err := Parse("select * from user") - assert.Equal(t, errUnsupportDDL, err) + sqlFile := filepath.Join(t.TempDir(), "tmp.sql") + err := ioutil.WriteFile(sqlFile, []byte("select * from user"), 0777) + assert.Nil(t, err) + + tables, err := Parse(sqlFile) + assert.Nil(t, err) + assert.Equal(t, 0, len(tables)) } func TestParseCreateTable(t *testing.T) { - table, err := Parse("CREATE TABLE `test_user` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL comment '手\\t机 号',\n `class` bigint NOT NULL comment '班级',\n `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL comment '姓\n 名',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP comment '创建\\r时间',\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `mobile_unique` (`mobile`),\n UNIQUE KEY `class_name_unique` (`class`,`name`),\n KEY `create_index` (`create_time`),\n KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;") + sqlFile := filepath.Join(t.TempDir(), "tmp.sql") + err := ioutil.WriteFile(sqlFile, []byte("CREATE TABLE `test_user` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL comment '手\\t机 号',\n `class` bigint NOT NULL comment '班级',\n `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL comment '姓\n 名',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP comment '创建\\r时间',\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `mobile_unique` (`mobile`),\n UNIQUE KEY `class_name_unique` (`class`,`name`),\n KEY `create_index` (`create_time`),\n KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;"), 0777) + assert.Nil(t, err) + + tables, err := Parse(sqlFile) + assert.Equal(t, 1, len(tables)) + table := tables[0] assert.Nil(t, err) assert.Equal(t, "test_user", table.Name.Source()) assert.Equal(t, "id", table.PrimaryKey.Name.Source()) assert.Equal(t, true, table.ContainsTime()) - assert.Equal(t, true, func() bool { - mobileUniqueIndex, ok := table.UniqueIndex["mobile_unique"] - if !ok { - return false - } - - classNameUniqueIndex, ok := table.UniqueIndex["class_name_unique"] - if !ok { - return false - } - - equal := func(f1, f2 []*Field) bool { - sort.Slice(f1, func(i, j int) bool { - return f1[i].Name.Source() < f1[j].Name.Source() - }) - sort.Slice(f2, func(i, j int) bool { - return f2[i].Name.Source() < f2[j].Name.Source() - }) - - if len(f2) != len(f2) { - return false - } - - for index, f := range f1 { - if f1[index].Name.Source() != f.Name.Source() { - return false - } - } - return true - } - - if !equal(mobileUniqueIndex, []*Field{ - { - Name: stringx.From("mobile"), - DataBaseType: "varchar", - DataType: "string", - SeqInIndex: 1, - }, - }) { - return false - } - - return equal(classNameUniqueIndex, []*Field{ - { - Name: stringx.From("class"), - DataBaseType: "bigint", - DataType: "int64", - SeqInIndex: 1, - }, - { - Name: stringx.From("name"), - DataBaseType: "varchar", - DataType: "string", - SeqInIndex: 2, - }, - }) - }()) + assert.Equal(t, 2, len(table.UniqueIndex)) assert.True(t, func() bool { for _, e := range table.Fields { if e.Comment != util.TrimNewLine(e.Comment) {