diff --git a/rest/server.go b/rest/server.go index 64d6f6e4..962c5999 100644 --- a/rest/server.go +++ b/rest/server.go @@ -103,6 +103,13 @@ func WithJwtTransition(secret, prevSecret string) RouteOption { } } +func WithMiddlewares(ms []Middleware, rs ...Route) []Route { + for i := len(ms) - 1; i >= 0; i-- { + rs = WithMiddleware(ms[i], rs...) + } + return rs +} + func WithMiddleware(middleware Middleware, rs ...Route) []Route { routes := make([]Route, len(rs)) diff --git a/rest/server_test.go b/rest/server_test.go index 01929454..fe7831ab 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -68,3 +68,75 @@ func TestWithMiddleware(t *testing.T) { "wan": "2020", }, m) } + +func TestMultiMiddleware(t *testing.T) { + m := make(map[string]string) + router := router.NewPatRouter() + handler := func(w http.ResponseWriter, r *http.Request) { + var v struct { + Nickname string `form:"nickname"` + Zipcode int64 `form:"zipcode"` + } + + err := httpx.Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, m[v.Nickname])) + assert.Nil(t, err) + } + rs := WithMiddlewares([]Middleware{ + func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var v struct { + Name string `path:"name"` + Year string `path:"year"` + } + assert.Nil(t, httpx.ParsePath(r, &v)) + m[v.Name] = v.Year + next.ServeHTTP(w, r) + } + }, + func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var v struct { + Name string `form:"nickname"` + Zipcode string `form:"zipcode"` + } + assert.Nil(t, httpx.ParseForm(r, &v)) + assert.NotEmpty(t, m) + m[v.Name] = v.Zipcode + v.Zipcode + next.ServeHTTP(w, r) + } + }, + }, Route{ + Method: http.MethodGet, + Path: "/first/:name/:year", + Handler: handler, + }, Route{ + Method: http.MethodGet, + Path: "/second/:name/:year", + Handler: handler, + }) + + urls := []string{ + "http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000", + "http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000", + } + for _, route := range rs { + assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler)) + } + for _, url := range urls { + r, err := http.NewRequest(http.MethodGet, url, nil) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) + + assert.Equal(t, "whatever:200000200000", rr.Body.String()) + } + + assert.EqualValues(t, map[string]string{ + "kevin": "2017", + "wan": "2020", + "whatever": "200000200000", + }, m) +} diff --git a/tools/goctl/api/gogen/genroutes.go b/tools/goctl/api/gogen/genroutes.go index bceea619..06f2069b 100644 --- a/tools/goctl/api/gogen/genroutes.go +++ b/tools/goctl/api/gogen/genroutes.go @@ -31,9 +31,9 @@ func RegisterHandlers(engine *rest.Server, serverCtx *svc.ServiceContext) { } ` routesAdditionTemplate = ` - engine.AddRoutes([]rest.Route{ + engine.AddRoutes( {{.routes}} - }{{.jwt}}{{.signature}}) + {{.jwt}}{{.signature}}) ` ) @@ -52,6 +52,7 @@ type ( jwtEnabled bool signatureEnabled bool authName string + middleware []string } route struct { method string @@ -87,8 +88,22 @@ func genRoutes(dir string, api *spec.ApiSpec, force bool) error { if g.signatureEnabled { signature = fmt.Sprintf(", rest.WithSignature(serverCtx.Config.%s.Signature)", g.authName) } + + var routes string + if len(g.middleware) > 0 { + var params = g.middleware + for i := range params { + params[i] = "serverCtx." + params[i] + } + var middlewareStr = strings.Join(params, ", ") + routes = fmt.Sprintf("rest.WithMultiMiddleware([]rest.Middleware{ %s }, []rest.Route{\n %s \n}),", + middlewareStr, strings.TrimSpace(gbuilder.String())) + } else { + routes = fmt.Sprintf("[]rest.Route{\n %s \n},", strings.TrimSpace(gbuilder.String())) + } + if err := gt.Execute(&builder, map[string]string{ - "routes": strings.TrimSpace(gbuilder.String()), + "routes": routes, "jwt": jwt, "signature": signature, }); err != nil { @@ -185,6 +200,11 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) { groupedRoutes.authName = value groupedRoutes.jwtEnabled = true } + if value, ok := apiutil.GetAnnotationValue(g.Annotations, "server", "middleware"); ok { + for _, item := range strings.Split(value, ",") { + groupedRoutes.middleware = append(groupedRoutes.middleware, item) + } + } routes = append(routes, groupedRoutes) } diff --git a/tools/goctl/api/gogen/gensvc.go b/tools/goctl/api/gogen/gensvc.go index 9efd9dd3..e71c17d1 100644 --- a/tools/goctl/api/gogen/gensvc.go +++ b/tools/goctl/api/gogen/gensvc.go @@ -9,16 +9,20 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/templatex" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/vars" ) const ( contextFilename = "servicecontext.go" contextTemplate = `package svc -import {{.configImport}} +import ( + {{.configImport}} +) type ServiceContext struct { Config {{.config}} + {{.middleware}} } func NewServiceContext(c {{.config}}) *ServiceContext { @@ -53,12 +57,22 @@ func genServiceContext(dir string, api *spec.ApiSpec) error { return err } + var middlewareStr string + for _, item := range getMiddleware(api) { + middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item) + } + var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\"" + if len(middlewareStr) > 0 { + configImport += fmt.Sprintf("\n\"%s/rest\"", vars.ProjectOpenSourceUrl) + } + t := template.Must(template.New("contextTemplate").Parse(text)) buffer := new(bytes.Buffer) err = t.Execute(buffer, map[string]string{ "configImport": configImport, "config": "config.Config", + "middleware": middlewareStr, }) if err != nil { return nil diff --git a/tools/goctl/api/gogen/util.go b/tools/goctl/api/gogen/util.go index 0e0e3ef5..6068bcee 100644 --- a/tools/goctl/api/gogen/util.go +++ b/tools/goctl/api/gogen/util.go @@ -66,6 +66,18 @@ func getAuths(api *spec.ApiSpec) []string { return authNames.KeysStr() } +func getMiddleware(api *spec.ApiSpec) []string { + result := collection.NewSet() + for _, g := range api.Service.Groups { + if value, ok := util.GetAnnotationValue(g.Annotations, "server", "middleware"); ok { + for _, item := range strings.Split(value, ",") { + result.Add(strings.TrimSpace(item)) + } + } + } + return result.KeysStr() +} + func formatCode(code string) string { ret, err := goformat.Source([]byte(code)) if err != nil { diff --git a/tools/goctl/api/parser/parser_test.go b/tools/goctl/api/parser/parser_test.go index 6df80950..34e5c180 100644 --- a/tools/goctl/api/parser/parser_test.go +++ b/tools/goctl/api/parser/parser_test.go @@ -119,6 +119,24 @@ service A-api { } ` +const apiHasMiddleware = ` +type Request struct { + Name string ` + "`" + `path:"name,options=you|me"` + "`" + ` +} + +type Response struct { + Message string ` + "`" + `json:"message"` + "`" + ` +} + +@server( + middleware: TokenValidate +) +service A-api { + @handler GreetHandler + get /greet/from/:name(Request) returns (Response) +} +` + func TestParser(t *testing.T) { filename := "greet.api" err := ioutil.WriteFile(filename, []byte(testApiTemplate), os.ModePerm) @@ -198,3 +216,16 @@ func TestAnonymousAnnotation(t *testing.T) { assert.Equal(t, len(api.Service.Routes), 1) assert.Equal(t, api.Service.Routes[0].Annotations[0].Value, "GreetHandler") } + +func TestApiHasMiddleware(t *testing.T) { + filename := "greet.api" + err := ioutil.WriteFile(filename, []byte(apiHasMiddleware), os.ModePerm) + assert.Nil(t, err) + defer os.Remove(filename) + + parser, err := NewParser(filename) + assert.Nil(t, err) + + _, err = parser.Parse() + assert.Nil(t, err) +}