diff --git a/tools/goctl/util/pathx/file.go b/tools/goctl/util/pathx/file.go index 4fe25283..5858ec51 100644 --- a/tools/goctl/util/pathx/file.go +++ b/tools/goctl/util/pathx/file.go @@ -75,12 +75,23 @@ func FileNameWithoutExt(file string) string { // GetGoctlHome returns the path value of the goctl, the default path is ~/.goctl, if the path has // been set by calling the RegisterGoctlHome method, the user-defined path refers to. -func GetGoctlHome() (string, error) { +func GetGoctlHome() (home string, err error) { + defer func() { + if err != nil { + return + } + info, err := os.Stat(home) + if err == nil && !info.IsDir() { + os.Rename(home, home+".old") + MkdirIfNotExist(home) + } + }() if len(goctlHome) != 0 { - return goctlHome, nil + home = goctlHome + return } - - return GetDefaultGoctlHome() + home, err = GetDefaultGoctlHome() + return } // GetDefaultGoctlHome returns the path value of the goctl home where Join $HOME with .goctl. diff --git a/tools/goctl/util/pathx/file_test.go b/tools/goctl/util/pathx/file_test.go index 7aa81642..00690a17 100644 --- a/tools/goctl/util/pathx/file_test.go +++ b/tools/goctl/util/pathx/file_test.go @@ -74,3 +74,35 @@ func TestGetGitHome(t *testing.T) { expected := filepath.Join(homeDir, goctlDir, gitDir) assert.Equal(t, expected, actual) } + +func TestGetGoctlHome(t *testing.T) { + t.Run("goctl_is_file", func(t *testing.T) { + tmpFile := filepath.Join(t.TempDir(), "a.tmp") + backupTempFile := tmpFile + ".old" + err := ioutil.WriteFile(tmpFile, nil, 0666) + if err != nil { + return + } + RegisterGoctlHome(tmpFile) + home, err := GetGoctlHome() + if err != nil { + return + } + info, err := os.Stat(home) + assert.Nil(t, err) + assert.True(t, info.IsDir()) + + _, err = os.Stat(backupTempFile) + assert.Nil(t, err) + }) + + t.Run("goctl_is_dir", func(t *testing.T) { + RegisterGoctlHome("") + dir := t.TempDir() + RegisterGoctlHome(dir) + home, err := GetGoctlHome() + assert.Nil(t, err) + assert.Equal(t, dir, home) + }) + +}