You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-zero/tools/goctl/rpc/parser/pbast.go

644 lines
15 KiB
Go

package parser
import (
"errors"
"fmt"
"go/ast"
"go/parser"
"go/token"
"io/ioutil"
"sort"
"strings"
"github.com/tal-tech/go-zero/core/lang"
sx "github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
// Markers used when splitting and rebuilding Go type expressions.
const (
	flagStar = "*"
	flagDot  = "."
)

// suffixServer is the suffix of a gRPC server interface name; the service
// name is the interface name with this suffix trimmed.
const suffixServer = "Server"

// referenceContext is the package of context.Context parameters, which are
// skipped when picking a method's request and response types.
const referenceContext = "context"

// unknownPrefix marks fields whose names start with it as skipped
// (presumably protoc-gen-go's internal XXX_ bookkeeping fields).
const unknownPrefix = "XXX_"

// ignoreJsonTagExpression is the json struct tag that marks a field as ignored.
const ignoreJsonTagExpression = `json:"-"`
var (
// errorParseError is returned by parseType when it is given a nil expression.
errorParseError = errors.New("pb parse error")
// typeTemplate wraps all generated type declarations in one `type ( ... )` group.
typeTemplate = `type (
{{.types}}
)`
// structTemplate renders one struct; the `type` keyword is emitted only when .type is true.
structTemplate = `{{if .type}}type {{end}}{{.name}} struct {
{{.fields}}
}`
// fieldTemplate renders one struct field: optional doc lines, then name, type,
// tag and an optional trailing comment.
fieldTemplate = `{{if .hasDoc}}{{.doc}}
{{end}}{{.name}} {{.type}} {{.tag}}{{if .hasComment}}{{.comment}}{{end}}`
// anyTypeTemplate is the struct returned by genAnyCode when ContainsAny is set.
anyTypeTemplate = "Any struct {\n\tTypeUrl string `json:\"typeUrl\"`\n\tValue []byte `json:\"value\"`\n}"
// objectM memoizes parseObject results. It is package-level state keyed only
// by struct name and never reset.
// NOTE(review): same-named structs parsed from different files, or in later
// runs within the same process, receive the first cached result.
objectM = make(map[string]*Struct)
)
// astParser extracts message structs and gRPC service definitions from the Go
// source referenced by Proto.PbSrc and from the sources of Proto.Import.
type astParser struct {
	// filterStruct holds the message names to keep; taken from Proto.Message.
	filterStruct map[string]lang.PlaceholderType
	// filterEnum is taken from Proto.Enum.
	filterEnum map[string]*Enum
	console.Console
	// fileSet records positions of every parsed file; used for error line numbers.
	fileSet *token.FileSet
	proto   *Proto
}

// Field describes one member of a message struct, or one parameter/result of
// a service method.
type Field struct {
	Name     stringx.String
	Type     Type
	JsonTag  string
	Document []string
	Comment  []string
}

// Struct describes a message struct together with its fields.
type Struct struct {
	Name     stringx.String
	Document []string
	Comment  []string
	Field    []*Field
}

// ConstLit groups named integer literals (Lit); it is not referenced in this file.
type ConstLit struct {
	Name     stringx.String
	Document []string
	Comment  []string
	Lit      []*Lit
}

// Lit is a single key/value pair of a ConstLit.
type Lit struct {
	Key   string
	Value int
}

// Type describes one Go type expression in the forms the generators need.
type Type struct {
	// Expression is the qualified type without pointer, e.g. context.Context.
	Expression string
	// StarExpression is the pointer form, e.g. *context.Context.
	StarExpression string
	// InvokeTypeExpression is the type without package qualifier, e.g. *Context.
	InvokeTypeExpression string
	// Package is the qualifier, e.g. context.
	Package string
	// Name is the bare type name, e.g. Context.
	Name string
}

// Func describes one method of a gRPC server interface.
type Func struct {
	Name         stringx.String
	ParameterIn  Type
	ParameterOut Type
	Document     []string
}

// RpcService is a gRPC service: the server interface name without its
// "Server" suffix, plus its methods.
type RpcService struct {
	Name  stringx.String
	Funcs []*Func
}

// PbAst is the result of parsing for rpc generation: structs, imports and services.
type PbAst struct {
	// Deprecated: ContainsAny will be removed in the future.
	ContainsAny bool
	Imports     map[string]string
	Structure   map[string]*Struct
	Service     []*RpcService
	*Proto
}
// MustNewAstParser returns an astParser for proto. The message and enum
// filters are taken from proto.Message and proto.Enum, and log is embedded as
// the parser's console (used by its Must calls).
func MustNewAstParser(proto *Proto, log console.Console) *astParser {
	p := &astParser{proto: proto, fileSet: token.NewFileSet()}
	p.filterStruct = proto.Message
	p.filterEnum = proto.Enum
	p.Console = log
	return p
}
// Parse parses the main source (Proto.PbSrc) and then every imported source,
// and merges the results into a PbAst. Structs from imports overwrite
// same-keyed structs of the main file; imports and services come from the
// main file only.
func (a *astParser) Parse() (*PbAst, error) {
	structure, imports, services, err := a.parse(a.proto.PbSrc)
	if err != nil {
		return nil, err
	}
	dependencyStructure, err := a.parseExternalDependency()
	if err != nil {
		return nil, err
	}

	result := &PbAst{
		ContainsAny: a.proto.ContainsAny,
		Imports:     make(map[string]string),
		Structure:   make(map[string]*Struct),
		Proto:       a.proto,
	}
	// Order matters: dependency structs are copied last so they win on key clashes.
	for _, source := range []map[string]*Struct{structure, dependencyStructure} {
		for key, st := range source {
			result.Structure[key] = st
		}
	}
	for key, path := range imports {
		result.Imports[key] = path
	}
	result.Service = append(result.Service, services...)
	return result, nil
}
// parse reads and parses the Go file at pbSrc. It returns the non-nil structs
// selected by mustScope, the file's imports keyed by import alias ("" for an
// unaliased import) with the quoted path as value, and the services found.
// On error, the (empty) maps are still returned together with the error.
func (a *astParser) parse(pbSrc string) (map[string]*Struct, map[string]string, []*RpcService, error) {
	structure := make(map[string]*Struct)
	imports := make(map[string]string)

	src, err := ioutil.ReadFile(pbSrc)
	if err != nil {
		return structure, imports, nil, err
	}
	file, err := parser.ParseFile(a.fileSet, "", src, parser.ParseComments)
	if err != nil {
		return structure, imports, nil, err
	}
	// Keep only the comment groups the comment map associates with nodes of file.
	file.Comments = ast.NewCommentMap(a.fileSet, file, file.Comments).Filter(file).Comments()

	scoped, found := a.mustScope(file.Scope, a.mustGetIndentName(file.Name))
	for key, st := range scoped {
		if st != nil {
			structure[key] = st
		}
	}
	for _, spec := range file.Imports {
		if spec.Path == nil {
			continue
		}
		imports[a.mustGetIndentName(spec.Name)] = spec.Path.Value
	}
	return structure, imports, append([]*RpcService(nil), found...), nil
}
// parseExternalDependency parses the source of every proto import
// (Import[].OriginalPbPath) and merges their structs into one map; a later
// import overwrites an earlier one on key clashes. The imports and services of
// those files are discarded. The first parse error aborts and is returned.
func (a *astParser) parseExternalDependency() (map[string]*Struct, error) {
	merged := make(map[string]*Struct)
	for _, dep := range a.proto.Import {
		depStructs, _, _, err := a.parse(dep.OriginalPbPath)
		if err != nil {
			return nil, err
		}
		for key, st := range depStructs {
			merged[key] = st
		}
	}
	return merged, nil
}
// mustScope collects the struct types and the *Server interfaces declared in
// scope. sourcePackage is the package name of the file being parsed.
//
// Structs are keyed by lower-cased name, and only names present in
// a.filterStruct are returned (a name that was not found maps to nil).
// Struct parse errors are passed to a.Must. Services are returned sorted by
// name so that generated code is stable between runs.
func (a *astParser) mustScope(scope *ast.Scope, sourcePackage string) (map[string]*Struct, []*RpcService) {
	if scope == nil {
		return nil, nil
	}
	structs := make(map[string]*Struct)
	serviceList := make([]*RpcService, 0)
	for name, obj := range scope.Objects {
		// A nil Decl fails the assertion and is skipped as well.
		typeSpec, ok := obj.Decl.(*ast.TypeSpec)
		if !ok {
			continue
		}
		switch v := typeSpec.Type.(type) {
		case *ast.StructType:
			st, err := a.parseObject(name, v, sourcePackage)
			a.Must(err)
			structs[st.Name.Lower()] = st
		case *ast.InterfaceType:
			if !strings.HasSuffix(name, suffixServer) {
				continue
			}
			serviceList = append(serviceList, &RpcService{
				Name:  stringx.From(strings.TrimSuffix(name, suffixServer)),
				Funcs: a.mustServerFunctions(v, sourcePackage),
			})
		}
	}

	// scope.Objects is a map with random iteration order; sort for determinism.
	sort.Slice(serviceList, func(i, j int) bool {
		return serviceList[i].Name.Source() < serviceList[j].Name.Source()
	})

	targetStruct := make(map[string]*Struct, len(a.filterStruct))
	for st := range a.filterStruct {
		lower := strings.ToLower(st)
		targetStruct[lower] = structs[lower]
	}
	return targetStruct, serviceList
}
// mustServerFunctions converts the methods of a gRPC server interface into
// Func descriptors, in declaration order.
//
// For each method, ParameterIn and ParameterOut are the first parameter and
// result whose Package is not referenceContext (so context.Context is
// skipped); they stay zero when no such field exists. Interface entries
// without a method name (embedded interfaces) are skipped instead of
// panicking. Parse errors are passed to a.Must. Returns nil when the
// interface has no method list.
func (a *astParser) mustServerFunctions(v *ast.InterfaceType, sourcePackage string) []*Func {
	funcs := make([]*Func, 0)
	methodObject := v.Methods
	if methodObject == nil {
		return nil
	}

	// firstNonContext returns the type of the first field that is not a
	// context parameter, or the zero Type when there is none.
	firstNonContext := func(list []*Field) Type {
		for _, data := range list {
			if data.Type.Package != referenceContext {
				return data.Type
			}
		}
		return Type{}
	}

	for _, method := range methodObject.List {
		// Embedded interfaces have no Names; indexing Names[0] would panic.
		if len(method.Names) == 0 {
			continue
		}
		var item Func
		item.Name = stringx.From(a.mustGetIndentName(method.Names[0]))
		item.Document = a.parseCommentOrDoc(method.Doc)
		if method.Type == nil {
			funcs = append(funcs, &item)
			continue
		}
		fn, ok := method.Type.(*ast.FuncType)
		if !ok {
			continue
		}
		if fn.Params != nil {
			inList, err := a.parseFields(fn.Params.List, true, sourcePackage)
			a.Must(err)
			item.ParameterIn = firstNonContext(inList)
		}
		if fn.Results != nil {
			outList, err := a.parseFields(fn.Results.List, true, sourcePackage)
			a.Must(err)
			item.ParameterOut = firstNonContext(outList)
		}
		funcs = append(funcs, &item)
	}
	return funcs
}
// getFieldType derives a Type descriptor from a type expression string as
// produced by parseType, e.g. "int64", "*User", "[]*User", "map[string]string"
// or "*timestamp.Timestamp". sourcePackage is the package name of the file
// being parsed; it qualifies unqualified non-scalar types and becomes Package
// when the final expression has no qualifier.
//
// NOTE(review): for slice/map types the computed Package keeps the type prefix
// (e.g. "[]*User" yields Package "[]*<sourcePackage>"), and Go's float32 and
// float64 are not in the scalar list below, so they take the default branch
// and get a "*<sourcePackage>.float64"-style StarExpression.
func (a *astParser) getFieldType(v string, sourcePackage string) Type {
var pkg, name, expression, starExpression, invokeTypeExpression string
// Qualified type (pkg.Name): v is kept as the star form, and the invoke form
// drops the qualifier while preserving a "*", "[]" or "map[K]" prefix.
if strings.Contains(v, ".") {
starExpression = v
if strings.Contains(v, "*") {
leftIndex := strings.Index(v, "*")
rightIndex := strings.Index(v, ".")
// always true because v contains "*"; the else branch is unreachable
if leftIndex >= 0 {
invokeTypeExpression = v[0:leftIndex+1] + v[rightIndex+1:]
} else {
invokeTypeExpression = v[rightIndex+1:]
}
} else {
if strings.HasPrefix(v, "map[") || strings.HasPrefix(v, "[]") {
leftIndex := strings.Index(v, "]")
rightIndex := strings.Index(v, ".")
invokeTypeExpression = v[0:leftIndex+1] + v[rightIndex+1:]
} else {
rightIndex := strings.Index(v, ".")
invokeTypeExpression = v[rightIndex+1:]
}
}
} else {
// Unqualified type.
expression = strings.TrimPrefix(v, flagStar)
switch v {
// Scalar names are used as-is; Name, Expression and StarExpression end up
// empty for them and Package falls back to sourcePackage.
case "double", "float", "int32", "int64", "uint32", "uint64", "sint32", "sint64", "fixed32", "fixed64", "sfixed32", "sfixed64",
"bool", "string", "bytes":
invokeTypeExpression = v
break
default:
// Non-scalar type: qualify it with sourcePackage. For slices/maps every
// "*" becomes "*<sourcePackage>."; otherwise the star form is
// "*<sourcePackage>.<name>".
name = expression
invokeTypeExpression = v
if strings.HasPrefix(v, "map[") || strings.HasPrefix(v, "[]") {
starExpression = strings.ReplaceAll(v, flagStar, flagStar+sourcePackage+".")
} else {
starExpression = fmt.Sprintf("*%v.%v", sourcePackage, name)
invokeTypeExpression = v
}
}
}
// Split the qualified form into Package and Name at the last ".". Without a
// qualifier, Package falls back to sourcePackage and Name keeps whatever was
// set above.
expression = strings.TrimPrefix(starExpression, flagStar)
index := strings.LastIndex(expression, flagDot)
if index > 0 {
pkg = expression[0:index]
name = expression[index+1:]
} else {
pkg = sourcePackage
}
return Type{
Expression: expression,
StarExpression: starExpression,
InvokeTypeExpression: invokeTypeExpression,
Package: pkg,
Name: name,
}
}
func (a *astParser) parseObject(structName string, tp *ast.StructType, sourcePackage string) (*Struct, error) {
if data, ok := objectM[structName]; ok {
return data, nil
}
var st Struct
st.Name = stringx.From(structName)
if tp == nil {
return &st, nil
}
fields := tp.Fields
if fields == nil {
objectM[structName] = &st
return &st, nil
}
fieldList := fields.List
members, err := a.parseFields(fieldList, false, sourcePackage)
if err != nil {
return nil, err
}
for _, m := range members {
var field Field
field.Name = m.Name
field.Type = m.Type
field.JsonTag = m.JsonTag
field.Document = m.Document
field.Comment = m.Comment
st.Field = append(st.Field, &field)
}
objectM[structName] = &st
return &st, nil
}
// parseFields converts AST fields into Field descriptors.
//
// onlyType selects the mode:
//   - true (method parameters/results): untagged fields are kept and only
//     Name, Type and JsonTag are filled.
//   - false (message structs): untagged fields are skipped, Document and
//     Comment are collected, and an embedded (unnamed) field is an error.
//
// In both modes, fields tagged json:"-" and fields whose name starts with
// unknownPrefix are skipped. sourcePackage qualifies unqualified types.
func (a *astParser) parseFields(fields []*ast.Field, onlyType bool, sourcePackage string) ([]*Field, error) {
	ret := make([]*Field, 0)
	for _, field := range fields {
		var item Field
		tag := a.parseTag(field.Tag)
		if tag == "" && !onlyType {
			continue
		}
		// parseTag wraps the tag in backquotes while ignoreJsonTagExpression is
		// unwrapped; compare without the quotes so this check can actually match.
		if strings.Trim(tag, "`") == ignoreJsonTagExpression {
			continue
		}
		item.JsonTag = tag
		name := a.parseName(field.Names)
		if strings.HasPrefix(name, unknownPrefix) {
			continue
		}
		item.Name = stringx.From(name)
		typeName, err := a.parseType(field.Type)
		if err != nil {
			return nil, err
		}
		item.Type = a.getFieldType(typeName, sourcePackage)
		if onlyType {
			ret = append(ret, &item)
			continue
		}
		// Embedded fields have no name, so report the embedded type instead.
		if name == "" {
			return nil, a.wrapError(field.Pos(), "unexpected inline type:%s", typeName)
		}
		item.Document = a.parseCommentOrDoc(field.Doc)
		item.Comment = a.parseCommentOrDoc(field.Comment)
		ret = append(ret, &item)
	}
	return ret, nil
}
// parseTag reduces a struct tag literal to a single backquoted tag. The raw
// literal is split on single spaces; when there are several parts the second
// is kept (presumably the json part of protoc-gen-go tags), otherwise the only
// part. Backquotes inside it are removed and the result is rewrapped in
// backquotes. A nil literal yields "".
func (a *astParser) parseTag(basicLit *ast.BasicLit) string {
	if basicLit == nil {
		return ""
	}
	parts := strings.Split(basicLit.Value, " ")
	picked := parts[0]
	if len(parts) > 1 {
		picked = parts[1]
	}
	return fmt.Sprintf("`%s`", strings.ReplaceAll(picked, "`", ""))
}
// parseType renders a field or parameter type expression back into Go source
// form, e.g. int, string, []int64, map[string]User, *User or pkg.Type.
//
// A nil expression yields errorParseError. Channel, function, inline struct
// and any other unsupported expression yield an error carrying the source
// line. Interface types render as "interface{}" whatever their method set, an
// array length is dropped ([N]T renders as []T), and a selector keeps its
// qualifier only when the qualifier is a plain identifier.
func (a *astParser) parseType(expr ast.Expr) (string, error) {
	if expr == nil {
		return "", errorParseError
	}
	switch node := expr.(type) {
	case *ast.Ident:
		return a.mustGetIndentName(node), nil
	case *ast.SelectorExpr:
		parts := []string{}
		if qualifier, ok := node.X.(*ast.Ident); ok {
			parts = append(parts, a.mustGetIndentName(qualifier))
		}
		parts = append(parts, a.mustGetIndentName(node.Sel))
		return strings.Join(parts, "."), nil
	case *ast.StarExpr:
		inner, err := a.parseType(node.X)
		if err != nil {
			return "", err
		}
		return fmt.Sprintf("*%s", inner), nil
	case *ast.ArrayType:
		elem, err := a.parseType(node.Elt)
		if err != nil {
			return "", err
		}
		return fmt.Sprintf("[]%s", elem), nil
	case *ast.MapType:
		key, err := a.parseType(node.Key)
		if err != nil {
			return "", err
		}
		value, err := a.parseType(node.Value)
		if err != nil {
			return "", err
		}
		return fmt.Sprintf("map[%s]%s", key, value), nil
	case *ast.InterfaceType:
		return "interface{}", nil
	case *ast.StructType:
		return "", a.wrapError(node.Pos(), "unexpected inline struct type")
	case *ast.FuncType:
		return "", a.wrapError(node.Pos(), "unexpected type 'func'")
	case *ast.ChanType:
		return "", a.wrapError(node.Pos(), "unexpected type 'chan'")
	default:
		return "", a.wrapError(node.Pos(), "unexpected type '%v'", node)
	}
}
// parseName returns the name of the first identifier in names, or "" when
// names is empty. Only the first name of a multi-name field is used.
func (a *astParser) parseName(names []*ast.Ident) string {
	if len(names) > 0 {
		return a.mustGetIndentName(names[0])
	}
	return ""
}
// parseCommentOrDoc returns the whitespace-trimmed text of every non-nil,
// non-blank comment in cg, comment markers included. A nil group yields nil;
// a group without usable comments yields an empty, non-nil slice.
func (a *astParser) parseCommentOrDoc(cg *ast.CommentGroup) []string {
	if cg == nil {
		return nil
	}
	lines := make([]string, 0, len(cg.List))
	for _, c := range cg.List {
		if c == nil {
			continue
		}
		if text := strings.TrimSpace(c.Text); text != "" {
			lines = append(lines, text)
		}
	}
	return lines
}
// mustGetIndentName returns ident's name, or "" for a nil identifier.
func (a *astParser) mustGetIndentName(ident *ast.Ident) string {
	if ident != nil {
		return ident.Name
	}
	return ""
}
// wrapError formats the message and prefixes it with the source line of pos
// as resolved by a.fileSet (line 0 when pos is unknown to the file set).
func (a *astParser) wrapError(pos token.Pos, format string, arg ...interface{}) error {
	message := fmt.Sprintf(format, arg...)
	line := a.fileSet.Position(pos).Line
	return fmt.Errorf("line %v: %s", line, message)
}
// GetDoc returns the method's doc comment lines joined with util.NL.
func (f *Func) GetDoc() string {
	doc := strings.Join(f.Document, util.NL)
	return doc
}
// HaveDoc reports whether the method carries at least one doc comment line.
func (f *Func) HaveDoc() bool {
	return len(f.Document) != 0
}
func (a *PbAst) GenEnumCode() (string, error) {
var element []string
for _, item := range a.Enum {
code, err := item.GenEnumCode()
if err != nil {
return "", err
}
element = append(element, code)
}
return strings.Join(element, util.NL), nil
}
// GenTypesCode renders every message struct, the optional Any type and every
// enum type into a single `type ( ... )` declaration, entries separated by a
// blank line. It returns the first rendering error.
//
// Structs are ordered by name and enums by map key, so repeated runs produce
// identical output. The Any type is included only when genAnyCode produces
// code, so that no empty entry adds stray blank lines.
func (a *PbAst) GenTypesCode() (string, error) {
	sts := make([]*Struct, 0, len(a.Structure))
	for _, item := range a.Structure {
		sts = append(sts, item)
	}
	sort.Slice(sts, func(i, j int) bool {
		return sts[i].Name.Source() < sts[j].Name.Source()
	})

	types := make([]string, 0, len(sts)+len(a.Enum)+1)
	for _, s := range sts {
		structCode, err := s.genCode(false)
		if err != nil {
			return "", err
		}
		if structCode == "" {
			continue
		}
		types = append(types, structCode)
	}

	// genAnyCode returns "" when ContainsAny is false; skip that empty entry.
	if anyCode := a.genAnyCode(); anyCode != "" {
		types = append(types, anyCode)
	}

	// a.Enum is a map; iterate its keys in sorted order for deterministic output.
	enumNames := make([]string, 0, len(a.Enum))
	for name := range a.Enum {
		enumNames = append(enumNames, name)
	}
	sort.Strings(enumNames)
	for _, name := range enumNames {
		typeCode, err := a.Enum[name].GenEnumTypeCode()
		if err != nil {
			return "", err
		}
		types = append(types, typeCode)
	}

	buffer, err := util.With("type").Parse(typeTemplate).Execute(map[string]interface{}{
		"types": strings.Join(types, util.NL+util.NL),
	})
	if err != nil {
		return "", err
	}
	return buffer.String(), nil
}
// genAnyCode returns anyTypeTemplate when the proto uses Any (ContainsAny),
// and "" otherwise.
func (a *PbAst) genAnyCode() string {
	if a.ContainsAny {
		return anyTypeTemplate
	}
	return ""
}
func (s *Struct) genCode(containsTypeStatement bool) (string, error) {
fields := make([]string, 0)
for _, f := range s.Field {
var comment, doc string
if len(f.Comment) > 0 {
comment = f.Comment[0]
}
doc = strings.Join(f.Document, util.NL)
buffer, err := util.With(sx.Rand()).Parse(fieldTemplate).Execute(map[string]interface{}{
"name": f.Name.Title(),
"type": f.Type.InvokeTypeExpression,
"tag": f.JsonTag,
"hasDoc": len(f.Document) > 0,
"doc": doc,
"hasComment": len(f.Comment) > 0,
"comment": comment,
})
if err != nil {
return "", err
}
fields = append(fields, buffer.String())
}
buffer, err := util.With("struct").Parse(structTemplate).Execute(map[string]interface{}{
"type": containsTypeStatement,
"name": s.Name.Title(),
"fields": strings.Join(fields, util.NL),
})
if err != nil {
return "", err
}
return buffer.String(), nil
}