diff --git a/tools/goctl/model/sql/example/sql/user.sql b/tools/goctl/model/sql/example/sql/user.sql index 3a6b4241..0e705744 100644 --- a/tools/goctl/model/sql/example/sql/user.sql +++ b/tools/goctl/model/sql/example/sql/user.sql @@ -8,12 +8,14 @@ CREATE TABLE `user` `mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号', `gender` char(5) COLLATE utf8mb4_general_ci NOT NULL COMMENT '男|女|未公\r开', `nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT '用户昵称', + `type` tinyint(1) COLLATE utf8mb4_general_ci DEFAULT 0 COMMENT '用户类型', `create_time` timestamp NULL, `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, PRIMARY KEY (`id`), UNIQUE KEY `name_index` (`name`), UNIQUE KEY `name_index2` (`name`), UNIQUE KEY `user_index` (`user`), + UNIQUE KEY `type_index` (`type`), UNIQUE KEY `mobile_index` (`mobile`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci; diff --git a/tools/goctl/model/sql/model/informationschemamodel.go b/tools/goctl/model/sql/model/informationschemamodel.go index 6a8d95b2..5d53422b 100644 --- a/tools/goctl/model/sql/model/informationschemamodel.go +++ b/tools/goctl/model/sql/model/informationschemamodel.go @@ -6,6 +6,7 @@ import ( "github.com/tal-tech/go-zero/core/stores/sqlx" "github.com/tal-tech/go-zero/tools/goctl/model/sql/util" + su "github.com/tal-tech/go-zero/tools/goctl/util" ) const indexPri = "PRIMARY" @@ -144,14 +145,15 @@ func (m *InformationSchemaModel) FindIndex(db, table, column string) ([]*DbIndex // Convert converts column data into Table func (c *ColumnData) Convert() (*Table, error) { var table Table - table.Table = c.Table - table.Db = c.Db + table.Table = su.EscapeGolangKeyword(c.Table) + table.Db = su.EscapeGolangKeyword(c.Db) table.Columns = c.Columns table.UniqueIndex = map[string][]*Column{} table.NormalIndex = map[string][]*Column{} m := make(map[string][]*Column) for _, each := range c.Columns { + each.Name = su.EscapeGolangKeyword(each.Name) each.Comment = util.TrimNewLine(each.Comment) if each.Index != nil { m[each.Index.IndexName] = append(m[each.Index.IndexName], each) diff --git a/tools/goctl/model/sql/parser/parser.go b/tools/goctl/model/sql/parser/parser.go index 27821024..eb46882c 100644 --- a/tools/goctl/model/sql/parser/parser.go +++ b/tools/goctl/model/sql/parser/parser.go @@ -10,6 +10,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/model/sql/converter" "github.com/tal-tech/go-zero/tools/goctl/model/sql/model" "github.com/tal-tech/go-zero/tools/goctl/model/sql/util" + su "github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util/console" "github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/zeromicro/ddl-parser/parser" @@ -49,11 +50,12 @@ type ( // Parse parses ddl into golang structure func Parse(filename, database string) ([]*Table, error) { p := parser.NewParser() - tables, err := p.From(filename) + ts, err := p.From(filename) if err != nil { return nil, err } + tables := GetSafeTables(ts) indexNameGen := func(column ...string) string { return strings.Join(column, "_") } @@ -167,7 +169,7 @@ func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string joinRet := strings.Join(list, ",") if uniqueSet.Contains(joinRet) { - log.Warning("table %s: duplicate unique index %s", tableName, joinRet) + log.Warning("[checkDuplicateUniqueIndex]: table %s: duplicate unique index %s", tableName, joinRet) delete(uniqueIndex, k) continue } @@ -213,10 +215,10 @@ func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, ma if column.Constraint != nil { if column.Name == primaryColumn { if !column.Constraint.AutoIncrement && dataType == "int64" { - log.Warning("%s: The primary key is recommended to add constraint `AUTO_INCREMENT`", column.Name) + log.Warning("[convertColumns]: The primary key %q is recommended to add constraint `AUTO_INCREMENT`", column.Name) } } else if column.Constraint.NotNull && !column.Constraint.HasDefaultValue { - log.Warning("%s: The column is recommended to add constraint `DEFAULT`", column.Name) + log.Warning("[convertColumns]: The column %q is recommended to add constraint `DEFAULT`", column.Name) } } @@ -302,7 +304,7 @@ func ConvertDataType(table *model.Table) (*Table, error) { if len(each) == 1 { one := each[0] if one.Name == table.PrimaryKey.Name { - log.Warning("table %s: duplicate unique index with primary key, %s", table.Table, one.Name) + log.Warning("[ConvertDataType]: table q%, duplicate unique index with primary key: %q", table.Table, one.Name) continue } } @@ -316,7 +318,7 @@ func ConvertDataType(table *model.Table) (*Table, error) { uniqueKey := strings.Join(uniqueJoin, ",") if uniqueIndexSet.Contains(uniqueKey) { - log.Warning("table %s: duplicate unique index, %s", table.Table, uniqueKey) + log.Warning("[ConvertDataType]: table %q, duplicate unique index %q", table.Table, uniqueKey) continue } @@ -351,3 +353,33 @@ func getTableFields(table *model.Table) (map[string]*Field, error) { } return fieldM, nil } + +func GetSafeTables(tables []*parser.Table) []*parser.Table { + var list []*parser.Table + for _, t := range tables { + table := GetSafeTable(t) + list = append(list, table) + } + + return list +} + +func GetSafeTable(table *parser.Table) *parser.Table { + table.Name = su.EscapeGolangKeyword(table.Name) + for _, c := range table.Columns { + c.Name = su.EscapeGolangKeyword(c.Name) + } + + for _, e := range table.Constraints { + var uniqueKeys, primaryKeys []string + for _, u := range e.ColumnUniqueKey { + uniqueKeys = append(uniqueKeys, su.EscapeGolangKeyword(u)) + } + for _, p := range e.ColumnPrimaryKey { + primaryKeys = append(primaryKeys, su.EscapeGolangKeyword(p)) + } + e.ColumnUniqueKey = uniqueKeys + e.ColumnPrimaryKey = primaryKeys + } + return table +} diff --git a/tools/goctl/util/string.go b/tools/goctl/util/string.go index e8325f0d..0aa93845 100644 --- a/tools/goctl/util/string.go +++ b/tools/goctl/util/string.go @@ -1,6 +1,37 @@ package util -import "strings" +import ( + "strings" + + "github.com/tal-tech/go-zero/tools/goctl/util/console" +) + +var goKeyword = map[string]string{ + "var": "variable", + "const": "constant", + "package": "pkg", + "func": "function", + "return": "rtn", + "defer": "dfr", + "go": "goo", + "select": "slt", + "struct": "structure", + "interface": "itf", + "chan": "channel", + "type": "tp", + "map": "mp", + "range": "rg", + "break": "brk", + "case": "caz", + "continue": "ctn", + "for": "fr", + "fallthrough": "fth", + "else": "es", + "if": "ef", + "switch": "swt", + "goto": "gt", + "default": "dft", +} // Title returns a string value with s[0] which has been convert into upper case that // there are not empty input text @@ -64,3 +95,18 @@ func isLetter(r rune) bool { func isNumber(r rune) bool { return '0' <= r && r <= '9' } + +func EscapeGolangKeyword(s string) string { + if !isGolangKeyword(s) { + return s + } + + r := goKeyword[s] + console.Info("[EscapeGolangKeyword]: go keyword is forbidden %q, converted into %q", s, r) + return r +} + +func isGolangKeyword(s string) bool { + _, ok := goKeyword[s] + return ok +} diff --git a/tools/goctl/util/string_test.go b/tools/goctl/util/string_test.go index 5680ead7..46a0c47a 100644 --- a/tools/goctl/util/string_test.go +++ b/tools/goctl/util/string_test.go @@ -1,6 +1,7 @@ package util import ( + "strings" "testing" "github.com/stretchr/testify/assert" @@ -64,3 +65,10 @@ func TestSafeString(t *testing.T) { assert.Equal(t, e.expected, SafeString(e.input)) } } + +func TestEscapeGoKeyword(t *testing.T) { + for k := range goKeyword { + assert.Equal(t, goKeyword[k], EscapeGolangKeyword(k)) + assert.False(t, isGolangKeyword(strings.Title(k))) + } +}