* fix #1070

* test: add more tests
master
Kevin Wan 3 years ago committed by GitHub
parent b8ea16a88e
commit 62266d8f91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -13,6 +13,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/tal-tech/go-zero/rest/httpx"
"github.com/tal-tech/go-zero/rest/internal" "github.com/tal-tech/go-zero/rest/internal"
) )
@ -91,12 +92,14 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer tw.mu.Unlock() defer tw.mu.Unlock()
// there isn't any user-defined middleware before TimoutHandler, // there isn't any user-defined middleware before TimoutHandler,
// so we can guarantee that cancelation in biz related code won't come here. // so we can guarantee that cancelation in biz related code won't come here.
if errors.Is(ctx.Err(), context.Canceled) { httpx.Error(w, ctx.Err(), func(w http.ResponseWriter, err error) {
w.WriteHeader(statusClientClosedRequest) if errors.Is(err, context.Canceled) {
} else { w.WriteHeader(statusClientClosedRequest)
w.WriteHeader(http.StatusServiceUnavailable) } else {
} w.WriteHeader(http.StatusServiceUnavailable)
io.WriteString(w, h.errorBody()) }
io.WriteString(w, h.errorBody())
})
tw.timedOut = true tw.timedOut = true
} }
} }

@ -14,13 +14,17 @@ var (
) )
// Error writes err into w. // Error writes err into w.
func Error(w http.ResponseWriter, err error) { func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
lock.RLock() lock.RLock()
handler := errorHandler handler := errorHandler
lock.RUnlock() lock.RUnlock()
if handler == nil { if handler == nil {
http.Error(w, err.Error(), http.StatusBadRequest) if len(fns) > 0 {
fns[0](w, err)
} else {
http.Error(w, err.Error(), http.StatusBadRequest)
}
return return
} }

@ -95,6 +95,18 @@ func TestError(t *testing.T) {
} }
} }
func TestErrorWithHandler(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
}
Error(&w, errors.New("foo"), func(w http.ResponseWriter, err error) {
http.Error(w, err.Error(), 499)
})
assert.Equal(t, 499, w.code)
assert.True(t, w.hasBody)
assert.Equal(t, "foo", strings.TrimSpace(w.builder.String()))
}
func TestOk(t *testing.T) { func TestOk(t *testing.T) {
w := tracedResponseWriter{ w := tracedResponseWriter{
headers: make(map[string][]string), headers: make(map[string][]string),

@ -52,7 +52,6 @@ func ApiCommand(c *cli.Context) error {
} }
defer fp.Close() defer fp.Close()
home := c.String("home") home := c.String("home")
remote := c.String("remote") remote := c.String("remote")
if len(remote) > 0 { if len(remote) > 0 {

Loading…
Cancel
Save