diff --git a/tools/goctl/model/sql/example/makefile b/tools/goctl/model/sql/example/makefile index 17fccaac..fa80267c 100644 --- a/tools/goctl/model/sql/example/makefile +++ b/tools/goctl/model/sql/example/makefile @@ -11,8 +11,8 @@ fromDDLWithoutCache: # generate model with cache from data source -user=ugozero -password= +user=root +password=password datasource=127.0.0.1:3306 database=gozero diff --git a/tools/goctl/model/sql/example/sql/user.sql b/tools/goctl/model/sql/example/sql/user.sql index 5aaec5e6..4a0cb2d0 100644 --- a/tools/goctl/model/sql/example/sql/user.sql +++ b/tools/goctl/model/sql/example/sql/user.sql @@ -11,6 +11,7 @@ CREATE TABLE `user` ( `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, PRIMARY KEY (`id`), UNIQUE KEY `name_index` (`name`), + UNIQUE KEY `name_index2` (`name`), UNIQUE KEY `user_index` (`user`), UNIQUE KEY `mobile_index` (`mobile`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci; diff --git a/tools/goctl/model/sql/parser/parser.go b/tools/goctl/model/sql/parser/parser.go index ce454c44..cfe45273 100644 --- a/tools/goctl/model/sql/parser/parser.go +++ b/tools/goctl/model/sql/parser/parser.go @@ -75,15 +75,21 @@ func Parse(ddl string) (*Table, error) { return nil, err } - fields, primaryKey, fieldM, err := convertColumns(columns, primaryColumn) + primaryKey, fieldM, err := convertColumns(columns, primaryColumn) if err != nil { return nil, err } + var fields []*Field + for _, e := range fieldM { + fields = append(fields, e) + } + var ( uniqueIndex = make(map[string][]*Field) normalIndex = make(map[string][]*Field) ) + for indexName, each := range uniqueKeyMap { for _, columnName := range each { uniqueIndex[indexName] = append(uniqueIndex[indexName], fieldM[columnName]) @@ -96,6 +102,41 @@ func Parse(ddl string) (*Table, error) { } } + log := console.NewColorConsole() + uniqueSet := collection.NewSet() + for k, i := range uniqueIndex { + var list []string + for _, e := range i { + list = append(list, e.Name.Source()) + } + + joinRet := strings.Join(list, ",") + if uniqueSet.Contains(joinRet) { + log.Warning("table %s: duplicate unique index %s", tableName, joinRet) + delete(uniqueIndex, k) + continue + } + + uniqueSet.AddStr(joinRet) + } + + normalIndexSet := collection.NewSet() + for k, i := range normalIndex { + var list []string + for _, e := range i { + list = append(list, e.Name.Source()) + } + + joinRet := strings.Join(list, ",") + if normalIndexSet.Contains(joinRet) { + log.Warning("table %s: duplicate index %s", tableName, joinRet) + delete(normalIndex, k) + continue + } + + normalIndexSet.Add(joinRet) + } + return &Table{ Name: stringx.From(tableName), PrimaryKey: primaryKey, @@ -105,9 +146,8 @@ func Parse(ddl string) (*Table, error) { }, nil } -func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) ([]*Field, Primary, map[string]*Field, error) { +func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) { var ( - fields []*Field primaryKey Primary fieldM = make(map[string]*Field) ) @@ -135,7 +175,7 @@ func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) dataType, err := converter.ConvertDataType(column.Type.Type, isDefaultNull) if err != nil { - return nil, Primary{}, nil, err + return Primary{}, nil, err } var field Field @@ -151,10 +191,9 @@ func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) } } - fields = append(fields, &field) fieldM[field.Name.Source()] = &field } - return fields, primaryKey, fieldM, nil + return primaryKey, fieldM, nil } func convertIndexes(indexes []*sqlparser.IndexDefinition) (string, map[string][]string, map[string][]string, error) { @@ -293,7 +332,7 @@ func ConvertDataType(table *model.Table) (*Table, error) { if len(each) == 1 { one := each[0] if one.Name == table.PrimaryKey.Name { - log.Warning("duplicate unique index with primary key, %s", one.Name) + log.Warning("table %s: duplicate unique index with primary key, %s", table.Table, one.Name) continue } } @@ -307,19 +346,30 @@ func ConvertDataType(table *model.Table) (*Table, error) { uniqueKey := strings.Join(uniqueJoin, ",") if uniqueIndexSet.Contains(uniqueKey) { - log.Warning("duplicate unique index, %s", uniqueKey) + log.Warning("table %s: duplicate unique index, %s", table.Table, uniqueKey) continue } + uniqueIndexSet.AddStr(uniqueKey) reply.UniqueIndex[indexName] = list } + normalIndexSet := collection.NewSet() for indexName, each := range table.NormalIndex { var list []*Field + var normalJoin []string for _, c := range each { list = append(list, fieldM[c.Name]) + normalJoin = append(normalJoin, c.Name) + } + + normalKey := strings.Join(normalJoin, ",") + if normalIndexSet.Contains(normalKey) { + log.Warning("table %s: duplicate index, %s", table.Table, normalKey) + continue } + normalIndexSet.AddStr(normalKey) sort.Slice(list, func(i, j int) bool { return list[i].SeqInIndex < list[j].SeqInIndex }) diff --git a/tools/goctl/rpc/generator/gen_test.go b/tools/goctl/rpc/generator/gen_test.go index b1c0720c..11db212e 100644 --- a/tools/goctl/rpc/generator/gen_test.go +++ b/tools/goctl/rpc/generator/gen_test.go @@ -78,9 +78,4 @@ func TestRpcGenerate(t *testing.T) { return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package") }()) } - - // invalid directory - projectDir = filepath.Join(t.TempDir(), ".....") - err = g.Generate("./test.proto", projectDir, nil) - assert.NotNil(t, err) }