add user middleware chain function (#1913)

* add user middleware chain function

* fix staticcheck SA4006

* chang code Implementation style

Co-authored-by: kemq1 <kemq1@spdb.com.cn>
master
magickeha 2 years ago committed by GitHub
parent 9b6e4c440c
commit 6976ba7e13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -33,6 +33,7 @@ type engine struct {
shedder load.Shedder
priorityShedder load.Shedder
tlsConfig *tls.Config
chain *alice.Chain
}
func newEngine(c RestConf) *engine {
@ -85,19 +86,25 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met
func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
route Route, verifier func(chain alice.Chain) alice.Chain) error {
chain := alice.New(
handler.TracingHandler(ng.conf.Name, route.Path),
ng.getLogHandler(),
handler.PrometheusHandler(route.Path),
handler.MaxConns(ng.conf.MaxConns),
handler.BreakerHandler(route.Method, route.Path, metrics),
handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)),
handler.RecoverHandler,
handler.MetricHandler(metrics),
handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
handler.GunzipHandler,
)
var chain alice.Chain
if ng.chain == nil {
chain = alice.New(
handler.TracingHandler(ng.conf.Name, route.Path),
ng.getLogHandler(),
handler.PrometheusHandler(route.Path),
handler.MaxConns(ng.conf.MaxConns),
handler.BreakerHandler(route.Method, route.Path, metrics),
handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)),
handler.RecoverHandler,
handler.MetricHandler(metrics),
handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
handler.GunzipHandler,
)
} else {
chain = *ng.chain
}
chain = ng.appendAuthHandler(fr, chain, verifier)
for _, middleware := range ng.middlewares {
@ -206,6 +213,10 @@ func (ng *engine) setTlsConfig(cfg *tls.Config) {
ng.tlsConfig = cfg
}
func (ng *engine) setChainConfig(chain *alice.Chain) {
ng.chain = chain
}
func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
ng.unauthorizedCallback = callback
}

@ -229,6 +229,44 @@ func TestEngine_checkedMaxBytes(t *testing.T) {
}
}
func TestEngine_checkedChain(t *testing.T) {
var called int32
middleware1 := func() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&called, 1)
next.ServeHTTP(w, r)
atomic.AddInt32(&called, 1)
})
}
}
middleware2 := func() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&called, 1)
next.ServeHTTP(w, r)
atomic.AddInt32(&called, 1)
})
}
}
server := MustNewServer(RestConf{}, WithChain(middleware1(), middleware2()))
server.router = chainRouter{}
server.AddRoutes(
[]Route{
{
Method: http.MethodGet,
Path: "/",
Handler: func(_ http.ResponseWriter, _ *http.Request) {
atomic.AddInt32(&called, 1)
},
},
},
)
server.ngin.bindRoutes(chainRouter{})
assert.Equal(t, int32(5), atomic.LoadInt32(&called))
}
func TestEngine_notFoundHandler(t *testing.T) {
logx.Disable()
@ -343,3 +381,19 @@ func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
}
type chainRouter struct{}
func (c chainRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
}
func (c chainRouter) Handle(_, _ string, handler http.Handler) error {
handler.ServeHTTP(nil, nil)
return nil
}
func (c chainRouter) SetNotFoundHandler(_ http.Handler) {
}
func (c chainRouter) SetNotAllowedHandler(_ http.Handler) {
}

@ -7,6 +7,7 @@ import (
"path"
"time"
"github.com/justinas/alice"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest/handler"
"github.com/zeromicro/go-zero/rest/httpx"
@ -242,6 +243,17 @@ func WithTLSConfig(cfg *tls.Config) RunOption {
}
}
// WithChain returns a RunOption that with given chain config.
func WithChain(middlewares ...func(http.Handler) http.Handler) RunOption {
return func(svr *Server) {
chain := alice.New()
for _, middleware := range middlewares {
chain = chain.Append(middleware)
}
svr.ngin.setChainConfig(&chain)
}
}
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
return func(svr *Server) {

Loading…
Cancel
Save