chore: refactor to simplify disabling builtin middlewares (#2031)

* chore: refactor to simplify disabling builtin middlewares

* chore: rename methods
master
Kevin Wan 2 years ago committed by GitHub
parent 6976ba7e13
commit 018ca82048
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -29,11 +29,11 @@ type engine struct {
routes []featuredRoutes routes []featuredRoutes
unauthorizedCallback handler.UnauthorizedCallback unauthorizedCallback handler.UnauthorizedCallback
unsignedCallback handler.UnsignedCallback unsignedCallback handler.UnsignedCallback
disableDefaultMiddlewares bool
middlewares []Middleware middlewares []Middleware
shedder load.Shedder shedder load.Shedder
priorityShedder load.Shedder priorityShedder load.Shedder
tlsConfig *tls.Config tlsConfig *tls.Config
chain *alice.Chain
} }
func newEngine(c RestConf) *engine { func newEngine(c RestConf) *engine {
@ -87,7 +87,7 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met
func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics, func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
route Route, verifier func(chain alice.Chain) alice.Chain) error { route Route, verifier func(chain alice.Chain) alice.Chain) error {
var chain alice.Chain var chain alice.Chain
if ng.chain == nil { if !ng.disableDefaultMiddlewares {
chain = alice.New( chain = alice.New(
handler.TracingHandler(ng.conf.Name, route.Path), handler.TracingHandler(ng.conf.Name, route.Path),
ng.getLogHandler(), ng.getLogHandler(),
@ -101,15 +101,12 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta
handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)), handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
handler.GunzipHandler, handler.GunzipHandler,
) )
} else {
chain = *ng.chain
} }
chain = ng.appendAuthHandler(fr, chain, verifier)
for _, middleware := range ng.middlewares { for _, middleware := range ng.middlewares {
chain = chain.Append(convertMiddleware(middleware)) chain = chain.Append(convertMiddleware(middleware))
} }
chain = ng.appendAuthHandler(fr, chain, verifier)
handle := chain.ThenFunc(route.Handler) handle := chain.ThenFunc(route.Handler)
return router.Handle(route.Method, route.Path, handle) return router.Handle(route.Method, route.Path, handle)
@ -213,10 +210,6 @@ func (ng *engine) setTlsConfig(cfg *tls.Config) {
ng.tlsConfig = cfg ng.tlsConfig = cfg
} }
func (ng *engine) setChainConfig(chain *alice.Chain) {
ng.chain = chain
}
func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) { func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
ng.unauthorizedCallback = callback ng.unauthorizedCallback = callback
} }

@ -250,7 +250,9 @@ func TestEngine_checkedChain(t *testing.T) {
} }
} }
server := MustNewServer(RestConf{}, WithChain(middleware1(), middleware2())) server := MustNewServer(RestConf{}, DisableDefaultMiddlewares())
server.Use(ToMiddleware(middleware1()))
server.Use(ToMiddleware(middleware2()))
server.router = chainRouter{} server.router = chainRouter{}
server.AddRoutes( server.AddRoutes(
[]Route{ []Route{

@ -7,7 +7,6 @@ import (
"path" "path"
"time" "time"
"github.com/justinas/alice"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest/handler" "github.com/zeromicro/go-zero/rest/handler"
"github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/httpx"
@ -96,6 +95,13 @@ func (s *Server) Use(middleware Middleware) {
s.ngin.use(middleware) s.ngin.use(middleware)
} }
// DisableDefaultMiddlewares returns a RunOption that disables the builtin middlewares.
func DisableDefaultMiddlewares() RunOption {
return func(svr *Server) {
svr.ngin.disableDefaultMiddlewares = true
}
}
// ToMiddleware converts the given handler to a Middleware. // ToMiddleware converts the given handler to a Middleware.
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware { func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
return func(handle http.HandlerFunc) http.HandlerFunc { return func(handle http.HandlerFunc) http.HandlerFunc {
@ -243,17 +249,6 @@ func WithTLSConfig(cfg *tls.Config) RunOption {
} }
} }
// WithChain returns a RunOption that with given chain config.
func WithChain(middlewares ...func(http.Handler) http.Handler) RunOption {
return func(svr *Server) {
chain := alice.New()
for _, middleware := range middlewares {
chain = chain.Append(middleware)
}
svr.ngin.setChainConfig(&chain)
}
}
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set. // WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption { func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
return func(svr *Server) { return func(svr *Server) {

@ -15,6 +15,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/conf" "github.com/zeromicro/go-zero/core/conf"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/service"
"github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/httpx"
"github.com/zeromicro/go-zero/rest/router" "github.com/zeromicro/go-zero/rest/router"
) )
@ -102,6 +103,18 @@ Port: 54321
} }
} }
func TestNewServerError(t *testing.T) {
_, err := NewServer(RestConf{
ServiceConf: service.ServiceConf{
Log: logx.LogConf{
// file mode, no path specified
Mode: "file",
},
},
})
assert.NotNil(t, err)
}
func TestWithMaxBytes(t *testing.T) { func TestWithMaxBytes(t *testing.T) {
const maxBytes = 1000 const maxBytes = 1000
var fr featuredRoutes var fr featuredRoutes
@ -320,6 +333,7 @@ Port: 54321
rt := router.NewRouter() rt := router.NewRouter()
svr, err := NewServer(cnf, WithRouter(rt)) svr, err := NewServer(cnf, WithRouter(rt))
assert.Nil(t, err) assert.Nil(t, err)
defer svr.Stop()
opt := WithCors("local") opt := WithCors("local")
opt(svr) opt(svr)
@ -408,3 +422,16 @@ Port: 54321
out := <-ch out := <-ch
assert.Equal(t, expect, out) assert.Equal(t, expect, out)
} }
func TestHandleError(t *testing.T) {
assert.NotPanics(t, func() {
handleError(nil)
handleError(http.ErrServerClosed)
})
}
func TestValidateSecret(t *testing.T) {
assert.Panics(t, func() {
validateSecret("short")
})
}

Loading…
Cancel
Save