From 3c6951577d5a3ec15f6338ba88c8ccfefa81e2ab Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Mon, 15 Mar 2021 20:11:09 +0800 Subject: [PATCH] make hijack more stable (#565) --- rest/handler/authhandler.go | 6 ++++- rest/handler/authhandler_test.go | 30 ++++++++++++++++++++++ rest/handler/cryptionhandler.go | 6 ++++- rest/handler/cryptionhandler_test.go | 13 ++++++++++ rest/handler/loghandler.go | 17 ++++++++++++- rest/handler/loghandler_test.go | 38 ++++++++++++++++++++++++++++ 6 files changed, 107 insertions(+), 3 deletions(-) diff --git a/rest/handler/authhandler.go b/rest/handler/authhandler.go index e6ccf695..16289961 100644 --- a/rest/handler/authhandler.go +++ b/rest/handler/authhandler.go @@ -143,7 +143,11 @@ func (grw *guardedResponseWriter) Header() http.Header { // Hijack implements the http.Hijacker interface. // This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it. func (grw *guardedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return grw.writer.(http.Hijacker).Hijack() + if hijacked, ok := grw.writer.(http.Hijacker); ok { + return hijacked.Hijack() + } + + return nil, nil, errors.New("server doesn't support hijacking") } func (grw *guardedResponseWriter) Write(body []byte) (int, error) { diff --git a/rest/handler/authhandler_test.go b/rest/handler/authhandler_test.go index 22ae3384..16e13d74 100644 --- a/rest/handler/authhandler_test.go +++ b/rest/handler/authhandler_test.go @@ -1,6 +1,8 @@ package handler import ( + "bufio" + "net" "net/http" "net/http/httptest" "testing" @@ -87,6 +89,26 @@ func TestAuthHandler_NilError(t *testing.T) { }) } +func TestAuthHandler_Flush(t *testing.T) { + resp := httptest.NewRecorder() + handler := newGuardedResponseWriter(resp) + handler.Flush() + assert.True(t, resp.Flushed) +} + +func TestAuthHandler_Hijack(t *testing.T) { + resp := httptest.NewRecorder() + writer := newGuardedResponseWriter(resp) + assert.NotPanics(t, func() { + writer.Hijack() + }) + + writer = newGuardedResponseWriter(mockedHijackable{resp}) + assert.NotPanics(t, func() { + writer.Hijack() + }) +} + func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) { now := time.Now().Unix() claims := make(jwt.MapClaims) @@ -101,3 +123,11 @@ func buildToken(secretKey string, payloads map[string]interface{}, seconds int64 return token.SignedString([]byte(secretKey)) } + +type mockedHijackable struct { + *httptest.ResponseRecorder +} + +func (m mockedHijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, nil +} diff --git a/rest/handler/cryptionhandler.go b/rest/handler/cryptionhandler.go index caae3f9d..002c6190 100644 --- a/rest/handler/cryptionhandler.go +++ b/rest/handler/cryptionhandler.go @@ -99,7 +99,11 @@ func (w *cryptionResponseWriter) Header() http.Header { // Hijack implements the http.Hijacker interface. // This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it. func (w *cryptionResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return w.ResponseWriter.(http.Hijacker).Hijack() + if hijacked, ok := w.ResponseWriter.(http.Hijacker); ok { + return hijacked.Hijack() + } + + return nil, nil, errors.New("server doesn't support hijacking") } func (w *cryptionResponseWriter) Write(p []byte) (int, error) { diff --git a/rest/handler/cryptionhandler_test.go b/rest/handler/cryptionhandler_test.go index 9819f4dc..abc705dc 100644 --- a/rest/handler/cryptionhandler_test.go +++ b/rest/handler/cryptionhandler_test.go @@ -103,3 +103,16 @@ func TestCryptionHandlerFlush(t *testing.T) { assert.Nil(t, err) assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String()) } + +func TestCryptionHandler_Hijack(t *testing.T) { + resp := httptest.NewRecorder() + writer := newCryptionResponseWriter(resp) + assert.NotPanics(t, func() { + writer.Hijack() + }) + + writer = newCryptionResponseWriter(mockedHijackable{resp}) + assert.NotPanics(t, func() { + writer.Hijack() + }) +} diff --git a/rest/handler/loghandler.go b/rest/handler/loghandler.go index 4a0800d5..3e21786d 100644 --- a/rest/handler/loghandler.go +++ b/rest/handler/loghandler.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "context" + "errors" "fmt" "io" "net" @@ -40,7 +41,11 @@ func (w *loggedResponseWriter) Header() http.Header { // Hijack implements the http.Hijacker interface. // This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it. func (w *loggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return w.w.(http.Hijacker).Hijack() + if hijacked, ok := w.w.(http.Hijacker); ok { + return hijacked.Hijack() + } + + return nil, nil, errors.New("server doesn't support hijacking") } func (w *loggedResponseWriter) Write(bytes []byte) (int, error) { @@ -91,6 +96,16 @@ func (w *detailLoggedResponseWriter) Header() http.Header { return w.writer.Header() } +// Hijack implements the http.Hijacker interface. +// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it. +func (w *detailLoggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacked, ok := w.writer.w.(http.Hijacker); ok { + return hijacked.Hijack() + } + + return nil, nil, errors.New("server doesn't support hijacking") +} + func (w *detailLoggedResponseWriter) Write(bs []byte) (int, error) { w.buf.Write(bs) return w.writer.Write(bs) diff --git a/rest/handler/loghandler_test.go b/rest/handler/loghandler_test.go index ed97ca58..f94576fc 100644 --- a/rest/handler/loghandler_test.go +++ b/rest/handler/loghandler_test.go @@ -62,6 +62,44 @@ func TestLogHandlerSlow(t *testing.T) { } } +func TestLogHandler_Hijack(t *testing.T) { + resp := httptest.NewRecorder() + writer := &loggedResponseWriter{ + w: resp, + } + assert.NotPanics(t, func() { + writer.Hijack() + }) + + writer = &loggedResponseWriter{ + w: mockedHijackable{resp}, + } + assert.NotPanics(t, func() { + writer.Hijack() + }) +} + +func TestDetailedLogHandler_Hijack(t *testing.T) { + resp := httptest.NewRecorder() + writer := &detailLoggedResponseWriter{ + writer: &loggedResponseWriter{ + w: resp, + }, + } + assert.NotPanics(t, func() { + writer.Hijack() + }) + + writer = &detailLoggedResponseWriter{ + writer: &loggedResponseWriter{ + w: mockedHijackable{resp}, + }, + } + assert.NotPanics(t, func() { + writer.Hijack() + }) +} + func BenchmarkLogHandler(b *testing.B) { b.ReportAllocs()