chore: avoid nested WithCodeResponseWriter (#3406)

master
Kevin Wan 1 year ago committed by GitHub
parent e8c1e6e09b
commit 13cdbdc98b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -28,7 +28,7 @@ func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handle
return return
} }
cw := &response.WithCodeResponseWriter{Writer: w} cw := response.NewWithCodeResponseWriter(w)
defer func() { defer func() {
if cw.Code < http.StatusInternalServerError { if cw.Code < http.StatusInternalServerError {
promise.Accept() promise.Accept()

@ -36,14 +36,11 @@ func LogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timer := utils.NewElapsedTimer() timer := utils.NewElapsedTimer()
logs := new(internal.LogCollector) logs := new(internal.LogCollector)
lrw := response.WithCodeResponseWriter{ lrw := response.NewWithCodeResponseWriter(w)
Writer: w,
Code: http.StatusOK,
}
var dup io.ReadCloser var dup io.ReadCloser
r.Body, dup = iox.DupReadCloser(r.Body) r.Body, dup = iox.DupReadCloser(r.Body)
next.ServeHTTP(&lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs))) next.ServeHTTP(lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs)))
r.Body = dup r.Body = dup
logBrief(r, lrw.Code, timer, logs) logBrief(r, lrw.Code, timer, logs)
}) })
@ -54,7 +51,8 @@ type detailLoggedResponseWriter struct {
buf *bytes.Buffer buf *bytes.Buffer
} }
func newDetailLoggedResponseWriter(writer *response.WithCodeResponseWriter, buf *bytes.Buffer) *detailLoggedResponseWriter { func newDetailLoggedResponseWriter(writer *response.WithCodeResponseWriter,
buf *bytes.Buffer) *detailLoggedResponseWriter {
return &detailLoggedResponseWriter{ return &detailLoggedResponseWriter{
writer: writer, writer: writer,
buf: buf, buf: buf,
@ -93,10 +91,8 @@ func DetailedLogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timer := utils.NewElapsedTimer() timer := utils.NewElapsedTimer()
var buf bytes.Buffer var buf bytes.Buffer
lrw := newDetailLoggedResponseWriter(&response.WithCodeResponseWriter{ rw := response.NewWithCodeResponseWriter(w)
Writer: w, lrw := newDetailLoggedResponseWriter(rw, &buf)
Code: http.StatusOK,
}, &buf)
var dup io.ReadCloser var dup io.ReadCloser
r.Body, dup = iox.DupReadCloser(r.Body) r.Body, dup = iox.DupReadCloser(r.Body)

@ -2,6 +2,7 @@ package handler
import ( import (
"bytes" "bytes"
"errors"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -88,18 +89,23 @@ func TestLogHandlerSlow(t *testing.T) {
func TestDetailedLogHandler_Hijack(t *testing.T) { func TestDetailedLogHandler_Hijack(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
writer := &detailLoggedResponseWriter{ writer := &detailLoggedResponseWriter{
writer: &response.WithCodeResponseWriter{ writer: response.NewWithCodeResponseWriter(resp),
Writer: resp,
},
} }
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
_, _, _ = writer.Hijack() _, _, _ = writer.Hijack()
}) })
writer = &detailLoggedResponseWriter{ writer = &detailLoggedResponseWriter{
writer: &response.WithCodeResponseWriter{ writer: response.NewWithCodeResponseWriter(resp),
Writer: mockedHijackable{resp}, }
}, assert.NotPanics(t, func() {
_, _, _ = writer.Hijack()
})
writer = &detailLoggedResponseWriter{
writer: response.NewWithCodeResponseWriter(mockedHijackable{
ResponseRecorder: resp,
}),
} }
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
_, _, _ = writer.Hijack() _, _, _ = writer.Hijack()
@ -133,6 +139,13 @@ func TestWrapStatusCodeWithColor(t *testing.T) {
assert.Equal(t, "503", wrapStatusCode(http.StatusServiceUnavailable)) assert.Equal(t, "503", wrapStatusCode(http.StatusServiceUnavailable))
} }
func TestDumpRequest(t *testing.T) {
const errMsg = "error"
r := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
r.Body = mockedReadCloser{errMsg: errMsg}
assert.Equal(t, errMsg, dumpRequest(r))
}
func BenchmarkLogHandler(b *testing.B) { func BenchmarkLogHandler(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
@ -146,3 +159,15 @@ func BenchmarkLogHandler(b *testing.B) {
handler.ServeHTTP(resp, req) handler.ServeHTTP(resp, req)
} }
} }
type mockedReadCloser struct {
errMsg string
}
func (m mockedReadCloser) Read(p []byte) (n int, err error) {
return 0, errors.New(m.errMsg)
}
func (m mockedReadCloser) Close() error {
return nil
}

@ -35,7 +35,7 @@ func PrometheusHandler(path, method string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
startTime := timex.Now() startTime := timex.Now()
cw := &response.WithCodeResponseWriter{Writer: w} cw := response.NewWithCodeResponseWriter(w)
defer func() { defer func() {
metricServerReqDur.Observe(timex.Since(startTime).Milliseconds(), path, method) metricServerReqDur.Observe(timex.Since(startTime).Milliseconds(), path, method)
metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code), method) metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code), method)

@ -41,7 +41,7 @@ func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Hand
return return
} }
cw := &response.WithCodeResponseWriter{Writer: w} cw := response.NewWithCodeResponseWriter(w)
defer func() { defer func() {
if cw.Code == http.StatusServiceUnavailable { if cw.Code == http.StatusServiceUnavailable {
promise.Fail() promise.Fail()

@ -67,9 +67,10 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(ctx) r = r.WithContext(ctx)
done := make(chan struct{}) done := make(chan struct{})
tw := &timeoutWriter{ tw := &timeoutWriter{
w: w, w: w,
h: make(http.Header), h: make(http.Header),
req: r, req: r,
code: http.StatusOK,
} }
panicChan := make(chan any, 1) panicChan := make(chan any, 1)
go func() { go func() {
@ -91,10 +92,12 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
for k, vv := range tw.h { for k, vv := range tw.h {
dst[k] = vv dst[k] = vv
} }
if !tw.wroteHeader {
tw.code = http.StatusOK // We don't need to write header 200, because it's written by default.
// If we write it again, it will cause a warning: `http: superfluous response.WriteHeader call`.
if tw.code != http.StatusOK {
w.WriteHeader(tw.code)
} }
w.WriteHeader(tw.code)
w.Write(tw.wbuf.Bytes()) w.Write(tw.wbuf.Bytes())
case <-ctx.Done(): case <-ctx.Done():
tw.mu.Lock() tw.mu.Lock()

@ -100,6 +100,18 @@ func TestWithinTimeout(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.Code) assert.Equal(t, http.StatusOK, resp.Code)
} }
func TestWithinTimeoutBadCode(t *testing.T) {
timeoutHandler := TimeoutHandler(time.Second)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusInternalServerError, resp.Code)
}
func TestWithTimeoutTimedout(t *testing.T) { func TestWithTimeoutTimedout(t *testing.T) {
timeoutHandler := TimeoutHandler(time.Millisecond) timeoutHandler := TimeoutHandler(time.Millisecond)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -208,9 +220,7 @@ func TestTimeoutHijack(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
writer := &timeoutWriter{ writer := &timeoutWriter{
w: &response.WithCodeResponseWriter{ w: response.NewWithCodeResponseWriter(resp),
Writer: resp,
},
} }
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
@ -218,9 +228,7 @@ func TestTimeoutHijack(t *testing.T) {
}) })
writer = &timeoutWriter{ writer = &timeoutWriter{
w: &response.WithCodeResponseWriter{ w: response.NewWithCodeResponseWriter(mockedHijackable{resp}),
Writer: mockedHijackable{resp},
},
} }
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
@ -274,9 +282,7 @@ func TestTimeoutWriter_Hijack(t *testing.T) {
func TestTimeoutWroteTwice(t *testing.T) { func TestTimeoutWroteTwice(t *testing.T) {
c := logtest.NewCollector(t) c := logtest.NewCollector(t)
writer := &timeoutWriter{ writer := &timeoutWriter{
w: &response.WithCodeResponseWriter{ w: response.NewWithCodeResponseWriter(httptest.NewRecorder()),
Writer: httptest.NewRecorder(),
},
h: make(http.Header), h: make(http.Header),
req: httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody), req: httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody),
} }

@ -60,7 +60,7 @@ func TraceHandler(serviceName, path string, opts ...TraceOption) func(http.Handl
// convenient for tracking error messages // convenient for tracking error messages
propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header())) propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header()))
trw := &response.WithCodeResponseWriter{Writer: w, Code: http.StatusOK} trw := response.NewWithCodeResponseWriter(w)
next.ServeHTTP(trw, r.WithContext(spanCtx)) next.ServeHTTP(trw, r.WithContext(spanCtx))
span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(trw.Code)...) span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(trw.Code)...)

