package cors

import (
	"bufio"
	"net"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestCorsHandlerWithOrigins checks the status code and allow-origin header
// produced by NotAllowedHandler and Middleware for every combination of the
// table cases below and the OPTIONS/GET/POST methods, both without and with a
// custom header callback.
func TestCorsHandlerWithOrigins(t *testing.T) {
	tests := []struct {
		name      string
		origins   []string
		reqOrigin string
		expect    string
	}{
		{
			name:   "allow all origins without request origin",
			expect: allOrigins,
		},
		{
			name:      "allow one origin",
			origins:   []string{"http://local"},
			reqOrigin: "http://local",
			expect:    "http://local",
		},
		{
			name:      "allow many origins",
			origins:   []string{"http://local", "http://remote"},
			reqOrigin: "http://local",
			expect:    "http://local",
		},
		{
			name:      "allow all origins with request origin",
			reqOrigin: "http://local",
			expect:    "*",
		},
		{
			name:      "allow many origins with all mark",
			origins:   []string{"http://local", "http://remote", "*"},
			reqOrigin: "http://another",
			expect:    "http://another",
		},
		{
			name:      "not allow origin",
			origins:   []string{"http://local", "http://remote"},
			reqOrigin: "http://another",
		},
	}

	methods := []string{
		http.MethodOptions,
		http.MethodGet,
		http.MethodPost,
	}

	for _, test := range tests {
		for _, method := range methods {
			test, method := test, method
			// Include the method so every subtest has a unique, filterable name
			// rather than relying on t.Run's "#01" deduplication suffixes.
			name := test.name + "-" + method

			t.Run(name+"-handler", func(t *testing.T) {
				handler := NotAllowedHandler(nil, test.origins...)
				w := serveWithOrigin(handler, method, test.reqOrigin)
				assertCorsResult(t, w, method, http.StatusNotFound, test.expect)
			})

			t.Run(name+"-handler-custom", func(t *testing.T) {
				handler := NotAllowedHandler(func(w http.ResponseWriter) {
					w.Header().Set("foo", "bar")
				}, test.origins...)
				w := serveWithOrigin(handler, method, test.reqOrigin)
				assertCorsResult(t, w, method, http.StatusNotFound, test.expect)
				assert.Equal(t, "bar", w.Header().Get("foo"))
			})
		}
	}

	// okHandler is the downstream handler wrapped by Middleware; non-preflight
	// requests are expected to reach it and return 200.
	okHandler := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}

	for _, test := range tests {
		for _, method := range methods {
			test, method := test, method
			name := test.name + "-" + method

			t.Run(name+"-middleware", func(t *testing.T) {
				handler := Middleware(nil, test.origins...)(okHandler)
				w := serveWithOrigin(handler, method, test.reqOrigin)
				assertCorsResult(t, w, method, http.StatusOK, test.expect)
			})

			t.Run(name+"-middleware-custom", func(t *testing.T) {
				handler := Middleware(func(header http.Header) {
					header.Set("foo", "bar")
				}, test.origins...)(okHandler)
				w := serveWithOrigin(handler, method, test.reqOrigin)
				assertCorsResult(t, w, method, http.StatusOK, test.expect)
				assert.Equal(t, "bar", w.Header().Get("foo"))
			})
		}
	}
}

// serveWithOrigin sends a request with the given method to http://localhost,
// sets its origin header to origin, serves it with handler and returns the
// recorded response.
func serveWithOrigin(handler http.Handler, method, origin string) *httptest.ResponseRecorder {
	r := httptest.NewRequest(method, "http://localhost", nil)
	r.Header.Set(originHeader, origin)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, r)
	return w
}

// assertCorsResult checks the recorded response. OPTIONS requests must return
// 204 No Content and any other method must return otherStatus. The
// allow-origin header must equal expectOrigin in both cases.
func assertCorsResult(t *testing.T, w *httptest.ResponseRecorder, method string,
	otherStatus int, expectOrigin string) {
	t.Helper()

	wantStatus := otherStatus
	if method == http.MethodOptions {
		wantStatus = http.StatusNoContent
	}
	assert.Equal(t, wantStatus, w.Code)
	assert.Equal(t, expectOrigin, w.Header().Get(allowOrigin))
}

// TestGuardedResponseWriter_Flush checks that headers, status and body written
// from a NotAllowedHandler callback, followed by a Flush through the writer it
// receives, reach the underlying recorder.
func TestGuardedResponseWriter_Flush(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
	handler := NotAllowedHandler(func(w http.ResponseWriter) {
		w.Header().Set("X-Test", "test")
		w.WriteHeader(http.StatusServiceUnavailable)
		_, err := w.Write([]byte("content"))
		assert.NoError(t, err)

		// Only flush if the assertion holds; calling Flush on a nil Flusher
		// would panic and mask the real failure reported by assert.True.
		if flusher, ok := w.(http.Flusher); assert.True(t, ok) {
			flusher.Flush()
		}
	}, "foo.com")

	resp := httptest.NewRecorder()
	handler.ServeHTTP(resp, req)

	assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
	assert.Equal(t, "test", resp.Header().Get("X-Test"))
	assert.Equal(t, "content", resp.Body.String())
}

// TestGuardedResponseWriter_Hijack checks that guardedResponseWriter.Hijack
// does not panic, whether or not the wrapped writer implements http.Hijacker.
func TestGuardedResponseWriter_Hijack(t *testing.T) {
	resp := httptest.NewRecorder()

	// httptest.ResponseRecorder does not implement http.Hijacker.
	writer := &guardedResponseWriter{
		w: resp,
	}
	assert.NotPanics(t, func() {
		writer.Hijack()
	})

	writer = &guardedResponseWriter{
		w: mockedHijackable{resp},
	}
	assert.NotPanics(t, func() {
		writer.Hijack()
	})
}

// Compile-time check that the mock really satisfies http.Hijacker.
var _ http.Hijacker = mockedHijackable{}

// mockedHijackable is a ResponseRecorder that also implements http.Hijacker.
type mockedHijackable struct {
	*httptest.ResponseRecorder
}

// Hijack is a no-op stub that returns all nil values.
func (m mockedHijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	return nil, nil, nil
}