diff --git a/rest/internal/cors/handlers.go b/rest/internal/cors/handlers.go index 6b27ef6c..4547094a 100644 --- a/rest/internal/cors/handlers.go +++ b/rest/internal/cors/handlers.go @@ -1,6 +1,11 @@ package cors -import "net/http" +import ( + "bufio" + "errors" + "net" + "net/http" +) const ( allowOrigin = "Access-Control-Allow-Origin" @@ -25,15 +30,16 @@ const ( // At most one origin can be specified, other origins are ignored if given, default to be *. func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - checkAndSetHeaders(w, r, origins) + gw := &guardedResponseWriter{w: w} + checkAndSetHeaders(gw, r, origins) if fn != nil { - fn(w) + fn(gw) } - if r.Method != http.MethodOptions { - w.WriteHeader(http.StatusNotFound) + if r.Method == http.MethodOptions { + gw.WriteHeader(http.StatusNoContent) } else { - w.WriteHeader(http.StatusNoContent) + gw.WriteHeader(http.StatusNotFound) } }) } @@ -56,6 +62,44 @@ func Middleware(fn func(w http.ResponseWriter), origins ...string) func(http.Han } } +type guardedResponseWriter struct { + w http.ResponseWriter + wroteHeader bool +} + +func (w *guardedResponseWriter) Flush() { + if flusher, ok := w.w.(http.Flusher); ok { + flusher.Flush() + } +} + +func (w *guardedResponseWriter) Header() http.Header { + return w.w.Header() +} + +// Hijack implements the http.Hijacker interface. +// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it. +func (w *guardedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacked, ok := w.w.(http.Hijacker); ok { + return hijacked.Hijack() + } + + return nil, nil, errors.New("server doesn't support hijacking") +} + +func (w *guardedResponseWriter) Write(bytes []byte) (int, error) { + return w.w.Write(bytes) +} + +func (w *guardedResponseWriter) WriteHeader(code int) { + if w.wroteHeader { + return + } + + w.w.WriteHeader(code) + w.wroteHeader = true +} + func checkAndSetHeaders(w http.ResponseWriter, r *http.Request, origins []string) { setVaryHeaders(w, r) diff --git a/rest/internal/cors/handlers_test.go b/rest/internal/cors/handlers_test.go index 0e112b93..f5e8e017 100644 --- a/rest/internal/cors/handlers_test.go +++ b/rest/internal/cors/handlers_test.go @@ -1,6 +1,8 @@ package cors import ( + "bufio" + "net" "net/http" "net/http/httptest" "testing" @@ -129,3 +131,48 @@ func TestCorsHandlerWithOrigins(t *testing.T) { } } } + +func TestGuardedResponseWriter_Flush(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + handler := NotAllowedHandler(func(w http.ResponseWriter) { + w.Header().Set("X-Test", "test") + w.WriteHeader(http.StatusServiceUnavailable) + _, err := w.Write([]byte("content")) + assert.Nil(t, err) + + flusher, ok := w.(http.Flusher) + assert.True(t, ok) + flusher.Flush() + }, "foo.com") + + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusServiceUnavailable, resp.Code) + assert.Equal(t, "test", resp.Header().Get("X-Test")) + assert.Equal(t, "content", resp.Body.String()) +} + +func TestGuardedResponseWriter_Hijack(t *testing.T) { + resp := httptest.NewRecorder() + writer := &guardedResponseWriter{ + w: resp, + } + assert.NotPanics(t, func() { + writer.Hijack() + }) + + writer = &guardedResponseWriter{ + w: mockedHijackable{resp}, + } + assert.NotPanics(t, func() { + writer.Hijack() + }) +} + +type mockedHijackable struct { + *httptest.ResponseRecorder +} + +func (m mockedHijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, nil +}