refactor(rest): keep rest log collector context key private (#3407)

master
cong 1 year ago committed by GitHub
parent b71453985c
commit 61e562d0c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,7 +3,6 @@ package handler
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -44,7 +43,7 @@ func LogHandler(next http.Handler) http.Handler {
var dup io.ReadCloser var dup io.ReadCloser
r.Body, dup = iox.DupReadCloser(r.Body) r.Body, dup = iox.DupReadCloser(r.Body)
next.ServeHTTP(&lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs))) next.ServeHTTP(&lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs)))
r.Body = dup r.Body = dup
logBrief(r, lrw.Code, timer, logs) logBrief(r, lrw.Code, timer, logs)
}) })
@ -102,7 +101,7 @@ func DetailedLogHandler(next http.Handler) http.Handler {
var dup io.ReadCloser var dup io.ReadCloser
r.Body, dup = iox.DupReadCloser(r.Body) r.Body, dup = iox.DupReadCloser(r.Body)
logs := new(internal.LogCollector) logs := new(internal.LogCollector)
next.ServeHTTP(lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs))) next.ServeHTTP(lrw, r.WithContext(internal.WithLogCollector(r.Context(), logs)))
r.Body = dup r.Body = dup
logDetails(r, lrw, timer, logs) logDetails(r, lrw, timer, logs)
}) })

@ -22,7 +22,7 @@ func TestLogHandler(t *testing.T) {
for _, logHandler := range handlers { for _, logHandler := range handlers {
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Context().Value(internal.LogContext).(*internal.LogCollector).Append("anything") internal.LogCollectorFromContext(r.Context()).Append("anything")
w.Header().Set("X-Test", "test") w.Header().Set("X-Test", "test")
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)
_, err := w.Write([]byte("content")) _, err := w.Write([]byte("content"))
@ -49,7 +49,7 @@ func TestLogHandlerVeryLong(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "http://localhost", &buf) req := httptest.NewRequest(http.MethodPost, "http://localhost", &buf)
handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Context().Value(internal.LogContext).(*internal.LogCollector).Append("anything") internal.LogCollectorFromContext(r.Context()).Append("anything")
_, _ = io.Copy(io.Discard, r.Body) _, _ = io.Copy(io.Discard, r.Body)
w.Header().Set("X-Test", "test") w.Header().Set("X-Test", "test")
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)

@ -2,6 +2,7 @@ package internal
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"net/http" "net/http"
"sync" "sync"
@ -10,13 +11,32 @@ import (
"github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/httpx"
) )
// LogContext is a context key. // logContextKey is a context key.
var LogContext = contextKey("request_logs") var logContextKey = contextKey("request_logs")
// A LogCollector is used to collect logs. type (
type LogCollector struct { // LogCollector is used to collect logs.
LogCollector struct {
Messages []string Messages []string
lock sync.Mutex lock sync.Mutex
}
contextKey string
)
// WithLogCollector returns a new context with LogCollector.
func WithLogCollector(ctx context.Context, lc *LogCollector) context.Context {
return context.WithValue(ctx, logContextKey, lc)
}
// LogCollectorFromContext returns LogCollector from ctx.
func LogCollectorFromContext(ctx context.Context) *LogCollector {
val := ctx.Value(logContextKey)
if val == nil {
return nil
}
return val.(*LogCollector)
} }
// Append appends msg into log context. // Append appends msg into log context.
@ -73,9 +93,9 @@ func Infof(r *http.Request, format string, v ...any) {
} }
func appendLog(r *http.Request, message string) { func appendLog(r *http.Request, message string) {
logs := r.Context().Value(LogContext) logs := LogCollectorFromContext(r.Context())
if logs != nil { if logs != nil {
logs.(*LogCollector).Append(message) logs.Append(message)
} }
} }
@ -90,9 +110,3 @@ func formatf(r *http.Request, format string, v ...any) string {
func formatWithReq(r *http.Request, v string) string { func formatWithReq(r *http.Request, v string) string {
return fmt.Sprintf("(%s - %s) %s", r.RequestURI, httpx.GetRemoteAddr(r), v) return fmt.Sprintf("(%s - %s) %s", r.RequestURI, httpx.GetRemoteAddr(r), v)
} }
type contextKey string
func (c contextKey) String() string {
return "rest/internal context key " + string(c)
}

@ -14,7 +14,7 @@ import (
func TestInfo(t *testing.T) { func TestInfo(t *testing.T) {
collector := new(LogCollector) collector := new(LogCollector)
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
req = req.WithContext(context.WithValue(req.Context(), LogContext, collector)) req = req.WithContext(WithLogCollector(req.Context(), collector))
Info(req, "first") Info(req, "first")
Infof(req, "second %s", "third") Infof(req, "second %s", "third")
val := collector.Flush() val := collector.Flush()
@ -35,7 +35,10 @@ func TestError(t *testing.T) {
assert.True(t, strings.Contains(val, "third")) assert.True(t, strings.Contains(val, "third"))
} }
func TestContextKey_String(t *testing.T) { func TestLogCollectorContext(t *testing.T) {
val := contextKey("foo") ctx := context.Background()
assert.True(t, strings.Contains(val.String(), "foo")) assert.Nil(t, LogCollectorFromContext(ctx))
collector := new(LogCollector)
ctx = WithLogCollector(ctx, collector)
assert.Equal(t, collector, LogCollectorFromContext(ctx))
} }

Loading…
Cancel
Save