diff --git a/rest/handler/breakerhandler.go b/rest/handler/breakerhandler.go index e9d7243a..05b85bc8 100644 --- a/rest/handler/breakerhandler.go +++ b/rest/handler/breakerhandler.go @@ -28,7 +28,7 @@ func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handle return } - cw := &response.WithCodeResponseWriter{Writer: w} + cw := response.NewWithCodeResponseWriter(w) defer func() { if cw.Code < http.StatusInternalServerError { promise.Accept() diff --git a/rest/handler/loghandler.go b/rest/handler/loghandler.go index 179a72f9..07cbdf19 100644 --- a/rest/handler/loghandler.go +++ b/rest/handler/loghandler.go @@ -36,14 +36,11 @@ func LogHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { timer := utils.NewElapsedTimer() logs := new(internal.LogCollector) - lrw := response.WithCodeResponseWriter{ - Writer: w, - Code: http.StatusOK, - } + lrw := response.NewWithCodeResponseWriter(w) var dup io.ReadCloser r.Body, dup = iox.DupReadCloser(r.Body) - next.ServeHTTP(&lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs))) + next.ServeHTTP(lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs))) r.Body = dup logBrief(r, lrw.Code, timer, logs) }) @@ -54,7 +51,8 @@ type detailLoggedResponseWriter struct { buf *bytes.Buffer } -func newDetailLoggedResponseWriter(writer *response.WithCodeResponseWriter, buf *bytes.Buffer) *detailLoggedResponseWriter { +func newDetailLoggedResponseWriter(writer *response.WithCodeResponseWriter, + buf *bytes.Buffer) *detailLoggedResponseWriter { return &detailLoggedResponseWriter{ writer: writer, buf: buf, @@ -93,10 +91,8 @@ func DetailedLogHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { timer := utils.NewElapsedTimer() var buf bytes.Buffer - lrw := newDetailLoggedResponseWriter(&response.WithCodeResponseWriter{ - Writer: w, - Code: http.StatusOK, - }, &buf) + rw := response.NewWithCodeResponseWriter(w) + lrw := newDetailLoggedResponseWriter(rw, &buf) var dup io.ReadCloser r.Body, dup = iox.DupReadCloser(r.Body) diff --git a/rest/handler/loghandler_test.go b/rest/handler/loghandler_test.go index 61156e76..c5763379 100644 --- a/rest/handler/loghandler_test.go +++ b/rest/handler/loghandler_test.go @@ -2,6 +2,7 @@ package handler import ( "bytes" + "errors" "io" "net/http" "net/http/httptest" @@ -88,18 +89,23 @@ func TestLogHandlerSlow(t *testing.T) { func TestDetailedLogHandler_Hijack(t *testing.T) { resp := httptest.NewRecorder() writer := &detailLoggedResponseWriter{ - writer: &response.WithCodeResponseWriter{ - Writer: resp, - }, + writer: response.NewWithCodeResponseWriter(resp), } assert.NotPanics(t, func() { _, _, _ = writer.Hijack() }) writer = &detailLoggedResponseWriter{ - writer: &response.WithCodeResponseWriter{ - Writer: mockedHijackable{resp}, - }, + writer: response.NewWithCodeResponseWriter(resp), + } + assert.NotPanics(t, func() { + _, _, _ = writer.Hijack() + }) + + writer = &detailLoggedResponseWriter{ + writer: response.NewWithCodeResponseWriter(mockedHijackable{ + ResponseRecorder: resp, + }), } assert.NotPanics(t, func() { _, _, _ = writer.Hijack() @@ -133,6 +139,13 @@ func TestWrapStatusCodeWithColor(t *testing.T) { assert.Equal(t, "503", wrapStatusCode(http.StatusServiceUnavailable)) } +func TestDumpRequest(t *testing.T) { + const errMsg = "error" + r := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) + r.Body = mockedReadCloser{errMsg: errMsg} + assert.Equal(t, errMsg, dumpRequest(r)) +} + func BenchmarkLogHandler(b *testing.B) { b.ReportAllocs() @@ -146,3 +159,15 @@ func BenchmarkLogHandler(b *testing.B) { handler.ServeHTTP(resp, req) } } + +type mockedReadCloser struct { + errMsg string +} + +func (m mockedReadCloser) Read(p []byte) (n int, err error) { + return 0, errors.New(m.errMsg) +} + +func (m mockedReadCloser) Close() error { + return nil +} diff --git a/rest/handler/prometheushandler.go b/rest/handler/prometheushandler.go index cf0d91fc..86676d4c 100644 --- a/rest/handler/prometheushandler.go +++ b/rest/handler/prometheushandler.go @@ -35,7 +35,7 @@ func PrometheusHandler(path, method string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { startTime := timex.Now() - cw := &response.WithCodeResponseWriter{Writer: w} + cw := response.NewWithCodeResponseWriter(w) defer func() { metricServerReqDur.Observe(timex.Since(startTime).Milliseconds(), path, method) metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code), method) diff --git a/rest/handler/sheddinghandler.go b/rest/handler/sheddinghandler.go index 977824dc..80cfe122 100644 --- a/rest/handler/sheddinghandler.go +++ b/rest/handler/sheddinghandler.go @@ -41,7 +41,7 @@ func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Hand return } - cw := &response.WithCodeResponseWriter{Writer: w} + cw := response.NewWithCodeResponseWriter(w) defer func() { if cw.Code == http.StatusServiceUnavailable { promise.Fail() diff --git a/rest/handler/timeouthandler.go b/rest/handler/timeouthandler.go index c956b49e..25a98126 100644 --- a/rest/handler/timeouthandler.go +++ b/rest/handler/timeouthandler.go @@ -67,9 +67,10 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { r = r.WithContext(ctx) done := make(chan struct{}) tw := &timeoutWriter{ - w: w, - h: make(http.Header), - req: r, + w: w, + h: make(http.Header), + req: r, + code: http.StatusOK, } panicChan := make(chan any, 1) go func() { @@ -91,10 +92,12 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { for k, vv := range tw.h { dst[k] = vv } - if !tw.wroteHeader { - tw.code = http.StatusOK + + // We don't need to write header 200, because it's written by default. + // If we write it again, it will cause a warning: `http: superfluous response.WriteHeader call`. + if tw.code != http.StatusOK { + w.WriteHeader(tw.code) } - w.WriteHeader(tw.code) w.Write(tw.wbuf.Bytes()) case <-ctx.Done(): tw.mu.Lock() diff --git a/rest/handler/timeouthandler_test.go b/rest/handler/timeouthandler_test.go index a234ce09..0d96c592 100644 --- a/rest/handler/timeouthandler_test.go +++ b/rest/handler/timeouthandler_test.go @@ -100,6 +100,18 @@ func TestWithinTimeout(t *testing.T) { assert.Equal(t, http.StatusOK, resp.Code) } +func TestWithinTimeoutBadCode(t *testing.T) { + timeoutHandler := TimeoutHandler(time.Second) + handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusInternalServerError, resp.Code) +} + func TestWithTimeoutTimedout(t *testing.T) { timeoutHandler := TimeoutHandler(time.Millisecond) handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -208,9 +220,7 @@ func TestTimeoutHijack(t *testing.T) { resp := httptest.NewRecorder() writer := &timeoutWriter{ - w: &response.WithCodeResponseWriter{ - Writer: resp, - }, + w: response.NewWithCodeResponseWriter(resp), } assert.NotPanics(t, func() { @@ -218,9 +228,7 @@ func TestTimeoutHijack(t *testing.T) { }) writer = &timeoutWriter{ - w: &response.WithCodeResponseWriter{ - Writer: mockedHijackable{resp}, - }, + w: response.NewWithCodeResponseWriter(mockedHijackable{resp}), } assert.NotPanics(t, func() { @@ -274,9 +282,7 @@ func TestTimeoutWriter_Hijack(t *testing.T) { func TestTimeoutWroteTwice(t *testing.T) { c := logtest.NewCollector(t) writer := &timeoutWriter{ - w: &response.WithCodeResponseWriter{ - Writer: httptest.NewRecorder(), - }, + w: response.NewWithCodeResponseWriter(httptest.NewRecorder()), h: make(http.Header), req: httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody), } diff --git a/rest/handler/tracehandler.go b/rest/handler/tracehandler.go index bc98e73b..7a11da46 100644 --- a/rest/handler/tracehandler.go +++ b/rest/handler/tracehandler.go @@ -60,7 +60,7 @@ func TraceHandler(serviceName, path string, opts ...TraceOption) func(http.Handl // convenient for tracking error messages propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header())) - trw := &response.WithCodeResponseWriter{Writer: w, Code: http.StatusOK} + trw := response.NewWithCodeResponseWriter(w) next.ServeHTTP(trw, r.WithContext(spanCtx)) span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(trw.Code)...) diff --git a/rest/internal/response/withcoderesponsewriter.go b/rest/internal/response/withcoderesponsewriter.go index 4fa9631e..2a49c79a 100644 --- a/rest/internal/response/withcoderesponsewriter.go +++ b/rest/internal/response/withcoderesponsewriter.go @@ -13,6 +13,20 @@ type WithCodeResponseWriter struct { Code int } +// NewWithCodeResponseWriter returns a WithCodeResponseWriter. +// If writer is already a WithCodeResponseWriter, it returns writer directly. +func NewWithCodeResponseWriter(writer http.ResponseWriter) *WithCodeResponseWriter { + switch w := writer.(type) { + case *WithCodeResponseWriter: + return w + default: + return &WithCodeResponseWriter{ + Writer: writer, + Code: http.StatusOK, + } + } +} + // Flush flushes the response writer. func (w *WithCodeResponseWriter) Flush() { if flusher, ok := w.Writer.(http.Flusher); ok { diff --git a/rest/internal/response/withcoderesponsewriter_test.go b/rest/internal/response/withcoderesponsewriter_test.go index c03e84c5..00485d3d 100644 --- a/rest/internal/response/withcoderesponsewriter_test.go +++ b/rest/internal/response/withcoderesponsewriter_test.go @@ -11,7 +11,7 @@ import ( func TestWithCodeResponseWriter(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - cw := &WithCodeResponseWriter{Writer: w} + cw := NewWithCodeResponseWriter(w) cw.Header().Set("X-Test", "test") cw.WriteHeader(http.StatusServiceUnavailable) @@ -34,9 +34,7 @@ func TestWithCodeResponseWriter(t *testing.T) { func TestWithCodeResponseWriter_Hijack(t *testing.T) { resp := httptest.NewRecorder() - writer := &WithCodeResponseWriter{ - Writer: resp, - } + writer := NewWithCodeResponseWriter(NewWithCodeResponseWriter(resp)) assert.NotPanics(t, func() { writer.Hijack() })