feat: rest.WithChain to replace builtin middlewares (#2033)

* feat: rest.WithChain to replace builtin middlewares

* chore: add comments

* chore: refine code
master
Kevin Wan 2 years ago committed by GitHub
parent 50f16e2892
commit 47c49de94e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,109 @@
package chain
// This is a modified version of https://github.com/justinas/alice
// The original code is licensed under the MIT license.
// It's modified for couple reasons:
// - Added the Chain interface
// - Added support for the Chain.Prepend(...) method
import "net/http"
type (
// Chain defines a chain of middleware.
Chain interface {
Append(middlewares ...Middleware) Chain
Prepend(middlewares ...Middleware) Chain
Then(h http.Handler) http.Handler
ThenFunc(fn http.HandlerFunc) http.Handler
}
// Middleware is an HTTP middleware.
Middleware func(http.Handler) http.Handler
// chain acts as a list of http.Handler middlewares.
// chain is effectively immutable:
// once created, it will always hold
// the same set of middlewares in the same order.
chain struct {
middlewares []Middleware
}
)
// New creates a new Chain, memorizing the given list of middleware middlewares.
// New serves no other function, middlewares are only called upon a call to Then() or ThenFunc().
func New(middlewares ...Middleware) Chain {
return chain{middlewares: append(([]Middleware)(nil), middlewares...)}
}
// Append extends a chain, adding the specified middlewares as the last ones in the request flow.
//
// c := chain.New(m1, m2)
// c.Append(m3, m4)
// // requests in c go m1 -> m2 -> m3 -> m4
func (c chain) Append(middlewares ...Middleware) Chain {
return chain{middlewares: join(c.middlewares, middlewares)}
}
// Prepend extends a chain by adding the specified chain as the first one in the request flow.
//
// c := chain.New(m3, m4)
// c1 := chain.New(m1, m2)
// c.Prepend(c1)
// // requests in c go m1 -> m2 -> m3 -> m4
func (c chain) Prepend(middlewares ...Middleware) Chain {
return chain{middlewares: join(middlewares, c.middlewares)}
}
// Then chains the middleware and returns the final http.Handler.
// New(m1, m2, m3).Then(h)
// is equivalent to:
// m1(m2(m3(h)))
// When the request comes in, it will be passed to m1, then m2, then m3
// and finally, the given handler
// (assuming every middleware calls the following one).
//
// A chain can be safely reused by calling Then() several times.
// stdStack := chain.New(ratelimitHandler, csrfHandler)
// indexPipe = stdStack.Then(indexHandler)
// authPipe = stdStack.Then(authHandler)
// Note that middlewares are called on every call to Then() or ThenFunc()
// and thus several instances of the same middleware will be created
// when a chain is reused in this way.
// For proper middleware, this should cause no problems.
//
// Then() treats nil as http.DefaultServeMux.
func (c chain) Then(h http.Handler) http.Handler {
if h == nil {
h = http.DefaultServeMux
}
for i := range c.middlewares {
h = c.middlewares[len(c.middlewares)-1-i](h)
}
return h
}
// ThenFunc works identically to Then, but takes
// a HandlerFunc instead of a Handler.
//
// The following two statements are equivalent:
// c.Then(http.HandlerFunc(fn))
// c.ThenFunc(fn)
//
// ThenFunc provides all the guarantees of Then.
func (c chain) ThenFunc(fn http.HandlerFunc) http.Handler {
// This nil check cannot be removed due to the "nil is not nil" common mistake in Go.
// Required due to: https://stackoverflow.com/questions/33426977/how-to-golang-check-a-variable-is-nil
if fn == nil {
return c.Then(nil)
}
return c.Then(fn)
}
func join(a, b []Middleware) []Middleware {
mids := make([]Middleware, 0, len(a)+len(b))
mids = append(mids, a...)
mids = append(mids, b...)
return mids
}

@ -0,0 +1,126 @@
package chain
import (
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
// A constructor for middleware
// that writes its own "tag" into the RW and does nothing else.
// Useful in checking if a chain is behaving in the right order.
func tagMiddleware(tag string) Middleware {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(tag))
h.ServeHTTP(w, r)
})
}
}
// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer),
// but the best we can do.
func funcsEqual(f1, f2 interface{}) bool {
val1 := reflect.ValueOf(f1)
val2 := reflect.ValueOf(f2)
return val1.Pointer() == val2.Pointer()
}
var testApp = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("app\n"))
})
func TestNew(t *testing.T) {
c1 := func(h http.Handler) http.Handler {
return nil
}
c2 := func(h http.Handler) http.Handler {
return http.StripPrefix("potato", nil)
}
slice := []Middleware{c1, c2}
c := New(slice...)
for k := range slice {
assert.True(t, funcsEqual(c.(chain).middlewares[k], slice[k]),
"New does not add constructors correctly")
}
}
func TestThenWorksWithNoMiddleware(t *testing.T) {
assert.True(t, funcsEqual(New().Then(testApp), testApp),
"Then does not work with no middleware")
}
func TestThenTreatsNilAsDefaultServeMux(t *testing.T) {
assert.Equal(t, http.DefaultServeMux, New().Then(nil),
"Then does not treat nil as DefaultServeMux")
}
func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) {
assert.Equal(t, http.DefaultServeMux, New().ThenFunc(nil),
"ThenFunc does not treat nil as DefaultServeMux")
}
func TestThenFuncConstructsHandlerFunc(t *testing.T) {
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
})
chained := New().ThenFunc(fn)
rec := httptest.NewRecorder()
chained.ServeHTTP(rec, (*http.Request)(nil))
assert.Equal(t, reflect.TypeOf((http.HandlerFunc)(nil)), reflect.TypeOf(chained),
"ThenFunc does not construct HandlerFunc")
}
func TestThenOrdersHandlersCorrectly(t *testing.T) {
t1 := tagMiddleware("t1\n")
t2 := tagMiddleware("t2\n")
t3 := tagMiddleware("t3\n")
chained := New(t1, t2, t3).Then(testApp)
w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
chained.ServeHTTP(w, r)
assert.Equal(t, "t1\nt2\nt3\napp\n", w.Body.String(),
"Then does not order handlers correctly")
}
func TestAppendAddsHandlersCorrectly(t *testing.T) {
c := New(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
c = c.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
h := c.Then(testApp)
w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
assert.Nil(t, err)
h.ServeHTTP(w, r)
assert.Equal(t, "t1\nt2\nt3\nt4\napp\n", w.Body.String(),
"Append does not add handlers correctly")
}
func TestExtendAddsHandlersCorrectly(t *testing.T) {
c := New(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
c = c.Prepend(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
h := c.Then(testApp)
w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
assert.Nil(t, err)
h.ServeHTTP(w, r)
assert.Equal(t, "t1\nt2\nt3\nt4\napp\n", w.Body.String(),
"Extend does not add handlers in correctly")
}

@ -8,10 +8,10 @@ import (
"sort" "sort"
"time" "time"
"github.com/justinas/alice"
"github.com/zeromicro/go-zero/core/codec" "github.com/zeromicro/go-zero/core/codec"
"github.com/zeromicro/go-zero/core/load" "github.com/zeromicro/go-zero/core/load"
"github.com/zeromicro/go-zero/core/stat" "github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/rest/chain"
"github.com/zeromicro/go-zero/rest/handler" "github.com/zeromicro/go-zero/rest/handler"
"github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/httpx"
"github.com/zeromicro/go-zero/rest/internal" "github.com/zeromicro/go-zero/rest/internal"
@ -29,7 +29,7 @@ type engine struct {
routes []featuredRoutes routes []featuredRoutes
unauthorizedCallback handler.UnauthorizedCallback unauthorizedCallback handler.UnauthorizedCallback
unsignedCallback handler.UnsignedCallback unsignedCallback handler.UnsignedCallback
disableDefaultMiddlewares bool chain chain.Chain
middlewares []Middleware middlewares []Middleware
shedder load.Shedder shedder load.Shedder
priorityShedder load.Shedder priorityShedder load.Shedder
@ -53,20 +53,20 @@ func (ng *engine) addRoutes(r featuredRoutes) {
ng.routes = append(ng.routes, r) ng.routes = append(ng.routes, r)
} }
func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain, func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
verifier func(alice.Chain) alice.Chain) alice.Chain { verifier func(chain.Chain) chain.Chain) chain.Chain {
if fr.jwt.enabled { if fr.jwt.enabled {
if len(fr.jwt.prevSecret) == 0 { if len(fr.jwt.prevSecret) == 0 {
chain = chain.Append(handler.Authorize(fr.jwt.secret, chn = chn.Append(handler.Authorize(fr.jwt.secret,
handler.WithUnauthorizedCallback(ng.unauthorizedCallback))) handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
} else { } else {
chain = chain.Append(handler.Authorize(fr.jwt.secret, chn = chn.Append(handler.Authorize(fr.jwt.secret,
handler.WithPrevSecret(fr.jwt.prevSecret), handler.WithPrevSecret(fr.jwt.prevSecret),
handler.WithUnauthorizedCallback(ng.unauthorizedCallback))) handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
} }
} }
return verifier(chain) return verifier(chn)
} }
func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error { func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
@ -85,10 +85,10 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met
} }
func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics, func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
route Route, verifier func(chain alice.Chain) alice.Chain) error { route Route, verifier func(chain.Chain) chain.Chain) error {
var chain alice.Chain chn := ng.chain
if !ng.disableDefaultMiddlewares { if chn == nil {
chain = alice.New( chn = chain.New(
handler.TracingHandler(ng.conf.Name, route.Path), handler.TracingHandler(ng.conf.Name, route.Path),
ng.getLogHandler(), ng.getLogHandler(),
handler.PrometheusHandler(route.Path), handler.PrometheusHandler(route.Path),
@ -103,11 +103,12 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta
) )
} }
chn = ng.appendAuthHandler(fr, chn, verifier)
for _, middleware := range ng.middlewares { for _, middleware := range ng.middlewares {
chain = chain.Append(convertMiddleware(middleware)) chn = chn.Append(convertMiddleware(middleware))
} }
chain = ng.appendAuthHandler(fr, chain, verifier) handle := chn.ThenFunc(route.Handler)
handle := chain.ThenFunc(route.Handler)
return router.Handle(route.Method, route.Path, handle) return router.Handle(route.Method, route.Path, handle)
} }
@ -171,16 +172,16 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
// notFoundHandler returns a middleware that handles 404 not found requests. // notFoundHandler returns a middleware that handles 404 not found requests.
func (ng *engine) notFoundHandler(next http.Handler) http.Handler { func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
chain := alice.New( chn := chain.New(
handler.TracingHandler(ng.conf.Name, ""), handler.TracingHandler(ng.conf.Name, ""),
ng.getLogHandler(), ng.getLogHandler(),
) )
var h http.Handler var h http.Handler
if next != nil { if next != nil {
h = chain.Then(next) h = chn.Then(next)
} else { } else {
h = chain.Then(http.NotFoundHandler()) h = chn.Then(http.NotFoundHandler())
} }
cw := response.NewHeaderOnceResponseWriter(w) cw := response.NewHeaderOnceResponseWriter(w)
@ -218,10 +219,10 @@ func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) {
ng.unsignedCallback = callback ng.unsignedCallback = callback
} }
func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) { func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chain) chain.Chain, error) {
if !signature.enabled { if !signature.enabled {
return func(chain alice.Chain) alice.Chain { return func(chn chain.Chain) chain.Chain {
return chain return chn
}, nil }, nil
} }
@ -230,8 +231,8 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic
return nil, ErrSignatureConfig return nil, ErrSignatureConfig
} }
return func(chain alice.Chain) alice.Chain { return func(chn chain.Chain) chain.Chain {
return chain return chn
}, nil }, nil
} }
@ -247,14 +248,13 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic
decrypters[fingerprint] = decrypter decrypters[fingerprint] = decrypter
} }
return func(chain alice.Chain) alice.Chain { return func(chn chain.Chain) chain.Chain {
if ng.unsignedCallback != nil { if ng.unsignedCallback != nil {
return chain.Append(handler.ContentSecurityHandler( return chn.Append(handler.ContentSecurityHandler(
decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback)) decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
} }
return chain.Append(handler.ContentSecurityHandler( return chn.Append(handler.ContentSecurityHandler(decrypters, signature.Expiry, signature.Strict))
decrypters, signature.Expiry, signature.Strict))
}, nil }, nil
} }

