diff --git a/rest/handler/timeouthandler.go b/rest/handler/timeouthandler.go index bb53acb8..0e81c161 100644 --- a/rest/handler/timeouthandler.go +++ b/rest/handler/timeouthandler.go @@ -1,19 +1,187 @@ package handler import ( + "bytes" + "context" + "errors" + "fmt" + "io" "net/http" + "path" + "runtime" + "strings" + "sync" "time" + + "github.com/tal-tech/go-zero/rest/internal" ) -const reason = "Request Timeout" +const ( + statusClientClosedRequest = 499 + reason = "Request Timeout" +) // TimeoutHandler returns the handler with given timeout. +// If client closed request, code 499 will be logged. +// Notice: even if canceled in server side, 499 will be logged as well. func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { if duration > 0 { - return http.TimeoutHandler(next, duration, reason) + return &timeoutHandler{ + handler: next, + dt: duration, + } } return next } } + +// timeoutHandler is the handler that controls the request timeout. +// Why we implement it on our own, because the stdlib implementation +// treats the ClientClosedRequest as http.StatusServiceUnavailable. +// And we write the codes in logs as code 499, which is defined by nginx. +type timeoutHandler struct { + handler http.Handler + dt time.Duration +} + +func (h *timeoutHandler) errorBody() string { + return reason +} + +func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx, cancelCtx := context.WithTimeout(r.Context(), h.dt) + defer cancelCtx() + + r = r.WithContext(ctx) + done := make(chan struct{}) + tw := &timeoutWriter{ + w: w, + h: make(http.Header), + req: r, + } + panicChan := make(chan interface{}, 1) + go func() { + defer func() { + if p := recover(); p != nil { + panicChan <- p + } + }() + h.handler.ServeHTTP(tw, r) + close(done) + }() + select { + case p := <-panicChan: + panic(p) + case <-done: + tw.mu.Lock() + defer tw.mu.Unlock() + dst := w.Header() + for k, vv := range tw.h { + dst[k] = vv + } + if !tw.wroteHeader { + tw.code = http.StatusOK + } + w.WriteHeader(tw.code) + w.Write(tw.wbuf.Bytes()) + case <-ctx.Done(): + tw.mu.Lock() + defer tw.mu.Unlock() + if errors.Is(ctx.Err(), context.Canceled) { + w.WriteHeader(statusClientClosedRequest) + } else { + w.WriteHeader(http.StatusServiceUnavailable) + } + io.WriteString(w, h.errorBody()) + tw.timedOut = true + } +} + +type timeoutWriter struct { + w http.ResponseWriter + h http.Header + wbuf bytes.Buffer + req *http.Request + + mu sync.Mutex + timedOut bool + wroteHeader bool + code int +} + +var _ http.Pusher = (*timeoutWriter)(nil) + +// Push implements the Pusher interface. +func (tw *timeoutWriter) Push(target string, opts *http.PushOptions) error { + if pusher, ok := tw.w.(http.Pusher); ok { + return pusher.Push(target, opts) + } + return http.ErrNotSupported +} + +func (tw *timeoutWriter) Header() http.Header { return tw.h } + +func (tw *timeoutWriter) Write(p []byte) (int, error) { + tw.mu.Lock() + defer tw.mu.Unlock() + + if tw.timedOut { + return 0, http.ErrHandlerTimeout + } + + if !tw.wroteHeader { + tw.writeHeaderLocked(http.StatusOK) + } + return tw.wbuf.Write(p) +} + +func (tw *timeoutWriter) writeHeaderLocked(code int) { + checkWriteHeaderCode(code) + + switch { + case tw.timedOut: + return + case tw.wroteHeader: + if tw.req != nil { + caller := relevantCaller() + internal.Errorf(tw.req, "http: superfluous response.WriteHeader call from %s (%s:%d)", + caller.Function, path.Base(caller.File), caller.Line) + } + default: + tw.wroteHeader = true + tw.code = code + } +} + +func (tw *timeoutWriter) WriteHeader(code int) { + tw.mu.Lock() + defer tw.mu.Unlock() + tw.writeHeaderLocked(code) +} + +func checkWriteHeaderCode(code int) { + if code < 100 || code > 599 { + panic(fmt.Sprintf("invalid WriteHeader code %v", code)) + } +} + +// relevantCaller searches the call stack for the first function outside of net/http. +// The purpose of this function is to provide more helpful error messages. +func relevantCaller() runtime.Frame { + pc := make([]uintptr, 16) + n := runtime.Callers(1, pc) + frames := runtime.CallersFrames(pc[:n]) + var frame runtime.Frame + for { + frame, more := frames.Next() + if !strings.HasPrefix(frame.Function, "net/http.") { + return frame + } + if !more { + break + } + } + return frame +} diff --git a/rest/handler/timeouthandler_test.go b/rest/handler/timeouthandler_test.go index e1a7ecd7..74c70687 100644 --- a/rest/handler/timeouthandler_test.go +++ b/rest/handler/timeouthandler_test.go @@ -1,6 +1,7 @@ package handler import ( + "context" "io/ioutil" "log" "net/http" @@ -39,6 +40,20 @@ func TestWithinTimeout(t *testing.T) { assert.Equal(t, http.StatusOK, resp.Code) } +func TestWithTimeoutTimedout(t *testing.T) { + timeoutHandler := TimeoutHandler(time.Millisecond) + handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Millisecond * 10) + w.Write([]byte(`foo`)) + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusServiceUnavailable, resp.Code) +} + func TestWithoutTimeout(t *testing.T) { timeoutHandler := TimeoutHandler(0) handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -50,3 +65,91 @@ func TestWithoutTimeout(t *testing.T) { handler.ServeHTTP(resp, req) assert.Equal(t, http.StatusOK, resp.Code) } + +func TestTimeoutPanic(t *testing.T) { + timeoutHandler := TimeoutHandler(time.Minute) + handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("foo") + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + assert.Panics(t, func() { + handler.ServeHTTP(resp, req) + }) +} + +func TestTimeoutWroteHeaderTwice(t *testing.T) { + timeoutHandler := TimeoutHandler(time.Minute) + handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(`hello`)) + w.Header().Set("foo", "bar") + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) +} + +func TestTimeoutWriteBadCode(t *testing.T) { + timeoutHandler := TimeoutHandler(time.Minute) + handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(1000) + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + assert.Panics(t, func() { + handler.ServeHTTP(resp, req) + }) +} + +func TestTimeoutClientClosed(t *testing.T) { + timeoutHandler := TimeoutHandler(time.Minute) + handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(1000) + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + ctx, cancel := context.WithCancel(context.Background()) + req = req.WithContext(ctx) + cancel() + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, statusClientClosedRequest, resp.Code) +} + +func TestTimeoutPusher(t *testing.T) { + handler := &timeoutWriter{ + w: mockedPusher{}, + } + + assert.Panics(t, func() { + handler.Push("any", nil) + }) + + handler = &timeoutWriter{ + w: httptest.NewRecorder(), + } + assert.Equal(t, http.ErrNotSupported, handler.Push("any", nil)) +} + +type mockedPusher struct{} + +func (m mockedPusher) Header() http.Header { + panic("implement me") +} + +func (m mockedPusher) Write(bytes []byte) (int, error) { + panic("implement me") +} + +func (m mockedPusher) WriteHeader(statusCode int) { + panic("implement me") +} + +func (m mockedPusher) Push(target string, opts *http.PushOptions) error { + panic("implement me") +}