feat: treat client closed requests as code 499 (#1350)

* feat: treat client closed requests as code 499

* chore: add comments
master
Kevin Wan 3 years ago committed by GitHub
parent 2cdf5e7395
commit 4ba2ff7cdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,19 +1,187 @@
package handler
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"path"
"runtime"
"strings"
"sync"
"time"
"github.com/tal-tech/go-zero/rest/internal"
)
const reason = "Request Timeout"
const (
statusClientClosedRequest = 499
reason = "Request Timeout"
)
// TimeoutHandler returns the handler with given timeout.
// If client closed request, code 499 will be logged.
// Notice: even if canceled in server side, 499 will be logged as well.
func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
if duration > 0 {
return http.TimeoutHandler(next, duration, reason)
return &timeoutHandler{
handler: next,
dt: duration,
}
}
return next
}
}
// timeoutHandler is the handler that controls the request timeout.
// Why we implement it on our own, because the stdlib implementation
// treats the ClientClosedRequest as http.StatusServiceUnavailable.
// And we write the codes in logs as code 499, which is defined by nginx.
type timeoutHandler struct {
handler http.Handler
dt time.Duration
}
func (h *timeoutHandler) errorBody() string {
return reason
}
func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, cancelCtx := context.WithTimeout(r.Context(), h.dt)
defer cancelCtx()
r = r.WithContext(ctx)
done := make(chan struct{})
tw := &timeoutWriter{
w: w,
h: make(http.Header),
req: r,
}
panicChan := make(chan interface{}, 1)
go func() {
defer func() {
if p := recover(); p != nil {
panicChan <- p
}
}()
h.handler.ServeHTTP(tw, r)
close(done)
}()
select {
case p := <-panicChan:
panic(p)
case <-done:
tw.mu.Lock()
defer tw.mu.Unlock()
dst := w.Header()
for k, vv := range tw.h {
dst[k] = vv
}
if !tw.wroteHeader {
tw.code = http.StatusOK
}
w.WriteHeader(tw.code)
w.Write(tw.wbuf.Bytes())
case <-ctx.Done():
tw.mu.Lock()
defer tw.mu.Unlock()
if errors.Is(ctx.Err(), context.Canceled) {
w.WriteHeader(statusClientClosedRequest)
} else {
w.WriteHeader(http.StatusServiceUnavailable)
}
io.WriteString(w, h.errorBody())
tw.timedOut = true
}
}
type timeoutWriter struct {
w http.ResponseWriter
h http.Header
wbuf bytes.Buffer
req *http.Request
mu sync.Mutex
timedOut bool
wroteHeader bool
code int
}
var _ http.Pusher = (*timeoutWriter)(nil)
// Push implements the Pusher interface.
func (tw *timeoutWriter) Push(target string, opts *http.PushOptions) error {
if pusher, ok := tw.w.(http.Pusher); ok {
return pusher.Push(target, opts)
}
return http.ErrNotSupported
}
func (tw *timeoutWriter) Header() http.Header { return tw.h }
func (tw *timeoutWriter) Write(p []byte) (int, error) {
tw.mu.Lock()
defer tw.mu.Unlock()
if tw.timedOut {
return 0, http.ErrHandlerTimeout
}
if !tw.wroteHeader {
tw.writeHeaderLocked(http.StatusOK)
}
return tw.wbuf.Write(p)
}
func (tw *timeoutWriter) writeHeaderLocked(code int) {
checkWriteHeaderCode(code)
switch {
case tw.timedOut:
return
case tw.wroteHeader:
if tw.req != nil {
caller := relevantCaller()
internal.Errorf(tw.req, "http: superfluous response.WriteHeader call from %s (%s:%d)",
caller.Function, path.Base(caller.File), caller.Line)
}
default:
tw.wroteHeader = true
tw.code = code
}
}
func (tw *timeoutWriter) WriteHeader(code int) {
tw.mu.Lock()
defer tw.mu.Unlock()
tw.writeHeaderLocked(code)
}
func checkWriteHeaderCode(code int) {
if code < 100 || code > 599 {
panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
}
// relevantCaller searches the call stack for the first function outside of net/http.
// The purpose of this function is to provide more helpful error messages.
func relevantCaller() runtime.Frame {
pc := make([]uintptr, 16)
n := runtime.Callers(1, pc)
frames := runtime.CallersFrames(pc[:n])
var frame runtime.Frame
for {
frame, more := frames.Next()
if !strings.HasPrefix(frame.Function, "net/http.") {
return frame
}
if !more {
break
}
}
return frame
}

@ -1,6 +1,7 @@
package handler
import (
"context"
"io/ioutil"
"log"
"net/http"
@ -39,6 +40,20 @@ func TestWithinTimeout(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.Code)
}
func TestWithTimeoutTimedout(t *testing.T) {
timeoutHandler := TimeoutHandler(time.Millisecond)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Millisecond * 10)
w.Write([]byte(`foo`))
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
}
func TestWithoutTimeout(t *testing.T) {
timeoutHandler := TimeoutHandler(0)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -50,3 +65,91 @@ func TestWithoutTimeout(t *testing.T) {
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}
func TestTimeoutPanic(t *testing.T) {
timeoutHandler := TimeoutHandler(time.Minute)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("foo")
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
assert.Panics(t, func() {
handler.ServeHTTP(resp, req)
})
}
func TestTimeoutWroteHeaderTwice(t *testing.T) {
timeoutHandler := TimeoutHandler(time.Minute)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`hello`))
w.Header().Set("foo", "bar")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}
func TestTimeoutWriteBadCode(t *testing.T) {
timeoutHandler := TimeoutHandler(time.Minute)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(1000)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
assert.Panics(t, func() {
handler.ServeHTTP(resp, req)
})
}
func TestTimeoutClientClosed(t *testing.T) {
timeoutHandler := TimeoutHandler(time.Minute)
handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(1000)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
ctx, cancel := context.WithCancel(context.Background())
req = req.WithContext(ctx)
cancel()
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, statusClientClosedRequest, resp.Code)
}
func TestTimeoutPusher(t *testing.T) {
handler := &timeoutWriter{
w: mockedPusher{},
}
assert.Panics(t, func() {
handler.Push("any", nil)
})
handler = &timeoutWriter{
w: httptest.NewRecorder(),
}
assert.Equal(t, http.ErrNotSupported, handler.Push("any", nil))
}
type mockedPusher struct{}
func (m mockedPusher) Header() http.Header {
panic("implement me")
}
func (m mockedPusher) Write(bytes []byte) (int, error) {
panic("implement me")
}
func (m mockedPusher) WriteHeader(statusCode int) {
panic("implement me")
}
func (m mockedPusher) Push(target string, opts *http.PushOptions) error {
panic("implement me")
}

Loading…
Cancel
Save