diff --git a/rest/httpc/internal/loginterceptor_test.go b/rest/httpc/internal/loginterceptor_test.go index 108dfdee..39d445d3 100644 --- a/rest/httpc/internal/loginterceptor_test.go +++ b/rest/httpc/internal/loginterceptor_test.go @@ -11,6 +11,7 @@ import ( func TestLogInterceptor(t *testing.T) { svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { })) + defer svr.Close() req, err := http.NewRequest(http.MethodGet, svr.URL, nil) assert.Nil(t, err) req, handler := LogInterceptor(req) @@ -24,6 +25,7 @@ func TestLogInterceptorServerError(t *testing.T) { svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) + defer svr.Close() req, err := http.NewRequest(http.MethodGet, svr.URL, nil) assert.Nil(t, err) req, handler := LogInterceptor(req) diff --git a/rest/httpc/requests_test.go b/rest/httpc/requests_test.go index 6696d20f..0deeff2a 100644 --- a/rest/httpc/requests_test.go +++ b/rest/httpc/requests_test.go @@ -11,6 +11,7 @@ import ( func TestDo(t *testing.T) { svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { })) + defer svr.Close() _, err := Get("foo", "tcp://bad request") assert.NotNil(t, err) resp, err := Get("foo", svr.URL) @@ -20,6 +21,7 @@ func TestDo(t *testing.T) { func TestDoNotFound(t *testing.T) { svr := httptest.NewServer(http.NotFoundHandler()) + defer svr.Close() _, err := Post("foo", "tcp://bad request", "application/json", nil) assert.NotNil(t, err) resp, err := Post("foo", svr.URL, "application/json", nil) @@ -29,6 +31,7 @@ func TestDoNotFound(t *testing.T) { func TestDoMoved(t *testing.T) { svr := httptest.NewServer(http.RedirectHandler("/foo", http.StatusMovedPermanently)) + defer svr.Close() req, err := http.NewRequest(http.MethodGet, svr.URL, nil) assert.Nil(t, err) _, err = Do("foo", req) diff --git a/rest/httpc/responses.go b/rest/httpc/responses.go new file mode 100644 index 00000000..eb571e53 --- /dev/null +++ b/rest/httpc/responses.go @@ -0,0 +1,33 @@ +package httpc + +import ( + "net/http" + "strings" + + "github.com/zeromicro/go-zero/core/mapping" + "github.com/zeromicro/go-zero/rest/internal/encoding" +) + +func Parse(resp *http.Response, val interface{}) error { + if err := ParseHeaders(resp, val); err != nil { + return err + } + + return ParseJsonBody(resp, val) +} + +func ParseHeaders(resp *http.Response, val interface{}) error { + return encoding.ParseHeaders(resp.Header, val) +} + +func ParseJsonBody(resp *http.Response, val interface{}) error { + if withJsonBody(resp) { + return mapping.UnmarshalJsonReader(resp.Body, val) + } + + return mapping.UnmarshalJsonMap(nil, val) +} + +func withJsonBody(r *http.Response) bool { + return r.ContentLength > 0 && strings.Contains(r.Header.Get(contentType), applicationJson) +} diff --git a/rest/httpc/responses_test.go b/rest/httpc/responses_test.go new file mode 100644 index 00000000..13e4acc6 --- /dev/null +++ b/rest/httpc/responses_test.go @@ -0,0 +1,58 @@ +package httpc + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + var val struct { + Foo string `header:"foo"` + Name string `json:"name"` + Value int `json:"value"` + } + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("foo", "bar") + w.Header().Set(contentType, applicationJson) + w.Write([]byte(`{"name":"kevin","value":100}`)) + })) + defer svr.Close() + resp, err := Get("foo", svr.URL) + assert.Nil(t, err) + assert.Nil(t, Parse(resp, &val)) + assert.Equal(t, "bar", val.Foo) + assert.Equal(t, "kevin", val.Name) + assert.Equal(t, 100, val.Value) +} + +func TestParseHeaderError(t *testing.T) { + var val struct { + Foo int `header:"foo"` + } + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("foo", "bar") + w.Header().Set(contentType, applicationJson) + })) + defer svr.Close() + resp, err := Get("foo", svr.URL) + assert.Nil(t, err) + assert.NotNil(t, Parse(resp, &val)) +} + +func TestParseNoBody(t *testing.T) { + var val struct { + Foo string `header:"foo"` + } + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("foo", "bar") + w.Header().Set(contentType, applicationJson) + })) + defer svr.Close() + resp, err := Get("foo", svr.URL) + assert.Nil(t, err) + assert.Nil(t, Parse(resp, &val)) + assert.Equal(t, "bar", val.Foo) +} diff --git a/rest/httpc/service.go b/rest/httpc/service.go index db17b280..8297a3df 100644 --- a/rest/httpc/service.go +++ b/rest/httpc/service.go @@ -9,9 +9,6 @@ import ( "github.com/zeromicro/go-zero/rest/httpc/internal" ) -// ContentType means Content-Type. -const ContentType = "Content-Type" - var interceptors = []internal.Interceptor{ internal.LogInterceptor, } @@ -86,13 +83,13 @@ func (s namedService) Get(url string) (*http.Response, error) { } // Post sends an HTTP POST request to the service. -func (s namedService) Post(url, contentType string, body io.Reader) (*http.Response, error) { +func (s namedService) Post(url, ctype string, body io.Reader) (*http.Response, error) { r, err := http.NewRequest(http.MethodPost, url, body) if err != nil { return nil, err } - r.Header.Set(ContentType, contentType) + r.Header.Set(contentType, ctype) return s.Do(r) } diff --git a/rest/httpc/service_test.go b/rest/httpc/service_test.go index f619fe76..f92177a4 100644 --- a/rest/httpc/service_test.go +++ b/rest/httpc/service_test.go @@ -10,6 +10,7 @@ import ( func TestNamedService_Do(t *testing.T) { svr := httptest.NewServer(http.RedirectHandler("/foo", http.StatusMovedPermanently)) + defer svr.Close() req, err := http.NewRequest(http.MethodGet, svr.URL, nil) assert.Nil(t, err) service := NewService("foo") @@ -22,6 +23,7 @@ func TestNamedService_Get(t *testing.T) { svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("foo", r.Header.Get("foo")) })) + defer svr.Close() service := NewService("foo", func(r *http.Request) *http.Request { r.Header.Set("foo", "bar") return r @@ -34,6 +36,7 @@ func TestNamedService_Get(t *testing.T) { func TestNamedService_Post(t *testing.T) { svr := httptest.NewServer(http.NotFoundHandler()) + defer svr.Close() service := NewService("foo") _, err := service.Post("tcp://bad request", "application/json", nil) assert.NotNil(t, err) diff --git a/rest/httpc/vars.go b/rest/httpc/vars.go new file mode 100644 index 00000000..32d68da9 --- /dev/null +++ b/rest/httpc/vars.go @@ -0,0 +1,6 @@ +package httpc + +const ( + contentType = "Content-Type" + applicationJson = "application/json" +) diff --git a/rest/httpx/requests.go b/rest/httpx/requests.go index 46c879bb..4b3da146 100644 --- a/rest/httpx/requests.go +++ b/rest/httpx/requests.go @@ -3,17 +3,16 @@ package httpx import ( "io" "net/http" - "net/textproto" "strings" "github.com/zeromicro/go-zero/core/mapping" + "github.com/zeromicro/go-zero/rest/internal/encoding" "github.com/zeromicro/go-zero/rest/pathvar" ) const ( formKey = "form" pathKey = "path" - headerKey = "header" maxMemory = 32 << 20 // 32MB maxBodyLen = 8 << 20 // 8MB separator = ";" @@ -21,10 +20,8 @@ const ( ) var ( - formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues()) - pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues()) - headerUnmarshaler = mapping.NewUnmarshaler(headerKey, mapping.WithStringValues(), - mapping.WithCanonicalKeyFunc(textproto.CanonicalMIMEHeaderKey)) + formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues()) + pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues()) ) // Parse parses the request. @@ -46,16 +43,7 @@ func Parse(r *http.Request, v interface{}) error { // ParseHeaders parses the headers request. func ParseHeaders(r *http.Request, v interface{}) error { - m := map[string]interface{}{} - for k, v := range r.Header { - if len(v) == 1 { - m[k] = v[0] - } else { - m[k] = v - } - } - - return headerUnmarshaler.Unmarshal(m, v) + return encoding.ParseHeaders(r.Header, v) } // ParseForm parses the form request. diff --git a/rest/internal/encoding/parser.go b/rest/internal/encoding/parser.go new file mode 100644 index 00000000..b9bfec05 --- /dev/null +++ b/rest/internal/encoding/parser.go @@ -0,0 +1,27 @@ +package encoding + +import ( + "net/http" + "net/textproto" + + "github.com/zeromicro/go-zero/core/mapping" +) + +const headerKey = "header" + +var headerUnmarshaler = mapping.NewUnmarshaler(headerKey, mapping.WithStringValues(), + mapping.WithCanonicalKeyFunc(textproto.CanonicalMIMEHeaderKey)) + +// ParseHeaders parses the headers request. +func ParseHeaders(header http.Header, v interface{}) error { + m := map[string]interface{}{} + for k, v := range header { + if len(v) == 1 { + m[k] = v[0] + } else { + m[k] = v + } + } + + return headerUnmarshaler.Unmarshal(m, v) +} diff --git a/rest/internal/encoding/parser_test.go b/rest/internal/encoding/parser_test.go new file mode 100644 index 00000000..70e001f4 --- /dev/null +++ b/rest/internal/encoding/parser_test.go @@ -0,0 +1,40 @@ +package encoding + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseHeaders(t *testing.T) { + var val struct { + Foo string `header:"foo"` + Baz int `header:"baz"` + Qux bool `header:"qux,default=true"` + } + r := httptest.NewRequest(http.MethodGet, "/any", nil) + r.Header.Set("foo", "bar") + r.Header.Set("baz", "1") + assert.Nil(t, ParseHeaders(r.Header, &val)) + assert.Equal(t, "bar", val.Foo) + assert.Equal(t, 1, val.Baz) + assert.True(t, val.Qux) +} + +func TestParseHeadersMulti(t *testing.T) { + var val struct { + Foo []string `header:"foo"` + Baz int `header:"baz"` + Qux bool `header:"qux,default=true"` + } + r := httptest.NewRequest(http.MethodGet, "/any", nil) + r.Header.Set("foo", "bar") + r.Header.Add("foo", "bar1") + r.Header.Set("baz", "1") + assert.Nil(t, ParseHeaders(r.Header, &val)) + assert.Equal(t, []string{"bar", "bar1"}, val.Foo) + assert.Equal(t, 1, val.Baz) + assert.True(t, val.Qux) +}