diff --git a/rest/internal/cors/handlers.go b/rest/internal/cors/handlers.go index 4547094a..c613b27d 100644 --- a/rest/internal/cors/handlers.go +++ b/rest/internal/cors/handlers.go @@ -45,12 +45,12 @@ func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.H } // Middleware returns a middleware that adds CORS headers to the response. -func Middleware(fn func(w http.ResponseWriter), origins ...string) func(http.HandlerFunc) http.HandlerFunc { +func Middleware(fn func(w http.Header), origins ...string) func(http.HandlerFunc) http.HandlerFunc { return func(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { checkAndSetHeaders(w, r, origins) if fn != nil { - fn(w) + fn(w.Header()) } if r.Method == http.MethodOptions { diff --git a/rest/internal/cors/handlers_test.go b/rest/internal/cors/handlers_test.go index f5e8e017..047fdb98 100644 --- a/rest/internal/cors/handlers_test.go +++ b/rest/internal/cors/handlers_test.go @@ -114,8 +114,8 @@ func TestCorsHandlerWithOrigins(t *testing.T) { r := httptest.NewRequest(method, "http://localhost", nil) r.Header.Set(originHeader, test.reqOrigin) w := httptest.NewRecorder() - handler := Middleware(func(w http.ResponseWriter) { - w.Header().Set("foo", "bar") + handler := Middleware(func(header http.Header) { + header.Set("foo", "bar") }, test.origins...)(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) diff --git a/rest/server.go b/rest/server.go index dd7d5b6b..23355ca9 100644 --- a/rest/server.go +++ b/rest/server.go @@ -106,10 +106,11 @@ func WithCors(origin ...string) RunOption { // WithCustomCors returns a func to enable CORS for given origin, or default to all origins (*), // fn lets caller customizing the response. -func WithCustomCors(fn func(http.ResponseWriter), origin ...string) RunOption { +func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(http.ResponseWriter), + origin ...string) RunOption { return func(server *Server) { - server.router.SetNotAllowedHandler(cors.NotAllowedHandler(fn, origin...)) - server.Use(cors.Middleware(fn, origin...)) + server.router.SetNotAllowedHandler(cors.NotAllowedHandler(notAllowedFn, origin...)) + server.Use(cors.Middleware(middlewareFn, origin...)) } } diff --git a/rest/server_test.go b/rest/server_test.go index 1f36f916..ef814553 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -322,8 +322,10 @@ Port: 54321 srv, err := NewServer(cnf, WithRouter(rt)) assert.Nil(t, err) - opt := WithCustomCors(func(w http.ResponseWriter) { - w.Header().Set("foo", "bar") + opt := WithCustomCors(func(header http.Header) { + header.Set("foo", "bar") + }, func(w http.ResponseWriter) { + w.WriteHeader(http.StatusOK) }, "local") opt(srv) }