diff --git a/rest/engine.go b/rest/engine.go index e19a56b1..881a0a9e 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -14,7 +14,6 @@ import ( "github.com/tal-tech/go-zero/rest/handler" "github.com/tal-tech/go-zero/rest/httpx" "github.com/tal-tech/go-zero/rest/internal" - "github.com/tal-tech/go-zero/rest/router" ) // use 1000m to represent 100% @@ -47,39 +46,10 @@ func newEngine(c RestConf) *engine { return srv } -func (ng *engine) AddRoutes(r featuredRoutes) { +func (ng *engine) addRoutes(r featuredRoutes) { ng.routes = append(ng.routes, r) } -func (ng *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) { - ng.unauthorizedCallback = callback -} - -func (ng *engine) SetUnsignedCallback(callback handler.UnsignedCallback) { - ng.unsignedCallback = callback -} - -func (ng *engine) Start() error { - return ng.StartWithRouter(router.NewRouter()) -} - -func (ng *engine) StartWithRouter(router httpx.Router) error { - if err := ng.bindRoutes(router); err != nil { - return err - } - - if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 { - return internal.StartHttp(ng.conf.Host, ng.conf.Port, router) - } - - return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile, - ng.conf.KeyFile, router, func(srv *http.Server) { - if ng.tlsConfig != nil { - srv.TLSConfig = ng.tlsConfig - } - }) -} - func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain, verifier func(alice.Chain) alice.Chain) alice.Chain { if fr.jwt.enabled { @@ -188,6 +158,14 @@ func (ng *engine) setTlsConfig(cfg *tls.Config) { ng.tlsConfig = cfg } +func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) { + ng.unauthorizedCallback = callback +} + +func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) { + ng.unsignedCallback = callback +} + func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) { if !signature.enabled { return func(chain alice.Chain) alice.Chain { @@ -228,6 +206,23 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic }, nil } +func (ng *engine) start(router httpx.Router) error { + if err := ng.bindRoutes(router); err != nil { + return err + } + + if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 { + return internal.StartHttp(ng.conf.Host, ng.conf.Port, router) + } + + return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile, + ng.conf.KeyFile, router, func(srv *http.Server) { + if ng.tlsConfig != nil { + srv.TLSConfig = ng.tlsConfig + } + }) +} + func (ng *engine) use(middleware Middleware) { ng.middlewares = append(ng.middlewares, middleware) } diff --git a/rest/engine_test.go b/rest/engine_test.go index dfd084c2..60526d5f 100644 --- a/rest/engine_test.go +++ b/rest/engine_test.go @@ -144,13 +144,13 @@ Verbose: true var cnf RestConf assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(yaml), &cnf)) ng := newEngine(cnf) - ng.AddRoutes(route) + ng.addRoutes(route) ng.use(func(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { next.ServeHTTP(w, r) } }) - assert.NotNil(t, ng.StartWithRouter(mockedRouter{})) + assert.NotNil(t, ng.start(mockedRouter{})) } } } diff --git a/rest/handlers.go b/rest/handlers.go deleted file mode 100644 index 8181c9b3..00000000 --- a/rest/handlers.go +++ /dev/null @@ -1,27 +0,0 @@ -package rest - -import "net/http" - -const ( - allowOrigin = "Access-Control-Allow-Origin" - allOrigins = "*" - allowMethods = "Access-Control-Allow-Methods" - allowHeaders = "Access-Control-Allow-Headers" - headers = "Content-Type, Content-Length, Origin" - methods = "GET, HEAD, POST, PATCH, PUT, DELETE" -) - -// CorsHandler handles cross domain OPTIONS requests. -// At most one origin can be specified, other origins are ignored if given. -func CorsHandler(origins ...string) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if len(origins) > 0 { - w.Header().Set(allowOrigin, origins[0]) - } else { - w.Header().Set(allowOrigin, allOrigins) - } - w.Header().Set(allowMethods, methods) - w.Header().Set(allowHeaders, headers) - w.WriteHeader(http.StatusNoContent) - }) -} diff --git a/rest/handlers_test.go b/rest/handlers_test.go deleted file mode 100644 index 366dce42..00000000 --- a/rest/handlers_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package rest - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCorsHandlerWithOrigins(t *testing.T) { - tests := []struct { - name string - origins []string - expect string - }{ - { - name: "allow all origins", - expect: allOrigins, - }, - { - name: "allow one origin", - origins: []string{"local"}, - expect: "local", - }, - { - name: "allow many origins", - origins: []string{"local", "remote"}, - expect: "local", - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - w := httptest.NewRecorder() - handler := CorsHandler(test.origins...) - handler.ServeHTTP(w, nil) - assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) - assert.Equal(t, test.expect, w.Header().Get(allowOrigin)) - }) - } -} diff --git a/rest/internal/cors/handlers.go b/rest/internal/cors/handlers.go new file mode 100644 index 00000000..3be0bd44 --- /dev/null +++ b/rest/internal/cors/handlers.go @@ -0,0 +1,64 @@ +package cors + +import "net/http" + +const ( + allowOrigin = "Access-Control-Allow-Origin" + allOrigins = "*" + allowMethods = "Access-Control-Allow-Methods" + allowHeaders = "Access-Control-Allow-Headers" + allowCredentials = "Access-Control-Allow-Credentials" + exposeHeaders = "Access-Control-Expose-Headers" + allowHeadersVal = "Content-Type, Origin, X-CSRF-Token, Authorization, AccessToken, Token, Range" + exposeHeadersVal = "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers" + methods = "GET, HEAD, POST, PATCH, PUT, DELETE" + allowTrue = "true" + maxAgeHeader = "Access-Control-Max-Age" + maxAgeHeaderVal = "86400" +) + +// Handler handles cross domain not allowed requests. +// At most one origin can be specified, other origins are ignored if given, default to be *. +func Handler(origin ...string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + setHeader(w, getOrigin(origin)) + + if r.Method != http.MethodOptions { + w.WriteHeader(http.StatusNotFound) + } else { + w.WriteHeader(http.StatusNoContent) + } + }) +} + +// Middleware returns a middleware that adds CORS headers to the response. +func Middleware(origin ...string) func(http.HandlerFunc) http.HandlerFunc { + return func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + setHeader(w, getOrigin(origin)) + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + } else { + next(w, r) + } + } + } +} + +func getOrigin(origins []string) string { + if len(origins) > 0 { + return origins[0] + } else { + return allOrigins + } +} + +func setHeader(w http.ResponseWriter, origin string) { + w.Header().Set(allowOrigin, origin) + w.Header().Set(allowMethods, methods) + w.Header().Set(allowHeaders, allowHeadersVal) + w.Header().Set(exposeHeaders, exposeHeadersVal) + w.Header().Set(allowCredentials, allowTrue) + w.Header().Set(maxAgeHeader, maxAgeHeaderVal) +} diff --git a/rest/internal/cors/handlers_test.go b/rest/internal/cors/handlers_test.go new file mode 100644 index 00000000..1e6b8420 --- /dev/null +++ b/rest/internal/cors/handlers_test.go @@ -0,0 +1,76 @@ +package cors + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCorsHandlerWithOrigins(t *testing.T) { + tests := []struct { + name string + origins []string + expect string + }{ + { + name: "allow all origins", + expect: allOrigins, + }, + { + name: "allow one origin", + origins: []string{"local"}, + expect: "local", + }, + { + name: "allow many origins", + origins: []string{"local", "remote"}, + expect: "local", + }, + } + + methods := []string{ + http.MethodOptions, + http.MethodGet, + http.MethodPost, + } + + for _, test := range tests { + for _, method := range methods { + test := test + t.Run(test.name+"-handler", func(t *testing.T) { + r := httptest.NewRequest(method, "http://localhost", nil) + w := httptest.NewRecorder() + handler := Handler(test.origins...) + handler.ServeHTTP(w, r) + if method == http.MethodOptions { + assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) + } else { + assert.Equal(t, http.StatusNotFound, w.Result().StatusCode) + } + assert.Equal(t, test.expect, w.Header().Get(allowOrigin)) + }) + } + } + + for _, test := range tests { + for _, method := range methods { + test := test + t.Run(test.name+"-middleware", func(t *testing.T) { + r := httptest.NewRequest(method, "http://localhost", nil) + w := httptest.NewRecorder() + handler := Middleware(test.origins...)(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler.ServeHTTP(w, r) + if method == http.MethodOptions { + assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) + } else { + assert.Equal(t, http.StatusOK, w.Result().StatusCode) + } + assert.Equal(t, test.expect, w.Header().Get(allowOrigin)) + }) + } + } +} diff --git a/rest/server.go b/rest/server.go index 18b41fe2..60221d6b 100644 --- a/rest/server.go +++ b/rest/server.go @@ -10,21 +10,18 @@ import ( "github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/rest/handler" "github.com/tal-tech/go-zero/rest/httpx" + "github.com/tal-tech/go-zero/rest/internal/cors" "github.com/tal-tech/go-zero/rest/router" ) type ( - runOptions struct { - start func(*engine) error - } - // RunOption defines the method to customize a Server. RunOption func(*Server) // A Server is a http server. Server struct { - ngin *engine - opts runOptions + ngin *engine + router httpx.Router } ) @@ -48,12 +45,8 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) { } server := &Server{ - ngin: newEngine(c), - opts: runOptions{ - start: func(ng *engine) error { - return ng.Start() - }, - }, + ngin: newEngine(c), + router: router.NewRouter(), } for _, opt := range opts { @@ -71,7 +64,7 @@ func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) { for _, opt := range opts { opt(&r) } - s.ngin.AddRoutes(r) + s.ngin.addRoutes(r) } // AddRoute adds given route into the Server. @@ -83,7 +76,7 @@ func (s *Server) AddRoute(r Route, opts ...RouteOption) { // Graceful shutdown is enabled by default. // Use proc.SetTimeToForceQuit to customize the graceful shutdown period. func (s *Server) Start() { - handleError(s.opts.start(s.ngin)) + handleError(s.ngin.start(s.router)) } // Stop stops the Server. @@ -103,6 +96,14 @@ func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware { } } +// WithCors returns a func to enable CORS for given origin, or default to all origins (*). +func WithCors(origin ...string) RunOption { + return func(server *Server) { + server.router.SetNotAllowedHandler(cors.Handler(origin...)) + server.Use(cors.Middleware(origin...)) + } +} + // WithJwt returns a func to enable jwt authentication in given route. func WithJwt(secret string) RouteOption { return func(r *featuredRoutes) { @@ -151,16 +152,16 @@ func WithMiddleware(middleware Middleware, rs ...Route) []Route { // WithNotFoundHandler returns a RunOption with not found handler set to given handler. func WithNotFoundHandler(handler http.Handler) RunOption { - rt := router.NewRouter() - rt.SetNotFoundHandler(handler) - return WithRouter(rt) + return func(server *Server) { + server.router.SetNotFoundHandler(handler) + } } // WithNotAllowedHandler returns a RunOption with not allowed handler set to given handler. func WithNotAllowedHandler(handler http.Handler) RunOption { - rt := router.NewRouter() - rt.SetNotAllowedHandler(handler) - return WithRouter(rt) + return func(server *Server) { + server.router.SetNotAllowedHandler(handler) + } } // WithPrefix adds group as a prefix to the route paths. @@ -189,9 +190,7 @@ func WithPriority() RouteOption { // WithRouter returns a RunOption that make server run with given router. func WithRouter(router httpx.Router) RunOption { return func(server *Server) { - server.opts.start = func(ng *engine) error { - return ng.StartWithRouter(router) - } + server.router = router } } @@ -222,14 +221,14 @@ func WithTLSConfig(cfg *tls.Config) RunOption { // WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set. func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption { return func(srv *Server) { - srv.ngin.SetUnauthorizedCallback(callback) + srv.ngin.setUnauthorizedCallback(callback) } } // WithUnsignedCallback returns a RunOption that with given unsigned callback set. func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption { return func(srv *Server) { - srv.ngin.SetUnsignedCallback(callback) + srv.ngin.setUnsignedCallback(callback) } } diff --git a/rest/server_test.go b/rest/server_test.go index 3208048c..da164c41 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -22,11 +22,6 @@ Port: 54321 ` var cnf RestConf assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf)) - failStart := func(server *Server) { - server.opts.start = func(e *engine) error { - return http.ErrServerClosed - } - } tests := []struct { c RestConf @@ -35,38 +30,40 @@ Port: 54321 }{ { c: RestConf{}, - opts: []RunOption{failStart}, + opts: []RunOption{WithRouter(mockedRouter{}), WithCors()}, fail: true, }, { c: cnf, - opts: []RunOption{failStart}, + opts: []RunOption{WithRouter(mockedRouter{})}, }, { c: cnf, - opts: []RunOption{WithNotAllowedHandler(nil), failStart}, + opts: []RunOption{WithRouter(mockedRouter{}), WithNotAllowedHandler(nil)}, }, { c: cnf, - opts: []RunOption{WithNotFoundHandler(nil), failStart}, + opts: []RunOption{WithNotFoundHandler(nil), WithRouter(mockedRouter{})}, }, { c: cnf, - opts: []RunOption{WithUnauthorizedCallback(nil), failStart}, + opts: []RunOption{WithUnauthorizedCallback(nil), WithRouter(mockedRouter{})}, }, { c: cnf, - opts: []RunOption{WithUnsignedCallback(nil), failStart}, + opts: []RunOption{WithUnsignedCallback(nil), WithRouter(mockedRouter{})}, }, } for _, test := range tests { - srv, err := NewServer(test.c, test.opts...) + var srv *Server + var err error if test.fail { + _, err = NewServer(test.c, test.opts...) assert.NotNil(t, err) - } - if err != nil { continue + } else { + srv = MustNewServer(test.c, test.opts...) } srv.Use(ToMiddleware(func(next http.Handler) http.Handler { @@ -80,8 +77,21 @@ Port: 54321 Handler: nil, }, WithJwt("thesecret"), WithSignature(SignatureConf{}), WithJwtTransition("preivous", "thenewone")) - srv.Start() - srv.Stop() + + func() { + defer func() { + p := recover() + switch v := p.(type) { + case error: + assert.Equal(t, "foo", v.Error()) + default: + t.Fail() + } + }() + + srv.Start() + srv.Stop() + }() } } @@ -180,6 +190,9 @@ func TestMultiMiddlewares(t *testing.T) { next.ServeHTTP(w, r) } }, + ToMiddleware(func(next http.Handler) http.Handler { + return next + }), }, Route{ Method: http.MethodGet, Path: "/first/:name/:year", @@ -282,3 +295,18 @@ Port: 54321 assert.Equal(t, srv.ngin.tlsConfig, testCase.res) } } + +func TestWithCors(t *testing.T) { + const configYaml = ` +Name: foo +Port: 54321 +` + var cnf RestConf + assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf)) + rt := router.NewRouter() + srv, err := NewServer(cnf, WithRouter(rt)) + assert.Nil(t, err) + + opt := WithCors("local") + opt(srv) +} diff --git a/tools/goctl/api/gogen/genroutes.go b/tools/goctl/api/gogen/genroutes.go index 7c14d83b..abd0c001 100644 --- a/tools/goctl/api/gogen/genroutes.go +++ b/tools/goctl/api/gogen/genroutes.go @@ -27,12 +27,12 @@ import ( {{.importPackages}} ) -func RegisterHandlers(engine *rest.Server, serverCtx *svc.ServiceContext) { +func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) { {{.routesAdditions}} } ` routesAdditionTemplate = ` - engine.AddRoutes( + server.AddRoutes( {{.routes}} {{.jwt}}{{.signature}} {{.prefix}} ) `