feat: support CORS by using rest.WithCors(...) (#1212)

* feat: support CORS by using rest.WithCors(...)

* chore: add comments

* refactor: lowercase unexported methods

* ci: fix lint errors
master
Kevin Wan 3 years ago committed by GitHub
parent e8efcef108
commit c28e01fed3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,7 +14,6 @@ import (
"github.com/tal-tech/go-zero/rest/handler" "github.com/tal-tech/go-zero/rest/handler"
"github.com/tal-tech/go-zero/rest/httpx" "github.com/tal-tech/go-zero/rest/httpx"
"github.com/tal-tech/go-zero/rest/internal" "github.com/tal-tech/go-zero/rest/internal"
"github.com/tal-tech/go-zero/rest/router"
) )
// use 1000m to represent 100% // use 1000m to represent 100%
@ -47,39 +46,10 @@ func newEngine(c RestConf) *engine {
return srv return srv
} }
func (ng *engine) AddRoutes(r featuredRoutes) { func (ng *engine) addRoutes(r featuredRoutes) {
ng.routes = append(ng.routes, r) ng.routes = append(ng.routes, r)
} }
func (ng *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) {
ng.unauthorizedCallback = callback
}
func (ng *engine) SetUnsignedCallback(callback handler.UnsignedCallback) {
ng.unsignedCallback = callback
}
func (ng *engine) Start() error {
return ng.StartWithRouter(router.NewRouter())
}
func (ng *engine) StartWithRouter(router httpx.Router) error {
if err := ng.bindRoutes(router); err != nil {
return err
}
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router)
}
return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
ng.conf.KeyFile, router, func(srv *http.Server) {
if ng.tlsConfig != nil {
srv.TLSConfig = ng.tlsConfig
}
})
}
func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain, func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
verifier func(alice.Chain) alice.Chain) alice.Chain { verifier func(alice.Chain) alice.Chain) alice.Chain {
if fr.jwt.enabled { if fr.jwt.enabled {
@ -188,6 +158,14 @@ func (ng *engine) setTlsConfig(cfg *tls.Config) {
ng.tlsConfig = cfg ng.tlsConfig = cfg
} }
func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
ng.unauthorizedCallback = callback
}
func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) {
ng.unsignedCallback = callback
}
func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) { func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
if !signature.enabled { if !signature.enabled {
return func(chain alice.Chain) alice.Chain { return func(chain alice.Chain) alice.Chain {
@ -228,6 +206,23 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic
}, nil }, nil
} }
func (ng *engine) start(router httpx.Router) error {
if err := ng.bindRoutes(router); err != nil {
return err
}
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router)
}
return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
ng.conf.KeyFile, router, func(srv *http.Server) {
if ng.tlsConfig != nil {
srv.TLSConfig = ng.tlsConfig
}
})
}
func (ng *engine) use(middleware Middleware) { func (ng *engine) use(middleware Middleware) {
ng.middlewares = append(ng.middlewares, middleware) ng.middlewares = append(ng.middlewares, middleware)
} }

@ -144,13 +144,13 @@ Verbose: true
var cnf RestConf var cnf RestConf
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(yaml), &cnf)) assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(yaml), &cnf))
ng := newEngine(cnf) ng := newEngine(cnf)
ng.AddRoutes(route) ng.addRoutes(route)
ng.use(func(next http.HandlerFunc) http.HandlerFunc { ng.use(func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
} }
}) })
assert.NotNil(t, ng.StartWithRouter(mockedRouter{})) assert.NotNil(t, ng.start(mockedRouter{}))
} }
} }
} }

@ -1,27 +0,0 @@
package rest
import "net/http"
const (
allowOrigin = "Access-Control-Allow-Origin"
allOrigins = "*"
allowMethods = "Access-Control-Allow-Methods"
allowHeaders = "Access-Control-Allow-Headers"
headers = "Content-Type, Content-Length, Origin"
methods = "GET, HEAD, POST, PATCH, PUT, DELETE"
)
// CorsHandler handles cross domain OPTIONS requests.
// At most one origin can be specified, other origins are ignored if given.
func CorsHandler(origins ...string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if len(origins) > 0 {
w.Header().Set(allowOrigin, origins[0])
} else {
w.Header().Set(allowOrigin, allOrigins)
}
w.Header().Set(allowMethods, methods)
w.Header().Set(allowHeaders, headers)
w.WriteHeader(http.StatusNoContent)
})
}

