From 543d5907103394ed4549254676f574897a237338 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Wed, 1 Dec 2021 17:45:48 +0800 Subject: [PATCH] fixes #987 (#1283) * fixes #987 * chore: fix test failure * chore: add comments --- tools/goctl/api/gogen/genlogic.go | 50 ++++++++++++++++--- tools/goctl/api/parser/g4/ast/service.go | 7 +-- .../goctl/api/parser/g4/test/service_test.go | 2 +- 3 files changed, 46 insertions(+), 13 deletions(-) diff --git a/tools/goctl/api/gogen/genlogic.go b/tools/goctl/api/gogen/genlogic.go index e425cf7a..873a7aaa 100644 --- a/tools/goctl/api/gogen/genlogic.go +++ b/tools/goctl/api/gogen/genlogic.go @@ -3,8 +3,10 @@ package gogen import ( "fmt" "path" + "strconv" "strings" + "github.com/tal-tech/go-zero/tools/goctl/api/parser/g4/gen/api" "github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/config" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" @@ -64,12 +66,8 @@ func genLogicByRoute(dir, rootPkg string, cfg *config.Config, group spec.Group, var requestString string if len(route.ResponseTypeName()) > 0 { resp := responseGoTypeName(route, typesPacket) - responseString = "(" + resp + ", error)" - if strings.HasPrefix(resp, "*") { - returnString = fmt.Sprintf("return &%s{}, nil", strings.TrimPrefix(resp, "*")) - } else { - returnString = fmt.Sprintf("return %s{}, nil", resp) - } + responseString = "(resp " + resp + ", err error)" + returnString = "return" } else { responseString = "error" returnString = "return nil" @@ -116,9 +114,47 @@ func genLogicImports(route spec.Route, parentPkg string) string { var imports []string imports = append(imports, `"context"`+"\n") imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, contextDir))) - if len(route.ResponseTypeName()) > 0 || len(route.RequestTypeName()) > 0 { + if shallImportTypesPackage(route) { imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, typesDir))) } imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceURL)) return strings.Join(imports, "\n\t") } + +func onlyPrimitiveTypes(val string) bool { + fields := strings.FieldsFunc(val, func(r rune) bool { + return r == '[' || r == ']' || r == ' ' + }) + + for _, field := range fields { + if field == "map" { + continue + } + // ignore array dimension number, like [5]int + if _, err := strconv.Atoi(field); err == nil { + continue + } + if !api.IsBasicType(field) { + return false + } + } + + return true +} + +func shallImportTypesPackage(route spec.Route) bool { + if len(route.RequestTypeName()) > 0 { + return true + } + + respTypeName := route.ResponseTypeName() + if len(respTypeName) == 0 { + return false + } + + if onlyPrimitiveTypes(respTypeName) { + return false + } + + return true +} diff --git a/tools/goctl/api/parser/g4/ast/service.go b/tools/goctl/api/parser/g4/ast/service.go index db5c5369..e3eb8855 100644 --- a/tools/goctl/api/parser/g4/ast/service.go +++ b/tools/goctl/api/parser/g4/ast/service.go @@ -267,11 +267,8 @@ func (v *ApiVisitor) VisitReplybody(ctx *api.ReplybodyContext) interface{} { } case *Literal: lit := dataType.Literal.Text() - if api.IsGolangKeyWord(dataType.Literal.Text()) { - v.panic(dataType.Literal, fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", dataType.Literal.Text())) - } - if api.IsBasicType(lit) { - v.panic(dt.Expr(), fmt.Sprintf("unsupport %s", dt.Expr().Text())) + if api.IsGolangKeyWord(lit) { + v.panic(dataType.Literal, fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", lit)) } default: v.panic(dt.Expr(), fmt.Sprintf("unsupport %s", dt.Expr().Text())) diff --git a/tools/goctl/api/parser/g4/test/service_test.go b/tools/goctl/api/parser/g4/test/service_test.go index 1d64a111..366b2e69 100644 --- a/tools/goctl/api/parser/g4/test/service_test.go +++ b/tools/goctl/api/parser/g4/test/service_test.go @@ -174,7 +174,7 @@ func TestRoute(t *testing.T) { assert.Error(t, err) _, err = parser.Accept(fn, ` post /foo/bar returns (int)`) - assert.Error(t, err) + assert.Nil(t, err) _, err = parser.Accept(fn, ` post /foo/bar returns (*int)`) assert.Error(t, err)