feat(goctl): supports api multi-level importing (#1747)

* feat(goctl): supports api  multi-level importing

Resolves: #1744

* fix(goctl): import-cycle, etc.

import-cycle will not be allowed
e.g., a.api -> b.api -> a.api
regular multiple-import will be allowed
e.g., a.api -> b.api -> c.api
                   -> c.api

* refactor(goctl): adds comments to exported var

* fix(goctl): typo in a comment
master
Fyn 3 years ago committed by GitHub
parent 252fabcc4b
commit 6d9dfc08f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -15,12 +15,18 @@ import (
type ( type (
// Parser provides api parsing capabilities // Parser provides api parsing capabilities
Parser struct { Parser struct {
linePrefix string
debug bool
log console.Console
antlr.DefaultErrorListener antlr.DefaultErrorListener
linePrefix string
debug bool
log console.Console
src string src string
skipCheckTypeDeclaration bool skipCheckTypeDeclaration bool
handlerMap map[string]PlaceHolder
routeMap map[string]PlaceHolder
typeMap map[string]PlaceHolder
fileMap map[string]PlaceHolder
importStatck importStack
syntax *SyntaxExpr
} }
// ParserOption defines an function with argument Parser // ParserOption defines an function with argument Parser
@ -35,6 +41,10 @@ func NewParser(options ...ParserOption) *Parser {
for _, opt := range options { for _, opt := range options {
opt(p) opt(p)
} }
p.handlerMap = make(map[string]PlaceHolder)
p.routeMap = make(map[string]PlaceHolder)
p.typeMap = make(map[string]PlaceHolder)
p.fileMap = make(map[string]PlaceHolder)
return p return p
} }
@ -84,6 +94,7 @@ func (p *Parser) Parse(filename string) (*Api, error) {
return nil, err return nil, err
} }
p.importStatck.push(p.src)
return p.parse(filename, data) return p.parse(filename, data)
} }
@ -100,6 +111,7 @@ func (p *Parser) ParseContent(content string, filename ...string) (*Api, error)
p.src = abs p.src = abs
} }
p.importStatck.push(p.src)
return p.parse(f, content) return p.parse(f, content)
} }
@ -113,12 +125,43 @@ func (p *Parser) parse(filename, content string) (*Api, error) {
var apiAstList []*Api var apiAstList []*Api
apiAstList = append(apiAstList, root) apiAstList = append(apiAstList, root)
for _, imp := range root.Import { p.storeVerificationInfo(root)
p.syntax = root.Syntax
impApiAstList, err := p.invokeImportedApi(root.Import)
if err != nil {
return nil, err
}
apiAstList = append(apiAstList, impApiAstList...)
if !p.skipCheckTypeDeclaration {
err = p.checkTypeDeclaration(apiAstList)
if err != nil {
return nil, err
}
}
allApi := p.memberFill(apiAstList)
return allApi, nil
}
func (p *Parser) invokeImportedApi(imports []*ImportExpr) ([]*Api, error) {
var apiAstList []*Api
for _, imp := range imports {
dir := filepath.Dir(p.src) dir := filepath.Dir(p.src)
impPath := strings.ReplaceAll(imp.Value.Text(), "\"", "") impPath := strings.ReplaceAll(imp.Value.Text(), "\"", "")
if !filepath.IsAbs(impPath) { if !filepath.IsAbs(impPath) {
impPath = filepath.Join(dir, impPath) impPath = filepath.Join(dir, impPath)
} }
// import cycle check
if err := p.importStatck.push(impPath); err != nil {
return nil, err
}
// ignore already imported file
if p.alreadyImported(impPath) {
continue
}
p.fileMap[impPath] = PlaceHolder{}
data, err := p.readContent(impPath) data, err := p.readContent(impPath)
if err != nil { if err != nil {
return nil, err return nil, err
@ -129,23 +172,26 @@ func (p *Parser) parse(filename, content string) (*Api, error) {
return nil, err return nil, err
} }
err = p.valid(root, nestedApi) err = p.valid(nestedApi)
if err != nil { if err != nil {
return nil, err return nil, err
} }
p.storeVerificationInfo(nestedApi)
apiAstList = append(apiAstList, nestedApi) apiAstList = append(apiAstList, nestedApi)
} list, err := p.invokeImportedApi(nestedApi.Import)
p.importStatck.pop()
apiAstList = append(apiAstList, list...)
if !p.skipCheckTypeDeclaration {
err = p.checkTypeDeclaration(apiAstList)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
return apiAstList, nil
}
allApi := p.memberFill(apiAstList) func (p *Parser) alreadyImported(filename string) bool {
return allApi, nil _, ok := p.fileMap[filename]
return ok
} }
func (p *Parser) invoke(linePrefix, content string) (v *Api, err error) { func (p *Parser) invoke(linePrefix, content string) (v *Api, err error) {
@ -184,58 +230,48 @@ func (p *Parser) invoke(linePrefix, content string) (v *Api, err error) {
return return
} }
func (p *Parser) valid(mainApi, nestedApi *Api) error { // storeVerificationInfo stores information for verification
err := p.nestedApiCheck(mainApi, nestedApi) func (p *Parser) storeVerificationInfo(api *Api) {
if err != nil { routeMap := func(list []*ServiceRoute) {
return err
}
mainHandlerMap := make(map[string]PlaceHolder)
mainRouteMap := make(map[string]PlaceHolder)
mainTypeMap := make(map[string]PlaceHolder)
routeMap := func(list []*ServiceRoute) (map[string]PlaceHolder, map[string]PlaceHolder) {
handlerMap := make(map[string]PlaceHolder)
routeMap := make(map[string]PlaceHolder)
for _, g := range list { for _, g := range list {
handler := g.GetHandler() handler := g.GetHandler()
if handler.IsNotNil() { if handler.IsNotNil() {
handlerName := handler.Text() handlerName := handler.Text()
handlerMap[handlerName] = Holder p.handlerMap[handlerName] = Holder
route := fmt.Sprintf("%s://%s", g.Route.Method.Text(), g.Route.Path.Text()) route := fmt.Sprintf("%s://%s", g.Route.Method.Text(), g.Route.Path.Text())
routeMap[route] = Holder p.routeMap[route] = Holder
} }
} }
}
return handlerMap, routeMap for _, each := range api.Service {
routeMap(each.ServiceApi.ServiceRoute)
} }
for _, each := range mainApi.Service { for _, each := range api.Type {
h, r := routeMap(each.ServiceApi.ServiceRoute) p.typeMap[each.NameExpr().Text()] = Holder
}
}
for k, v := range h { func (p *Parser) valid(nestedApi *Api) error {
mainHandlerMap[k] = v
}
for k, v := range r { if p.syntax != nil && nestedApi.Syntax != nil {
mainRouteMap[k] = v if p.syntax.Version.Text() != nestedApi.Syntax.Version.Text() {
syntaxToken := nestedApi.Syntax.Syntax
return fmt.Errorf("%s line %d:%d multiple syntax declaration, expecting syntax '%s', but found '%s'",
nestedApi.LinePrefix, syntaxToken.Line(), syntaxToken.Column(), p.syntax.Version.Text(), nestedApi.Syntax.Version.Text())
} }
} }
for _, each := range mainApi.Type {
mainTypeMap[each.NameExpr().Text()] = Holder
}
// duplicate route check // duplicate route check
err = p.duplicateRouteCheck(nestedApi, mainHandlerMap, mainRouteMap) err := p.duplicateRouteCheck(nestedApi)
if err != nil { if err != nil {
return err return err
} }
// duplicate type check // duplicate type check
for _, each := range nestedApi.Type { for _, each := range nestedApi.Type {
if _, ok := mainTypeMap[each.NameExpr().Text()]; ok { if _, ok := p.typeMap[each.NameExpr().Text()]; ok {
return fmt.Errorf("%s line %d:%d duplicate type declaration '%s'", return fmt.Errorf("%s line %d:%d duplicate type declaration '%s'",
nestedApi.LinePrefix, each.NameExpr().Line(), each.NameExpr().Column(), each.NameExpr().Text()) nestedApi.LinePrefix, each.NameExpr().Line(), each.NameExpr().Column(), each.NameExpr().Text())
} }
@ -244,7 +280,7 @@ func (p *Parser) valid(mainApi, nestedApi *Api) error {
return nil return nil
} }
func (p *Parser) duplicateRouteCheck(nestedApi *Api, mainHandlerMap, mainRouteMap map[string]PlaceHolder) error { func (p *Parser) duplicateRouteCheck(nestedApi *Api) error {
for _, each := range nestedApi.Service { for _, each := range nestedApi.Service {
var prefix, group string var prefix, group string
if each.AtServer != nil { if each.AtServer != nil {
@ -267,13 +303,13 @@ func (p *Parser) duplicateRouteCheck(nestedApi *Api, mainHandlerMap, mainRouteMa
if len(group) > 0 { if len(group) > 0 {
handlerKey = fmt.Sprintf("%s/%s", group, handler.Text()) handlerKey = fmt.Sprintf("%s/%s", group, handler.Text())
} }
if _, ok := mainHandlerMap[handlerKey]; ok { if _, ok := p.handlerMap[handlerKey]; ok {
return fmt.Errorf("%s line %d:%d duplicate handler '%s'", return fmt.Errorf("%s line %d:%d duplicate handler '%s'",
nestedApi.LinePrefix, handler.Line(), handler.Column(), handlerKey) nestedApi.LinePrefix, handler.Line(), handler.Column(), handlerKey)
} }
p := fmt.Sprintf("%s://%s", r.Route.Method.Text(), path.Join(prefix, r.Route.Path.Text())) routeKey := fmt.Sprintf("%s://%s", r.Route.Method.Text(), path.Join(prefix, r.Route.Path.Text()))
if _, ok := mainRouteMap[p]; ok { if _, ok := p.routeMap[routeKey]; ok {
return fmt.Errorf("%s line %d:%d duplicate route '%s'", return fmt.Errorf("%s line %d:%d duplicate route '%s'",
nestedApi.LinePrefix, r.Route.Method.Line(), r.Route.Method.Column(), r.Route.Method.Text()+" "+r.Route.Path.Text()) nestedApi.LinePrefix, r.Route.Method.Line(), r.Route.Method.Column(), r.Route.Method.Text()+" "+r.Route.Path.Text())
} }

@ -0,0 +1,99 @@
package ast
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
)
func Test_ImportCycle(t *testing.T) {
const (
mainFilename = "main.api"
subAFilename = "a.api"
subBFilename = "b.api"
mainSrc = `import "./a.api"`
subASrc = `import "./b.api"`
subBSrc = `import "./a.api"`
)
var err error
dir := pathx.MustTempDir()
defer os.RemoveAll(dir)
mainPath := filepath.Join(dir, mainFilename)
err = ioutil.WriteFile(mainPath, []byte(mainSrc), 0777)
require.NoError(t, err)
subAPath := filepath.Join(dir, subAFilename)
err = ioutil.WriteFile(subAPath, []byte(subASrc), 0777)
require.NoError(t, err)
subBPath := filepath.Join(dir, subBFilename)
err = ioutil.WriteFile(subBPath, []byte(subBSrc), 0777)
require.NoError(t, err)
_, err = NewParser().Parse(mainPath)
assert.ErrorIs(t, err, ErrImportCycleNotAllowed)
}
func Test_MultiImportedShouldAllowed(t *testing.T) {
const (
mainFilename = "main.api"
subAFilename = "a.api"
subBFilename = "b.api"
mainSrc = "import \"./b.api\"\n" +
"import \"./a.api\"\n" +
"type Main { b B `json:\"b\"`}"
subASrc = "import \"./b.api\"\n" +
"type A { b B `json:\"b\"`}\n"
subBSrc = `type B{}`
)
var err error
dir := pathx.MustTempDir()
defer os.RemoveAll(dir)
mainPath := filepath.Join(dir, mainFilename)
err = ioutil.WriteFile(mainPath, []byte(mainSrc), 0777)
require.NoError(t, err)
subAPath := filepath.Join(dir, subAFilename)
err = ioutil.WriteFile(subAPath, []byte(subASrc), 0777)
require.NoError(t, err)
subBPath := filepath.Join(dir, subBFilename)
err = ioutil.WriteFile(subBPath, []byte(subBSrc), 0777)
require.NoError(t, err)
_, err = NewParser().Parse(mainPath)
assert.NoError(t, err)
}
func Test_RedundantDeclarationShouldNotBeAllowed(t *testing.T) {
const (
mainFilename = "main.api"
subAFilename = "a.api"
subBFilename = "b.api"
mainSrc = "import \"./a.api\"\n" +
"import \"./b.api\"\n"
subASrc = `import "./b.api"
type A{}`
subBSrc = `type A{}`
)
var err error
dir := pathx.MustTempDir()
defer os.RemoveAll(dir)
mainPath := filepath.Join(dir, mainFilename)
err = ioutil.WriteFile(mainPath, []byte(mainSrc), 0777)
require.NoError(t, err)
subAPath := filepath.Join(dir, subAFilename)
err = ioutil.WriteFile(subAPath, []byte(subASrc), 0777)
require.NoError(t, err)
subBPath := filepath.Join(dir, subBFilename)
err = ioutil.WriteFile(subBPath, []byte(subBSrc), 0777)
require.NoError(t, err)
_, err = NewParser().Parse(mainPath)
assert.Error(t, err)
assert.Contains(t, err.Error(), "duplicate type declaration")
}

@ -0,0 +1,23 @@
package ast
import "errors"
// ErrImportCycleNotAllowed defines an error for circular importing
var ErrImportCycleNotAllowed = errors.New("import cycle not allowed")
// importStack a stack of import paths
type importStack []string
func (s *importStack) push(p string) error {
for _, x := range *s {
if x == p {
return ErrImportCycleNotAllowed
}
}
*s = append(*s, p)
return nil
}
func (s *importStack) pop() {
*s = (*s)[0 : len(*s)-1]
}
Loading…
Cancel
Save