You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
go-zero/tools/goctl/pkg/parser/api/parser/api.go

334 lines
8.2 KiB
Go

package parser
import (
"fmt"
"path/filepath"
"strings"
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/ast"
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/placeholder"
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/token"
)
// API is the parsed api file. It groups the statements of one AST by kind;
// type and service statements of imported files are appended by mergeAPI.
type API struct {
// Filename is copied from the Filename of the AST the API was built from.
Filename string
// Syntax is the syntax statement of the file. convert2API substitutes a
// synthetic `"v1"` statement when the first statement is not a syntax statement.
Syntax *ast.SyntaxStmt
info *ast.InfoStmt // Info block does not participate in code generation.
importStmt []ast.ImportStmt // ImportStmt block does not participate in code generation.
// TypeStmt holds the type statements in the order they were collected.
TypeStmt []ast.TypeStmt
// ServiceStmts holds the service statements in the order they were collected.
ServiceStmts []*ast.ServiceStmt
// importManager is keyed by import path; parseImportedAPI rejects a path
// that is already present in it.
importManager map[string]placeholder.Type
}
// convert2API builds an API from the parsed AST a and validates it with
// SelfCheck.
//
// The entries of importManager are copied into the new API, so the caller's
// map is never modified. When a does not start with a syntax statement (or has
// no statements at all) a synthetic `syntax = "v1"` statement is used instead.
// A second syntax statement or a second info statement is reported as a
// duplicate-statement error. On any error the returned *API is nil.
func convert2API(a *ast.AST, importManager map[string]placeholder.Type) (*API, error) {
	api := &API{
		Filename:      a.Filename,
		importManager: make(map[string]placeholder.Type, len(importManager)),
	}
	for k, v := range importManager {
		api.importManager[k] = v
	}

	// Only look at the first statement when there is one; indexing an empty
	// statement list would panic.
	var (
		syntax *ast.SyntaxStmt
		ok     bool
	)
	if len(a.Stmts) > 0 {
		syntax, ok = a.Stmts[0].(*ast.SyntaxStmt)
	}
	if !ok {
		syntax = &ast.SyntaxStmt{
			Syntax: ast.NewTokenNode(token.Token{
				Type: token.IDENT,
				Text: token.Syntax,
			}),
			Assign: ast.NewTokenNode(token.Token{
				Type: token.ASSIGN,
				Text: "=",
			}),
			Value: ast.NewTokenNode(token.Token{
				Type: token.STRING,
				Text: `"v1"`,
			}),
		}
	}
	api.Syntax = syntax

	var hasSyntax, hasInfo bool
	for _, stmt := range a.Stmts {
		switch val := stmt.(type) {
		case *ast.SyntaxStmt:
			if hasSyntax {
				return nil, ast.DuplicateStmtError(val.Pos(), "duplicate syntax statement")
			}
			hasSyntax = true
		case *ast.InfoStmt:
			if hasInfo {
				return nil, ast.DuplicateStmtError(val.Pos(), "duplicate info statement")
			}
			hasInfo = true
			api.info = val
		case ast.ImportStmt:
			api.importStmt = append(api.importStmt, val)
		case ast.TypeStmt:
			api.TypeStmt = append(api.TypeStmt, val)
		case *ast.ServiceStmt:
			api.ServiceStmts = append(api.ServiceStmts, val)
		}
	}

	if err := api.SelfCheck(); err != nil {
		return nil, err
	}
	return api, nil
}
// checkImportStmt passes every import value of the file, whether written as a
// single import or inside an import group, to the "import value expression"
// check item and returns whatever error the filter accumulated.
// NOTE(review): the filter presumably reports repeated values; confirm against
// newFilter's implementation.
func (api *API) checkImportStmt() error {
	filter := newFilter()
	checker := filter.addCheckItem("import value expression")
	for _, stmt := range api.importStmt {
		if literal, ok := stmt.(*ast.ImportLiteralStmt); ok {
			checker.check(literal.Value)
			continue
		}
		if group, ok := stmt.(*ast.ImportGroupStmt); ok {
			checker.check(group.Values...)
		}
	}
	return filter.error()
}
// checkInfoStmt passes the key of every entry in the info block to the
// "info key expression" check item. A file without an info block is valid and
// yields nil.
func (api *API) checkInfoStmt() error {
	info := api.info
	if info == nil {
		return nil
	}

	filter := newFilter()
	keyChecker := filter.addCheckItem("info key expression")
	for i := range info.Values {
		keyChecker.check(info.Values[i].Key)
	}
	return filter.error()
}
// checkServiceStmt validates all service statements of the API.
//
// It reports a "multiple service name" syntax error for every service
// statement whose name (with a trailing "-api" removed) differs from the name
// of the first service statement, and it passes every handler name and every
// route to the filter. Routes are keyed as "[prefix]:route", where prefix is
// the "prefix" value of the statement's @server block (empty when absent), so
// the same route under different prefixes yields different keys.
func (api *API) checkServiceStmt() error {
	f := newFilter()
	serviceNameChecker := f.addCheckItem("service name expression")
	handlerChecker := f.addCheckItem("handler expression")
	pathChecker := f.addCheckItem("path expression")

	var (
		firstName string
		haveName  bool
	)
	for _, v := range api.ServiceStmts {
		name := strings.TrimSuffix(v.Name.Format(""), "-api")
		switch {
		case !haveName:
			// The first statement fixes the service name for the whole API.
			firstName, haveName = name, true
		case name != firstName:
			// Previously names were compared against themselves, so this
			// error could never be produced.
			serviceNameChecker.errorManager.add(ast.SyntaxError(v.Name.Pos(), "multiple service name"))
		}

		group := api.getAtServerValue(v.AtServerStmt, "prefix")
		for _, item := range v.Routes {
			handlerChecker.check(item.AtHandler.Name)
			path := fmt.Sprintf("[%s]:%s", group, item.Route.Format(""))
			pathChecker.check(ast.NewTokenNode(token.Token{
				Text:     path,
				Position: item.Route.Pos(),
			}))
		}
	}
	return f.error()
}
// checkTypeStmt passes the name of every declared type, from both single type
// statements and type groups, to the "type expression" check item and returns
// the error accumulated by the filter.
func (api *API) checkTypeStmt() error {
	filter := newFilter()
	nameChecker := filter.addCheckItem("type expression")
	for _, stmt := range api.TypeStmt {
		switch decl := stmt.(type) {
		case *ast.TypeGroupStmt:
			for _, member := range decl.ExprList {
				nameChecker.check(member.Name)
			}
		case *ast.TypeLiteralStmt:
			nameChecker.check(decl.Expr.Name)
		}
	}
	return filter.error()
}
// checkTypeDeclareContext collects the names of all types declared in the API
// and then verifies, via checkTypeContext, that every referenced type is either
// one of them or a base type.
func (api *API) checkTypeDeclareContext() error {
	declared := make(map[string]placeholder.Type)
	declare := func(name string) {
		declared[name] = placeholder.PlaceHolder
	}

	for _, stmt := range api.TypeStmt {
		if literal, ok := stmt.(*ast.TypeLiteralStmt); ok {
			declare(literal.Expr.Name.Token.Text)
			continue
		}
		if group, ok := stmt.(*ast.TypeGroupStmt); ok {
			for _, member := range group.ExprList {
				declare(member.Name.Token.Text)
			}
		}
	}
	return api.checkTypeContext(declared)
}
// checkTypeContext checks the data type of every type declaration against
// declareContext, the set of known type names, and returns the errors collected
// across all declarations.
func (api *API) checkTypeContext(declareContext map[string]placeholder.Type) error {
	errs := newErrorManager()
	for _, stmt := range api.TypeStmt {
		if literal, ok := stmt.(*ast.TypeLiteralStmt); ok {
			errs.add(api.checkTypeExprContext(declareContext, literal.Expr.DataType))
			continue
		}
		group, ok := stmt.(*ast.TypeGroupStmt)
		if !ok {
			continue
		}
		for _, member := range group.ExprList {
			errs.add(api.checkTypeExprContext(declareContext, member.DataType))
		}
	}
	return errs.error()
}
// checkTypeExprContext recursively verifies that every named type reachable
// from tp is resolvable: it must be a base type (per IsBaseType) or a key of
// declareContext. Arrays, slices and pointers are checked through their element
// type, maps through key and value, structs through every element. Data types
// not listed in the switch are accepted without inspection.
func (api *API) checkTypeExprContext(declareContext map[string]placeholder.Type, tp ast.DataType) error {
	switch val := tp.(type) {
	case *ast.BaseDataType:
		name := val.Base.Token.Text
		if IsBaseType(name) {
			return nil
		}
		if _, declared := declareContext[name]; declared {
			return nil
		}
		return ast.SyntaxError(val.Base.Pos(), "unresolved type <%s>", name)
	case *ast.StructDataType:
		errs := newErrorManager()
		for _, elem := range val.Elements {
			errs.add(api.checkTypeExprContext(declareContext, elem.DataType))
		}
		return errs.error()
	case *ast.MapDataType:
		errs := newErrorManager()
		errs.add(api.checkTypeExprContext(declareContext, val.Key))
		errs.add(api.checkTypeExprContext(declareContext, val.Value))
		return errs.error()
	case *ast.ArrayDataType:
		return api.checkTypeExprContext(declareContext, val.DataType)
	case *ast.SliceDataType:
		return api.checkTypeExprContext(declareContext, val.DataType)
	case *ast.PointerDataType:
		return api.checkTypeExprContext(declareContext, val.DataType)
	default:
		return nil
	}
}
// getAtServerValue returns the text of the first value in atServer whose key
// text equals key. It returns "" when atServer is nil or no key matches.
func (api *API) getAtServerValue(atServer *ast.AtServerStmt, key string) string {
	if atServer != nil {
		for i := range atServer.Values {
			kv := atServer.Values[i]
			if kv.Key.Token.Text == key {
				return kv.Value.Token.Text
			}
		}
	}
	return ""
}
// mergeAPI folds the imported API in into api: the import bookkeeping of in is
// copied over first, then the two syntax values are compared and a syntax
// error is returned if they differ, and finally the type and service
// statements of in are appended to those of api.
func (api *API) mergeAPI(in *API) error {
	for path, mark := range in.importManager {
		api.importManager[path] = mark
	}

	want, got := api.Syntax.Value.Format(), in.Syntax.Value.Format()
	if want != got {
		return ast.SyntaxError(in.Syntax.Value.Pos(),
			"multiple syntax value expression, expected <%s>, got <%s>",
			want,
			got,
		)
	}

	api.ServiceStmts = append(api.ServiceStmts, in.ServiceStmts...)
	api.TypeStmt = append(api.TypeStmt, in.TypeStmt...)
	return nil
}
// parseImportedAPI parses every file named by imports and returns one *API per
// distinct import value, in the order the values are first written.
//
// Import values have their double quotes removed; relative paths are resolved
// against the directory of api.Filename. A path that is already recorded in
// api.importManager is rejected with an "import circle not allowed" syntax
// error; otherwise it is recorded before the file is parsed. Parsing or
// validation errors of an imported file are returned as-is, with a nil list.
func (api *API) parseImportedAPI(imports []ast.ImportStmt) ([]*API, error) {
	var list []*API
	if len(imports) == 0 {
		return list, nil
	}

	// Remember first-seen order next to the set: ranging over the map alone
	// would make merge order and the first reported error vary between runs.
	var (
		order          []string
		importValueSet = map[string]token.Token{}
	)
	record := func(tok token.Token) {
		value := strings.ReplaceAll(tok.Text, `"`, "")
		if _, seen := importValueSet[value]; !seen {
			order = append(order, value)
		}
		importValueSet[value] = tok
	}
	for _, imp := range imports {
		switch val := imp.(type) {
		case *ast.ImportLiteralStmt:
			record(val.Value.Token)
		case *ast.ImportGroupStmt:
			for _, v := range val.Values {
				record(v.Token)
			}
		}
	}

	dir := filepath.Dir(api.Filename)
	for _, value := range order {
		tok := importValueSet[value]
		impPath := value
		if !filepath.IsAbs(impPath) {
			impPath = filepath.Join(dir, impPath)
		}

		// import cycle check
		if _, ok := api.importManager[impPath]; ok {
			return nil, ast.SyntaxError(tok.Position, "import circle not allowed")
		}
		api.importManager[impPath] = placeholder.PlaceHolder

		p := New(impPath, "")
		tree := p.Parse() // not named "ast": that would shadow the package
		if err := p.CheckErrors(); err != nil {
			return nil, err
		}

		// convert2API runs SelfCheck, which already resolves and merges the
		// imports of the nested file. Resolving them a second time here would
		// find every nested import recorded in its importManager and fail with
		// a spurious "import circle not allowed".
		nestedAPI, err := convert2API(tree, api.importManager)
		if err != nil {
			return nil, err
		}
		list = append(list, nestedAPI)
	}
	return list, nil
}
// parseReverse parses the files imported by this API and merges each of them
// into api, stopping at the first parse or merge error.
func (api *API) parseReverse() error {
	imported, err := api.parseImportedAPI(api.importStmt)
	if err != nil {
		return err
	}

	for i := range imported {
		if mergeErr := api.mergeAPI(imported[i]); mergeErr != nil {
			return mergeErr
		}
	}
	return nil
}
// SelfCheck resolves the imports of the API and then validates it. The steps
// run in a fixed order (imports are merged first, then import, info, type and
// service statements are checked, then type references) and the first error
// encountered is returned.
func (api *API) SelfCheck() error {
	steps := []func() error{
		api.parseReverse,
		api.checkImportStmt,
		api.checkInfoStmt,
		api.checkTypeStmt,
		api.checkServiceStmt,
		api.checkTypeDeclareContext,
	}
	for _, step := range steps {
		if err := step(); err != nil {
			return err
		}
	}
	return nil
}