feature : responses whit context (#2637)

master
heyehang 2 years ago committed by GitHub
parent 9941055eaa
commit a644ec7edd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -122,7 +122,7 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A
return func(w http.ResponseWriter, r *http.Request) {
parser, err := internal.NewRequestParser(r, resolver)
if err != nil {
httpx.Error(w, err)
httpx.ErrorCtx(r.Context(), w, err)
return
}
@ -134,12 +134,12 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A
handler := internal.NewEventHandler(w, resolver)
if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), rpcPath, s.prepareMetadata(r.Header),
handler, parser.Next); err != nil {
httpx.Error(w, err)
httpx.ErrorCtx(r.Context(), w, err)
}
st := handler.Status
if st.Code() != codes.OK {
httpx.Error(w, st.Err())
httpx.ErrorCtx(r.Context(), w, st.Err())
}
}
}

@ -99,7 +99,7 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer tw.mu.Unlock()
// there isn't any user-defined middleware before TimoutHandler,
// so we can guarantee that cancelation in biz related code won't come here.
httpx.Error(w, ctx.Err(), func(w http.ResponseWriter, err error) {
httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) {
if errors.Is(err, context.Canceled) {
w.WriteHeader(statusClientClosedRequest)
} else {

@ -1,6 +1,7 @@
package httpx
import (
"context"
"encoding/json"
"net/http"
"sync"
@ -11,8 +12,9 @@ import (
)
var (
errorHandler func(error) (int, interface{})
lock sync.RWMutex
errorHandler func(error) (int, interface{})
lock sync.RWMutex
errorHandlerCtx func(context.Context, error) (int, interface{})
)
// Error writes err into w.
@ -87,3 +89,71 @@ func WriteJson(w http.ResponseWriter, code int, v interface{}) {
logx.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
}
}
// Error writes err into w.
func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
lock.RLock()
handlerCtx := errorHandlerCtx
lock.RUnlock()
if handlerCtx == nil {
if len(fns) > 0 {
fns[0](w, err)
} else if errcode.IsGrpcError(err) {
// don't unwrap error and get status.Message(),
// it hides the rpc error headers.
http.Error(w, err.Error(), errcode.CodeFromGrpcError(err))
} else {
http.Error(w, err.Error(), http.StatusBadRequest)
}
return
}
code, body := handlerCtx(ctx, err)
if body == nil {
w.WriteHeader(code)
return
}
e, ok := body.(error)
if ok {
http.Error(w, e.Error(), code)
} else {
WriteJsonCtx(ctx, w, code, body)
}
}
// OkJson writes v into w with 200 OK.
func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v interface{}) {
WriteJsonCtx(ctx, w, http.StatusOK, v)
}
// WriteJson writes v as json string into w with code.
func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v interface{}) {
bs, err := json.Marshal(v)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set(ContentType, header.JsonContentType)
w.WriteHeader(code)
if n, err := w.Write(bs); err != nil {
// http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
// so it's ignored here.
if err != http.ErrHandlerTimeout {
logx.WithContext(ctx).Errorf("write response failed, error: %s", err)
}
} else if n < len(bs) {
logx.WithContext(ctx).Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
}
}
// SetErrorHandler sets the error handler, which is called on calling Error.
func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, interface{})) {
lock.Lock()
defer lock.Unlock()
errorHandlerCtx = handlerCtx
}

@ -1,6 +1,7 @@
package httpx
import (
"context"
"errors"
"net/http"
"strings"
@ -214,3 +215,115 @@ func (w *tracedResponseWriter) WriteHeader(code int) {
w.wroteHeader = true
w.code = code
}
func TestErrorCtx(t *testing.T) {
const (
body = "foo"
wrappedBody = `"foo"`
)
tests := []struct {
name string
input string
errorHandlerCtx func(context.Context, error) (int, interface{})
expectHasBody bool
expectBody string
expectCode int
}{
{
name: "default error handler",
input: body,
expectHasBody: true,
expectBody: body,
expectCode: http.StatusBadRequest,
},
{
name: "customized error handler return string",
input: body,
errorHandlerCtx: func(ctx context.Context, err error) (int, interface{}) {
return http.StatusForbidden, err.Error()
},
expectHasBody: true,
expectBody: wrappedBody,
expectCode: http.StatusForbidden,
},
{
name: "customized error handler return error",
input: body,
errorHandlerCtx: func(ctx context.Context, err error) (int, interface{}) {
return http.StatusForbidden, err
},
expectHasBody: true,
expectBody: body,
expectCode: http.StatusForbidden,
},
{
name: "customized error handler return nil",
input: body,
errorHandlerCtx: func(context.Context, error) (int, interface{}) {
return http.StatusForbidden, nil
},
expectHasBody: false,
expectBody: "",
expectCode: http.StatusForbidden,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
}
if test.errorHandlerCtx != nil {
lock.RLock()
prev := errorHandlerCtx
lock.RUnlock()
SetErrorHandlerCtx(test.errorHandlerCtx)
defer func() {
lock.Lock()
test.errorHandlerCtx = prev
lock.Unlock()
}()
}
ErrorCtx(context.Background(), &w, errors.New(test.input))
assert.Equal(t, test.expectCode, w.code)
assert.Equal(t, test.expectHasBody, w.hasBody)
assert.Equal(t, test.expectBody, strings.TrimSpace(w.builder.String()))
})
}
//The current handler is a global event,Set default values to avoid impacting subsequent unit tests
SetErrorHandlerCtx(nil)
}
func TestErrorWithGrpcErrorCtx(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
}
ErrorCtx(context.Background(), &w, status.Error(codes.Unavailable, "foo"))
assert.Equal(t, http.StatusServiceUnavailable, w.code)
assert.True(t, w.hasBody)
assert.True(t, strings.Contains(w.builder.String(), "foo"))
}
func TestErrorWithHandlerCtx(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
}
ErrorCtx(context.Background(), &w, errors.New("foo"), func(w http.ResponseWriter, err error) {
http.Error(w, err.Error(), 499)
})
assert.Equal(t, 499, w.code)
assert.True(t, w.hasBody)
assert.Equal(t, "foo", strings.TrimSpace(w.builder.String()))
}
func TestWriteJsonCtxMarshalFailed(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
}
WriteJsonCtx(context.Background(), &w, http.StatusOK, map[string]interface{}{
"Data": complex(0, 0),
})
assert.Equal(t, http.StatusInternalServerError, w.code)
}

@ -11,16 +11,16 @@ func {{.HandlerName}}(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
{{if .HasRequest}}var req types.{{.RequestType}}
if err := httpx.Parse(r, &req); err != nil {
httpx.Error(w, err)
httpx.ErrorCtx(r.Context(), w, err)
return
}
{{end}}l := {{.LogicName}}.New{{.LogicType}}(r.Context(), svcCtx)
{{if .HasResp}}resp, {{end}}err := l.{{.Call}}({{if .HasRequest}}&req{{end}})
if err != nil {
httpx.Error(w, err)
httpx.ErrorCtx(r.Context(), w, err)
} else {
{{if .HasResp}}httpx.OkJson(w, resp){{else}}httpx.Ok(w){{end}}
{{if .HasResp}}httpx.OkJsonCtx(r.Context(), w, resp){{else}}httpx.Ok(w){{end}}
}
}
}

Loading…
Cancel
Save