You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-zero/ngin/internal/router/patrouter_test.go

123 lines
3.0 KiB
Go

package router
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"zero/ngin/internal/context"
)
// mockedResponseWriter is a minimal http.ResponseWriter used by the router
// tests. It remembers the status code passed to WriteHeader and discards
// headers and body bytes.
type mockedResponseWriter struct {
	code int // status passed to the most recent WriteHeader call; zero until then
}

// Compile-time check that the mock satisfies http.ResponseWriter.
var _ http.ResponseWriter = (*mockedResponseWriter)(nil)

// Header returns a new, empty header map on every call, so headers set by a
// handler are not retained between calls.
func (w *mockedResponseWriter) Header() http.Header {
	return make(http.Header)
}

// Write discards p and reports every byte as written with no error.
func (w *mockedResponseWriter) Write(p []byte) (int, error) {
	n := len(p)
	return n, nil
}

// WriteHeader records statusCode so tests can inspect it via the code field.
func (w *mockedResponseWriter) WriteHeader(statusCode int) {
	w.code = statusCode
}
// TestPatRouterHandleErrors verifies that Handle rejects invalid registrations
// with the expected sentinel error: an unknown HTTP method yields
// ErrInvalidMethod, and an empty path with a valid method yields ErrInvalidPath.
func TestPatRouterHandleErrors(t *testing.T) {
	tests := []struct {
		method string
		path   string
		err    error
	}{
		// Both method and path are invalid here; the expectation encodes that
		// the method is reported first.
		{"FAKE", "", ErrInvalidMethod},
		{"GET", "", ErrInvalidPath},
	}

	for _, test := range tests {
		t.Run(test.method, func(t *testing.T) {
			router := NewPatRouter()
			err := router.Handle(test.method, test.path, nil)
			// Compare the returned error against the per-case expectation.
			// assert.Error(t, X, err) would only check that X is non-nil and
			// would never look at err.
			assert.Equal(t, test.err, err)
		})
	}
}
// TestPatRouterNotFound verifies that a request whose path matches no
// registered route is dispatched to the handler installed with
// SetNotFoundHandler.
func TestPatRouterNotFound(t *testing.T) {
	router := NewPatRouter()

	var handledByNotFound bool
	router.SetNotFoundHandler(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
		handledByNotFound = true
	}))

	// Register an unrelated route. Its error is not checked because only the
	// not-found dispatch is under test.
	_ = router.Handle(http.MethodGet, "/a/b", nil)

	req, _ := http.NewRequest(http.MethodGet, "/b/c", nil)
	router.ServeHTTP(new(mockedResponseWriter), req)

	assert.True(t, handledByNotFound)
}
// TestPatRouter registers three routes ("/a/:b", "/a/b/c" and "/b/c") for the
// method under test. For each request path it checks:
//   - whether any registered handler ran;
//   - which status code was written to the response writer;
//   - for routable paths, that a PUT to the same path yields
//     405 Method Not Allowed.
func TestPatRouter(t *testing.T) {
	tests := []struct {
		method string
		path   string
		expect bool
		code   int
	}{
		// we don't explicitly set status code, framework will do it.
		{http.MethodGet, "/a/b", true, 0},
		{http.MethodGet, "/a/b/", true, 0},
		{http.MethodGet, "/a/b?a=b", true, 0},
		{http.MethodGet, "/a/b/?a=b", true, 0},
		{http.MethodGet, "/a/b/c?a=b", true, 0},
		{http.MethodGet, "/b/d", false, http.StatusNotFound},
	}

	for _, test := range tests {
		t.Run(test.method+":"+test.path, func(t *testing.T) {
			routed := false
			router := NewPatRouter()
			err := router.Handle(test.method, "/a/:b", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				routed = true
				// The parameterized route must expose exactly one path variable.
				assert.Equal(t, 1, len(context.Vars(r)))
			}))
			assert.Nil(t, err)
			err = router.Handle(test.method, "/a/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				routed = true
				// The static route must expose no path variables.
				assert.Nil(t, context.Vars(r))
			}))
			assert.Nil(t, err)
			err = router.Handle(test.method, "/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				routed = true
			}))
			assert.Nil(t, err)

			w := new(mockedResponseWriter)
			r, err := http.NewRequest(test.method, test.path, nil)
			if err != nil {
				// A nil request would panic inside ServeHTTP; fail clearly instead.
				t.Fatalf("building %s %q request: %v", test.method, test.path, err)
			}
			router.ServeHTTP(w, r)
			assert.Equal(t, test.expect, routed)
			assert.Equal(t, test.code, w.code)

			// For paths that route successfully, a different method must be
			// rejected with 405.
			if test.code == 0 {
				r, err = http.NewRequest(http.MethodPut, test.path, nil)
				if err != nil {
					t.Fatalf("building PUT %q request: %v", test.path, err)
				}
				router.ServeHTTP(w, r)
				assert.Equal(t, http.StatusMethodNotAllowed, w.code)
			}
		})
	}
}
// BenchmarkPatRouter measures dispatch of a GET request through a route with
// two path variables ("/api/:user/:name") to a no-op handler.
func BenchmarkPatRouter(b *testing.B) {
	router := NewPatRouter()
	err := router.Handle(http.MethodGet, "/api/:user/:name", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	}))
	if err != nil {
		// Without the route, every iteration would hit the not-found path and
		// the numbers would be meaningless.
		b.Fatal(err)
	}
	w := &mockedResponseWriter{}
	r, err := http.NewRequest(http.MethodGet, "/api/a/b", nil)
	if err != nil {
		b.Fatal(err)
	}

	b.ReportAllocs()
	// Exclude router construction and request setup from time and alloc counts.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		router.ServeHTTP(w, r)
	}
}