From 641ebf166780d418389acc00df48da3ae601a774 Mon Sep 17 00:00:00 2001 From: xiang Date: Wed, 4 Jan 2023 10:21:57 +0800 Subject: [PATCH] feat: trace http.status_code (#2708) * feat: trace http.status_code * feat: implements http.Flusher & http.Hijacker for traceResponseWriter * test: delete notTracingSpans after test * feat: trace http.status_code * feat: implements http.Flusher & http.Hijacker for traceResponseWriter * test: delete notTracingSpans after test * refactor: update trace handler span message * fix: code conflict --- rest/handler/tracinghandler.go | 59 ++++++++++++++++++++++++++++- rest/handler/tracinghandler_test.go | 51 +++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 1 deletion(-) diff --git a/rest/handler/tracinghandler.go b/rest/handler/tracinghandler.go index 2d6b7119..7d6b285f 100644 --- a/rest/handler/tracinghandler.go +++ b/rest/handler/tracinghandler.go @@ -1,17 +1,26 @@ package handler import ( + "bufio" + "errors" + "net" "net/http" "sync" "github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/core/trace" "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/propagation" semconv "go.opentelemetry.io/otel/semconv/v1.4.0" oteltrace "go.opentelemetry.io/otel/trace" ) +const ( + traceKeyStatusCode = "http.status_code" +) + var notTracingSpans sync.Map // DontTraceSpan disable tracing for the specified span name. @@ -19,6 +28,40 @@ func DontTraceSpan(spanName string) { notTracingSpans.Store(spanName, lang.Placeholder) } +type traceResponseWriter struct { + w http.ResponseWriter + code int +} + +// Flush implements the http.Flusher interface. +func (w *traceResponseWriter) Flush() { + if flusher, ok := w.w.(http.Flusher); ok { + flusher.Flush() + } +} + +// Hijack implements the http.Hijacker interface. +func (w *traceResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacked, ok := w.w.(http.Hijacker); ok { + return hijacked.Hijack() + } + + return nil, nil, errors.New("server doesn't support hijacking") +} + +func (w *traceResponseWriter) Header() http.Header { + return w.w.Header() +} + +func (w *traceResponseWriter) Write(data []byte) (int, error) { + return w.w.Write(data) +} + +func (w *traceResponseWriter) WriteHeader(statusCode int) { + w.w.WriteHeader(statusCode) + w.code = statusCode +} + // TracingHandler return a middleware that process the opentelemetry. func TracingHandler(serviceName, path string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { @@ -48,7 +91,21 @@ func TracingHandler(serviceName, path string) func(http.Handler) http.Handler { // convenient for tracking error messages propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header())) - next.ServeHTTP(w, r.WithContext(spanCtx)) + trw := &traceResponseWriter{ + w: w, + code: http.StatusOK, + } + next.ServeHTTP(trw, r.WithContext(spanCtx)) + + span.SetAttributes(attribute.KeyValue{ + Key: traceKeyStatusCode, + Value: attribute.IntValue(trw.code), + }) + if trw.code >= http.StatusBadRequest { + span.SetStatus(codes.Error, http.StatusText(trw.code)) + } else { + span.SetStatus(codes.Ok, http.StatusText(trw.code)) + } }) } } diff --git a/rest/handler/tracinghandler_test.go b/rest/handler/tracinghandler_test.go index f33d5926..79a94c4d 100644 --- a/rest/handler/tracinghandler_test.go +++ b/rest/handler/tracinghandler_test.go @@ -2,8 +2,10 @@ package handler import ( "context" + "io" "net/http" "net/http/httptest" + "strconv" "testing" "github.com/stretchr/testify/assert" @@ -59,8 +61,10 @@ func TestDontTracingSpan(t *testing.T) { Batcher: "jaeger", Sampler: 1.0, }) + defer ztrace.StopAgent() DontTraceSpan("bar") + defer notTracingSpans.Delete("bar") for _, test := range []string{"", "bar", "foo"} { t.Run(test, func(t *testing.T) { @@ -97,3 +101,50 @@ func TestDontTracingSpan(t *testing.T) { }) } } + +func TestTraceResponseWriter(t *testing.T) { + ztrace.StartAgent(ztrace.Config{ + Name: "go-zero-test", + Endpoint: "http://localhost:14268/api/traces", + Batcher: "jaeger", + Sampler: 1.0, + }) + defer ztrace.StopAgent() + + for _, test := range []int{0, 200, 300, 400, 401, 500, 503} { + t.Run(strconv.Itoa(test), func(t *testing.T) { + h := chain.New(TracingHandler("foo", "bar")).Then( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + span := trace.SpanFromContext(r.Context()) + spanCtx := span.SpanContext() + assert.True(t, span.IsRecording()) + assert.True(t, spanCtx.IsValid()) + if test != 0 { + w.WriteHeader(test) + } + w.Write([]byte("hello")) + })) + ts := httptest.NewServer(h) + defer ts.Close() + + client := ts.Client() + err := func(ctx context.Context) error { + ctx, span := otel.Tracer("httptrace/client").Start(ctx, "test") + defer span.End() + + req, _ := http.NewRequest("GET", ts.URL, nil) + otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(req.Header)) + + res, err := client.Do(req) + assert.Nil(t, err) + resBody := make([]byte, 5) + _, err = res.Body.Read(resBody) + assert.Equal(t, io.EOF, err) + assert.Equal(t, []byte("hello"), resBody, "response body fail") + return res.Body.Close() + }(context.Background()) + + assert.Nil(t, err) + }) + } +}