From 9592639cb449b5108e51cbe4f5630146547ff389 Mon Sep 17 00:00:00 2001 From: kevin Date: Tue, 17 Nov 2020 18:04:48 +0800 Subject: [PATCH] add error handle tests --- rest/httpx/responses.go | 21 ++++++++---- rest/httpx/responses_test.go | 64 ++++++++++++++++++++++++++++++++---- 2 files changed, 72 insertions(+), 13 deletions(-) diff --git a/rest/httpx/responses.go b/rest/httpx/responses.go index f2b932fa..39a7ce06 100644 --- a/rest/httpx/responses.go +++ b/rest/httpx/responses.go @@ -9,16 +9,27 @@ import ( ) var ( - errorHandler = defaultErrorHandler + errorHandler func(error) (int, interface{}) lock sync.RWMutex ) func Error(w http.ResponseWriter, err error) { lock.RLock() - code, body := errorHandler(err) + handler := errorHandler lock.RUnlock() - WriteJson(w, code, body) + if handler == nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + code, body := errorHandler(err) + e, ok := body.(error) + if ok { + http.Error(w, e.Error(), code) + } else { + WriteJson(w, code, body) + } } func Ok(w http.ResponseWriter) { @@ -51,7 +62,3 @@ func WriteJson(w http.ResponseWriter, code int, v interface{}) { logx.Errorf("actual bytes: %d, written bytes: %d", len(bs), n) } } - -func defaultErrorHandler(err error) (int, interface{}) { - return http.StatusBadRequest, err -} diff --git a/rest/httpx/responses_test.go b/rest/httpx/responses_test.go index fdb23578..71f57648 100644 --- a/rest/httpx/responses_test.go +++ b/rest/httpx/responses_test.go @@ -19,13 +19,65 @@ func init() { } func TestError(t *testing.T) { - const body = "foo" - w := tracedResponseWriter{ - headers: make(map[string][]string), + const ( + body = "foo" + wrappedBody = `"foo"` + ) + + tests := []struct { + name string + input string + errorHandler func(error) (int, interface{}) + expectBody string + expectCode int + }{ + { + name: "default error handler", + input: body, + expectBody: body, + expectCode: http.StatusBadRequest, + }, + { + name: "customized error handler return string", + input: body, + errorHandler: func(err error) (int, interface{}) { + return http.StatusForbidden, err.Error() + }, + expectBody: wrappedBody, + expectCode: http.StatusForbidden, + }, + { + name: "customized error handler return error", + input: body, + errorHandler: func(err error) (int, interface{}) { + return http.StatusForbidden, err + }, + expectBody: body, + expectCode: http.StatusForbidden, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + w := tracedResponseWriter{ + headers: make(map[string][]string), + } + if test.errorHandler != nil { + lock.RLock() + prev := errorHandler + lock.RUnlock() + SetErrorHandler(test.errorHandler) + defer func() { + lock.Lock() + errorHandler = prev + lock.Unlock() + }() + } + Error(&w, errors.New(test.input)) + assert.Equal(t, test.expectCode, w.code) + assert.Equal(t, test.expectBody, strings.TrimSpace(w.builder.String())) + }) } - Error(&w, errors.New(body)) - assert.Equal(t, http.StatusBadRequest, w.code) - assert.Equal(t, body, strings.TrimSpace(w.builder.String())) } func TestOk(t *testing.T) {