diff --git a/rest/handler/timeouthandler.go b/rest/handler/timeouthandler.go index 1b9532cf..aab564a3 100644 --- a/rest/handler/timeouthandler.go +++ b/rest/handler/timeouthandler.go @@ -13,6 +13,7 @@ import ( "sync" "time" + "github.com/tal-tech/go-zero/rest/httpx" "github.com/tal-tech/go-zero/rest/internal" ) @@ -91,12 +92,14 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer tw.mu.Unlock() // there isn't any user-defined middleware before TimoutHandler, // so we can guarantee that cancelation in biz related code won't come here. - if errors.Is(ctx.Err(), context.Canceled) { - w.WriteHeader(statusClientClosedRequest) - } else { - w.WriteHeader(http.StatusServiceUnavailable) - } - io.WriteString(w, h.errorBody()) + httpx.Error(w, ctx.Err(), func(w http.ResponseWriter, err error) { + if errors.Is(err, context.Canceled) { + w.WriteHeader(statusClientClosedRequest) + } else { + w.WriteHeader(http.StatusServiceUnavailable) + } + io.WriteString(w, h.errorBody()) + }) tw.timedOut = true } } diff --git a/rest/httpx/responses.go b/rest/httpx/responses.go index 055f521b..ef1ebc0a 100644 --- a/rest/httpx/responses.go +++ b/rest/httpx/responses.go @@ -14,13 +14,17 @@ var ( ) // Error writes err into w. -func Error(w http.ResponseWriter, err error) { +func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) { lock.RLock() handler := errorHandler lock.RUnlock() if handler == nil { - http.Error(w, err.Error(), http.StatusBadRequest) + if len(fns) > 0 { + fns[0](w, err) + } else { + http.Error(w, err.Error(), http.StatusBadRequest) + } return } diff --git a/rest/httpx/responses_test.go b/rest/httpx/responses_test.go index adb94c79..17d14d66 100644 --- a/rest/httpx/responses_test.go +++ b/rest/httpx/responses_test.go @@ -95,6 +95,18 @@ func TestError(t *testing.T) { } } +func TestErrorWithHandler(t *testing.T) { + w := tracedResponseWriter{ + headers: make(map[string][]string), + } + Error(&w, errors.New("foo"), func(w http.ResponseWriter, err error) { + http.Error(w, err.Error(), 499) + }) + assert.Equal(t, 499, w.code) + assert.True(t, w.hasBody) + assert.Equal(t, "foo", strings.TrimSpace(w.builder.String())) +} + func TestOk(t *testing.T) { w := tracedResponseWriter{ headers: make(map[string][]string), diff --git a/tools/goctl/api/apigen/gen.go b/tools/goctl/api/apigen/gen.go index b4c7ea09..fcf352c4 100644 --- a/tools/goctl/api/apigen/gen.go +++ b/tools/goctl/api/apigen/gen.go @@ -52,7 +52,6 @@ func ApiCommand(c *cli.Context) error { } defer fp.Close() - home := c.String("home") remote := c.String("remote") if len(remote) > 0 {