From 28cb2c58044950850a6facd30a96d6a8f054ff22 Mon Sep 17 00:00:00 2001 From: chen quan Date: Sat, 17 Feb 2024 15:06:45 +0800 Subject: [PATCH] feat: support sse ignore timeout (#2041) Co-authored-by: Kevin Wan --- rest/handler/timeouthandler.go | 8 ++++++-- rest/handler/timeouthandler_test.go | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/rest/handler/timeouthandler.go b/rest/handler/timeouthandler.go index 25a98126..9a1b001d 100644 --- a/rest/handler/timeouthandler.go +++ b/rest/handler/timeouthandler.go @@ -24,6 +24,8 @@ const ( reason = "Request Timeout" headerUpgrade = "Upgrade" valueWebsocket = "websocket" + headerAccept = "Accept" + valueSSE = "text/event-stream" ) // TimeoutHandler returns the handler with given timeout. @@ -56,7 +58,9 @@ func (h *timeoutHandler) errorBody() string { } func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.Header.Get(headerUpgrade) == valueWebsocket { + if r.Header.Get(headerUpgrade) == valueWebsocket || + // Server-Sent Event ignore timeout. + r.Header.Get(headerAccept) == valueSSE { h.handler.ServeHTTP(w, r) return } @@ -110,7 +114,7 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } else { w.WriteHeader(http.StatusServiceUnavailable) } - io.WriteString(w, h.errorBody()) + _, _ = io.WriteString(w, h.errorBody()) }) tw.timedOut = true } diff --git a/rest/handler/timeouthandler_test.go b/rest/handler/timeouthandler_test.go index 0d96c592..ee6fdb91 100644 --- a/rest/handler/timeouthandler_test.go +++ b/rest/handler/timeouthandler_test.go @@ -156,6 +156,22 @@ func TestTimeoutPanic(t *testing.T) { }) } +func TestTimeoutSSE(t *testing.T) { + timeoutHandler := TimeoutHandler(time.Millisecond) + handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Millisecond * 10) + r.Header.Set("Content-Type", "text/event-stream") + r.Header.Set("Cache-Control", "no-cache") + r.Header.Set("Connection", "keep-alive") + r.Header.Set("Transfer-Encoding", "chunked") + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Set(headerAccept, valueSSE) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) +} func TestTimeoutWebsocket(t *testing.T) { timeoutHandler := TimeoutHandler(time.Millisecond) handler := timeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {