(goctl:) fix circle import in case new parser (#3750)

Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
master
kesonan 1 year ago committed by GitHub
parent c46bcf7e1b
commit 5e63002cf8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,31 @@
package importstack
import "errors"
// ErrImportCycleNotAllowed defines an error for circular importing
var ErrImportCycleNotAllowed = errors.New("import cycle not allowed")
// ImportStack a stack of import paths
type ImportStack []string
func New() *ImportStack {
return &ImportStack{}
}
func (s *ImportStack) Push(p string) error {
for _, x := range *s {
if x == p {
return ErrImportCycleNotAllowed
}
}
*s = append(*s, p)
return nil
}
func (s *ImportStack) Pop() {
*s = (*s)[0 : len(*s)-1]
}
func (s *ImportStack) List() []string {
return *s
}

@ -5,8 +5,10 @@ import (
"sort" "sort"
"strings" "strings"
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/tools/goctl/api/spec" "github.com/zeromicro/go-zero/tools/goctl/api/spec"
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/ast" "github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/ast"
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/importstack"
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/placeholder" "github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/placeholder"
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/token" "github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/token"
) )
@ -390,9 +392,14 @@ func Parse(filename string, src interface{}) (*spec.ApiSpec, error) {
return nil, err return nil, err
} }
var importManager = make(map[string]placeholder.Type) is := importstack.New()
importManager[ast.Filename] = placeholder.PlaceHolder err := is.Push(ast.Filename)
api, err := convert2API(ast, importManager) if err != nil {
return nil, err
}
importSet := map[string]lang.PlaceholderType{}
api, err := convert2API(ast, importSet, is)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -5,7 +5,9 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/ast" "github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/ast"
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/importstack"
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/placeholder" "github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/placeholder"
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/token" "github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/token"
) )
@ -18,15 +20,17 @@ type API struct {
importStmt []ast.ImportStmt // ImportStmt block does not participate in code generation. importStmt []ast.ImportStmt // ImportStmt block does not participate in code generation.
TypeStmt []ast.TypeStmt TypeStmt []ast.TypeStmt
ServiceStmts []*ast.ServiceStmt ServiceStmts []*ast.ServiceStmt
importManager map[string]placeholder.Type importManager *importstack.ImportStack
importSet map[string]lang.PlaceholderType
} }
func convert2API(a *ast.AST, importManager map[string]placeholder.Type) (*API, error) { func convert2API(a *ast.AST, importSet map[string]lang.PlaceholderType, is *importstack.ImportStack) (*API, error) {
var api = new(API) var api = new(API)
api.importManager = make(map[string]placeholder.Type) api.importManager = is
api.importSet = make(map[string]lang.PlaceholderType)
api.Filename = a.Filename api.Filename = a.Filename
for k, v := range importManager { for k, v := range importSet {
api.importManager[k] = v api.importSet[k] = v
} }
one := a.Stmts[0] one := a.Stmts[0]
syntax, ok := one.(*ast.SyntaxStmt) syntax, ok := one.(*ast.SyntaxStmt)
@ -230,9 +234,6 @@ func (api *API) getAtServerValue(atServer *ast.AtServerStmt, key string) string
} }
func (api *API) mergeAPI(in *API) error { func (api *API) mergeAPI(in *API) error {
for k, v := range in.importManager {
api.importManager[k] = v
}
if api.Syntax.Value.Format() != in.Syntax.Value.Format() { if api.Syntax.Value.Format() != in.Syntax.Value.Format() {
return ast.SyntaxError(in.Syntax.Value.Pos(), return ast.SyntaxError(in.Syntax.Value.Pos(),
"multiple syntax value expression, expected <%s>, got <%s>", "multiple syntax value expression, expected <%s>, got <%s>",
@ -269,11 +270,15 @@ func (api *API) parseImportedAPI(imports []ast.ImportStmt) ([]*API, error) {
impPath = filepath.Join(dir, impPath) impPath = filepath.Join(dir, impPath)
} }
// import cycle check // import cycle check
if _, ok := api.importManager[impPath]; ok { if err := api.importManager.Push(impPath); err != nil {
return nil, ast.SyntaxError(tok.Position, "import circle not allowed") return nil, ast.SyntaxError(tok.Position, err.Error())
} else { }
api.importManager[impPath] = placeholder.PlaceHolder
if _, ok := api.importSet[impPath]; ok {
api.importManager.Pop()
continue
} }
api.importSet[impPath] = lang.Placeholder
p := New(impPath, "") p := New(impPath, "")
ast := p.Parse() ast := p.Parse()
@ -281,7 +286,7 @@ func (api *API) parseImportedAPI(imports []ast.ImportStmt) ([]*API, error) {
return nil, err return nil, err
} }
nestedApi, err := convert2API(ast, api.importManager) nestedApi, err := convert2API(ast, api.importSet, api.importManager)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -290,6 +295,7 @@ func (api *API) parseImportedAPI(imports []ast.ImportStmt) ([]*API, error) {
return nil, err return nil, err
} }
api.importManager.Pop()
list = append(list, nestedApi) list = append(list, nestedApi)
if err != nil { if err != nil {

@ -26,7 +26,6 @@ const (
stringOpen stringOpen
stringClose stringClose
// string mode end // string mode end
) )
var missingInput = errors.New("missing input") var missingInput = errors.New("missing input")
@ -268,6 +267,7 @@ func (s *Scanner) scanNanosecond(bgPos int) token.Token {
return s.illegalToken() return s.illegalToken()
} }
s.readRune() s.readRune()
return token.Token{ return token.Token{
Type: token.DURATION, Type: token.DURATION,
Text: string(s.data[bgPos:s.position]), Text: string(s.data[bgPos:s.position]),
@ -485,6 +485,7 @@ func (s *Scanner) scanLineComment() token.Token {
for s.ch != '\n' && s.ch != 0 { for s.ch != '\n' && s.ch != 0 {
s.readRune() s.readRune()
} }
return token.Token{ return token.Token{
Type: token.COMMENT, Type: token.COMMENT,
Text: string(s.data[position:s.position]), Text: string(s.data[position:s.position]),
@ -546,6 +547,7 @@ func (s *Scanner) assertExpected(actual token.Type, expected ...token.Type) erro
strings.Join(expects, " | "), strings.Join(expects, " | "),
actual.String(), actual.String(),
)) ))
return errors.New(text) return errors.New(text)
} }
@ -560,6 +562,7 @@ func (s *Scanner) assertExpectedString(actual string, expected ...string) error
strings.Join(expects, " | "), strings.Join(expects, " | "),
actual, actual,
)) ))
return errors.New(text) return errors.New(text)
} }
@ -647,21 +650,22 @@ func NewScanner(filename string, src interface{}) (*Scanner, error) {
} }
func readData(filename string, src interface{}) ([]byte, error) { func readData(filename string, src interface{}) ([]byte, error) {
data, err := os.ReadFile(filename) if strings.HasSuffix(filename, ".api") {
if err == nil { data, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
return data, nil return data, nil
} }
switch v := src.(type) { switch v := src.(type) {
case []byte: case []byte:
data = append(data, v...) return v, nil
case *bytes.Buffer: case *bytes.Buffer:
data = v.Bytes() return v.Bytes(), nil
case string: case string:
data = []byte(v) return []byte(v), nil
default: default:
return nil, fmt.Errorf("unsupported type: %T", src) return nil, fmt.Errorf("unsupported type: %T", src)
} }
return data, nil
} }

Loading…
Cancel
Save