add error handle tests

master
kevin 4 years ago
parent abcb28e506
commit 9592639cb4

@ -9,17 +9,28 @@ import (
) )
var ( var (
errorHandler = defaultErrorHandler errorHandler func(error) (int, interface{})
lock sync.RWMutex lock sync.RWMutex
) )
func Error(w http.ResponseWriter, err error) { func Error(w http.ResponseWriter, err error) {
lock.RLock() lock.RLock()
code, body := errorHandler(err) handler := errorHandler
lock.RUnlock() lock.RUnlock()
if handler == nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
code, body := errorHandler(err)
e, ok := body.(error)
if ok {
http.Error(w, e.Error(), code)
} else {
WriteJson(w, code, body) WriteJson(w, code, body)
} }
}
func Ok(w http.ResponseWriter) { func Ok(w http.ResponseWriter) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@ -51,7 +62,3 @@ func WriteJson(w http.ResponseWriter, code int, v interface{}) {
logx.Errorf("actual bytes: %d, written bytes: %d", len(bs), n) logx.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
} }
} }
func defaultErrorHandler(err error) (int, interface{}) {
return http.StatusBadRequest, err
}

@ -19,13 +19,65 @@ func init() {
} }
func TestError(t *testing.T) { func TestError(t *testing.T) {
const body = "foo" const (
body = "foo"
wrappedBody = `"foo"`
)
tests := []struct {
name string
input string
errorHandler func(error) (int, interface{})
expectBody string
expectCode int
}{
{
name: "default error handler",
input: body,
expectBody: body,
expectCode: http.StatusBadRequest,
},
{
name: "customized error handler return string",
input: body,
errorHandler: func(err error) (int, interface{}) {
return http.StatusForbidden, err.Error()
},
expectBody: wrappedBody,
expectCode: http.StatusForbidden,
},
{
name: "customized error handler return error",
input: body,
errorHandler: func(err error) (int, interface{}) {
return http.StatusForbidden, err
},
expectBody: body,
expectCode: http.StatusForbidden,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
w := tracedResponseWriter{ w := tracedResponseWriter{
headers: make(map[string][]string), headers: make(map[string][]string),
} }
Error(&w, errors.New(body)) if test.errorHandler != nil {
assert.Equal(t, http.StatusBadRequest, w.code) lock.RLock()
assert.Equal(t, body, strings.TrimSpace(w.builder.String())) prev := errorHandler
lock.RUnlock()
SetErrorHandler(test.errorHandler)
defer func() {
lock.Lock()
errorHandler = prev
lock.Unlock()
}()
}
Error(&w, errors.New(test.input))
assert.Equal(t, test.expectCode, w.code)
assert.Equal(t, test.expectBody, strings.TrimSpace(w.builder.String()))
})
}
} }
func TestOk(t *testing.T) { func TestOk(t *testing.T) {

Loading…
Cancel
Save