diff --git a/tools/goctl/pkg/parser/api/importstack/importstack.go b/tools/goctl/pkg/parser/api/importstack/importstack.go new file mode 100644 index 00000000..38546649 --- /dev/null +++ b/tools/goctl/pkg/parser/api/importstack/importstack.go @@ -0,0 +1,31 @@ +package importstack + +import "errors" + +// ErrImportCycleNotAllowed defines an error for circular importing +var ErrImportCycleNotAllowed = errors.New("import cycle not allowed") + +// ImportStack a stack of import paths +type ImportStack []string + +func New() *ImportStack { + return &ImportStack{} +} + +func (s *ImportStack) Push(p string) error { + for _, x := range *s { + if x == p { + return ErrImportCycleNotAllowed + } + } + *s = append(*s, p) + return nil +} + +func (s *ImportStack) Pop() { + *s = (*s)[0 : len(*s)-1] +} + +func (s *ImportStack) List() []string { + return *s +} diff --git a/tools/goctl/pkg/parser/api/parser/analyzer.go b/tools/goctl/pkg/parser/api/parser/analyzer.go index cd03d97c..236630e4 100644 --- a/tools/goctl/pkg/parser/api/parser/analyzer.go +++ b/tools/goctl/pkg/parser/api/parser/analyzer.go @@ -5,8 +5,10 @@ import ( "sort" "strings" + "github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/tools/goctl/api/spec" "github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/ast" + "github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/importstack" "github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/placeholder" "github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/token" ) @@ -390,9 +392,14 @@ func Parse(filename string, src interface{}) (*spec.ApiSpec, error) { return nil, err } - var importManager = make(map[string]placeholder.Type) - importManager[ast.Filename] = placeholder.PlaceHolder - api, err := convert2API(ast, importManager) + is := importstack.New() + err := is.Push(ast.Filename) + if err != nil { + return nil, err + } + + importSet := map[string]lang.PlaceholderType{} + api, err := convert2API(ast, importSet, is) if err != nil { return nil, err } diff --git a/tools/goctl/pkg/parser/api/parser/api.go b/tools/goctl/pkg/parser/api/parser/api.go index c67c5b53..10b5a1a4 100644 --- a/tools/goctl/pkg/parser/api/parser/api.go +++ b/tools/goctl/pkg/parser/api/parser/api.go @@ -5,7 +5,9 @@ import ( "path/filepath" "strings" + "github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/ast" + "github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/importstack" "github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/placeholder" "github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/token" ) @@ -18,15 +20,17 @@ type API struct { importStmt []ast.ImportStmt // ImportStmt block does not participate in code generation. TypeStmt []ast.TypeStmt ServiceStmts []*ast.ServiceStmt - importManager map[string]placeholder.Type + importManager *importstack.ImportStack + importSet map[string]lang.PlaceholderType } -func convert2API(a *ast.AST, importManager map[string]placeholder.Type) (*API, error) { +func convert2API(a *ast.AST, importSet map[string]lang.PlaceholderType, is *importstack.ImportStack) (*API, error) { var api = new(API) - api.importManager = make(map[string]placeholder.Type) + api.importManager = is + api.importSet = make(map[string]lang.PlaceholderType) api.Filename = a.Filename - for k, v := range importManager { - api.importManager[k] = v + for k, v := range importSet { + api.importSet[k] = v } one := a.Stmts[0] syntax, ok := one.(*ast.SyntaxStmt) @@ -230,9 +234,6 @@ func (api *API) getAtServerValue(atServer *ast.AtServerStmt, key string) string } func (api *API) mergeAPI(in *API) error { - for k, v := range in.importManager { - api.importManager[k] = v - } if api.Syntax.Value.Format() != in.Syntax.Value.Format() { return ast.SyntaxError(in.Syntax.Value.Pos(), "multiple syntax value expression, expected <%s>, got <%s>", @@ -269,11 +270,15 @@ func (api *API) parseImportedAPI(imports []ast.ImportStmt) ([]*API, error) { impPath = filepath.Join(dir, impPath) } // import cycle check - if _, ok := api.importManager[impPath]; ok { - return nil, ast.SyntaxError(tok.Position, "import circle not allowed") - } else { - api.importManager[impPath] = placeholder.PlaceHolder + if err := api.importManager.Push(impPath); err != nil { + return nil, ast.SyntaxError(tok.Position, err.Error()) + } + + if _, ok := api.importSet[impPath]; ok { + api.importManager.Pop() + continue } + api.importSet[impPath] = lang.Placeholder p := New(impPath, "") ast := p.Parse() @@ -281,7 +286,7 @@ func (api *API) parseImportedAPI(imports []ast.ImportStmt) ([]*API, error) { return nil, err } - nestedApi, err := convert2API(ast, api.importManager) + nestedApi, err := convert2API(ast, api.importSet, api.importManager) if err != nil { return nil, err } @@ -290,6 +295,7 @@ func (api *API) parseImportedAPI(imports []ast.ImportStmt) ([]*API, error) { return nil, err } + api.importManager.Pop() list = append(list, nestedApi) if err != nil { diff --git a/tools/goctl/pkg/parser/api/scanner/scanner.go b/tools/goctl/pkg/parser/api/scanner/scanner.go index 35a9332d..87767c17 100644 --- a/tools/goctl/pkg/parser/api/scanner/scanner.go +++ b/tools/goctl/pkg/parser/api/scanner/scanner.go @@ -26,7 +26,6 @@ const ( stringOpen stringClose // string mode end - ) var missingInput = errors.New("missing input") @@ -268,6 +267,7 @@ func (s *Scanner) scanNanosecond(bgPos int) token.Token { return s.illegalToken() } s.readRune() + return token.Token{ Type: token.DURATION, Text: string(s.data[bgPos:s.position]), @@ -485,6 +485,7 @@ func (s *Scanner) scanLineComment() token.Token { for s.ch != '\n' && s.ch != 0 { s.readRune() } + return token.Token{ Type: token.COMMENT, Text: string(s.data[position:s.position]), @@ -546,6 +547,7 @@ func (s *Scanner) assertExpected(actual token.Type, expected ...token.Type) erro strings.Join(expects, " | "), actual.String(), )) + return errors.New(text) } @@ -560,6 +562,7 @@ func (s *Scanner) assertExpectedString(actual string, expected ...string) error strings.Join(expects, " | "), actual, )) + return errors.New(text) } @@ -647,21 +650,22 @@ func NewScanner(filename string, src interface{}) (*Scanner, error) { } func readData(filename string, src interface{}) ([]byte, error) { - data, err := os.ReadFile(filename) - if err == nil { + if strings.HasSuffix(filename, ".api") { + data, err := os.ReadFile(filename) + if err != nil { + return nil, err + } return data, nil } switch v := src.(type) { case []byte: - data = append(data, v...) + return v, nil case *bytes.Buffer: - data = v.Bytes() + return v.Bytes(), nil case string: - data = []byte(v) + return []byte(v), nil default: return nil, fmt.Errorf("unsupported type: %T", src) } - - return data, nil }