can only specify one origin in cors

master
kevin 4 years ago
parent 96cb7af728
commit 9c8f31cf83

@ -1,26 +1,24 @@
package rest package rest
import ( import "net/http"
"net/http"
"strings"
)
const ( const (
allowOrigin = "Access-Control-Allow-Origin" allowOrigin = "Access-Control-Allow-Origin"
allOrigin = "*" allOrigins = "*"
allowMethods = "Access-Control-Allow-Methods" allowMethods = "Access-Control-Allow-Methods"
allowHeaders = "Access-Control-Allow-Headers" allowHeaders = "Access-Control-Allow-Headers"
headers = "Content-Type, Content-Length, Origin" headers = "Content-Type, Content-Length, Origin"
methods = "GET, HEAD, POST, PATCH, PUT, DELETE" methods = "GET, HEAD, POST, PATCH, PUT, DELETE"
separator = ", "
) )
// CorsHandler handles cross domain OPTIONS requests.
// At most one origin can be specified, other origins are ignored if given.
func CorsHandler(origins ...string) http.Handler { func CorsHandler(origins ...string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if len(origins) > 0 { if len(origins) > 0 {
w.Header().Set(allowOrigin, strings.Join(origins, separator)) w.Header().Set(allowOrigin, origins[0])
} else { } else {
w.Header().Set(allowOrigin, allOrigin) w.Header().Set(allowOrigin, allOrigins)
} }
w.Header().Set(allowMethods, methods) w.Header().Set(allowMethods, methods)
w.Header().Set(allowHeaders, headers) w.Header().Set(allowHeaders, headers)

@ -3,25 +3,40 @@ package rest
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestCorsHandler(t *testing.T) { func TestCorsHandlerWithOrigins(t *testing.T) {
w := httptest.NewRecorder() tests := []struct {
handler := CorsHandler() name string
handler.ServeHTTP(w, nil) origins []string
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) expect string
assert.Equal(t, allOrigin, w.Header().Get(allowOrigin)) }{
{
name: "allow all origins",
expect: allOrigins,
},
{
name: "allow one origin",
origins: []string{"local"},
expect: "local",
},
{
name: "allow many origins",
origins: []string{"local", "remote"},
expect: "local",
},
} }
func TestCorsHandlerWithOrigins(t *testing.T) { for _, test := range tests {
origins := []string{"local", "remote"} t.Run(test.name, func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler := CorsHandler(origins...) handler := CorsHandler(test.origins...)
handler.ServeHTTP(w, nil) handler.ServeHTTP(w, nil)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
assert.Equal(t, strings.Join(origins, separator), w.Header().Get(allowOrigin)) assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
})
}
} }

Loading…
Cancel
Save