@ -13,6 +13,20 @@ type WithCodeResponseWriter struct {
Code int Code int
} }
// NewWithCodeResponseWriter returns a WithCodeResponseWriter.
// If writer is already a WithCodeResponseWriter, it returns writer directly.
func NewWithCodeResponseWriter(writer http.ResponseWriter) *WithCodeResponseWriter {
switch w := writer.(type) {
case *WithCodeResponseWriter:
return w
default:
return &WithCodeResponseWriter{
Writer: writer,
Code: http.StatusOK,
}
}
}
// Flush flushes the response writer. // Flush flushes the response writer.
func (w *WithCodeResponseWriter) Flush() { func (w *WithCodeResponseWriter) Flush() {
if flusher, ok := w.Writer.(http.Flusher); ok { if flusher, ok := w.Writer.(http.Flusher); ok {

@ -11,7 +11,7 @@ import (
func TestWithCodeResponseWriter(t *testing.T) { func TestWithCodeResponseWriter(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cw := &WithCodeResponseWriter{Writer: w} cw := NewWithCodeResponseWriter(w)
cw.Header().Set("X-Test", "test") cw.Header().Set("X-Test", "test")
cw.WriteHeader(http.StatusServiceUnavailable) cw.WriteHeader(http.StatusServiceUnavailable)
@ -34,9 +34,7 @@ func TestWithCodeResponseWriter(t *testing.T) {
func TestWithCodeResponseWriter_Hijack(t *testing.T) { func TestWithCodeResponseWriter_Hijack(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
writer := &WithCodeResponseWriter{ writer := NewWithCodeResponseWriter(NewWithCodeResponseWriter(resp))
Writer: resp,
}
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
writer.Hijack() writer.Hijack()
}) })

Loading…
Cancel
Save