From 28409791fac79c0a1cd6f83089cfb50409b2295e Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Tue, 9 Nov 2021 20:35:57 +0800 Subject: [PATCH] feat: support CORS, better implementation (#1217) * feat: support CORS, better implementation * chore: refine code --- rest/internal/cors/handlers.go | 57 +++++++++++++++++++++++------ rest/internal/cors/handlers_test.go | 41 ++++++++++++++++----- rest/server.go | 2 +- 3 files changed, 78 insertions(+), 22 deletions(-) diff --git a/rest/internal/cors/handlers.go b/rest/internal/cors/handlers.go index 3be0bd44..f0b4d4fc 100644 --- a/rest/internal/cors/handlers.go +++ b/rest/internal/cors/handlers.go @@ -9,19 +9,23 @@ const ( allowHeaders = "Access-Control-Allow-Headers" allowCredentials = "Access-Control-Allow-Credentials" exposeHeaders = "Access-Control-Expose-Headers" + requestMethod = "Access-Control-Request-Method" + requestHeaders = "Access-Control-Request-Headers" allowHeadersVal = "Content-Type, Origin, X-CSRF-Token, Authorization, AccessToken, Token, Range" exposeHeadersVal = "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers" methods = "GET, HEAD, POST, PATCH, PUT, DELETE" allowTrue = "true" maxAgeHeader = "Access-Control-Max-Age" maxAgeHeaderVal = "86400" + varyHeader = "Vary" + originHeader = "Origin" ) -// Handler handles cross domain not allowed requests. +// NotAllowedHandler handles cross domain not allowed requests. // At most one origin can be specified, other origins are ignored if given, default to be *. -func Handler(origin ...string) http.Handler { +func NotAllowedHandler(origins ...string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - setHeader(w, getOrigin(origin)) + checkAndSetHeaders(w, r, origins) if r.Method != http.MethodOptions { w.WriteHeader(http.StatusNotFound) @@ -32,10 +36,10 @@ func Handler(origin ...string) http.Handler { } // Middleware returns a middleware that adds CORS headers to the response. -func Middleware(origin ...string) func(http.HandlerFunc) http.HandlerFunc { +func Middleware(origins ...string) func(http.HandlerFunc) http.HandlerFunc { return func(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - setHeader(w, getOrigin(origin)) + checkAndSetHeaders(w, r, origins) if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) @@ -46,12 +50,32 @@ func Middleware(origin ...string) func(http.HandlerFunc) http.HandlerFunc { } } -func getOrigin(origins []string) string { - if len(origins) > 0 { - return origins[0] - } else { - return allOrigins +func checkAndSetHeaders(w http.ResponseWriter, r *http.Request, origins []string) { + setVaryHeaders(w, r) + + if len(origins) == 0 { + setHeader(w, allOrigins) + return + } + + origin := r.Header.Get(originHeader) + if isOriginAllowed(origins, origin) { + setHeader(w, origin) + } +} + +func isOriginAllowed(allows []string, origin string) bool { + for _, o := range allows { + if o == allOrigins { + return true + } + + if o == origin { + return true + } } + + return false } func setHeader(w http.ResponseWriter, origin string) { @@ -59,6 +83,17 @@ func setHeader(w http.ResponseWriter, origin string) { w.Header().Set(allowMethods, methods) w.Header().Set(allowHeaders, allowHeadersVal) w.Header().Set(exposeHeaders, exposeHeadersVal) - w.Header().Set(allowCredentials, allowTrue) + if origin != allOrigins { + w.Header().Set(allowCredentials, allowTrue) + } w.Header().Set(maxAgeHeader, maxAgeHeaderVal) } + +func setVaryHeaders(w http.ResponseWriter, r *http.Request) { + header := w.Header() + header.Add(varyHeader, originHeader) + if r.Method == http.MethodOptions { + header.Add(varyHeader, requestMethod) + header.Add(varyHeader, requestHeaders) + } +} diff --git a/rest/internal/cors/handlers_test.go b/rest/internal/cors/handlers_test.go index 1e6b8420..03052b29 100644 --- a/rest/internal/cors/handlers_test.go +++ b/rest/internal/cors/handlers_test.go @@ -10,23 +10,42 @@ import ( func TestCorsHandlerWithOrigins(t *testing.T) { tests := []struct { - name string - origins []string - expect string + name string + origins []string + reqOrigin string + expect string }{ { name: "allow all origins", expect: allOrigins, }, { - name: "allow one origin", - origins: []string{"local"}, - expect: "local", + name: "allow one origin", + origins: []string{"http://local"}, + reqOrigin: "http://local", + expect: "http://local", }, { - name: "allow many origins", - origins: []string{"local", "remote"}, - expect: "local", + name: "allow many origins", + origins: []string{"http://local", "http://remote"}, + reqOrigin: "http://local", + expect: "http://local", + }, + { + name: "allow all origins", + reqOrigin: "http://local", + expect: "*", + }, + { + name: "allow many origins with all mark", + origins: []string{"http://local", "http://remote", "*"}, + reqOrigin: "http://another", + expect: "http://another", + }, + { + name: "not allow origin", + origins: []string{"http://local", "http://remote"}, + reqOrigin: "http://another", }, } @@ -41,8 +60,9 @@ func TestCorsHandlerWithOrigins(t *testing.T) { test := test t.Run(test.name+"-handler", func(t *testing.T) { r := httptest.NewRequest(method, "http://localhost", nil) + r.Header.Set(originHeader, test.reqOrigin) w := httptest.NewRecorder() - handler := Handler(test.origins...) + handler := NotAllowedHandler(test.origins...) handler.ServeHTTP(w, r) if method == http.MethodOptions { assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) @@ -59,6 +79,7 @@ func TestCorsHandlerWithOrigins(t *testing.T) { test := test t.Run(test.name+"-middleware", func(t *testing.T) { r := httptest.NewRequest(method, "http://localhost", nil) + r.Header.Set(originHeader, test.reqOrigin) w := httptest.NewRecorder() handler := Middleware(test.origins...)(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) diff --git a/rest/server.go b/rest/server.go index 60221d6b..e847ca71 100644 --- a/rest/server.go +++ b/rest/server.go @@ -99,7 +99,7 @@ func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware { // WithCors returns a func to enable CORS for given origin, or default to all origins (*). func WithCors(origin ...string) RunOption { return func(server *Server) { - server.router.SetNotAllowedHandler(cors.Handler(origin...)) + server.router.SetNotAllowedHandler(cors.NotAllowedHandler(origin...)) server.Use(cors.Middleware(origin...)) } }