diff --git a/example/http/demo/main.go b/example/http/demo/main.go index 47b05806..9838c8ab 100644 --- a/example/http/demo/main.go +++ b/example/http/demo/main.go @@ -56,7 +56,7 @@ func main() { Port: *port, Timeout: *timeout, MaxConns: 500, - }) + }, rest.WithNotAllowedHandler(rest.CorsHandler())) defer engine.Stop() engine.Use(first) diff --git a/rest/handlers.go b/rest/handlers.go new file mode 100644 index 00000000..a6bffbb6 --- /dev/null +++ b/rest/handlers.go @@ -0,0 +1,29 @@ +package rest + +import ( + "net/http" + "strings" +) + +const ( + allowOrigin = "Access-Control-Allow-Origin" + allOrigin = "*" + allowMethods = "Access-Control-Allow-Methods" + allowHeaders = "Access-Control-Allow-Headers" + headers = "Content-Type, Content-Length, Origin" + methods = "GET, HEAD, POST, PATCH, PUT, DELETE" + separator = ", " +) + +func CorsHandler(origins ...string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(origins) > 0 { + w.Header().Set(allowOrigin, strings.Join(origins, separator)) + } else { + w.Header().Set(allowOrigin, allOrigin) + } + w.Header().Set(allowMethods, methods) + w.Header().Set(allowHeaders, headers) + w.WriteHeader(http.StatusNoContent) + }) +} diff --git a/rest/handlers_test.go b/rest/handlers_test.go new file mode 100644 index 00000000..9b2dd746 --- /dev/null +++ b/rest/handlers_test.go @@ -0,0 +1,27 @@ +package rest + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCorsHandler(t *testing.T) { + w := httptest.NewRecorder() + handler := CorsHandler() + handler.ServeHTTP(w, nil) + assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) + assert.Equal(t, allOrigin, w.Header().Get(allowOrigin)) +} + +func TestCorsHandlerWithOrigins(t *testing.T) { + origins := []string{"local", "remote"} + w := httptest.NewRecorder() + handler := CorsHandler(origins...) + handler.ServeHTTP(w, nil) + assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) + assert.Equal(t, strings.Join(origins, separator), w.Header().Get(allowOrigin)) +} diff --git a/rest/httpx/router.go b/rest/httpx/router.go index 11bbc9a8..7d3e32b8 100644 --- a/rest/httpx/router.go +++ b/rest/httpx/router.go @@ -6,4 +6,5 @@ type Router interface { http.Handler Handle(method string, path string, handler http.Handler) error SetNotFoundHandler(handler http.Handler) + SetNotAllowedHandler(handler http.Handler) } diff --git a/rest/router/patrouter.go b/rest/router/patrouter.go index 4b2d23bc..cee331d5 100644 --- a/rest/router/patrouter.go +++ b/rest/router/patrouter.go @@ -22,8 +22,9 @@ var ( ) type patRouter struct { - trees map[string]*search.Tree - notFound http.Handler + trees map[string]*search.Tree + notFound http.Handler + notAllowed http.Handler } func NewRouter() httpx.Router { @@ -63,11 +64,17 @@ func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - if allow, ok := pr.methodNotAllowed(r.Method, reqPath); ok { + allow, ok := pr.methodNotAllowed(r.Method, reqPath) + if !ok { + pr.handleNotFound(w, r) + return + } + + if pr.notAllowed != nil { + pr.notAllowed.ServeHTTP(w, r) + } else { w.Header().Set(allowHeader, allow) w.WriteHeader(http.StatusMethodNotAllowed) - } else { - pr.handleNotFound(w, r) } } @@ -75,6 +82,10 @@ func (pr *patRouter) SetNotFoundHandler(handler http.Handler) { pr.notFound = handler } +func (pr *patRouter) SetNotAllowedHandler(handler http.Handler) { + pr.notAllowed = handler +} + func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) { if pr.notFound != nil { pr.notFound.ServeHTTP(w, r) diff --git a/rest/router/patrouter_test.go b/rest/router/patrouter_test.go index 6594df3e..1a952b4b 100644 --- a/rest/router/patrouter_test.go +++ b/rest/router/patrouter_test.go @@ -60,13 +60,30 @@ func TestPatRouterNotFound(t *testing.T) { router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { notFound = true })) - router.Handle(http.MethodGet, "/a/b", nil) + err := router.Handle(http.MethodGet, "/a/b", + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + assert.Nil(t, err) r, _ := http.NewRequest(http.MethodGet, "/b/c", nil) w := new(mockedResponseWriter) router.ServeHTTP(w, r) assert.True(t, notFound) } +func TestPatRouterNotAllowed(t *testing.T) { + var notAllowed bool + router := NewRouter() + router.SetNotAllowedHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + notAllowed = true + })) + err := router.Handle(http.MethodGet, "/a/b", + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + assert.Nil(t, err) + r, _ := http.NewRequest(http.MethodPost, "/a/b", nil) + w := new(mockedResponseWriter) + router.ServeHTTP(w, r) + assert.True(t, notAllowed) +} + func TestPatRouter(t *testing.T) { tests := []struct { method string diff --git a/rest/server.go b/rest/server.go index 962c5999..bc23ce6c 100644 --- a/rest/server.go +++ b/rest/server.go @@ -1,12 +1,14 @@ package rest import ( + "errors" "log" "net/http" "github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/rest/handler" "github.com/tal-tech/go-zero/rest/httpx" + "github.com/tal-tech/go-zero/rest/router" ) type ( @@ -32,6 +34,10 @@ func MustNewServer(c RestConf, opts ...RunOption) *Server { } func NewServer(c RestConf, opts ...RunOption) (*Server, error) { + if len(opts) > 1 { + return nil, errors.New("only one RunOption is allowed") + } + if err := c.SetUp(); err != nil { return nil, err } @@ -125,6 +131,18 @@ func WithMiddleware(middleware Middleware, rs ...Route) []Route { return routes } +func WithNotFoundHandler(handler http.Handler) RunOption { + rt := router.NewRouter() + rt.SetNotFoundHandler(handler) + return WithRouter(rt) +} + +func WithNotAllowedHandler(handler http.Handler) RunOption { + rt := router.NewRouter() + rt.SetNotAllowedHandler(handler) + return WithRouter(rt) +} + func WithPriority() RouteOption { return func(r *featuredRoutes) { r.priority = true diff --git a/rest/server_test.go b/rest/server_test.go index a363b91d..413325ad 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -12,6 +12,11 @@ import ( "github.com/tal-tech/go-zero/rest/router" ) +func TestNewServer(t *testing.T) { + _, err := NewServer(RestConf{}, WithNotFoundHandler(nil), WithNotAllowedHandler(nil)) + assert.NotNil(t, err) +} + func TestWithMiddleware(t *testing.T) { m := make(map[string]string) router := router.NewRouter() @@ -69,7 +74,7 @@ func TestWithMiddleware(t *testing.T) { }, m) } -func TestMultiMiddleware(t *testing.T) { +func TestMultiMiddlewares(t *testing.T) { m := make(map[string]string) router := router.NewRouter() handler := func(w http.ResponseWriter, r *http.Request) { @@ -140,3 +145,9 @@ func TestMultiMiddleware(t *testing.T) { "whatever": "200000200000", }, m) } + +func TestWithPriority(t *testing.T) { + var fr featuredRoutes + WithPriority()(&fr) + assert.True(t, fr.priority) +}