can only specify one origin in cors

master
kevin 4 years ago
parent 96cb7af728
commit 9c8f31cf83

@ -1,26 +1,24 @@
package rest
import (
"net/http"
"strings"
)
import "net/http"
const (
allowOrigin = "Access-Control-Allow-Origin"
allOrigin = "*"
allOrigins = "*"
allowMethods = "Access-Control-Allow-Methods"
allowHeaders = "Access-Control-Allow-Headers"
headers = "Content-Type, Content-Length, Origin"
methods = "GET, HEAD, POST, PATCH, PUT, DELETE"
separator = ", "
)
// CorsHandler handles cross domain OPTIONS requests.
// At most one origin can be specified, other origins are ignored if given.
func CorsHandler(origins ...string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if len(origins) > 0 {
w.Header().Set(allowOrigin, strings.Join(origins, separator))
w.Header().Set(allowOrigin, origins[0])
} else {
w.Header().Set(allowOrigin, allOrigin)
w.Header().Set(allowOrigin, allOrigins)
}
w.Header().Set(allowMethods, methods)
w.Header().Set(allowHeaders, headers)

@ -3,25 +3,40 @@ package rest
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCorsHandler(t *testing.T) {
w := httptest.NewRecorder()
handler := CorsHandler()
handler.ServeHTTP(w, nil)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
assert.Equal(t, allOrigin, w.Header().Get(allowOrigin))
}
func TestCorsHandlerWithOrigins(t *testing.T) {
origins := []string{"local", "remote"}
w := httptest.NewRecorder()
handler := CorsHandler(origins...)
handler.ServeHTTP(w, nil)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
assert.Equal(t, strings.Join(origins, separator), w.Header().Get(allowOrigin))
tests := []struct {
name string
origins []string
expect string
}{
{
name: "allow all origins",
expect: allOrigins,
},
{
name: "allow one origin",
origins: []string{"local"},
expect: "local",
},
{
name: "allow many origins",
origins: []string{"local", "remote"},
expect: "local",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
w := httptest.NewRecorder()
handler := CorsHandler(test.origins...)
handler.ServeHTTP(w, nil)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
})
}
}

Loading…
Cancel
Save