diff --git a/rest/handlers.go b/rest/handlers.go index a6bffbb6..8181c9b3 100644 --- a/rest/handlers.go +++ b/rest/handlers.go @@ -1,26 +1,24 @@ package rest -import ( - "net/http" - "strings" -) +import "net/http" const ( allowOrigin = "Access-Control-Allow-Origin" - allOrigin = "*" + allOrigins = "*" allowMethods = "Access-Control-Allow-Methods" allowHeaders = "Access-Control-Allow-Headers" headers = "Content-Type, Content-Length, Origin" methods = "GET, HEAD, POST, PATCH, PUT, DELETE" - separator = ", " ) +// CorsHandler handles cross domain OPTIONS requests. +// At most one origin can be specified, other origins are ignored if given. func CorsHandler(origins ...string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if len(origins) > 0 { - w.Header().Set(allowOrigin, strings.Join(origins, separator)) + w.Header().Set(allowOrigin, origins[0]) } else { - w.Header().Set(allowOrigin, allOrigin) + w.Header().Set(allowOrigin, allOrigins) } w.Header().Set(allowMethods, methods) w.Header().Set(allowHeaders, headers) diff --git a/rest/handlers_test.go b/rest/handlers_test.go index 9b2dd746..366dce42 100644 --- a/rest/handlers_test.go +++ b/rest/handlers_test.go @@ -3,25 +3,40 @@ package rest import ( "net/http" "net/http/httptest" - "strings" "testing" "github.com/stretchr/testify/assert" ) -func TestCorsHandler(t *testing.T) { - w := httptest.NewRecorder() - handler := CorsHandler() - handler.ServeHTTP(w, nil) - assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) - assert.Equal(t, allOrigin, w.Header().Get(allowOrigin)) -} - func TestCorsHandlerWithOrigins(t *testing.T) { - origins := []string{"local", "remote"} - w := httptest.NewRecorder() - handler := CorsHandler(origins...) - handler.ServeHTTP(w, nil) - assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) - assert.Equal(t, strings.Join(origins, separator), w.Header().Get(allowOrigin)) + tests := []struct { + name string + origins []string + expect string + }{ + { + name: "allow all origins", + expect: allOrigins, + }, + { + name: "allow one origin", + origins: []string{"local"}, + expect: "local", + }, + { + name: "allow many origins", + origins: []string{"local", "remote"}, + expect: "local", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + w := httptest.NewRecorder() + handler := CorsHandler(test.origins...) + handler.ServeHTTP(w, nil) + assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) + assert.Equal(t, test.expect, w.Header().Get(allowOrigin)) + }) + } }