diff --git a/rest/server.go b/rest/server.go index 4c7a30ff..f3104344 100644 --- a/rest/server.go +++ b/rest/server.go @@ -307,3 +307,8 @@ func newCorsRouter(router httpx.Router, headerFn func(http.Header), origins ...s func (c *corsRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { c.middleware(c.Router.ServeHTTP)(w, r) } + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.ngin.bindRoutes(s.router) + s.router.ServeHTTP(w, r) +} diff --git a/rest/server_test.go b/rest/server_test.go index ad205552..bd215fae 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -535,3 +535,91 @@ func TestServer_WithCors(t *testing.T) { cr.ServeHTTP(httptest.NewRecorder(), req) assert.Equal(t, int32(0), atomic.LoadInt32(&called)) } + +func TestServer_ServeHTTP(t *testing.T) { + const configYaml = ` +Name: foo +Port: 54321 +` + + var cnf RestConf + assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf)) + + svr, err := NewServer(cnf) + assert.Nil(t, err) + + svr.AddRoutes([]Route{ + { + Method: http.MethodGet, + Path: "/foo", + Handler: func(writer http.ResponseWriter, request *http.Request) { + _, _ = writer.Write([]byte("succeed")) + writer.WriteHeader(http.StatusOK) + }, + }, + { + Method: http.MethodGet, + Path: "/bar", + Handler: func(writer http.ResponseWriter, request *http.Request) { + _, _ = writer.Write([]byte("succeed")) + writer.WriteHeader(http.StatusOK) + }, + }, + { + Method: http.MethodGet, + Path: "/user/:name", + Handler: func(writer http.ResponseWriter, request *http.Request) { + + var userInfo struct { + Name string `path:"name"` + } + + err := httpx.Parse(request, &userInfo) + if err != nil { + _, _ = writer.Write([]byte("failed")) + writer.WriteHeader(http.StatusBadRequest) + return + } + + _, _ = writer.Write([]byte("succeed")) + writer.WriteHeader(http.StatusOK) + }, + }, + }) + + testCase := []struct { + name string + path string + code int + }{ + { + name: "URI : /foo", + path: "/foo", + code: http.StatusOK, + }, + { + name: "URI : /bar", + path: "/bar", + code: http.StatusOK, + }, + { + name: "URI : undefined path", + path: "/test", + code: http.StatusNotFound, + }, + { + name: "URI : /user/:name", + path: "/user/abc", + code: http.StatusOK, + }, + } + + for _, test := range testCase { + t.Run(test.name, func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", test.path, nil) + svr.ServeHTTP(w, req) + assert.Equal(t, test.code, w.Code) + }) + } +}