chore: only allow cors middleware to change headers (#1276)

master
Kevin Wan 3 years ago committed by GitHub
parent c800f6f723
commit 3dda557410
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -45,12 +45,12 @@ func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.H
} }
// Middleware returns a middleware that adds CORS headers to the response. // Middleware returns a middleware that adds CORS headers to the response.
func Middleware(fn func(w http.ResponseWriter), origins ...string) func(http.HandlerFunc) http.HandlerFunc { func Middleware(fn func(w http.Header), origins ...string) func(http.HandlerFunc) http.HandlerFunc {
return func(next http.HandlerFunc) http.HandlerFunc { return func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
checkAndSetHeaders(w, r, origins) checkAndSetHeaders(w, r, origins)
if fn != nil { if fn != nil {
fn(w) fn(w.Header())
} }
if r.Method == http.MethodOptions { if r.Method == http.MethodOptions {

@ -114,8 +114,8 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
r := httptest.NewRequest(method, "http://localhost", nil) r := httptest.NewRequest(method, "http://localhost", nil)
r.Header.Set(originHeader, test.reqOrigin) r.Header.Set(originHeader, test.reqOrigin)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler := Middleware(func(w http.ResponseWriter) { handler := Middleware(func(header http.Header) {
w.Header().Set("foo", "bar") header.Set("foo", "bar")
}, test.origins...)(func(w http.ResponseWriter, r *http.Request) { }, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })

@ -106,10 +106,11 @@ func WithCors(origin ...string) RunOption {
// WithCustomCors returns a func to enable CORS for given origin, or default to all origins (*), // WithCustomCors returns a func to enable CORS for given origin, or default to all origins (*),
// fn lets caller customizing the response. // fn lets caller customizing the response.
func WithCustomCors(fn func(http.ResponseWriter), origin ...string) RunOption { func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(http.ResponseWriter),
origin ...string) RunOption {
return func(server *Server) { return func(server *Server) {
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(fn, origin...)) server.router.SetNotAllowedHandler(cors.NotAllowedHandler(notAllowedFn, origin...))
server.Use(cors.Middleware(fn, origin...)) server.Use(cors.Middleware(middlewareFn, origin...))
} }
} }

@ -322,8 +322,10 @@ Port: 54321
srv, err := NewServer(cnf, WithRouter(rt)) srv, err := NewServer(cnf, WithRouter(rt))
assert.Nil(t, err) assert.Nil(t, err)
opt := WithCustomCors(func(w http.ResponseWriter) { opt := WithCustomCors(func(header http.Header) {
w.Header().Set("foo", "bar") header.Set("foo", "bar")
}, func(w http.ResponseWriter) {
w.WriteHeader(http.StatusOK)
}, "local") }, "local")
opt(srv) opt(srv)
} }

Loading…
Cancel
Save