From 5d00dfb9620f6f2ebd81ba14049363f88574e65d Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Sun, 28 Aug 2022 15:41:02 +0800 Subject: [PATCH] fix: handle the scenarios that content-length is invalid (#2313) --- rest/httpc/responses.go | 21 ++++++-- rest/httpc/responses_test.go | 95 ++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 4 deletions(-) diff --git a/rest/httpc/responses.go b/rest/httpc/responses.go index 704d0a9d..7207cdc7 100644 --- a/rest/httpc/responses.go +++ b/rest/httpc/responses.go @@ -1,6 +1,8 @@ package httpc import ( + "bytes" + "io" "net/http" "strings" @@ -27,13 +29,24 @@ func ParseHeaders(resp *http.Response, val interface{}) error { func ParseJsonBody(resp *http.Response, val interface{}) error { defer resp.Body.Close() - if withJsonBody(resp) { - return mapping.UnmarshalJsonReader(resp.Body, val) + if isContentTypeJson(resp) { + if resp.ContentLength > 0 { + return mapping.UnmarshalJsonReader(resp.Body, val) + } + + var buf bytes.Buffer + if _, err := io.Copy(&buf, resp.Body); err != nil { + return err + } + + if buf.Len() > 0 { + return mapping.UnmarshalJsonReader(&buf, val) + } } return mapping.UnmarshalJsonMap(nil, val) } -func withJsonBody(r *http.Response) bool { - return r.ContentLength > 0 && strings.Contains(r.Header.Get(header.ContentType), header.ApplicationJson) +func isContentTypeJson(r *http.Response) bool { + return strings.Contains(r.Header.Get(header.ContentType), header.ApplicationJson) } diff --git a/rest/httpc/responses_test.go b/rest/httpc/responses_test.go index 6f1a2d1e..31369ee1 100644 --- a/rest/httpc/responses_test.go +++ b/rest/httpc/responses_test.go @@ -1,6 +1,7 @@ package httpc import ( + "errors" "net/http" "net/http/httptest" "testing" @@ -83,3 +84,97 @@ func TestParseWithZeroValue(t *testing.T) { assert.Equal(t, 0, val.Foo) assert.Equal(t, 0, val.Bar) } + +func TestParseWithNegativeContentLength(t *testing.T) { + var val struct { + Bar int `json:"bar"` + } + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(header.ContentType, header.JsonContentType) + w.Write([]byte(`{"bar":0}`)) + })) + defer svr.Close() + req, err := http.NewRequest(http.MethodGet, svr.URL, nil) + assert.Nil(t, err) + + tests := []struct { + name string + length int64 + }{ + { + name: "negative", + length: -1, + }, + { + name: "zero", + length: 0, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + resp, err := DoRequest(req) + resp.ContentLength = test.length + assert.Nil(t, err) + assert.Nil(t, Parse(resp, &val)) + assert.Equal(t, 0, val.Bar) + }) + } +} + +func TestParseWithNegativeContentLengthNoBody(t *testing.T) { + var val struct{} + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(header.ContentType, header.JsonContentType) + })) + defer svr.Close() + req, err := http.NewRequest(http.MethodGet, svr.URL, nil) + assert.Nil(t, err) + + tests := []struct { + name string + length int64 + }{ + { + name: "negative", + length: -1, + }, + { + name: "zero", + length: 0, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + resp, err := DoRequest(req) + resp.ContentLength = test.length + assert.Nil(t, err) + assert.Nil(t, Parse(resp, &val)) + }) + } +} + +func TestParseJsonBody_BodyError(t *testing.T) { + var val struct{} + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(header.ContentType, header.JsonContentType) + })) + defer svr.Close() + req, err := http.NewRequest(http.MethodGet, svr.URL, nil) + assert.Nil(t, err) + + resp, err := DoRequest(req) + resp.ContentLength = -1 + resp.Body = mockedReader{} + assert.Nil(t, err) + assert.NotNil(t, Parse(resp, &val)) +} + +type mockedReader struct{} + +func (m mockedReader) Close() error { + return nil +} + +func (m mockedReader) Read(p []byte) (n int, err error) { + return 0, errors.New("dummy") +}