diff --git a/tools/goctl/api/parser/parser_test.go b/tools/goctl/api/gogen/gen_test.go similarity index 66% rename from tools/goctl/api/parser/parser_test.go rename to tools/goctl/api/gogen/gen_test.go index 34e5c180..5ba5c3e1 100644 --- a/tools/goctl/api/parser/parser_test.go +++ b/tools/goctl/api/gogen/gen_test.go @@ -1,11 +1,15 @@ -package parser +package gogen import ( + goformat "go/format" "io/ioutil" "os" + "path/filepath" + "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/tools/goctl/api/parser" ) const testApiTemplate = ` @@ -137,13 +141,50 @@ service A-api { } ` +const apiJwt = ` +type Request struct { + Name string ` + "`" + `path:"name,options=you|me"` + "`" + ` +} + +type Response struct { + Message string ` + "`" + `json:"message"` + "`" + ` +} + +@server( + jwt: Auth +) +service A-api { + @handler GreetHandler + get /greet/from/:name(Request) returns (Response) +} +` + +const apiJwtWithMiddleware = ` +type Request struct { + Name string ` + "`" + `path:"name,options=you|me"` + "`" + ` +} + +type Response struct { + Message string ` + "`" + `json:"message"` + "`" + ` +} + +@server( + jwt: Auth + middleware: TokenValidate +) +service A-api { + @handler GreetHandler + get /greet/from/:name(Request) returns (Response) +} +` + func TestParser(t *testing.T) { filename := "greet.api" err := ioutil.WriteFile(filename, []byte(testApiTemplate), os.ModePerm) assert.Nil(t, err) defer os.Remove(filename) - parser, err := NewParser(filename) + parser, err := parser.NewParser(filename) assert.Nil(t, err) api, err := parser.Parse() @@ -157,6 +198,8 @@ func TestParser(t *testing.T) { assert.Equal(t, api.Service.Routes[1].RequestType.Name, "Request") assert.Equal(t, api.Service.Routes[1].ResponseType.Name, "") + + validate(t, filename) } func TestMultiService(t *testing.T) { @@ -165,7 +208,7 @@ func TestMultiService(t *testing.T) { assert.Nil(t, err) defer os.Remove(filename) - parser, err := NewParser(filename) + parser, err := parser.NewParser(filename) assert.Nil(t, err) api, err := parser.Parse() @@ -173,6 +216,8 @@ func TestMultiService(t *testing.T) { assert.Equal(t, len(api.Service.Routes), 2) assert.Equal(t, len(api.Service.Groups), 2) + + validate(t, filename) } func TestApiNoInfo(t *testing.T) { @@ -181,11 +226,13 @@ func TestApiNoInfo(t *testing.T) { assert.Nil(t, err) defer os.Remove(filename) - parser, err := NewParser(filename) + parser, err := parser.NewParser(filename) assert.Nil(t, err) _, err = parser.Parse() assert.Nil(t, err) + + validate(t, filename) } func TestInvalidApiFile(t *testing.T) { @@ -194,7 +241,7 @@ func TestInvalidApiFile(t *testing.T) { assert.Nil(t, err) defer os.Remove(filename) - parser, err := NewParser(filename) + parser, err := parser.NewParser(filename) assert.Nil(t, err) _, err = parser.Parse() @@ -207,7 +254,7 @@ func TestAnonymousAnnotation(t *testing.T) { assert.Nil(t, err) defer os.Remove(filename) - parser, err := NewParser(filename) + parser, err := parser.NewParser(filename) assert.Nil(t, err) api, err := parser.Parse() @@ -215,6 +262,8 @@ func TestAnonymousAnnotation(t *testing.T) { assert.Equal(t, len(api.Service.Routes), 1) assert.Equal(t, api.Service.Routes[0].Annotations[0].Value, "GreetHandler") + + validate(t, filename) } func TestApiHasMiddleware(t *testing.T) { @@ -223,9 +272,61 @@ func TestApiHasMiddleware(t *testing.T) { assert.Nil(t, err) defer os.Remove(filename) - parser, err := NewParser(filename) + parser, err := parser.NewParser(filename) + assert.Nil(t, err) + + _, err = parser.Parse() + assert.Nil(t, err) + + validate(t, filename) +} + +func TestApiHasJwt(t *testing.T) { + filename := "jwt.api" + err := ioutil.WriteFile(filename, []byte(apiJwt), os.ModePerm) + assert.Nil(t, err) + defer os.Remove(filename) + + parser, err := parser.NewParser(filename) assert.Nil(t, err) _, err = parser.Parse() assert.Nil(t, err) + + validate(t, filename) +} + +func TestApiHasJwtAndMiddleware(t *testing.T) { + filename := "jwt.api" + err := ioutil.WriteFile(filename, []byte(apiJwtWithMiddleware), os.ModePerm) + assert.Nil(t, err) + defer os.Remove(filename) + + parser, err := parser.NewParser(filename) + assert.Nil(t, err) + + _, err = parser.Parse() + assert.Nil(t, err) + + validate(t, filename) +} + +func validate(t *testing.T, api string) { + dir := "_go" + err := DoGenProject(api, dir, true) + defer os.RemoveAll(dir) + assert.Nil(t, err) + filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if strings.HasSuffix(path, ".go") { + code, err := ioutil.ReadFile(path) + assert.Nil(t, err) + assert.Nil(t, validateCode(string(code))) + } + return nil + }) +} + +func validateCode(code string) error { + _, err := goformat.Source([]byte(code)) + return err } diff --git a/tools/goctl/api/gogen/genroutes.go b/tools/goctl/api/gogen/genroutes.go index 35a1fe61..f8303778 100644 --- a/tools/goctl/api/gogen/genroutes.go +++ b/tools/goctl/api/gogen/genroutes.go @@ -32,8 +32,8 @@ func RegisterHandlers(engine *rest.Server, serverCtx *svc.ServiceContext) { ` routesAdditionTemplate = ` engine.AddRoutes( - {{.routes}} - {{.jwt}}{{.signature}}) + {{.routes}} {{.jwt}}{{.signature}} + ) ` ) @@ -71,6 +71,7 @@ func genRoutes(dir string, api *spec.ApiSpec, force bool) error { gt := template.Must(template.New("groupTemplate").Parse(routesAdditionTemplate)) for _, g := range groups { var gbuilder strings.Builder + gbuilder.WriteString("[]rest.Route{") for _, r := range g.routes { fmt.Fprintf(&gbuilder, ` { @@ -80,26 +81,29 @@ func genRoutes(dir string, api *spec.ApiSpec, force bool) error { },`, r.method, r.path, r.handler) } + var jwt string if g.jwtEnabled { - jwt = fmt.Sprintf(", rest.WithJwt(serverCtx.Config.%s.AccessSecret)", g.authName) + jwt = fmt.Sprintf("\n rest.WithJwt(serverCtx.Config.%s.AccessSecret),", g.authName) } var signature string if g.signatureEnabled { - signature = fmt.Sprintf(", rest.WithSignature(serverCtx.Config.%s.Signature)", g.authName) + signature = fmt.Sprintf("\n rest.WithSignature(serverCtx.Config.%s.Signature),", g.authName) } var routes string if len(g.middleware) > 0 { + gbuilder.WriteString("\n}...,") var params = g.middleware for i := range params { params[i] = "serverCtx." + params[i] } var middlewareStr = strings.Join(params, ", ") - routes = fmt.Sprintf("rest.WithMiddlewares(\n[]rest.Middleware{ %s }, \n[]rest.Route{\n %s \n}...,\n),", + routes = fmt.Sprintf("rest.WithMiddlewares(\n[]rest.Middleware{ %s }, \n %s \n),", middlewareStr, strings.TrimSpace(gbuilder.String())) } else { - routes = fmt.Sprintf("[]rest.Route{\n %s \n},", strings.TrimSpace(gbuilder.String())) + gbuilder.WriteString("\n},") + routes = strings.TrimSpace(gbuilder.String()) } if err := gt.Execute(&builder, map[string]string{