diff --git a/rest/engine.go b/rest/engine.go index d7dc8221..7cdc7249 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -25,15 +25,15 @@ const topCpuUsage = 1000 var ErrSignatureConfig = errors.New("bad config for Signature") type engine struct { - conf RestConf - routes []featuredRoutes - unauthorizedCallback handler.UnauthorizedCallback - unsignedCallback handler.UnsignedCallback - middlewares []Middleware - shedder load.Shedder - priorityShedder load.Shedder - tlsConfig *tls.Config - chain *alice.Chain + conf RestConf + routes []featuredRoutes + unauthorizedCallback handler.UnauthorizedCallback + unsignedCallback handler.UnsignedCallback + disableDefaultMiddlewares bool + middlewares []Middleware + shedder load.Shedder + priorityShedder load.Shedder + tlsConfig *tls.Config } func newEngine(c RestConf) *engine { @@ -87,7 +87,7 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics, route Route, verifier func(chain alice.Chain) alice.Chain) error { var chain alice.Chain - if ng.chain == nil { + if !ng.disableDefaultMiddlewares { chain = alice.New( handler.TracingHandler(ng.conf.Name, route.Path), ng.getLogHandler(), @@ -101,15 +101,12 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)), handler.GunzipHandler, ) - } else { - chain = *ng.chain } - chain = ng.appendAuthHandler(fr, chain, verifier) - for _, middleware := range ng.middlewares { chain = chain.Append(convertMiddleware(middleware)) } + chain = ng.appendAuthHandler(fr, chain, verifier) handle := chain.ThenFunc(route.Handler) return router.Handle(route.Method, route.Path, handle) @@ -213,10 +210,6 @@ func (ng *engine) setTlsConfig(cfg *tls.Config) { ng.tlsConfig = cfg } -func (ng *engine) setChainConfig(chain *alice.Chain) { - ng.chain = chain -} - func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) { ng.unauthorizedCallback = callback } diff --git a/rest/engine_test.go b/rest/engine_test.go index 1c84bff7..39686b5f 100644 --- a/rest/engine_test.go +++ b/rest/engine_test.go @@ -250,7 +250,9 @@ func TestEngine_checkedChain(t *testing.T) { } } - server := MustNewServer(RestConf{}, WithChain(middleware1(), middleware2())) + server := MustNewServer(RestConf{}, DisableDefaultMiddlewares()) + server.Use(ToMiddleware(middleware1())) + server.Use(ToMiddleware(middleware2())) server.router = chainRouter{} server.AddRoutes( []Route{ diff --git a/rest/server.go b/rest/server.go index 68154316..5544d449 100644 --- a/rest/server.go +++ b/rest/server.go @@ -7,7 +7,6 @@ import ( "path" "time" - "github.com/justinas/alice" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/rest/handler" "github.com/zeromicro/go-zero/rest/httpx" @@ -96,6 +95,13 @@ func (s *Server) Use(middleware Middleware) { s.ngin.use(middleware) } +// DisableDefaultMiddlewares returns a RunOption that disables the builtin middlewares. +func DisableDefaultMiddlewares() RunOption { + return func(svr *Server) { + svr.ngin.disableDefaultMiddlewares = true + } +} + // ToMiddleware converts the given handler to a Middleware. func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware { return func(handle http.HandlerFunc) http.HandlerFunc { @@ -243,17 +249,6 @@ func WithTLSConfig(cfg *tls.Config) RunOption { } } -// WithChain returns a RunOption that with given chain config. -func WithChain(middlewares ...func(http.Handler) http.Handler) RunOption { - return func(svr *Server) { - chain := alice.New() - for _, middleware := range middlewares { - chain = chain.Append(middleware) - } - svr.ngin.setChainConfig(&chain) - } -} - // WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set. func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption { return func(svr *Server) { diff --git a/rest/server_test.go b/rest/server_test.go index fba7f2eb..e4c1ef39 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/conf" "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/core/service" "github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/router" ) @@ -102,6 +103,18 @@ Port: 54321 } } +func TestNewServerError(t *testing.T) { + _, err := NewServer(RestConf{ + ServiceConf: service.ServiceConf{ + Log: logx.LogConf{ + // file mode, no path specified + Mode: "file", + }, + }, + }) + assert.NotNil(t, err) +} + func TestWithMaxBytes(t *testing.T) { const maxBytes = 1000 var fr featuredRoutes @@ -320,6 +333,7 @@ Port: 54321 rt := router.NewRouter() svr, err := NewServer(cnf, WithRouter(rt)) assert.Nil(t, err) + defer svr.Stop() opt := WithCors("local") opt(svr) @@ -408,3 +422,16 @@ Port: 54321 out := <-ch assert.Equal(t, expect, out) } + +func TestHandleError(t *testing.T) { + assert.NotPanics(t, func() { + handleError(nil) + handleError(http.ErrServerClosed) + }) +} + +func TestValidateSecret(t *testing.T) { + assert.Panics(t, func() { + validateSecret("short") + }) +}