@ -1,42 +0,0 @@
package rest
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCorsHandlerWithOrigins(t *testing.T) {
tests := []struct {
name string
origins []string
expect string
}{
{
name: "allow all origins",
expect: allOrigins,
},
{
name: "allow one origin",
origins: []string{"local"},
expect: "local",
},
{
name: "allow many origins",
origins: []string{"local", "remote"},
expect: "local",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
w := httptest.NewRecorder()
handler := CorsHandler(test.origins...)
handler.ServeHTTP(w, nil)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
})
}
}

@ -0,0 +1,64 @@
package cors
import "net/http"
const (
allowOrigin = "Access-Control-Allow-Origin"
allOrigins = "*"
allowMethods = "Access-Control-Allow-Methods"
allowHeaders = "Access-Control-Allow-Headers"
allowCredentials = "Access-Control-Allow-Credentials"
exposeHeaders = "Access-Control-Expose-Headers"
allowHeadersVal = "Content-Type, Origin, X-CSRF-Token, Authorization, AccessToken, Token, Range"
exposeHeadersVal = "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers"
methods = "GET, HEAD, POST, PATCH, PUT, DELETE"
allowTrue = "true"
maxAgeHeader = "Access-Control-Max-Age"
maxAgeHeaderVal = "86400"
)
// Handler handles cross domain not allowed requests.
// At most one origin can be specified, other origins are ignored if given, default to be *.
func Handler(origin ...string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
setHeader(w, getOrigin(origin))
if r.Method != http.MethodOptions {
w.WriteHeader(http.StatusNotFound)
} else {
w.WriteHeader(http.StatusNoContent)
}
})
}
// Middleware returns a middleware that adds CORS headers to the response.
func Middleware(origin ...string) func(http.HandlerFunc) http.HandlerFunc {
return func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
setHeader(w, getOrigin(origin))
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
} else {
next(w, r)
}
}
}
}
func getOrigin(origins []string) string {
if len(origins) > 0 {
return origins[0]
} else {
return allOrigins
}
}
func setHeader(w http.ResponseWriter, origin string) {
w.Header().Set(allowOrigin, origin)
w.Header().Set(allowMethods, methods)
w.Header().Set(allowHeaders, allowHeadersVal)
w.Header().Set(exposeHeaders, exposeHeadersVal)
w.Header().Set(allowCredentials, allowTrue)
w.Header().Set(maxAgeHeader, maxAgeHeaderVal)
}

