diff --git a/tools/goctl/util/file.go b/tools/goctl/util/file.go index 3f7bf9a8..4032f339 100644 --- a/tools/goctl/util/file.go +++ b/tools/goctl/util/file.go @@ -95,12 +95,30 @@ func GetGitHome() (string, error) { // GetTemplateDir returns the category path value in GoctlHome where could get it by GetGoctlHome func GetTemplateDir(category string) (string, error) { - goctlHome, err := GetGoctlHome() + home, err := GetGoctlHome() if err != nil { return "", err } + if home == goctlHome { + // backward compatible, it will be removed in the feature + // backward compatible start + beforeTemplateDir := filepath.Join(home, version.GetGoctlVersion(), category) + fs, _ := ioutil.ReadDir(beforeTemplateDir) + var hasContent bool + for _, e := range fs { + if e.Size() > 0 { + hasContent = true + } + } + if hasContent { + return beforeTemplateDir, nil + } + // backward compatible end + + return filepath.Join(home, category), nil + } - return filepath.Join(goctlHome, version.GetGoctlVersion(), category), nil + return filepath.Join(home, version.GetGoctlVersion(), category), nil } // InitTemplates creates template files GoctlHome where could get it by GetGoctlHome diff --git a/tools/goctl/util/file_test.go b/tools/goctl/util/file_test.go index b0927f07..77aa5a59 100644 --- a/tools/goctl/util/file_test.go +++ b/tools/goctl/util/file_test.go @@ -1,13 +1,66 @@ package util import ( + "io/ioutil" "os" "path/filepath" "testing" "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/tools/goctl/internal/version" ) +func TestGetTemplateDir(t *testing.T) { + category := "foo" + t.Run("before_have_templates", func(t *testing.T) { + home := t.TempDir() + RegisterGoctlHome("") + RegisterGoctlHome(home) + v := version.GetGoctlVersion() + dir := filepath.Join(home, v, category) + err := MkdirIfNotExist(dir) + if err != nil { + return + } + tempFile := filepath.Join(dir, "bar.txt") + err = ioutil.WriteFile(tempFile, []byte("foo"), os.ModePerm) + if err != nil { + return + } + templateDir, err := GetTemplateDir(category) + if err != nil { + return + } + assert.Equal(t, dir, templateDir) + RegisterGoctlHome("") + }) + + t.Run("before_has_no_template", func(t *testing.T) { + home := t.TempDir() + RegisterGoctlHome("") + RegisterGoctlHome(home) + dir := filepath.Join(home, category) + err := MkdirIfNotExist(dir) + if err != nil { + return + } + templateDir, err := GetTemplateDir(category) + if err != nil { + return + } + assert.Equal(t, dir, templateDir) + }) + + t.Run("default", func(t *testing.T) { + RegisterGoctlHome("") + dir, err := GetTemplateDir(category) + if err != nil { + return + } + assert.Contains(t, dir, version.BuildVersion) + }) +} + func TestGetGitHome(t *testing.T) { homeDir, err := os.UserHomeDir() if err != nil {