diff --git a/tools/goctl/model/sql/command/command.go b/tools/goctl/model/sql/command/command.go index fa016c5e..03f29f18 100644 --- a/tools/goctl/model/sql/command/command.go +++ b/tools/goctl/model/sql/command/command.go @@ -50,7 +50,7 @@ func MysqlDDL(ctx *cli.Context) error { return err } - return fromDDl(src, dir, cfg, cache, idea, database) + return fromDDL(src, dir, cfg, cache, idea, database) } // MySqlDataSource generates model code from datasource @@ -102,7 +102,7 @@ func PostgreSqlDataSource(ctx *cli.Context) error { return fromPostgreSqlDataSource(url, pattern, dir, schema, cfg, cache, idea) } -func fromDDl(src, dir string, cfg *config.Config, cache, idea bool, database string) error { +func fromDDL(src, dir string, cfg *config.Config, cache, idea bool, database string) error { log := console.NewConsole(idea) src = strings.TrimSpace(src) if len(src) == 0 { diff --git a/tools/goctl/model/sql/command/command_test.go b/tools/goctl/model/sql/command/command_test.go index d2495772..8da71a17 100644 --- a/tools/goctl/model/sql/command/command_test.go +++ b/tools/goctl/model/sql/command/command_test.go @@ -24,12 +24,12 @@ func TestFromDDl(t *testing.T) { err := gen.Clean() assert.Nil(t, err) - err = fromDDl("./user.sql", t.TempDir(), cfg, true, false, "go_zero") + err = fromDDL("./user.sql", t.TempDir(), cfg, true, false, "go_zero") assert.Equal(t, errNotMatched, err) // case dir is not exists unknownDir := filepath.Join(t.TempDir(), "test", "user.sql") - err = fromDDl(unknownDir, t.TempDir(), cfg, true, false, "go_zero") + err = fromDDL(unknownDir, t.TempDir(), cfg, true, false, "go_zero") assert.True(t, func() bool { switch err.(type) { case *os.PathError: @@ -40,7 +40,7 @@ func TestFromDDl(t *testing.T) { }()) // case empty src - err = fromDDl("", t.TempDir(), cfg, true, false, "go_zero") + err = fromDDL("", t.TempDir(), cfg, true, false, "go_zero") if err != nil { assert.Equal(t, "expected path or path globbing patterns, but nothing found", err.Error()) } @@ -70,9 +70,18 @@ func TestFromDDl(t *testing.T) { _, err = os.Stat(user2Sql) assert.Nil(t, err) - err = fromDDl(filepath.Join(tempDir, "user*.sql"), tempDir, cfg, true, false, "go_zero") - assert.Nil(t, err) + filename := filepath.Join(tempDir, "usermodel.go") + fromDDL := func(db string) { + err = fromDDL(filepath.Join(tempDir, "user*.sql"), tempDir, cfg, true, false, db) + assert.Nil(t, err) - _, err = os.Stat(filepath.Join(tempDir, "usermodel.go")) - assert.Nil(t, err) + _, err = os.Stat(filename) + assert.Nil(t, err) + } + + fromDDL("go_zero") + _ = os.Remove(filename) + fromDDL("go-zero") + _ = os.Remove(filename) + fromDDL("1gozero") } diff --git a/tools/goctl/model/sql/example/makefile b/tools/goctl/model/sql/example/makefile index fa80267c..c08e09ee 100644 --- a/tools/goctl/model/sql/example/makefile +++ b/tools/goctl/model/sql/example/makefile @@ -5,6 +5,10 @@ fromDDLWithCache: goctl template clean goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/cache" -cache +fromDDLWithCacheAndDb: + goctl template clean + goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/cache_db" -database="1gozero" -cache + fromDDLWithoutCache: goctl template clean; goctl model mysql ddl -src="./sql/*.sql" -dir="./sql/model/nocache" diff --git a/tools/goctl/model/sql/gen/gen.go b/tools/goctl/model/sql/gen/gen.go index 300c47cd..b70e3383 100644 --- a/tools/goctl/model/sql/gen/gen.go +++ b/tools/goctl/model/sql/gen/gen.go @@ -146,7 +146,7 @@ func (g *defaultGenerator) createFile(modelList map[string]string) error { return err } - name := modelFilename + ".go" + name := util.SafeString(modelFilename) + ".go" filename := filepath.Join(dirAbs, name) if util.FileExists(filename) { g.Warning("%s already exists, ignored.", name) diff --git a/tools/goctl/model/sql/gen/keys.go b/tools/goctl/model/sql/gen/keys.go index 8e0c6b91..db4003e0 100644 --- a/tools/goctl/model/sql/gen/keys.go +++ b/tools/goctl/model/sql/gen/keys.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/tal-tech/go-zero/tools/goctl/model/sql/parser" + "github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) @@ -59,9 +60,16 @@ func genCacheKey(db, table stringx.String, in []*parser.Field) Key { keyLeft, keyRight, dataKeyRight, keyExpression, dataKeyExpression string ) - varLeftJoin = append(varLeftJoin, "cache", db.Source(), table.Source()) - varRightJon = append(varRightJon, "cache", db.Source(), table.Source()) - keyLeftJoin = append(keyLeftJoin, db.Source(), table.Source()) + dbName, tableName := util.SafeString(db.Source()), util.SafeString(table.Source()) + if len(dbName) > 0 { + varLeftJoin = append(varLeftJoin, "cache", dbName, tableName) + varRightJon = append(varRightJon, "cache", dbName, tableName) + keyLeftJoin = append(keyLeftJoin, dbName, tableName) + } else { + varLeftJoin = append(varLeftJoin, "cache", tableName) + varRightJon = append(varRightJon, "cache", tableName) + keyLeftJoin = append(keyLeftJoin, tableName) + } for _, each := range in { varLeftJoin = append(varLeftJoin, each.Name.Source()) @@ -75,11 +83,11 @@ func genCacheKey(db, table stringx.String, in []*parser.Field) Key { varLeftJoin = append(varLeftJoin, "prefix") keyLeftJoin = append(keyLeftJoin, "key") - varLeft = varLeftJoin.Camel().With("").Untitle() + varLeft = util.SafeString(varLeftJoin.Camel().With("").Untitle()) varRight = fmt.Sprintf(`"%s"`, varRightJon.Camel().Untitle().With(":").Source()+":") varExpression = fmt.Sprintf(`%s = %s`, varLeft, varRight) - keyLeft = keyLeftJoin.Camel().With("").Untitle() + keyLeft = util.SafeString(keyLeftJoin.Camel().With("").Untitle()) keyRight = fmt.Sprintf(`fmt.Sprintf("%s%s", %s, %s)`, "%s", keyRightArgJoin.With(":").Source(), varLeft, keyRightJoin.With(", ").Source()) dataKeyRight = fmt.Sprintf(`fmt.Sprintf("%s%s", %s, %s)`, "%s", keyRightArgJoin.With(":").Source(), varLeft, dataRightJoin.With(", ").Source()) keyExpression = fmt.Sprintf("%s := %s", keyLeft, keyRight) diff --git a/tools/goctl/util/env/env_test.go b/tools/goctl/util/env/env_test.go new file mode 100644 index 00000000..044a870a --- /dev/null +++ b/tools/goctl/util/env/env_test.go @@ -0,0 +1,96 @@ +package env + +import ( + "bytes" + "fmt" + "os/exec" + "runtime" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/vars" +) + +func TestLookUpGo(t *testing.T) { + xGo, err := LookUpGo() + if err != nil { + return + } + + assert.True(t, util.FileExists(xGo)) + output, errOutput, err := execCommand(xGo, "version") + if err != nil { + return + } + + if len(errOutput) > 0 { + return + } + assert.Equal(t, wrapVersion(), output) +} + +func TestLookUpProtoc(t *testing.T) { + xProtoc, err := LookUpProtoc() + if err != nil { + return + } + + assert.True(t, util.FileExists(xProtoc)) + output, errOutput, err := execCommand(xProtoc, "--version") + if err != nil { + return + } + + if len(errOutput) > 0 { + return + } + assert.True(t, len(output) > 0) +} + +func TestLookUpProtocGenGo(t *testing.T) { + xProtocGenGo, err := LookUpProtocGenGo() + if err != nil { + return + } + assert.True(t, util.FileExists(xProtocGenGo)) +} + +func TestLookPath(t *testing.T) { + xGo, err := LookPath("go") + if err != nil { + return + } + assert.True(t, util.FileExists(xGo)) +} + +func TestCanExec(t *testing.T) { + canExec := runtime.GOOS != vars.OsJs && runtime.GOOS != vars.OsIOS + assert.Equal(t, canExec, CanExec()) +} + +func execCommand(cmd string, arg ...string) (stdout string, stderr string, err error) { + output := bytes.NewBuffer(nil) + errOutput := bytes.NewBuffer(nil) + c := exec.Command(cmd, arg...) + c.Stdout = output + c.Stderr = errOutput + err = c.Run() + if err != nil { + return + } + if errOutput.Len() > 0 { + stderr = errOutput.String() + return + } + stdout = strings.TrimSpace(output.String()) + return +} + +func wrapVersion() string { + version := runtime.Version() + os := runtime.GOOS + arch := runtime.GOARCH + return fmt.Sprintf("go version %s %s/%s", version, os, arch) +} diff --git a/tools/goctl/util/string.go b/tools/goctl/util/string.go index 019cb145..e8325f0d 100644 --- a/tools/goctl/util/string.go +++ b/tools/goctl/util/string.go @@ -32,3 +32,35 @@ func Index(slice []string, item string) int { return -1 } + +// SafeString converts the input string into a safe naming style in golang +func SafeString(in string) string { + if len(in) == 0 { + return in + } + + data := strings.Map(func(r rune) rune { + if isSafeRune(r) { + return r + } + return '_' + }, in) + + headRune := rune(data[0]) + if isNumber(headRune) { + return "_" + data + } + return data +} + +func isSafeRune(r rune) bool { + return isLetter(r) || isNumber(r) || r == '_' +} + +func isLetter(r rune) bool { + return 'A' <= r && r <= 'z' +} + +func isNumber(r rune) bool { + return '0' <= r && r <= '9' +} diff --git a/tools/goctl/util/string_test.go b/tools/goctl/util/string_test.go new file mode 100644 index 00000000..5680ead7 --- /dev/null +++ b/tools/goctl/util/string_test.go @@ -0,0 +1,66 @@ +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type data struct { + input string + expected string +} + +func TestTitle(t *testing.T) { + list := []*data{ + {input: "_", expected: "_"}, + {input: "abc", expected: "Abc"}, + {input: "ABC", expected: "ABC"}, + {input: "", expected: ""}, + {input: " abc", expected: " abc"}, + } + for _, e := range list { + assert.Equal(t, e.expected, Title(e.input)) + } +} + +func TestUntitle(t *testing.T) { + list := []*data{ + {input: "_", expected: "_"}, + {input: "Abc", expected: "abc"}, + {input: "ABC", expected: "aBC"}, + {input: "", expected: ""}, + {input: " abc", expected: " abc"}, + } + + for _, e := range list { + assert.Equal(t, e.expected, Untitle(e.input)) + } +} + +func TestIndex(t *testing.T) { + list := []string{"a", "b", "c"} + assert.Equal(t, 1, Index(list, "b")) + assert.Equal(t, -1, Index(list, "d")) +} + +func TestSafeString(t *testing.T) { + list := []*data{ + {input: "_", expected: "_"}, + {input: "a-b-c", expected: "a_b_c"}, + {input: "123abc", expected: "_123abc"}, + {input: "汉abc", expected: "_abc"}, + {input: "汉a字", expected: "_a_"}, + {input: "キャラクターabc", expected: "______abc"}, + {input: "-a_B-C", expected: "_a_B_C"}, + {input: "a_B C", expected: "a_B_C"}, + {input: "A#B#C", expected: "A_B_C"}, + {input: "_123", expected: "_123"}, + {input: "", expected: ""}, + {input: "\t", expected: "_"}, + {input: "\n", expected: "_"}, + } + for _, e := range list { + assert.Equal(t, e.expected, SafeString(e.input)) + } +}