From 7a75dce465ec049b1d11205f12cd4f2ec5b69772 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Wed, 14 Dec 2022 23:36:56 +0800 Subject: [PATCH] refactor: remove duplicated code (#2705) --- rest/httpx/responses.go | 115 ++++++++++++++++++---------------------- 1 file changed, 51 insertions(+), 64 deletions(-) diff --git a/rest/httpx/responses.go b/rest/httpx/responses.go index 9f2efa0a..fa6ca808 100644 --- a/rest/httpx/responses.go +++ b/rest/httpx/responses.go @@ -3,6 +3,7 @@ package httpx import ( "context" "encoding/json" + "fmt" "net/http" "sync" @@ -13,8 +14,8 @@ import ( var ( errorHandler func(error) (int, interface{}) - lock sync.RWMutex errorHandlerCtx func(context.Context, error) (int, interface{}) + lock sync.RWMutex ) // Error writes err into w. @@ -23,32 +24,26 @@ func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, handler := errorHandler lock.RUnlock() - if handler == nil { - if len(fns) > 0 { - fns[0](w, err) - } else if errcode.IsGrpcError(err) { - // don't unwrap error and get status.Message(), - // it hides the rpc error headers. - http.Error(w, err.Error(), errcode.CodeFromGrpcError(err)) - } else { - http.Error(w, err.Error(), http.StatusBadRequest) - } + doHandleError(w, err, handler, WriteJson, fns...) +} - return - } +// ErrorCtx writes err into w. +func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error, + fns ...func(w http.ResponseWriter, err error)) { + lock.RLock() + handlerCtx := errorHandlerCtx + lock.RUnlock() - code, body := handler(err) - if body == nil { - w.WriteHeader(code) - return + var handler func(error) (int, interface{}) + if handlerCtx != nil { + handler = func(err error) (int, interface{}) { + return handlerCtx(ctx, err) + } } - - e, ok := body.(error) - if ok { - http.Error(w, e.Error(), code) - } else { - WriteJson(w, code, body) + writeJson := func(w http.ResponseWriter, code int, v interface{}) { + WriteJsonCtx(ctx, w, code, v) } + doHandleError(w, err, handler, writeJson, fns...) } // Ok writes HTTP 200 OK into w. @@ -61,6 +56,11 @@ func OkJson(w http.ResponseWriter, v interface{}) { WriteJson(w, http.StatusOK, v) } +// OkJsonCtx writes v into w with 200 OK. +func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v interface{}) { + WriteJsonCtx(ctx, w, http.StatusOK, v) +} + // SetErrorHandler sets the error handler, which is called on calling Error. func SetErrorHandler(handler func(error) (int, interface{})) { lock.Lock() @@ -68,37 +68,35 @@ func SetErrorHandler(handler func(error) (int, interface{})) { errorHandler = handler } +// SetErrorHandlerCtx sets the error handler, which is called on calling Error. +func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, interface{})) { + lock.Lock() + defer lock.Unlock() + errorHandlerCtx = handlerCtx +} + // WriteJson writes v as json string into w with code. func WriteJson(w http.ResponseWriter, code int, v interface{}) { - bs, err := json.Marshal(v) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + if err := doWriteJson(w, code, v); err != nil { + logx.Error(err) } +} - w.Header().Set(ContentType, header.JsonContentType) - w.WriteHeader(code) - - if n, err := w.Write(bs); err != nil { - // http.ErrHandlerTimeout has been handled by http.TimeoutHandler, - // so it's ignored here. - if err != http.ErrHandlerTimeout { - logx.Errorf("write response failed, error: %s", err) - } - } else if n < len(bs) { - logx.Errorf("actual bytes: %d, written bytes: %d", len(bs), n) +// WriteJsonCtx writes v as json string into w with code. +func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v interface{}) { + if err := doWriteJson(w, code, v); err != nil { + logx.WithContext(ctx).Error(err) } } -// Error writes err into w. -func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) { - lock.RLock() - handlerCtx := errorHandlerCtx - lock.RUnlock() - - if handlerCtx == nil { +func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, interface{}), + writeJson func(w http.ResponseWriter, code int, v interface{}), + fns ...func(w http.ResponseWriter, err error)) { + if handler == nil { if len(fns) > 0 { - fns[0](w, err) + for _, fn := range fns { + fn(w, err) + } } else if errcode.IsGrpcError(err) { // don't unwrap error and get status.Message(), // it hides the rpc error headers. @@ -110,7 +108,7 @@ func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error, fns ...func return } - code, body := handlerCtx(ctx, err) + code, body := handler(err) if body == nil { w.WriteHeader(code) return @@ -120,21 +118,15 @@ func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error, fns ...func if ok { http.Error(w, e.Error(), code) } else { - WriteJsonCtx(ctx, w, code, body) + writeJson(w, code, body) } } -// OkJson writes v into w with 200 OK. -func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v interface{}) { - WriteJsonCtx(ctx, w, http.StatusOK, v) -} - -// WriteJson writes v as json string into w with code. -func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v interface{}) { +func doWriteJson(w http.ResponseWriter, code int, v interface{}) error { bs, err := json.Marshal(v) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) - return + return fmt.Errorf("marshal json failed, error: %w", err) } w.Header().Set(ContentType, header.JsonContentType) @@ -144,16 +136,11 @@ func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v interf // http.ErrHandlerTimeout has been handled by http.TimeoutHandler, // so it's ignored here. if err != http.ErrHandlerTimeout { - logx.WithContext(ctx).Errorf("write response failed, error: %s", err) + return fmt.Errorf("write response failed, error: %w", err) } } else if n < len(bs) { - logx.WithContext(ctx).Errorf("actual bytes: %d, written bytes: %d", len(bs), n) + return fmt.Errorf("actual bytes: %d, written bytes: %d", len(bs), n) } -} -// SetErrorHandler sets the error handler, which is called on calling Error. -func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, interface{})) { - lock.Lock() - defer lock.Unlock() - errorHandlerCtx = handlerCtx + return nil }