diff --git a/rest/httpx/responses.go b/rest/httpx/responses.go index 9ec67e5b..28b4ba1b 100644 --- a/rest/httpx/responses.go +++ b/rest/httpx/responses.go @@ -61,12 +61,16 @@ func SetErrorHandler(handler func(error) (int, interface{})) { // WriteJson writes v as json string into w with code. func WriteJson(w http.ResponseWriter, code int, v interface{}) { + bs, err := json.Marshal(v) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set(ContentType, ApplicationJson) w.WriteHeader(code) - if bs, err := json.Marshal(v); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } else if n, err := w.Write(bs); err != nil { + if n, err := w.Write(bs); err != nil { // http.ErrHandlerTimeout has been handled by http.TimeoutHandler, // so it's ignored here. if err != http.ErrHandlerTimeout { diff --git a/rest/httpx/responses_test.go b/rest/httpx/responses_test.go index c886ae2e..9a553cd0 100644 --- a/rest/httpx/responses_test.go +++ b/rest/httpx/responses_test.go @@ -146,6 +146,16 @@ func TestWriteJsonLessWritten(t *testing.T) { assert.Equal(t, http.StatusOK, w.code) } +func TestWriteJsonMarshalFailed(t *testing.T) { + w := tracedResponseWriter{ + headers: make(map[string][]string), + } + WriteJson(&w, http.StatusOK, map[string]interface{}{ + "Data": complex(0, 0), + }) + assert.Equal(t, http.StatusInternalServerError, w.code) +} + type tracedResponseWriter struct { headers map[string][]string builder strings.Builder @@ -153,6 +163,7 @@ type tracedResponseWriter struct { code int lessWritten bool timeout bool + wroteHeader bool } func (w *tracedResponseWriter) Header() http.Header { @@ -174,5 +185,9 @@ func (w *tracedResponseWriter) Write(bytes []byte) (n int, err error) { } func (w *tracedResponseWriter) WriteHeader(code int) { + if w.wroteHeader { + return + } + w.wroteHeader = true w.code = code }