@ -229,46 +229,6 @@ func TestEngine_checkedMaxBytes(t *testing.T) {
} }
} }
func TestEngine_checkedChain(t *testing.T) {
var called int32
middleware1 := func() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&called, 1)
next.ServeHTTP(w, r)
atomic.AddInt32(&called, 1)
})
}
}
middleware2 := func() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&called, 1)
next.ServeHTTP(w, r)
atomic.AddInt32(&called, 1)
})
}
}
server := MustNewServer(RestConf{}, DisableDefaultMiddlewares())
server.Use(ToMiddleware(middleware1()))
server.Use(ToMiddleware(middleware2()))
server.router = chainRouter{}
server.AddRoutes(
[]Route{
{
Method: http.MethodGet,
Path: "/",
Handler: func(_ http.ResponseWriter, _ *http.Request) {
atomic.AddInt32(&called, 1)
},
},
},
)
server.ngin.bindRoutes(chainRouter{})
assert.Equal(t, int32(5), atomic.LoadInt32(&called))
}
func TestEngine_notFoundHandler(t *testing.T) { func TestEngine_notFoundHandler(t *testing.T) {
logx.Disable() logx.Disable()
@ -374,7 +334,7 @@ type mockedRouter struct{}
func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) { func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
} }
func (m mockedRouter) Handle(_, _ string, _ http.Handler) error { func (m mockedRouter) Handle(_, _ string, handler http.Handler) error {
return errors.New("foo") return errors.New("foo")
} }
@ -383,19 +343,3 @@ func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) { func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
} }
type chainRouter struct{}
func (c chainRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
}
func (c chainRouter) Handle(_, _ string, handler http.Handler) error {
handler.ServeHTTP(nil, nil)
return nil
}
func (c chainRouter) SetNotFoundHandler(_ http.Handler) {
}
func (c chainRouter) SetNotAllowedHandler(_ http.Handler) {
}

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest/chain"
"github.com/zeromicro/go-zero/rest/handler" "github.com/zeromicro/go-zero/rest/handler"
"github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/httpx"
"github.com/zeromicro/go-zero/rest/internal/cors" "github.com/zeromicro/go-zero/rest/internal/cors"
@ -95,13 +96,6 @@ func (s *Server) Use(middleware Middleware) {
s.ngin.use(middleware) s.ngin.use(middleware)
} }
// DisableDefaultMiddlewares returns a RunOption that disables the builtin middlewares.
func DisableDefaultMiddlewares() RunOption {
return func(svr *Server) {
svr.ngin.disableDefaultMiddlewares = true
}
}
// ToMiddleware converts the given handler to a Middleware. // ToMiddleware converts the given handler to a Middleware.
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware { func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
return func(handle http.HandlerFunc) http.HandlerFunc { return func(handle http.HandlerFunc) http.HandlerFunc {
@ -109,6 +103,14 @@ func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
} }
} }
// WithChain returns a RunOption that uses the given chain to replace the default chain.
// JWT auth middleware and the middlewares that added by svr.Use() will be appended.
func WithChain(chn chain.Chain) RunOption {
return func(svr *Server) {
svr.ngin.chain = chn
}
}
// WithCors returns a func to enable CORS for given origin, or default to all origins (*). // WithCors returns a func to enable CORS for given origin, or default to all origins (*).
func WithCors(origin ...string) RunOption { func WithCors(origin ...string) RunOption {
return func(server *Server) { return func(server *Server) {

@ -9,6 +9,7 @@ import (
"net/http/httptest" "net/http/httptest"
"os" "os"
"strings" "strings"
"sync/atomic"
"testing" "testing"
"time" "time"
@ -16,6 +17,7 @@ import (
"github.com/zeromicro/go-zero/core/conf" "github.com/zeromicro/go-zero/core/conf"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/service" "github.com/zeromicro/go-zero/core/service"
"github.com/zeromicro/go-zero/rest/chain"
"github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/httpx"
"github.com/zeromicro/go-zero/rest/router" "github.com/zeromicro/go-zero/rest/router"
) )
@ -435,3 +437,44 @@ func TestValidateSecret(t *testing.T) {
validateSecret("short") validateSecret("short")
}) })
} }
func TestServer_WithChain(t *testing.T) {
var called int32
middleware1 := func() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&called, 1)
next.ServeHTTP(w, r)
atomic.AddInt32(&called, 1)
})
}
}
middleware2 := func() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&called, 1)
next.ServeHTTP(w, r)
atomic.AddInt32(&called, 1)
})
}
}
server := MustNewServer(RestConf{}, WithChain(chain.New(middleware1(), middleware2())))
server.AddRoutes(
[]Route{
{
Method: http.MethodGet,
Path: "/",
Handler: func(_ http.ResponseWriter, _ *http.Request) {
atomic.AddInt32(&called, 1)
},
},
},
)
rt := router.NewRouter()
assert.Nil(t, server.ngin.bindRoutes(rt))
req, err := http.NewRequest(http.MethodGet, "/", nil)
assert.Nil(t, err)
rt.ServeHTTP(httptest.NewRecorder(), req)
assert.Equal(t, int32(5), atomic.LoadInt32(&called))
}

Loading…
Cancel
Save