api add middleware support (#140)

* rebase upstream

* rebase

* trim no need line

* trim no need line

* trim no need line

* update doc

* remove update

* remove no need

* remove no need

* goctl add jwt support

* goctl add jwt support

* goctl add jwt support

* goctl support import

* goctl support import

* support return ()

* revert

* refactor and rename folder to group

* remove no need

* add anonymous annotation

* optimized

* rename

* rename

* update test

* api add middleware support: usage:

@server(
    middleware: M1, M2
)

* api add middleware support: usage:

@server(
    middleware: M1, M2
)

* simple logic

* should reverse middlewares

* optimized

* optimized

* rename

Co-authored-by: kingxt <dream4kingxt@163.com>
master
kingxt 4 years ago committed by GitHub
parent c9b0ac1ee4
commit aa3c391919
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -103,6 +103,13 @@ func WithJwtTransition(secret, prevSecret string) RouteOption {
} }
} }
func WithMiddlewares(ms []Middleware, rs ...Route) []Route {
for i := len(ms) - 1; i >= 0; i-- {
rs = WithMiddleware(ms[i], rs...)
}
return rs
}
func WithMiddleware(middleware Middleware, rs ...Route) []Route { func WithMiddleware(middleware Middleware, rs ...Route) []Route {
routes := make([]Route, len(rs)) routes := make([]Route, len(rs))

@ -68,3 +68,75 @@ func TestWithMiddleware(t *testing.T) {
"wan": "2020", "wan": "2020",
}, m) }, m)
} }
func TestMultiMiddleware(t *testing.T) {
m := make(map[string]string)
router := router.NewPatRouter()
handler := func(w http.ResponseWriter, r *http.Request) {
var v struct {
Nickname string `form:"nickname"`
Zipcode int64 `form:"zipcode"`
}
err := httpx.Parse(r, &v)
assert.Nil(t, err)
_, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, m[v.Nickname]))
assert.Nil(t, err)
}
rs := WithMiddlewares([]Middleware{
func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var v struct {
Name string `path:"name"`
Year string `path:"year"`
}
assert.Nil(t, httpx.ParsePath(r, &v))
m[v.Name] = v.Year
next.ServeHTTP(w, r)
}
},
func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var v struct {
Name string `form:"nickname"`
Zipcode string `form:"zipcode"`
}
assert.Nil(t, httpx.ParseForm(r, &v))
assert.NotEmpty(t, m)
m[v.Name] = v.Zipcode + v.Zipcode
next.ServeHTTP(w, r)
}
},
}, Route{
Method: http.MethodGet,
Path: "/first/:name/:year",
Handler: handler,
}, Route{
Method: http.MethodGet,
Path: "/second/:name/:year",
Handler: handler,
})
urls := []string{
"http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
}
for _, route := range rs {
assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler))
}
for _, url := range urls {
r, err := http.NewRequest(http.MethodGet, url, nil)
assert.Nil(t, err)
rr := httptest.NewRecorder()
router.ServeHTTP(rr, r)
assert.Equal(t, "whatever:200000200000", rr.Body.String())
}
assert.EqualValues(t, map[string]string{
"kevin": "2017",
"wan": "2020",
"whatever": "200000200000",
}, m)
}

