You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
123 lines
3.0 KiB
Go
123 lines
3.0 KiB
Go
package router
|
|
|
|
import (
|
|
"net/http"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"zero/ngin/internal/context"
|
|
)
|
|
|
|
// mockedResponseWriter is a minimal http.ResponseWriter for the router tests.
// It throws away headers and body bytes and only remembers the status code
// passed to WriteHeader, so a test can inspect it after ServeHTTP returns.
type mockedResponseWriter struct {
	// code is the status from the most recent WriteHeader call; it remains 0
	// when WriteHeader was never called (Write does not touch it).
	code int
}

// Compile-time check that the mock satisfies http.ResponseWriter.
var _ http.ResponseWriter = (*mockedResponseWriter)(nil)

// Header hands out a brand-new, empty header map on every call, so any
// header a handler sets is not retained.
func (w *mockedResponseWriter) Header() http.Header {
	return make(http.Header)
}

// Write discards p and reports every byte as written with no error.
func (w *mockedResponseWriter) Write(p []byte) (int, error) {
	written := len(p)
	return written, nil
}

// WriteHeader records statusCode so tests can assert on it later.
func (w *mockedResponseWriter) WriteHeader(statusCode int) {
	w.code = statusCode
}
|
|
|
|
func TestPatRouterHandleErrors(t *testing.T) {
|
|
tests := []struct {
|
|
method string
|
|
path string
|
|
err error
|
|
}{
|
|
{"FAKE", "", ErrInvalidMethod},
|
|
{"GET", "", ErrInvalidPath},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.method, func(t *testing.T) {
|
|
router := NewPatRouter()
|
|
err := router.Handle(test.method, test.path, nil)
|
|
assert.Error(t, ErrInvalidMethod, err)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPatRouterNotFound(t *testing.T) {
|
|
var notFound bool
|
|
router := NewPatRouter()
|
|
router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
notFound = true
|
|
}))
|
|
router.Handle(http.MethodGet, "/a/b", nil)
|
|
r, _ := http.NewRequest(http.MethodGet, "/b/c", nil)
|
|
w := new(mockedResponseWriter)
|
|
router.ServeHTTP(w, r)
|
|
assert.True(t, notFound)
|
|
}
|
|
|
|
func TestPatRouter(t *testing.T) {
|
|
tests := []struct {
|
|
method string
|
|
path string
|
|
expect bool
|
|
code int
|
|
err error
|
|
}{
|
|
// we don't explicitly set status code, framework will do it.
|
|
{http.MethodGet, "/a/b", true, 0, nil},
|
|
{http.MethodGet, "/a/b/", true, 0, nil},
|
|
{http.MethodGet, "/a/b?a=b", true, 0, nil},
|
|
{http.MethodGet, "/a/b/?a=b", true, 0, nil},
|
|
{http.MethodGet, "/a/b/c?a=b", true, 0, nil},
|
|
{http.MethodGet, "/b/d", false, http.StatusNotFound, nil},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.method+":"+test.path, func(t *testing.T) {
|
|
routed := false
|
|
router := NewPatRouter()
|
|
err := router.Handle(test.method, "/a/:b", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
routed = true
|
|
assert.Equal(t, 1, len(context.Vars(r)))
|
|
}))
|
|
assert.Nil(t, err)
|
|
err = router.Handle(test.method, "/a/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
routed = true
|
|
assert.Nil(t, context.Vars(r))
|
|
}))
|
|
assert.Nil(t, err)
|
|
err = router.Handle(test.method, "/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
routed = true
|
|
}))
|
|
assert.Nil(t, err)
|
|
|
|
w := new(mockedResponseWriter)
|
|
r, _ := http.NewRequest(test.method, test.path, nil)
|
|
router.ServeHTTP(w, r)
|
|
assert.Equal(t, test.expect, routed)
|
|
assert.Equal(t, test.code, w.code)
|
|
|
|
if test.code == 0 {
|
|
r, _ = http.NewRequest(http.MethodPut, test.path, nil)
|
|
router.ServeHTTP(w, r)
|
|
assert.Equal(t, http.StatusMethodNotAllowed, w.code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func BenchmarkPatRouter(b *testing.B) {
|
|
b.ReportAllocs()
|
|
|
|
router := NewPatRouter()
|
|
router.Handle(http.MethodGet, "/api/:user/:name", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
}))
|
|
w := &mockedResponseWriter{}
|
|
r, _ := http.NewRequest(http.MethodGet, "/api/a/b", nil)
|
|
for i := 0; i < b.N; i++ {
|
|
router.ServeHTTP(w, r)
|
|
}
|
|
}
|