feat: rest.WithChain to replace builtin middlewares (#2033)
* feat: rest.WithChain to replace builtin middlewares * chore: add comments * chore: refine codemaster
parent
50f16e2892
commit
47c49de94e
@ -0,0 +1,109 @@
|
|||||||
|
package chain
|
||||||
|
|
||||||
|
// This is a modified version of https://github.com/justinas/alice
|
||||||
|
// The original code is licensed under the MIT license.
|
||||||
|
// It's modified for couple reasons:
|
||||||
|
// - Added the Chain interface
|
||||||
|
// - Added support for the Chain.Prepend(...) method
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type (
|
||||||
|
// Chain defines a chain of middleware.
|
||||||
|
Chain interface {
|
||||||
|
Append(middlewares ...Middleware) Chain
|
||||||
|
Prepend(middlewares ...Middleware) Chain
|
||||||
|
Then(h http.Handler) http.Handler
|
||||||
|
ThenFunc(fn http.HandlerFunc) http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware is an HTTP middleware.
|
||||||
|
Middleware func(http.Handler) http.Handler
|
||||||
|
|
||||||
|
// chain acts as a list of http.Handler middlewares.
|
||||||
|
// chain is effectively immutable:
|
||||||
|
// once created, it will always hold
|
||||||
|
// the same set of middlewares in the same order.
|
||||||
|
chain struct {
|
||||||
|
middlewares []Middleware
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// New creates a new Chain, memorizing the given list of middleware middlewares.
|
||||||
|
// New serves no other function, middlewares are only called upon a call to Then() or ThenFunc().
|
||||||
|
func New(middlewares ...Middleware) Chain {
|
||||||
|
return chain{middlewares: append(([]Middleware)(nil), middlewares...)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append extends a chain, adding the specified middlewares as the last ones in the request flow.
|
||||||
|
//
|
||||||
|
// c := chain.New(m1, m2)
|
||||||
|
// c.Append(m3, m4)
|
||||||
|
// // requests in c go m1 -> m2 -> m3 -> m4
|
||||||
|
func (c chain) Append(middlewares ...Middleware) Chain {
|
||||||
|
return chain{middlewares: join(c.middlewares, middlewares)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepend extends a chain by adding the specified chain as the first one in the request flow.
|
||||||
|
//
|
||||||
|
// c := chain.New(m3, m4)
|
||||||
|
// c1 := chain.New(m1, m2)
|
||||||
|
// c.Prepend(c1)
|
||||||
|
// // requests in c go m1 -> m2 -> m3 -> m4
|
||||||
|
func (c chain) Prepend(middlewares ...Middleware) Chain {
|
||||||
|
return chain{middlewares: join(middlewares, c.middlewares)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then chains the middleware and returns the final http.Handler.
|
||||||
|
// New(m1, m2, m3).Then(h)
|
||||||
|
// is equivalent to:
|
||||||
|
// m1(m2(m3(h)))
|
||||||
|
// When the request comes in, it will be passed to m1, then m2, then m3
|
||||||
|
// and finally, the given handler
|
||||||
|
// (assuming every middleware calls the following one).
|
||||||
|
//
|
||||||
|
// A chain can be safely reused by calling Then() several times.
|
||||||
|
// stdStack := chain.New(ratelimitHandler, csrfHandler)
|
||||||
|
// indexPipe = stdStack.Then(indexHandler)
|
||||||
|
// authPipe = stdStack.Then(authHandler)
|
||||||
|
// Note that middlewares are called on every call to Then() or ThenFunc()
|
||||||
|
// and thus several instances of the same middleware will be created
|
||||||
|
// when a chain is reused in this way.
|
||||||
|
// For proper middleware, this should cause no problems.
|
||||||
|
//
|
||||||
|
// Then() treats nil as http.DefaultServeMux.
|
||||||
|
func (c chain) Then(h http.Handler) http.Handler {
|
||||||
|
if h == nil {
|
||||||
|
h = http.DefaultServeMux
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range c.middlewares {
|
||||||
|
h = c.middlewares[len(c.middlewares)-1-i](h)
|
||||||
|
}
|
||||||
|
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// ThenFunc works identically to Then, but takes
|
||||||
|
// a HandlerFunc instead of a Handler.
|
||||||
|
//
|
||||||
|
// The following two statements are equivalent:
|
||||||
|
// c.Then(http.HandlerFunc(fn))
|
||||||
|
// c.ThenFunc(fn)
|
||||||
|
//
|
||||||
|
// ThenFunc provides all the guarantees of Then.
|
||||||
|
func (c chain) ThenFunc(fn http.HandlerFunc) http.Handler {
|
||||||
|
// This nil check cannot be removed due to the "nil is not nil" common mistake in Go.
|
||||||
|
// Required due to: https://stackoverflow.com/questions/33426977/how-to-golang-check-a-variable-is-nil
|
||||||
|
if fn == nil {
|
||||||
|
return c.Then(nil)
|
||||||
|
}
|
||||||
|
return c.Then(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func join(a, b []Middleware) []Middleware {
|
||||||
|
mids := make([]Middleware, 0, len(a)+len(b))
|
||||||
|
mids = append(mids, a...)
|
||||||
|
mids = append(mids, b...)
|
||||||
|
return mids
|
||||||
|
}
|
@ -0,0 +1,126 @@
|
|||||||
|
package chain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A constructor for middleware
|
||||||
|
// that writes its own "tag" into the RW and does nothing else.
|
||||||
|
// Useful in checking if a chain is behaving in the right order.
|
||||||
|
func tagMiddleware(tag string) Middleware {
|
||||||
|
return func(h http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(tag))
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer),
|
||||||
|
// but the best we can do.
|
||||||
|
func funcsEqual(f1, f2 interface{}) bool {
|
||||||
|
val1 := reflect.ValueOf(f1)
|
||||||
|
val2 := reflect.ValueOf(f2)
|
||||||
|
return val1.Pointer() == val2.Pointer()
|
||||||
|
}
|
||||||
|
|
||||||
|
var testApp = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("app\n"))
|
||||||
|
})
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
c1 := func(h http.Handler) http.Handler {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c2 := func(h http.Handler) http.Handler {
|
||||||
|
return http.StripPrefix("potato", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
slice := []Middleware{c1, c2}
|
||||||
|
c := New(slice...)
|
||||||
|
for k := range slice {
|
||||||
|
assert.True(t, funcsEqual(c.(chain).middlewares[k], slice[k]),
|
||||||
|
"New does not add constructors correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThenWorksWithNoMiddleware(t *testing.T) {
|
||||||
|
assert.True(t, funcsEqual(New().Then(testApp), testApp),
|
||||||
|
"Then does not work with no middleware")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThenTreatsNilAsDefaultServeMux(t *testing.T) {
|
||||||
|
assert.Equal(t, http.DefaultServeMux, New().Then(nil),
|
||||||
|
"Then does not treat nil as DefaultServeMux")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) {
|
||||||
|
assert.Equal(t, http.DefaultServeMux, New().ThenFunc(nil),
|
||||||
|
"ThenFunc does not treat nil as DefaultServeMux")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThenFuncConstructsHandlerFunc(t *testing.T) {
|
||||||
|
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
})
|
||||||
|
chained := New().ThenFunc(fn)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
chained.ServeHTTP(rec, (*http.Request)(nil))
|
||||||
|
|
||||||
|
assert.Equal(t, reflect.TypeOf((http.HandlerFunc)(nil)), reflect.TypeOf(chained),
|
||||||
|
"ThenFunc does not construct HandlerFunc")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestThenOrdersHandlersCorrectly(t *testing.T) {
|
||||||
|
t1 := tagMiddleware("t1\n")
|
||||||
|
t2 := tagMiddleware("t2\n")
|
||||||
|
t3 := tagMiddleware("t3\n")
|
||||||
|
|
||||||
|
chained := New(t1, t2, t3).Then(testApp)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r, err := http.NewRequest("GET", "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
chained.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, "t1\nt2\nt3\napp\n", w.Body.String(),
|
||||||
|
"Then does not order handlers correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendAddsHandlersCorrectly(t *testing.T) {
|
||||||
|
c := New(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
|
||||||
|
c = c.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
|
||||||
|
h := c.Then(testApp)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r, err := http.NewRequest("GET", "/", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
assert.Equal(t, "t1\nt2\nt3\nt4\napp\n", w.Body.String(),
|
||||||
|
"Append does not add handlers correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtendAddsHandlersCorrectly(t *testing.T) {
|
||||||
|
c := New(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
|
||||||
|
c = c.Prepend(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
|
||||||
|
h := c.Then(testApp)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r, err := http.NewRequest("GET", "/", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
assert.Equal(t, "t1\nt2\nt3\nt4\napp\n", w.Body.String(),
|
||||||
|
"Extend does not add handlers in correctly")
|
||||||
|
}
|
Loading…
Reference in New Issue