@ -31,9 +31,9 @@ func RegisterHandlers(engine *rest.Server, serverCtx *svc.ServiceContext) {
} }
` `
routesAdditionTemplate = ` routesAdditionTemplate = `
engine.AddRoutes([]rest.Route{ engine.AddRoutes(
{{.routes}} {{.routes}}
}{{.jwt}}{{.signature}}) {{.jwt}}{{.signature}})
` `
) )
@ -52,6 +52,7 @@ type (
jwtEnabled bool jwtEnabled bool
signatureEnabled bool signatureEnabled bool
authName string authName string
middleware []string
} }
route struct { route struct {
method string method string
@ -87,8 +88,22 @@ func genRoutes(dir string, api *spec.ApiSpec, force bool) error {
if g.signatureEnabled { if g.signatureEnabled {
signature = fmt.Sprintf(", rest.WithSignature(serverCtx.Config.%s.Signature)", g.authName) signature = fmt.Sprintf(", rest.WithSignature(serverCtx.Config.%s.Signature)", g.authName)
} }
var routes string
if len(g.middleware) > 0 {
var params = g.middleware
for i := range params {
params[i] = "serverCtx." + params[i]
}
var middlewareStr = strings.Join(params, ", ")
routes = fmt.Sprintf("rest.WithMultiMiddleware([]rest.Middleware{ %s }, []rest.Route{\n %s \n}),",
middlewareStr, strings.TrimSpace(gbuilder.String()))
} else {
routes = fmt.Sprintf("[]rest.Route{\n %s \n},", strings.TrimSpace(gbuilder.String()))
}
if err := gt.Execute(&builder, map[string]string{ if err := gt.Execute(&builder, map[string]string{
"routes": strings.TrimSpace(gbuilder.String()), "routes": routes,
"jwt": jwt, "jwt": jwt,
"signature": signature, "signature": signature,
}); err != nil { }); err != nil {
@ -185,6 +200,11 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
groupedRoutes.authName = value groupedRoutes.authName = value
groupedRoutes.jwtEnabled = true groupedRoutes.jwtEnabled = true
} }
if value, ok := apiutil.GetAnnotationValue(g.Annotations, "server", "middleware"); ok {
for _, item := range strings.Split(value, ",") {
groupedRoutes.middleware = append(groupedRoutes.middleware, item)
}
}
routes = append(routes, groupedRoutes) routes = append(routes, groupedRoutes)
} }

@ -9,16 +9,20 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/templatex"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/vars"
) )
const ( const (
contextFilename = "servicecontext.go" contextFilename = "servicecontext.go"
contextTemplate = `package svc contextTemplate = `package svc
import {{.configImport}} import (
{{.configImport}}
)
type ServiceContext struct { type ServiceContext struct {
Config {{.config}} Config {{.config}}
{{.middleware}}
} }
func NewServiceContext(c {{.config}}) *ServiceContext { func NewServiceContext(c {{.config}}) *ServiceContext {
@ -53,12 +57,22 @@ func genServiceContext(dir string, api *spec.ApiSpec) error {
return err return err
} }
var middlewareStr string
for _, item := range getMiddleware(api) {
middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item)
}
var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\"" var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
if len(middlewareStr) > 0 {
configImport += fmt.Sprintf("\n\"%s/rest\"", vars.ProjectOpenSourceUrl)
}
t := template.Must(template.New("contextTemplate").Parse(text)) t := template.Must(template.New("contextTemplate").Parse(text))
buffer := new(bytes.Buffer) buffer := new(bytes.Buffer)
err = t.Execute(buffer, map[string]string{ err = t.Execute(buffer, map[string]string{
"configImport": configImport, "configImport": configImport,
"config": "config.Config", "config": "config.Config",
"middleware": middlewareStr,
}) })
if err != nil { if err != nil {
return nil return nil

@ -66,6 +66,18 @@ func getAuths(api *spec.ApiSpec) []string {
return authNames.KeysStr() return authNames.KeysStr()
} }
func getMiddleware(api *spec.ApiSpec) []string {
result := collection.NewSet()
for _, g := range api.Service.Groups {
if value, ok := util.GetAnnotationValue(g.Annotations, "server", "middleware"); ok {
for _, item := range strings.Split(value, ",") {
result.Add(strings.TrimSpace(item))
}
}
}
return result.KeysStr()
}
func formatCode(code string) string { func formatCode(code string) string {
ret, err := goformat.Source([]byte(code)) ret, err := goformat.Source([]byte(code))
if err != nil { if err != nil {

@ -119,6 +119,24 @@ service A-api {
} }
` `
const apiHasMiddleware = `
type Request struct {
Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
}
type Response struct {
Message string ` + "`" + `json:"message"` + "`" + `
}
@server(
middleware: TokenValidate
)
service A-api {
@handler GreetHandler
get /greet/from/:name(Request) returns (Response)
}
`
func TestParser(t *testing.T) { func TestParser(t *testing.T) {
filename := "greet.api" filename := "greet.api"
err := ioutil.WriteFile(filename, []byte(testApiTemplate), os.ModePerm) err := ioutil.WriteFile(filename, []byte(testApiTemplate), os.ModePerm)
@ -198,3 +216,16 @@ func TestAnonymousAnnotation(t *testing.T) {
assert.Equal(t, len(api.Service.Routes), 1) assert.Equal(t, len(api.Service.Routes), 1)
assert.Equal(t, api.Service.Routes[0].Annotations[0].Value, "GreetHandler") assert.Equal(t, api.Service.Routes[0].Annotations[0].Value, "GreetHandler")
} }
func TestApiHasMiddleware(t *testing.T) {
filename := "greet.api"
err := ioutil.WriteFile(filename, []byte(apiHasMiddleware), os.ModePerm)
assert.Nil(t, err)
defer os.Remove(filename)
parser, err := NewParser(filename)
assert.Nil(t, err)
_, err = parser.Parse()
assert.Nil(t, err)
}

Loading…
Cancel
Save