From 92b450eb11fa21c015f4d6c8b3ad060ce3bbf00b Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Mon, 18 Apr 2022 20:14:46 +0800 Subject: [PATCH] fix: ignore timeout on websocket (#1802) --- rest/handler/timeouthandler.go | 7 +++++++ rest/handler/timeouthandler_test.go | 13 +++++++++++++ 2 files changed, 20 insertions(+) diff --git a/rest/handler/timeouthandler.go b/rest/handler/timeouthandler.go index 672f44b6..ca4b4a49 100644 --- a/rest/handler/timeouthandler.go +++ b/rest/handler/timeouthandler.go @@ -20,6 +20,8 @@ import ( const ( statusClientClosedRequest = 499 reason = "Request Timeout" + headerUpgrade = "Upgrade" + valueWebsocket = "websocket" ) // TimeoutHandler returns the handler with given timeout. @@ -52,6 +54,11 @@ func (h *timeoutHandler) errorBody() string { } func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Header.Get(headerUpgrade) == valueWebsocket { + h.handler.ServeHTTP(w, r) + return + } + ctx, cancelCtx := context.WithTimeout(r.Context(), h.dt) defer cancelCtx() diff --git a/rest/handler/timeouthandler_test.go b/rest/handler/timeouthandler_test.go index 74c70687..d4d2ad20 100644 --- a/rest/handler/timeouthandler_test.go +++ b/rest/handler/timeouthandler_test.go @@ -79,6 +79,19 @@ func TestTimeoutPanic(t *testing.T) { }) } +func TestTimeoutWebsocket(t *testing.T) { + timeoutHandler := TimeoutHandler(time.Millisecond) + handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Millisecond * 10) + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Set(headerUpgrade, valueWebsocket) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) +} + func TestTimeoutWroteHeaderTwice(t *testing.T) { timeoutHandler := TimeoutHandler(time.Minute) handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {