support cors in rest server

master
kevin 4 years ago
parent 1c1e4bca86
commit fe0d0687f5

@ -56,7 +56,7 @@ func main() {
Port: *port, Port: *port,
Timeout: *timeout, Timeout: *timeout,
MaxConns: 500, MaxConns: 500,
}) }, rest.WithNotAllowedHandler(rest.CorsHandler()))
defer engine.Stop() defer engine.Stop()
engine.Use(first) engine.Use(first)

@ -0,0 +1,29 @@
package rest
import (
"net/http"
"strings"
)
const (
allowOrigin = "Access-Control-Allow-Origin"
allOrigin = "*"
allowMethods = "Access-Control-Allow-Methods"
allowHeaders = "Access-Control-Allow-Headers"
headers = "Content-Type, Content-Length, Origin"
methods = "GET, HEAD, POST, PATCH, PUT, DELETE"
separator = ", "
)
func CorsHandler(origins ...string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if len(origins) > 0 {
w.Header().Set(allowOrigin, strings.Join(origins, separator))
} else {
w.Header().Set(allowOrigin, allOrigin)
}
w.Header().Set(allowMethods, methods)
w.Header().Set(allowHeaders, headers)
w.WriteHeader(http.StatusNoContent)
})
}

@ -0,0 +1,27 @@
package rest
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCorsHandler(t *testing.T) {
w := httptest.NewRecorder()
handler := CorsHandler()
handler.ServeHTTP(w, nil)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
assert.Equal(t, allOrigin, w.Header().Get(allowOrigin))
}
func TestCorsHandlerWithOrigins(t *testing.T) {
origins := []string{"local", "remote"}
w := httptest.NewRecorder()
handler := CorsHandler(origins...)
handler.ServeHTTP(w, nil)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
assert.Equal(t, strings.Join(origins, separator), w.Header().Get(allowOrigin))
}

@ -6,4 +6,5 @@ type Router interface {
http.Handler http.Handler
Handle(method string, path string, handler http.Handler) error Handle(method string, path string, handler http.Handler) error
SetNotFoundHandler(handler http.Handler) SetNotFoundHandler(handler http.Handler)
SetNotAllowedHandler(handler http.Handler)
} }

@ -22,8 +22,9 @@ var (
) )
type patRouter struct { type patRouter struct {
trees map[string]*search.Tree trees map[string]*search.Tree
notFound http.Handler notFound http.Handler
notAllowed http.Handler
} }
func NewRouter() httpx.Router { func NewRouter() httpx.Router {
@ -63,11 +64,17 @@ func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
if allow, ok := pr.methodNotAllowed(r.Method, reqPath); ok { allow, ok := pr.methodNotAllowed(r.Method, reqPath)
if !ok {
pr.handleNotFound(w, r)
return
}
if pr.notAllowed != nil {
pr.notAllowed.ServeHTTP(w, r)
} else {
w.Header().Set(allowHeader, allow) w.Header().Set(allowHeader, allow)
w.WriteHeader(http.StatusMethodNotAllowed) w.WriteHeader(http.StatusMethodNotAllowed)
} else {
pr.handleNotFound(w, r)
} }
} }
@ -75,6 +82,10 @@ func (pr *patRouter) SetNotFoundHandler(handler http.Handler) {
pr.notFound = handler pr.notFound = handler
} }
func (pr *patRouter) SetNotAllowedHandler(handler http.Handler) {
pr.notAllowed = handler
}
func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) { func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
if pr.notFound != nil { if pr.notFound != nil {
pr.notFound.ServeHTTP(w, r) pr.notFound.ServeHTTP(w, r)

@ -60,13 +60,30 @@ func TestPatRouterNotFound(t *testing.T) {
router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
notFound = true notFound = true
})) }))
router.Handle(http.MethodGet, "/a/b", nil) err := router.Handle(http.MethodGet, "/a/b",
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
assert.Nil(t, err)
r, _ := http.NewRequest(http.MethodGet, "/b/c", nil) r, _ := http.NewRequest(http.MethodGet, "/b/c", nil)
w := new(mockedResponseWriter) w := new(mockedResponseWriter)
router.ServeHTTP(w, r) router.ServeHTTP(w, r)
assert.True(t, notFound) assert.True(t, notFound)
} }
func TestPatRouterNotAllowed(t *testing.T) {
var notAllowed bool
router := NewRouter()
router.SetNotAllowedHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
notAllowed = true
}))
err := router.Handle(http.MethodGet, "/a/b",
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
assert.Nil(t, err)
r, _ := http.NewRequest(http.MethodPost, "/a/b", nil)
w := new(mockedResponseWriter)
router.ServeHTTP(w, r)
assert.True(t, notAllowed)
}
func TestPatRouter(t *testing.T) { func TestPatRouter(t *testing.T) {
tests := []struct { tests := []struct {
method string method string

@ -1,12 +1,14 @@
package rest package rest
import ( import (
"errors"
"log" "log"
"net/http" "net/http"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/rest/handler" "github.com/tal-tech/go-zero/rest/handler"
"github.com/tal-tech/go-zero/rest/httpx" "github.com/tal-tech/go-zero/rest/httpx"
"github.com/tal-tech/go-zero/rest/router"
) )
type ( type (
@ -32,6 +34,10 @@ func MustNewServer(c RestConf, opts ...RunOption) *Server {
} }
func NewServer(c RestConf, opts ...RunOption) (*Server, error) { func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
if len(opts) > 1 {
return nil, errors.New("only one RunOption is allowed")
}
if err := c.SetUp(); err != nil { if err := c.SetUp(); err != nil {
return nil, err return nil, err
} }
@ -125,6 +131,18 @@ func WithMiddleware(middleware Middleware, rs ...Route) []Route {
return routes return routes
} }
func WithNotFoundHandler(handler http.Handler) RunOption {
rt := router.NewRouter()
rt.SetNotFoundHandler(handler)
return WithRouter(rt)
}
func WithNotAllowedHandler(handler http.Handler) RunOption {
rt := router.NewRouter()
rt.SetNotAllowedHandler(handler)
return WithRouter(rt)
}
func WithPriority() RouteOption { func WithPriority() RouteOption {
return func(r *featuredRoutes) { return func(r *featuredRoutes) {
r.priority = true r.priority = true

@ -12,6 +12,11 @@ import (
"github.com/tal-tech/go-zero/rest/router" "github.com/tal-tech/go-zero/rest/router"
) )
func TestNewServer(t *testing.T) {
_, err := NewServer(RestConf{}, WithNotFoundHandler(nil), WithNotAllowedHandler(nil))
assert.NotNil(t, err)
}
func TestWithMiddleware(t *testing.T) { func TestWithMiddleware(t *testing.T) {
m := make(map[string]string) m := make(map[string]string)
router := router.NewRouter() router := router.NewRouter()
@ -69,7 +74,7 @@ func TestWithMiddleware(t *testing.T) {
}, m) }, m)
} }
func TestMultiMiddleware(t *testing.T) { func TestMultiMiddlewares(t *testing.T) {
m := make(map[string]string) m := make(map[string]string)
router := router.NewRouter() router := router.NewRouter()
handler := func(w http.ResponseWriter, r *http.Request) { handler := func(w http.ResponseWriter, r *http.Request) {
@ -140,3 +145,9 @@ func TestMultiMiddleware(t *testing.T) {
"whatever": "200000200000", "whatever": "200000200000",
}, m) }, m)
} }
func TestWithPriority(t *testing.T) {
var fr featuredRoutes
WithPriority()(&fr)
assert.True(t, fr.priority)
}

Loading…
Cancel
Save