@ -0,0 +1,76 @@
package cors
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCorsHandlerWithOrigins(t *testing.T) {
tests := []struct {
name string
origins []string
expect string
}{
{
name: "allow all origins",
expect: allOrigins,
},
{
name: "allow one origin",
origins: []string{"local"},
expect: "local",
},
{
name: "allow many origins",
origins: []string{"local", "remote"},
expect: "local",
},
}
methods := []string{
http.MethodOptions,
http.MethodGet,
http.MethodPost,
}
for _, test := range tests {
for _, method := range methods {
test := test
t.Run(test.name+"-handler", func(t *testing.T) {
r := httptest.NewRequest(method, "http://localhost", nil)
w := httptest.NewRecorder()
handler := Handler(test.origins...)
handler.ServeHTTP(w, r)
if method == http.MethodOptions {
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
} else {
assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
}
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
})
}
}
for _, test := range tests {
for _, method := range methods {
test := test
t.Run(test.name+"-middleware", func(t *testing.T) {
r := httptest.NewRequest(method, "http://localhost", nil)
w := httptest.NewRecorder()
handler := Middleware(test.origins...)(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler.ServeHTTP(w, r)
if method == http.MethodOptions {
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
} else {
assert.Equal(t, http.StatusOK, w.Result().StatusCode)
}
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
})
}
}
}

@ -10,21 +10,18 @@ import (
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/rest/handler" "github.com/tal-tech/go-zero/rest/handler"
"github.com/tal-tech/go-zero/rest/httpx" "github.com/tal-tech/go-zero/rest/httpx"
"github.com/tal-tech/go-zero/rest/internal/cors"
"github.com/tal-tech/go-zero/rest/router" "github.com/tal-tech/go-zero/rest/router"
) )
type ( type (
runOptions struct {
start func(*engine) error
}
// RunOption defines the method to customize a Server. // RunOption defines the method to customize a Server.
RunOption func(*Server) RunOption func(*Server)
// A Server is a http server. // A Server is a http server.
Server struct { Server struct {
ngin *engine ngin *engine
opts runOptions router httpx.Router
} }
) )
@ -48,12 +45,8 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
} }
server := &Server{ server := &Server{
ngin: newEngine(c), ngin: newEngine(c),
opts: runOptions{ router: router.NewRouter(),
start: func(ng *engine) error {
return ng.Start()
},
},
} }
for _, opt := range opts { for _, opt := range opts {
@ -71,7 +64,7 @@ func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) {
for _, opt := range opts { for _, opt := range opts {
opt(&r) opt(&r)
} }
s.ngin.AddRoutes(r) s.ngin.addRoutes(r)
} }
// AddRoute adds given route into the Server. // AddRoute adds given route into the Server.
@ -83,7 +76,7 @@ func (s *Server) AddRoute(r Route, opts ...RouteOption) {
// Graceful shutdown is enabled by default. // Graceful shutdown is enabled by default.
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period. // Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
func (s *Server) Start() { func (s *Server) Start() {
handleError(s.opts.start(s.ngin)) handleError(s.ngin.start(s.router))
} }
// Stop stops the Server. // Stop stops the Server.
@ -103,6 +96,14 @@ func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
} }
} }
// WithCors returns a func to enable CORS for given origin, or default to all origins (*).
func WithCors(origin ...string) RunOption {
return func(server *Server) {
server.router.SetNotAllowedHandler(cors.Handler(origin...))
server.Use(cors.Middleware(origin...))
}
}
// WithJwt returns a func to enable jwt authentication in given route. // WithJwt returns a func to enable jwt authentication in given route.
func WithJwt(secret string) RouteOption { func WithJwt(secret string) RouteOption {
return func(r *featuredRoutes) { return func(r *featuredRoutes) {
@ -151,16 +152,16 @@ func WithMiddleware(middleware Middleware, rs ...Route) []Route {
// WithNotFoundHandler returns a RunOption with not found handler set to given handler. // WithNotFoundHandler returns a RunOption with not found handler set to given handler.
func WithNotFoundHandler(handler http.Handler) RunOption { func WithNotFoundHandler(handler http.Handler) RunOption {
rt := router.NewRouter() return func(server *Server) {
rt.SetNotFoundHandler(handler) server.router.SetNotFoundHandler(handler)
return WithRouter(rt) }
} }
// WithNotAllowedHandler returns a RunOption with not allowed handler set to given handler. // WithNotAllowedHandler returns a RunOption with not allowed handler set to given handler.
func WithNotAllowedHandler(handler http.Handler) RunOption { func WithNotAllowedHandler(handler http.Handler) RunOption {
rt := router.NewRouter() return func(server *Server) {
rt.SetNotAllowedHandler(handler) server.router.SetNotAllowedHandler(handler)
return WithRouter(rt) }
} }
// WithPrefix adds group as a prefix to the route paths. // WithPrefix adds group as a prefix to the route paths.
@ -189,9 +190,7 @@ func WithPriority() RouteOption {
// WithRouter returns a RunOption that make server run with given router. // WithRouter returns a RunOption that make server run with given router.
func WithRouter(router httpx.Router) RunOption { func WithRouter(router httpx.Router) RunOption {
return func(server *Server) { return func(server *Server) {
server.opts.start = func(ng *engine) error { server.router = router
return ng.StartWithRouter(router)
}
} }
} }
@ -222,14 +221,14 @@ func WithTLSConfig(cfg *tls.Config) RunOption {
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set. // WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption { func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
return func(srv *Server) { return func(srv *Server) {
srv.ngin.SetUnauthorizedCallback(callback) srv.ngin.setUnauthorizedCallback(callback)
} }
} }
// WithUnsignedCallback returns a RunOption that with given unsigned callback set. // WithUnsignedCallback returns a RunOption that with given unsigned callback set.
func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption { func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
return func(srv *Server) { return func(srv *Server) {
srv.ngin.SetUnsignedCallback(callback) srv.ngin.setUnsignedCallback(callback)
} }
} }

@ -22,11 +22,6 @@ Port: 54321
` `
var cnf RestConf var cnf RestConf
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf)) assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
failStart := func(server *Server) {
server.opts.start = func(e *engine) error {
return http.ErrServerClosed
}
}
tests := []struct { tests := []struct {
c RestConf c RestConf
@ -35,38 +30,40 @@ Port: 54321
}{ }{
{ {
c: RestConf{}, c: RestConf{},
opts: []RunOption{failStart}, opts: []RunOption{WithRouter(mockedRouter{}), WithCors()},
fail: true, fail: true,
}, },
{ {
c: cnf, c: cnf,
opts: []RunOption{failStart}, opts: []RunOption{WithRouter(mockedRouter{})},
}, },
{ {
c: cnf, c: cnf,
opts: []RunOption{WithNotAllowedHandler(nil), failStart}, opts: []RunOption{WithRouter(mockedRouter{}), WithNotAllowedHandler(nil)},
}, },
{ {
c: cnf, c: cnf,
opts: []RunOption{WithNotFoundHandler(nil), failStart}, opts: []RunOption{WithNotFoundHandler(nil), WithRouter(mockedRouter{})},
}, },
{ {
c: cnf, c: cnf,
opts: []RunOption{WithUnauthorizedCallback(nil), failStart}, opts: []RunOption{WithUnauthorizedCallback(nil), WithRouter(mockedRouter{})},
}, },
{ {
c: cnf, c: cnf,
opts: []RunOption{WithUnsignedCallback(nil), failStart}, opts: []RunOption{WithUnsignedCallback(nil), WithRouter(mockedRouter{})},
}, },
} }
for _, test := range tests { for _, test := range tests {
srv, err := NewServer(test.c, test.opts...) var srv *Server
var err error
if test.fail { if test.fail {
_, err = NewServer(test.c, test.opts...)
assert.NotNil(t, err) assert.NotNil(t, err)
}
if err != nil {
continue continue
} else {
srv = MustNewServer(test.c, test.opts...)
} }
srv.Use(ToMiddleware(func(next http.Handler) http.Handler { srv.Use(ToMiddleware(func(next http.Handler) http.Handler {
@ -80,8 +77,21 @@ Port: 54321
Handler: nil, Handler: nil,
}, WithJwt("thesecret"), WithSignature(SignatureConf{}), }, WithJwt("thesecret"), WithSignature(SignatureConf{}),
WithJwtTransition("preivous", "thenewone")) WithJwtTransition("preivous", "thenewone"))
srv.Start()
srv.Stop() func() {
defer func() {
p := recover()
switch v := p.(type) {
case error:
assert.Equal(t, "foo", v.Error())
default:
t.Fail()
}
}()
srv.Start()
srv.Stop()
}()
} }
} }
@ -180,6 +190,9 @@ func TestMultiMiddlewares(t *testing.T) {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
} }
}, },
ToMiddleware(func(next http.Handler) http.Handler {
return next
}),
}, Route{ }, Route{
Method: http.MethodGet, Method: http.MethodGet,
Path: "/first/:name/:year", Path: "/first/:name/:year",
@ -282,3 +295,18 @@ Port: 54321
assert.Equal(t, srv.ngin.tlsConfig, testCase.res) assert.Equal(t, srv.ngin.tlsConfig, testCase.res)
} }
} }
func TestWithCors(t *testing.T) {
const configYaml = `
Name: foo
Port: 54321
`
var cnf RestConf
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
rt := router.NewRouter()
srv, err := NewServer(cnf, WithRouter(rt))
assert.Nil(t, err)
opt := WithCors("local")
opt(srv)
}

@ -27,12 +27,12 @@ import (
{{.importPackages}} {{.importPackages}}
) )
func RegisterHandlers(engine *rest.Server, serverCtx *svc.ServiceContext) { func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
{{.routesAdditions}} {{.routesAdditions}}
} }
` `
routesAdditionTemplate = ` routesAdditionTemplate = `
engine.AddRoutes( server.AddRoutes(
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.routes}} {{.jwt}}{{.signature}} {{.prefix}}
) )
` `

Loading…
Cancel
Save