diff --git a/rest/chain/chain.go b/rest/chain/chain.go new file mode 100644 index 00000000..79d5fc60 --- /dev/null +++ b/rest/chain/chain.go @@ -0,0 +1,109 @@ +package chain + +// This is a modified version of https://github.com/justinas/alice +// The original code is licensed under the MIT license. +// It's modified for couple reasons: +// - Added the Chain interface +// - Added support for the Chain.Prepend(...) method + +import "net/http" + +type ( + // Chain defines a chain of middleware. + Chain interface { + Append(middlewares ...Middleware) Chain + Prepend(middlewares ...Middleware) Chain + Then(h http.Handler) http.Handler + ThenFunc(fn http.HandlerFunc) http.Handler + } + + // Middleware is an HTTP middleware. + Middleware func(http.Handler) http.Handler + + // chain acts as a list of http.Handler middlewares. + // chain is effectively immutable: + // once created, it will always hold + // the same set of middlewares in the same order. + chain struct { + middlewares []Middleware + } +) + +// New creates a new Chain, memorizing the given list of middleware middlewares. +// New serves no other function, middlewares are only called upon a call to Then() or ThenFunc(). +func New(middlewares ...Middleware) Chain { + return chain{middlewares: append(([]Middleware)(nil), middlewares...)} +} + +// Append extends a chain, adding the specified middlewares as the last ones in the request flow. +// +// c := chain.New(m1, m2) +// c.Append(m3, m4) +// // requests in c go m1 -> m2 -> m3 -> m4 +func (c chain) Append(middlewares ...Middleware) Chain { + return chain{middlewares: join(c.middlewares, middlewares)} +} + +// Prepend extends a chain by adding the specified chain as the first one in the request flow. +// +// c := chain.New(m3, m4) +// c1 := chain.New(m1, m2) +// c.Prepend(c1) +// // requests in c go m1 -> m2 -> m3 -> m4 +func (c chain) Prepend(middlewares ...Middleware) Chain { + return chain{middlewares: join(middlewares, c.middlewares)} +} + +// Then chains the middleware and returns the final http.Handler. +// New(m1, m2, m3).Then(h) +// is equivalent to: +// m1(m2(m3(h))) +// When the request comes in, it will be passed to m1, then m2, then m3 +// and finally, the given handler +// (assuming every middleware calls the following one). +// +// A chain can be safely reused by calling Then() several times. +// stdStack := chain.New(ratelimitHandler, csrfHandler) +// indexPipe = stdStack.Then(indexHandler) +// authPipe = stdStack.Then(authHandler) +// Note that middlewares are called on every call to Then() or ThenFunc() +// and thus several instances of the same middleware will be created +// when a chain is reused in this way. +// For proper middleware, this should cause no problems. +// +// Then() treats nil as http.DefaultServeMux. +func (c chain) Then(h http.Handler) http.Handler { + if h == nil { + h = http.DefaultServeMux + } + + for i := range c.middlewares { + h = c.middlewares[len(c.middlewares)-1-i](h) + } + + return h +} + +// ThenFunc works identically to Then, but takes +// a HandlerFunc instead of a Handler. +// +// The following two statements are equivalent: +// c.Then(http.HandlerFunc(fn)) +// c.ThenFunc(fn) +// +// ThenFunc provides all the guarantees of Then. +func (c chain) ThenFunc(fn http.HandlerFunc) http.Handler { + // This nil check cannot be removed due to the "nil is not nil" common mistake in Go. + // Required due to: https://stackoverflow.com/questions/33426977/how-to-golang-check-a-variable-is-nil + if fn == nil { + return c.Then(nil) + } + return c.Then(fn) +} + +func join(a, b []Middleware) []Middleware { + mids := make([]Middleware, 0, len(a)+len(b)) + mids = append(mids, a...) + mids = append(mids, b...) + return mids +} diff --git a/rest/chain/chain_test.go b/rest/chain/chain_test.go new file mode 100644 index 00000000..f5973c46 --- /dev/null +++ b/rest/chain/chain_test.go @@ -0,0 +1,126 @@ +package chain + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +// A constructor for middleware +// that writes its own "tag" into the RW and does nothing else. +// Useful in checking if a chain is behaving in the right order. +func tagMiddleware(tag string) Middleware { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(tag)) + h.ServeHTTP(w, r) + }) + } +} + +// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer), +// but the best we can do. +func funcsEqual(f1, f2 interface{}) bool { + val1 := reflect.ValueOf(f1) + val2 := reflect.ValueOf(f2) + return val1.Pointer() == val2.Pointer() +} + +var testApp = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("app\n")) +}) + +func TestNew(t *testing.T) { + c1 := func(h http.Handler) http.Handler { + return nil + } + + c2 := func(h http.Handler) http.Handler { + return http.StripPrefix("potato", nil) + } + + slice := []Middleware{c1, c2} + c := New(slice...) + for k := range slice { + assert.True(t, funcsEqual(c.(chain).middlewares[k], slice[k]), + "New does not add constructors correctly") + } +} + +func TestThenWorksWithNoMiddleware(t *testing.T) { + assert.True(t, funcsEqual(New().Then(testApp), testApp), + "Then does not work with no middleware") +} + +func TestThenTreatsNilAsDefaultServeMux(t *testing.T) { + assert.Equal(t, http.DefaultServeMux, New().Then(nil), + "Then does not treat nil as DefaultServeMux") +} + +func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) { + assert.Equal(t, http.DefaultServeMux, New().ThenFunc(nil), + "ThenFunc does not treat nil as DefaultServeMux") +} + +func TestThenFuncConstructsHandlerFunc(t *testing.T) { + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + }) + chained := New().ThenFunc(fn) + rec := httptest.NewRecorder() + + chained.ServeHTTP(rec, (*http.Request)(nil)) + + assert.Equal(t, reflect.TypeOf((http.HandlerFunc)(nil)), reflect.TypeOf(chained), + "ThenFunc does not construct HandlerFunc") +} + +func TestThenOrdersHandlersCorrectly(t *testing.T) { + t1 := tagMiddleware("t1\n") + t2 := tagMiddleware("t2\n") + t3 := tagMiddleware("t3\n") + + chained := New(t1, t2, t3).Then(testApp) + + w := httptest.NewRecorder() + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + chained.ServeHTTP(w, r) + + assert.Equal(t, "t1\nt2\nt3\napp\n", w.Body.String(), + "Then does not order handlers correctly") +} + +func TestAppendAddsHandlersCorrectly(t *testing.T) { + c := New(tagMiddleware("t1\n"), tagMiddleware("t2\n")) + c = c.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n")) + h := c.Then(testApp) + + w := httptest.NewRecorder() + r, err := http.NewRequest("GET", "/", nil) + assert.Nil(t, err) + + h.ServeHTTP(w, r) + assert.Equal(t, "t1\nt2\nt3\nt4\napp\n", w.Body.String(), + "Append does not add handlers correctly") +} + +func TestExtendAddsHandlersCorrectly(t *testing.T) { + c := New(tagMiddleware("t3\n"), tagMiddleware("t4\n")) + c = c.Prepend(tagMiddleware("t1\n"), tagMiddleware("t2\n")) + h := c.Then(testApp) + + w := httptest.NewRecorder() + r, err := http.NewRequest("GET", "/", nil) + assert.Nil(t, err) + + h.ServeHTTP(w, r) + assert.Equal(t, "t1\nt2\nt3\nt4\napp\n", w.Body.String(), + "Extend does not add handlers in correctly") +} diff --git a/rest/engine.go b/rest/engine.go index 7cdc7249..5209693c 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -8,10 +8,10 @@ import ( "sort" "time" - "github.com/justinas/alice" "github.com/zeromicro/go-zero/core/codec" "github.com/zeromicro/go-zero/core/load" "github.com/zeromicro/go-zero/core/stat" + "github.com/zeromicro/go-zero/rest/chain" "github.com/zeromicro/go-zero/rest/handler" "github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/internal" @@ -25,15 +25,15 @@ const topCpuUsage = 1000 var ErrSignatureConfig = errors.New("bad config for Signature") type engine struct { - conf RestConf - routes []featuredRoutes - unauthorizedCallback handler.UnauthorizedCallback - unsignedCallback handler.UnsignedCallback - disableDefaultMiddlewares bool - middlewares []Middleware - shedder load.Shedder - priorityShedder load.Shedder - tlsConfig *tls.Config + conf RestConf + routes []featuredRoutes + unauthorizedCallback handler.UnauthorizedCallback + unsignedCallback handler.UnsignedCallback + chain chain.Chain + middlewares []Middleware + shedder load.Shedder + priorityShedder load.Shedder + tlsConfig *tls.Config } func newEngine(c RestConf) *engine { @@ -53,20 +53,20 @@ func (ng *engine) addRoutes(r featuredRoutes) { ng.routes = append(ng.routes, r) } -func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain, - verifier func(alice.Chain) alice.Chain) alice.Chain { +func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain, + verifier func(chain.Chain) chain.Chain) chain.Chain { if fr.jwt.enabled { if len(fr.jwt.prevSecret) == 0 { - chain = chain.Append(handler.Authorize(fr.jwt.secret, + chn = chn.Append(handler.Authorize(fr.jwt.secret, handler.WithUnauthorizedCallback(ng.unauthorizedCallback))) } else { - chain = chain.Append(handler.Authorize(fr.jwt.secret, + chn = chn.Append(handler.Authorize(fr.jwt.secret, handler.WithPrevSecret(fr.jwt.prevSecret), handler.WithUnauthorizedCallback(ng.unauthorizedCallback))) } } - return verifier(chain) + return verifier(chn) } func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error { @@ -85,10 +85,10 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met } func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics, - route Route, verifier func(chain alice.Chain) alice.Chain) error { - var chain alice.Chain - if !ng.disableDefaultMiddlewares { - chain = alice.New( + route Route, verifier func(chain.Chain) chain.Chain) error { + chn := ng.chain + if chn == nil { + chn = chain.New( handler.TracingHandler(ng.conf.Name, route.Path), ng.getLogHandler(), handler.PrometheusHandler(route.Path), @@ -103,11 +103,12 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta ) } + chn = ng.appendAuthHandler(fr, chn, verifier) + for _, middleware := range ng.middlewares { - chain = chain.Append(convertMiddleware(middleware)) + chn = chn.Append(convertMiddleware(middleware)) } - chain = ng.appendAuthHandler(fr, chain, verifier) - handle := chain.ThenFunc(route.Handler) + handle := chn.ThenFunc(route.Handler) return router.Handle(route.Method, route.Path, handle) } @@ -171,16 +172,16 @@ func (ng *engine) getShedder(priority bool) load.Shedder { // notFoundHandler returns a middleware that handles 404 not found requests. func (ng *engine) notFoundHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - chain := alice.New( + chn := chain.New( handler.TracingHandler(ng.conf.Name, ""), ng.getLogHandler(), ) var h http.Handler if next != nil { - h = chain.Then(next) + h = chn.Then(next) } else { - h = chain.Then(http.NotFoundHandler()) + h = chn.Then(http.NotFoundHandler()) } cw := response.NewHeaderOnceResponseWriter(w) @@ -218,10 +219,10 @@ func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) { ng.unsignedCallback = callback } -func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) { +func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chain) chain.Chain, error) { if !signature.enabled { - return func(chain alice.Chain) alice.Chain { - return chain + return func(chn chain.Chain) chain.Chain { + return chn }, nil } @@ -230,8 +231,8 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic return nil, ErrSignatureConfig } - return func(chain alice.Chain) alice.Chain { - return chain + return func(chn chain.Chain) chain.Chain { + return chn }, nil } @@ -247,14 +248,13 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic decrypters[fingerprint] = decrypter } - return func(chain alice.Chain) alice.Chain { + return func(chn chain.Chain) chain.Chain { if ng.unsignedCallback != nil { - return chain.Append(handler.ContentSecurityHandler( + return chn.Append(handler.ContentSecurityHandler( decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback)) } - return chain.Append(handler.ContentSecurityHandler( - decrypters, signature.Expiry, signature.Strict)) + return chn.Append(handler.ContentSecurityHandler(decrypters, signature.Expiry, signature.Strict)) }, nil } diff --git a/rest/engine_test.go b/rest/engine_test.go index 39686b5f..b8633d88 100644 --- a/rest/engine_test.go +++ b/rest/engine_test.go @@ -229,46 +229,6 @@ func TestEngine_checkedMaxBytes(t *testing.T) { } } -func TestEngine_checkedChain(t *testing.T) { - var called int32 - middleware1 := func() func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt32(&called, 1) - next.ServeHTTP(w, r) - atomic.AddInt32(&called, 1) - }) - } - } - middleware2 := func() func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt32(&called, 1) - next.ServeHTTP(w, r) - atomic.AddInt32(&called, 1) - }) - } - } - - server := MustNewServer(RestConf{}, DisableDefaultMiddlewares()) - server.Use(ToMiddleware(middleware1())) - server.Use(ToMiddleware(middleware2())) - server.router = chainRouter{} - server.AddRoutes( - []Route{ - { - Method: http.MethodGet, - Path: "/", - Handler: func(_ http.ResponseWriter, _ *http.Request) { - atomic.AddInt32(&called, 1) - }, - }, - }, - ) - server.ngin.bindRoutes(chainRouter{}) - assert.Equal(t, int32(5), atomic.LoadInt32(&called)) -} - func TestEngine_notFoundHandler(t *testing.T) { logx.Disable() @@ -374,7 +334,7 @@ type mockedRouter struct{} func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) { } -func (m mockedRouter) Handle(_, _ string, _ http.Handler) error { +func (m mockedRouter) Handle(_, _ string, handler http.Handler) error { return errors.New("foo") } @@ -383,19 +343,3 @@ func (m mockedRouter) SetNotFoundHandler(_ http.Handler) { func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) { } - -type chainRouter struct{} - -func (c chainRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) { -} - -func (c chainRouter) Handle(_, _ string, handler http.Handler) error { - handler.ServeHTTP(nil, nil) - return nil -} - -func (c chainRouter) SetNotFoundHandler(_ http.Handler) { -} - -func (c chainRouter) SetNotAllowedHandler(_ http.Handler) { -} diff --git a/rest/server.go b/rest/server.go index 5544d449..b0debcac 100644 --- a/rest/server.go +++ b/rest/server.go @@ -8,6 +8,7 @@ import ( "time" "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/rest/chain" "github.com/zeromicro/go-zero/rest/handler" "github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/internal/cors" @@ -95,13 +96,6 @@ func (s *Server) Use(middleware Middleware) { s.ngin.use(middleware) } -// DisableDefaultMiddlewares returns a RunOption that disables the builtin middlewares. -func DisableDefaultMiddlewares() RunOption { - return func(svr *Server) { - svr.ngin.disableDefaultMiddlewares = true - } -} - // ToMiddleware converts the given handler to a Middleware. func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware { return func(handle http.HandlerFunc) http.HandlerFunc { @@ -109,6 +103,14 @@ func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware { } } +// WithChain returns a RunOption that uses the given chain to replace the default chain. +// JWT auth middleware and the middlewares that added by svr.Use() will be appended. +func WithChain(chn chain.Chain) RunOption { + return func(svr *Server) { + svr.ngin.chain = chn + } +} + // WithCors returns a func to enable CORS for given origin, or default to all origins (*). func WithCors(origin ...string) RunOption { return func(server *Server) { diff --git a/rest/server_test.go b/rest/server_test.go index e4c1ef39..0f4af3f5 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -9,6 +9,7 @@ import ( "net/http/httptest" "os" "strings" + "sync/atomic" "testing" "time" @@ -16,6 +17,7 @@ import ( "github.com/zeromicro/go-zero/core/conf" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/service" + "github.com/zeromicro/go-zero/rest/chain" "github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/router" ) @@ -435,3 +437,44 @@ func TestValidateSecret(t *testing.T) { validateSecret("short") }) } + +func TestServer_WithChain(t *testing.T) { + var called int32 + middleware1 := func() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&called, 1) + next.ServeHTTP(w, r) + atomic.AddInt32(&called, 1) + }) + } + } + middleware2 := func() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&called, 1) + next.ServeHTTP(w, r) + atomic.AddInt32(&called, 1) + }) + } + } + + server := MustNewServer(RestConf{}, WithChain(chain.New(middleware1(), middleware2()))) + server.AddRoutes( + []Route{ + { + Method: http.MethodGet, + Path: "/", + Handler: func(_ http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&called, 1) + }, + }, + }, + ) + rt := router.NewRouter() + assert.Nil(t, server.ngin.bindRoutes(rt)) + req, err := http.NewRequest(http.MethodGet, "/", nil) + assert.Nil(t, err) + rt.ServeHTTP(httptest.NewRecorder(), req) + assert.Equal(t, int32(5), atomic.LoadInt32(&called)) +}