diff --git a/tools/goctl/goctl.go b/tools/goctl/goctl.go index 6a5ea653..4bdcec81 100644 --- a/tools/goctl/goctl.go +++ b/tools/goctl/goctl.go @@ -6,6 +6,7 @@ import ( "runtime" "github.com/logrusorgru/aurora" + "github.com/urfave/cli" "github.com/zeromicro/go-zero/core/load" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/tools/goctl/api/apigen" @@ -30,7 +31,6 @@ import ( rpc "github.com/zeromicro/go-zero/tools/goctl/rpc/cli" "github.com/zeromicro/go-zero/tools/goctl/tpl" "github.com/zeromicro/go-zero/tools/goctl/upgrade" - "github.com/urfave/cli" ) const codeFailure = 1 @@ -58,7 +58,7 @@ var commands = []cli.Command{ }, cli.StringFlag{ Name: "version", - Usage: "the target release version of github.com/zeromicro/go-zero to refactor", + Usage: "the target release version of github.com/zeromicro/go-zero to migrate", }, }, }, diff --git a/tools/goctl/migrate/migrate.go b/tools/goctl/migrate/migrate.go index edbb7d76..ac4f13ac 100644 --- a/tools/goctl/migrate/migrate.go +++ b/tools/goctl/migrate/migrate.go @@ -10,13 +10,19 @@ import ( "io/fs" "io/ioutil" "os" + "os/signal" "path/filepath" + "runtime" "strings" + "syscall" "time" + "github.com/logrusorgru/aurora" + "github.com/urfave/cli" + "github.com/zeromicro/go-zero/core/syncx" "github.com/zeromicro/go-zero/tools/goctl/util/console" "github.com/zeromicro/go-zero/tools/goctl/util/ctx" - "github.com/urfave/cli" + "github.com/zeromicro/go-zero/tools/goctl/vars" ) const zeromicroVersion = "1.3.0" @@ -44,7 +50,9 @@ func Migrate(c *cli.Context) error { return err } - console.Success("[OK] refactor finish, execute %q on project root to check status.", "go test -race ./...") + if verbose { + console.Success("[OK] refactor finish, execute %q on project root to check status.", "go test -race ./...") + } return nil } @@ -54,6 +62,23 @@ func rewriteImport(verbose bool) error { time.Sleep(200 * time.Millisecond) } + var doneChan = syncx.NewDoneChan() + defer func() { + doneChan.Close() + }() + go func(dc *syncx.DoneChan) { + c := make(chan os.Signal) + signal.Notify(c, syscall.SIGTERM, syscall.SIGKILL, syscall.SIGINT, syscall.SIGTSTP, syscall.SIGQUIT) + select { + case <-c: + console.Error(` +migrate failed, reason: "User Canceled"`) + os.Exit(0) + case <-dc.Done(): + return + } + }(doneChan) + wd, err := os.Getwd() if err != nil { return err @@ -64,7 +89,8 @@ func rewriteImport(verbose bool) error { } root := project.Dir fsys := os.DirFS(root) - return fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error { + var final []*ast.Package + err = fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error { if !d.IsDir() { return nil } @@ -78,22 +104,70 @@ func rewriteImport(verbose bool) error { return err } - return rewriteFile(pkgs, verbose) + err = rewriteFile(pkgs, verbose) + if err != nil { + return err + } + for _, v := range pkgs { + final = append(final, v) + } + return nil }) + if err != nil { + return err + } + + if verbose { + console.Info("start to write files ... ") + } + return writeFile(final, verbose) } func rewriteFile(pkgs map[string]*ast.Package, verbose bool) error { for _, pkg := range pkgs { for filename, file := range pkg.Files { + var containsDeprecatedBuilderxPkg bool for _, imp := range file.Imports { if !strings.Contains(imp.Path.Value, deprecatedGoZeroMod) { continue } + + if verbose { + console.Debug("[...] migrate %q ... ", filepath.Base(filename)) + } + + if strings.Contains(imp.Path.Value, deprecatedBuilderx) { + containsDeprecatedBuilderxPkg = true + var doNext bool + refactorBuilderx(deprecatedBuilderx, replacementBuilderx, func(allow bool) { + doNext = !allow + if allow { + newPath := strings.ReplaceAll(imp.Path.Value, deprecatedBuilderx, replacementBuilderx) + imp.EndPos = imp.End() + imp.Path.Value = newPath + } + }) + if !doNext { + continue + } + } + newPath := strings.ReplaceAll(imp.Path.Value, deprecatedGoZeroMod, goZeroMod) imp.EndPos = imp.End() imp.Path.Value = newPath } + if containsDeprecatedBuilderxPkg { + replacePkg(file) + } + } + } + return nil +} + +func writeFile(pkgs []*ast.Package, verbose bool) error { + for _, pkg := range pkgs { + for filename, file := range pkg.Files { var w = bytes.NewBuffer(nil) err := format.Node(w, fset, file) if err != nil { @@ -105,9 +179,91 @@ func rewriteFile(pkgs map[string]*ast.Package, verbose bool) error { return fmt.Errorf("[rewriteImport] write file %s error: %+v", filename, err) } if verbose { - console.Success("[OK] rewriting %q ... ", filepath.Base(filename)) + console.Success("[OK] migrate %q success ", filepath.Base(filename)) } } } return nil } + +func replacePkg(file *ast.File) { + scope := file.Scope + if scope == nil { + return + } + obj := scope.Objects + for _, v := range obj { + decl := v.Decl + if decl == nil { + continue + } + vs, ok := decl.(*ast.ValueSpec) + if !ok { + continue + } + values := vs.Values + if len(values) != 1 { + continue + } + value := values[0] + callExpr, ok := value.(*ast.CallExpr) + if !ok { + continue + } + fn := callExpr.Fun + if fn == nil { + continue + } + selector, ok := fn.(*ast.SelectorExpr) + if !ok { + continue + } + x := selector.X + sel := selector.Sel + if x == nil || sel == nil { + continue + } + ident, ok := x.(*ast.Ident) + if !ok { + continue + } + if ident.Name == "builderx" { + ident.Name = "builder" + ident.NamePos = ident.End() + } + if sel.Name == "FieldNames" { + sel.Name = "RawFieldNames" + sel.NamePos = sel.End() + } + } +} + +func refactorBuilderx(deprecated, replacement string, fn func(allow bool)) { + msg := fmt.Sprintf(`Detects a deprecated package in the source code, +Deprecated package: %q +Replacement package: %q +It's recommended to use the replacement package, do you want to replace? +[input 'Y' for yes, 'N' for no]:`, deprecated, replacement) + + if runtime.GOOS != vars.OsWindows { + msg = aurora.Yellow(msg).String() + } + fmt.Print(msg) + var in string + for { + fmt.Scanln(&in) + if len(in) == 0 { + console.Warning("nothing input, please try again [input 'Y' for yes, 'N' for no]:") + continue + } + if strings.EqualFold(in, "Y") { + fn(true) + return + } else if strings.EqualFold(in, "N") { + fn(false) + return + } else { + console.Warning("invalid input, please try again [input 'Y' for yes, 'N' for no]:") + } + } +} diff --git a/tools/goctl/migrate/mod.go b/tools/goctl/migrate/mod.go index d1086281..3258d1e8 100644 --- a/tools/goctl/migrate/mod.go +++ b/tools/goctl/migrate/mod.go @@ -12,7 +12,10 @@ import ( "github.com/zeromicro/go-zero/tools/goctl/util/ctx" ) -const deprecatedGoZeroMod = "github.com/zeromicro/go-zero" +const deprecatedGoZeroMod = "github.com/tal-tech/go-zero" + +const deprecatedBuilderx = "github.com/tal-tech/go-zero/tools/goctl/model/sql/builderx" +const replacementBuilderx = "github.com/zeromicro/go-zero/core/stores/builder" const goZeroMod = "github.com/zeromicro/go-zero" var errInvalidGoMod = errors.New("it's only working for go module")