diff --git a/gateway/server.go b/gateway/server.go index 80309629..bd9a1f69 100644 --- a/gateway/server.go +++ b/gateway/server.go @@ -122,7 +122,7 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A return func(w http.ResponseWriter, r *http.Request) { parser, err := internal.NewRequestParser(r, resolver) if err != nil { - httpx.Error(w, err) + httpx.ErrorCtx(r.Context(), w, err) return } @@ -134,12 +134,12 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A handler := internal.NewEventHandler(w, resolver) if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), rpcPath, s.prepareMetadata(r.Header), handler, parser.Next); err != nil { - httpx.Error(w, err) + httpx.ErrorCtx(r.Context(), w, err) } st := handler.Status if st.Code() != codes.OK { - httpx.Error(w, st.Err()) + httpx.ErrorCtx(r.Context(), w, st.Err()) } } } diff --git a/rest/handler/timeouthandler.go b/rest/handler/timeouthandler.go index ca4b4a49..d234e3cb 100644 --- a/rest/handler/timeouthandler.go +++ b/rest/handler/timeouthandler.go @@ -99,7 +99,7 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer tw.mu.Unlock() // there isn't any user-defined middleware before TimoutHandler, // so we can guarantee that cancelation in biz related code won't come here. - httpx.Error(w, ctx.Err(), func(w http.ResponseWriter, err error) { + httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) { if errors.Is(err, context.Canceled) { w.WriteHeader(statusClientClosedRequest) } else { diff --git a/rest/httpx/responses.go b/rest/httpx/responses.go index e3b2c919..9f2efa0a 100644 --- a/rest/httpx/responses.go +++ b/rest/httpx/responses.go @@ -1,6 +1,7 @@ package httpx import ( + "context" "encoding/json" "net/http" "sync" @@ -11,8 +12,9 @@ import ( ) var ( - errorHandler func(error) (int, interface{}) - lock sync.RWMutex + errorHandler func(error) (int, interface{}) + lock sync.RWMutex + errorHandlerCtx func(context.Context, error) (int, interface{}) ) // Error writes err into w. @@ -87,3 +89,71 @@ func WriteJson(w http.ResponseWriter, code int, v interface{}) { logx.Errorf("actual bytes: %d, written bytes: %d", len(bs), n) } } + +// Error writes err into w. +func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) { + lock.RLock() + handlerCtx := errorHandlerCtx + lock.RUnlock() + + if handlerCtx == nil { + if len(fns) > 0 { + fns[0](w, err) + } else if errcode.IsGrpcError(err) { + // don't unwrap error and get status.Message(), + // it hides the rpc error headers. + http.Error(w, err.Error(), errcode.CodeFromGrpcError(err)) + } else { + http.Error(w, err.Error(), http.StatusBadRequest) + } + + return + } + + code, body := handlerCtx(ctx, err) + if body == nil { + w.WriteHeader(code) + return + } + + e, ok := body.(error) + if ok { + http.Error(w, e.Error(), code) + } else { + WriteJsonCtx(ctx, w, code, body) + } +} + +// OkJson writes v into w with 200 OK. +func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v interface{}) { + WriteJsonCtx(ctx, w, http.StatusOK, v) +} + +// WriteJson writes v as json string into w with code. +func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v interface{}) { + bs, err := json.Marshal(v) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set(ContentType, header.JsonContentType) + w.WriteHeader(code) + + if n, err := w.Write(bs); err != nil { + // http.ErrHandlerTimeout has been handled by http.TimeoutHandler, + // so it's ignored here. + if err != http.ErrHandlerTimeout { + logx.WithContext(ctx).Errorf("write response failed, error: %s", err) + } + } else if n < len(bs) { + logx.WithContext(ctx).Errorf("actual bytes: %d, written bytes: %d", len(bs), n) + } +} + +// SetErrorHandler sets the error handler, which is called on calling Error. +func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, interface{})) { + lock.Lock() + defer lock.Unlock() + errorHandlerCtx = handlerCtx +} diff --git a/rest/httpx/responses_test.go b/rest/httpx/responses_test.go index 8e8a882e..aef9f390 100644 --- a/rest/httpx/responses_test.go +++ b/rest/httpx/responses_test.go @@ -1,6 +1,7 @@ package httpx import ( + "context" "errors" "net/http" "strings" @@ -214,3 +215,115 @@ func (w *tracedResponseWriter) WriteHeader(code int) { w.wroteHeader = true w.code = code } + +func TestErrorCtx(t *testing.T) { + const ( + body = "foo" + wrappedBody = `"foo"` + ) + + tests := []struct { + name string + input string + errorHandlerCtx func(context.Context, error) (int, interface{}) + expectHasBody bool + expectBody string + expectCode int + }{ + { + name: "default error handler", + input: body, + expectHasBody: true, + expectBody: body, + expectCode: http.StatusBadRequest, + }, + { + name: "customized error handler return string", + input: body, + errorHandlerCtx: func(ctx context.Context, err error) (int, interface{}) { + return http.StatusForbidden, err.Error() + }, + expectHasBody: true, + expectBody: wrappedBody, + expectCode: http.StatusForbidden, + }, + { + name: "customized error handler return error", + input: body, + errorHandlerCtx: func(ctx context.Context, err error) (int, interface{}) { + return http.StatusForbidden, err + }, + expectHasBody: true, + expectBody: body, + expectCode: http.StatusForbidden, + }, + { + name: "customized error handler return nil", + input: body, + errorHandlerCtx: func(context.Context, error) (int, interface{}) { + return http.StatusForbidden, nil + }, + expectHasBody: false, + expectBody: "", + expectCode: http.StatusForbidden, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + w := tracedResponseWriter{ + headers: make(map[string][]string), + } + if test.errorHandlerCtx != nil { + lock.RLock() + prev := errorHandlerCtx + lock.RUnlock() + SetErrorHandlerCtx(test.errorHandlerCtx) + defer func() { + lock.Lock() + test.errorHandlerCtx = prev + lock.Unlock() + }() + } + ErrorCtx(context.Background(), &w, errors.New(test.input)) + assert.Equal(t, test.expectCode, w.code) + assert.Equal(t, test.expectHasBody, w.hasBody) + assert.Equal(t, test.expectBody, strings.TrimSpace(w.builder.String())) + }) + } + + //The current handler is a global event,Set default values to avoid impacting subsequent unit tests + SetErrorHandlerCtx(nil) +} + +func TestErrorWithGrpcErrorCtx(t *testing.T) { + w := tracedResponseWriter{ + headers: make(map[string][]string), + } + ErrorCtx(context.Background(), &w, status.Error(codes.Unavailable, "foo")) + assert.Equal(t, http.StatusServiceUnavailable, w.code) + assert.True(t, w.hasBody) + assert.True(t, strings.Contains(w.builder.String(), "foo")) +} + +func TestErrorWithHandlerCtx(t *testing.T) { + w := tracedResponseWriter{ + headers: make(map[string][]string), + } + ErrorCtx(context.Background(), &w, errors.New("foo"), func(w http.ResponseWriter, err error) { + http.Error(w, err.Error(), 499) + }) + assert.Equal(t, 499, w.code) + assert.True(t, w.hasBody) + assert.Equal(t, "foo", strings.TrimSpace(w.builder.String())) +} + +func TestWriteJsonCtxMarshalFailed(t *testing.T) { + w := tracedResponseWriter{ + headers: make(map[string][]string), + } + WriteJsonCtx(context.Background(), &w, http.StatusOK, map[string]interface{}{ + "Data": complex(0, 0), + }) + assert.Equal(t, http.StatusInternalServerError, w.code) +} diff --git a/tools/goctl/api/gogen/handler.tpl b/tools/goctl/api/gogen/handler.tpl index 033dc22e..a6480a23 100644 --- a/tools/goctl/api/gogen/handler.tpl +++ b/tools/goctl/api/gogen/handler.tpl @@ -11,16 +11,16 @@ func {{.HandlerName}}(svcCtx *svc.ServiceContext) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { {{if .HasRequest}}var req types.{{.RequestType}} if err := httpx.Parse(r, &req); err != nil { - httpx.Error(w, err) + httpx.ErrorCtx(r.Context(), w, err) return } {{end}}l := {{.LogicName}}.New{{.LogicType}}(r.Context(), svcCtx) {{if .HasResp}}resp, {{end}}err := l.{{.Call}}({{if .HasRequest}}&req{{end}}) if err != nil { - httpx.Error(w, err) + httpx.ErrorCtx(r.Context(), w, err) } else { - {{if .HasResp}}httpx.OkJson(w, resp){{else}}httpx.Ok(w){{end}} + {{if .HasResp}}httpx.OkJsonCtx(r.Context(), w, resp){{else}}httpx.Ok(w){{end}} } } }