You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-zero/tools/goctl/api/parser/g4/ast/apiparser.go

447 lines
11 KiB
Go

package ast
import (
"fmt"
"io/ioutil"
"path/filepath"
"strings"
"github.com/antlr/antlr4/runtime/Go/antlr"
"github.com/tal-tech/go-zero/tools/goctl/api/parser/g4/gen/api"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
)
// Parser turns goctl api source text into an *Api syntax tree.
//
// linePrefix is placed at the start of every syntax error message (usually the
// name of the file being parsed), debug enables logging of syntax errors and
// debug mode in the ApiVisitor, and log is where those messages are written.
// The embedded antlr.DefaultErrorListener supplies the ErrorListener callbacks
// that *Parser does not define itself, so *Parser can be registered as the
// ANTLR parser's error listener while implementing SyntaxError on its own.
type Parser struct {
	linePrefix string
	debug      bool
	log        console.Console
	antlr.DefaultErrorListener
}

// ParserOption customizes a Parser; NewParser applies options in order.
type ParserOption func(p *Parser)
// NewParser returns a Parser that logs through a color console, with each of
// options applied in the order given.
func NewParser(options ...ParserOption) *Parser {
	parser := &Parser{log: console.NewColorConsole()}
	for i := range options {
		options[i](parser)
	}
	return parser
}
// Accept parses content and passes the generated ANTLR parser, together with
// an ApiVisitor configured from p, to fn; whatever fn returns is returned as v.
// It lets the caller start from any grammar rule and is intended for debugging.
//
// Syntax errors reported through SyntaxError surface as panics; those and any
// other panic raised while fn runs are recovered and returned as err.
func (p *Parser) Accept(fn func(p *api.ApiParserParser, visitor *ApiVisitor) interface{}, content string) (v interface{}, err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		if e, isErr := r.(error); isErr {
			err = e
		} else {
			err = fmt.Errorf("%+v", r)
		}
	}()

	lexer := api.NewApiParserLexer(antlr.NewInputStream(content))
	// The lexer is given no error listener back, so lexer errors are not
	// forwarded to SyntaxError; only the parser reports through p.
	lexer.RemoveErrorListeners()

	ruleParser := api.NewApiParserParser(antlr.NewCommonTokenStream(lexer, antlr.LexerDefaultTokenChannel))
	ruleParser.RemoveErrorListeners()
	ruleParser.AddErrorListener(p)

	opts := []VisitorOption{WithVisitorPrefix(p.linePrefix)}
	if p.debug {
		opts = append(opts, WithVisitorDebug())
	}

	v = fn(ruleParser, NewApiVisitor(opts...))
	return v, nil
}
// Parse reads the api file at filename and parses it together with the files
// it imports, returning the merged result.
func (p *Parser) Parse(filename string) (*Api, error) {
	content, readErr := p.readContent(filename)
	if readErr != nil {
		return nil, readErr
	}
	return p.parse(filename, content)
}
// ParseContent parses api source held in memory. No file name is associated
// with it, so errors in content itself carry the Parser's current line prefix,
// and relative import paths are read relative to the current working directory.
func (p *Parser) ParseContent(content string) (*Api, error) {
	const noFilename = ""
	return p.parse(noFilename, content)
}
// parse parses content as the main api, then reads and parses every file it
// imports, validates them, and merges everything into a single *Api.
//
// filename is the file content was read from. It prefixes error messages for
// the main api and is the base for relative import paths: an import is
// resolved against filepath.Dir(filename), so imports are found no matter
// which directory the tool is run from. Pass "" when content did not come from
// a file; relative imports are then resolved against the current working
// directory.
//
// Every imported api is validated against the main api and all apis imported
// before it. A type, handler or route duplicated across two sibling imports is
// therefore reported instead of being silently merged twice.
func (p *Parser) parse(filename, content string) (*Api, error) {
	root, err := p.invoke(filename, content)
	if err != nil {
		return nil, err
	}

	baseDir := filepath.Dir(filename)
	apiAstList := []*Api{root}
	for _, imp := range root.Import {
		// Import values are quoted string literals, e.g. "user.api".
		impPath := strings.ReplaceAll(imp.Value.Text(), `"`, "")
		if !filepath.IsAbs(impPath) {
			impPath = filepath.Join(baseDir, impPath)
		}

		data, err := p.readContent(impPath)
		if err != nil {
			return nil, err
		}

		nestedApi, err := p.invoke(impPath, data)
		if err != nil {
			return nil, err
		}

		// memberFill merges root with the imports accepted so far, so the
		// duplicate handler/route/type checks in valid also cover siblings.
		if err := p.valid(p.memberFill(apiAstList), nestedApi); err != nil {
			return nil, err
		}

		apiAstList = append(apiAstList, nestedApi)
	}

	if err := p.checkTypeDeclaration(apiAstList); err != nil {
		return nil, err
	}

	return p.memberFill(apiAstList), nil
}
// invoke lexes and parses content as one api file and converts the parse tree
// into an *Api with an ApiVisitor.
//
// When linePrefix is non-empty, it is installed as p.linePrefix for the
// duration of the call. Syntax errors, the visitor and the returned
// Api.LinePrefix all carry it. The previous prefix is restored before
// returning; otherwise the name of the last file parsed would leak into later
// calls on the same Parser (e.g. a ParseContent or Accept after a Parse).
//
// Syntax errors are raised as panics by SyntaxError. Those, and any other
// panic during parsing or visiting, are recovered and returned as err, with
// v set to nil.
func (p *Parser) invoke(linePrefix, content string) (v *Api, err error) {
	savedPrefix := p.linePrefix
	defer func() {
		p.linePrefix = savedPrefix

		r := recover()
		if r == nil {
			return
		}
		switch e := r.(type) {
		case error:
			err = e
		default:
			err = fmt.Errorf("%+v", r)
		}
		v = nil
	}()

	if linePrefix != "" {
		p.linePrefix = linePrefix
	}

	inputStream := antlr.NewInputStream(content)
	lexer := api.NewApiParserLexer(inputStream)
	// NOTE(review): the lexer is left without any error listener, so lexer
	// errors are not forwarded to SyntaxError; only parser errors are.
	lexer.RemoveErrorListeners()
	tokens := antlr.NewCommonTokenStream(lexer, antlr.LexerDefaultTokenChannel)
	apiParser := api.NewApiParserParser(tokens)
	apiParser.RemoveErrorListeners()
	apiParser.AddErrorListener(p)

	visitorOptions := []VisitorOption{WithVisitorPrefix(p.linePrefix)}
	if p.debug {
		visitorOptions = append(visitorOptions, WithVisitorDebug())
	}

	visitor := NewApiVisitor(visitorOptions...)
	v = apiParser.Api().Accept(visitor).(*Api)
	v.LinePrefix = p.linePrefix
	return v, nil
}
// valid checks that nestedApi can be merged into mainApi. It first runs
// nestedApiCheck, then rejects any handler name, "method://path" route, or type
// name in nestedApi that mainApi already declares. Routes of mainApi without a
// handler are not recorded for the duplicate check.
func (p *Parser) valid(mainApi *Api, nestedApi *Api) error {
	if err := p.nestedApiCheck(mainApi, nestedApi); err != nil {
		return err
	}

	mainHandlerMap := make(map[string]PlaceHolder)
	mainRouteMap := make(map[string]PlaceHolder)
	for _, service := range mainApi.Service {
		for _, sr := range service.ServiceApi.ServiceRoute {
			handler := sr.GetHandler()
			if !handler.IsNotNil() {
				continue
			}
			mainHandlerMap[handler.Text()] = Holder
			routeKey := fmt.Sprintf("%s://%s", sr.Route.Method.Text(), sr.Route.Path.Text())
			mainRouteMap[routeKey] = Holder
		}
	}

	// duplicate route check
	if err := p.duplicateRouteCheck(nestedApi, mainHandlerMap, mainRouteMap); err != nil {
		return err
	}

	// duplicate type check
	mainTypeMap := make(map[string]PlaceHolder)
	for _, tp := range mainApi.Type {
		mainTypeMap[tp.NameExpr().Text()] = Holder
	}
	for _, tp := range nestedApi.Type {
		name := tp.NameExpr()
		if _, dup := mainTypeMap[name.Text()]; dup {
			return fmt.Errorf("%s line %d:%d duplicate type declaration '%s'",
				nestedApi.LinePrefix, name.Line(), name.Column(), name.Text())
		}
	}
	return nil
}
// duplicateRouteCheck walks every route of nestedApi and fails on the first
// route that has no handler, whose handler name is already in mainHandlerMap,
// or whose "method://path" key is already in mainRouteMap.
func (p *Parser) duplicateRouteCheck(nestedApi *Api, mainHandlerMap map[string]PlaceHolder, mainRouteMap map[string]PlaceHolder) error {
	for _, service := range nestedApi.Service {
		for _, sr := range service.ServiceApi.ServiceRoute {
			handler := sr.GetHandler()
			if !handler.IsNotNil() {
				return fmt.Errorf("%s handler not exist near line %d", nestedApi.LinePrefix, sr.Route.Method.Line())
			}

			if _, dup := mainHandlerMap[handler.Text()]; dup {
				return fmt.Errorf("%s line %d:%d duplicate handler '%s'",
					nestedApi.LinePrefix, handler.Line(), handler.Column(), handler.Text())
			}

			method, path := sr.Route.Method, sr.Route.Path
			routeKey := fmt.Sprintf("%s://%s", method.Text(), path.Text())
			if _, dup := mainRouteMap[routeKey]; dup {
				return fmt.Errorf("%s line %d:%d duplicate route '%s'",
					nestedApi.LinePrefix, method.Line(), method.Column(), method.Text()+" "+path.Text())
			}
		}
	}
	return nil
}
// nestedApiCheck reports whether nestedApi may be merged into mainApi:
//   - a nested api must not declare imports of its own;
//   - when both declare a syntax version, the versions must match;
//   - when mainApi declares a service, every service in nestedApi must use the
//     same name as mainApi's first service.
//
// Every error starts with the nested api's LinePrefix followed by the
// line:column of the offending token, matching the other checks in this file.
func (p *Parser) nestedApiCheck(mainApi *Api, nestedApi *Api) error {
	if len(nestedApi.Import) > 0 {
		importToken := nestedApi.Import[0].Import
		return fmt.Errorf("%s line %d:%d the nested api does not support import",
			nestedApi.LinePrefix, importToken.Line(), importToken.Column())
	}

	if mainApi.Syntax != nil && nestedApi.Syntax != nil {
		want, got := mainApi.Syntax.Version.Text(), nestedApi.Syntax.Version.Text()
		if want != got {
			syntaxToken := nestedApi.Syntax.Syntax
			return fmt.Errorf("%s line %d:%d multiple syntax declaration, expecting syntax '%s', but found '%s'",
				nestedApi.LinePrefix, syntaxToken.Line(), syntaxToken.Column(), want, got)
		}
	}

	if len(mainApi.Service) == 0 {
		return nil
	}

	wantName := mainApi.Service[0].ServiceApi.Name.Text()
	for _, service := range nestedApi.Service {
		name := service.ServiceApi.Name
		if name.Text() != wantName {
			// Point at the conflicting service name so the user can locate it.
			return fmt.Errorf("%s line %d:%d multiple service name declaration, expecting service name '%s', but found '%s'",
				nestedApi.LinePrefix, name.Line(), name.Column(), wantName, name.Text())
		}
	}
	return nil
}
// memberFill merges apiList into a new Api. Syntax, Info and Import come from
// the first element only; Type and Service are concatenated from every element
// in order. The input Apis are not modified. An empty apiList yields an empty Api.
func (p *Parser) memberFill(apiList []*Api) *Api {
	merged := &Api{}
	if len(apiList) > 0 {
		head := apiList[0]
		merged.Syntax = head.Syntax
		merged.Info = head.Info
		merged.Import = head.Import
	}
	for _, item := range apiList {
		merged.Type = append(merged.Type, item.Type...)
		merged.Service = append(merged.Service, item.Service...)
	}
	return merged
}
// checkTypeDeclaration checks every type reference in struct fields and in
// route request/reply bodies of each api in apiList against the set of types
// declared across all of apiList. Errors carry the LinePrefix of the api in
// which the unresolved reference appears.
func (p *Parser) checkTypeDeclaration(apiList []*Api) error {
	declared := make(map[string]TypeExpr)
	for _, item := range apiList {
		for _, tp := range item.Type {
			declared[tp.NameExpr().Text()] = tp
		}
	}

	for _, item := range apiList {
		if err := p.checkTypes(item, item.LinePrefix, declared); err != nil {
			return err
		}
		if err := p.checkServices(item, declared, item.LinePrefix); err != nil {
			return err
		}
	}
	return nil
}
// checkServices validates every route of every service in apiItem. The
// request body is checked by checkRequestBody. A reply that is a literal, or
// an array of literals or pointers, must name either a basic type or a type
// present in types. Errors are prefixed with linePrefix.
func (p *Parser) checkServices(apiItem *Api, types map[string]TypeExpr, linePrefix string) error {
	for _, service := range apiItem.Service {
		for _, sr := range service.ServiceApi.ServiceRoute {
			route := sr.Route
			if err := p.checkRequestBody(route, types, linePrefix); err != nil {
				return err
			}

			reply := route.Reply
			if reply == nil || !reply.Name.IsNotNil() || !reply.Name.Expr().IsNotNil() {
				continue
			}

			var structName string
			switch dt := reply.Name.(type) {
			case *Literal:
				structName = dt.Literal.Text()
			case *Array:
				switch elem := dt.Literal.(type) {
				case *Literal:
					structName = elem.Literal.Text()
				case *Pointer:
					structName = elem.Name.Text()
				}
			}

			if api.IsBasicType(structName) {
				continue
			}
			if _, found := types[structName]; !found {
				replyExpr := reply.Name.Expr()
				return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context",
					linePrefix, replyExpr.Line(), replyExpr.Column(), structName)
			}
		}
	}
	return nil
}
// checkRequestBody fails when route declares a request body whose name is not
// present in types. Unlike the reply check, basic type names are not exempted
// here. A route without a request body passes.
func (p *Parser) checkRequestBody(route *Route, types map[string]TypeExpr, linePrefix string) error {
	if route.Req == nil || !route.Req.Name.IsNotNil() {
		return nil
	}

	body := route.Req.Name.Expr()
	if !body.IsNotNil() {
		return nil
	}

	if _, found := types[body.Text()]; found {
		return nil
	}
	return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context",
		linePrefix, body.Line(), body.Column(), body.Text())
}
// checkTypes runs checkType on the data type of every field of every struct
// declared in apiItem, stopping at the first error. Non-struct type
// declarations are skipped.
func (p *Parser) checkTypes(apiItem *Api, linePrefix string, types map[string]TypeExpr) error {
	for _, declared := range apiItem.Type {
		structType, isStruct := declared.(*TypeStruct)
		if !isStruct {
			continue
		}
		for _, field := range structType.Fields {
			if err := p.checkType(linePrefix, types, field.DataType); err != nil {
				return err
			}
		}
	}
	return nil
}
// checkType verifies that the type referenced by expr is a basic type or is
// present in types. Map values and array elements are checked recursively;
// map keys are not inspected. A nil expr and any other kind of DataType are
// accepted without checks.
func (p *Parser) checkType(linePrefix string, types map[string]TypeExpr, expr DataType) error {
	if expr == nil {
		return nil
	}

	switch v := expr.(type) {
	case *Map:
		return p.checkType(linePrefix, types, v.Value)
	case *Array:
		return p.checkType(linePrefix, types, v.Literal)
	case *Literal:
		name := v.Literal.Text()
		if api.IsBasicType(name) {
			return nil
		}
		if _, declared := types[name]; !declared {
			return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context",
				linePrefix, v.Literal.Line(), v.Literal.Column(), name)
		}
	case *Pointer:
		name := v.Name.Text()
		if api.IsBasicType(name) {
			return nil
		}
		if _, declared := types[name]; !declared {
			return fmt.Errorf("%s line %d:%d can not found declaration '%s' in context",
				linePrefix, v.Name.Line(), v.Name.Column(), name)
		}
	}
	return nil
}
// readContent returns the contents of the file at filename as a string. All
// double quotes are removed from filename first, so a quoted import value can
// be passed directly. The path is then made absolute, so relative paths are
// resolved against the current working directory.
func (p *Parser) readContent(filename string) (string, error) {
	absPath, err := filepath.Abs(strings.ReplaceAll(filename, `"`, ""))
	if err != nil {
		return "", err
	}

	raw, err := ioutil.ReadFile(absPath)
	if err != nil {
		return "", err
	}
	return string(raw), nil
}
// SyntaxError is the ANTLR error-listener callback. It formats the error as
// "<linePrefix> line <line>:<column> <msg>", logs it when debug is enabled,
// and then panics with that string to abort parsing. Accept and invoke recover
// the panic and return it as an error.
func (p *Parser) SyntaxError(_ antlr.Recognizer, _ interface{}, line, column int, msg string, _ antlr.RecognitionException) {
	str := fmt.Sprintf(`%s line %d:%d %s`, p.linePrefix, line, column, msg)
	if p.debug {
		// Pass the message as an argument, not as the format string: msg comes
		// from ANTLR and may quote source text, and a '%' in it would otherwise
		// be interpreted as a formatting verb.
		p.log.Error("%s", str)
	}
	panic(str)
}
// WithParserDebug returns a ParserOption that enables debug mode. Syntax
// errors are logged as they are reported, and the ApiVisitor is created with
// WithVisitorDebug.
func WithParserDebug() ParserOption {
	enableDebug := func(parser *Parser) {
		parser.debug = true
	}
	return enableDebug
}
// WithParserPrefix returns a ParserOption that sets the default line prefix
// placed at the start of error messages. Parsing a named file uses that file's
// name instead.
func WithParserPrefix(prefix string) ParserOption {
	return func(parser *Parser) { parser.linePrefix = prefix }
}