From 97a8b3ade59014539081ce63e3c07ca30c239b70 Mon Sep 17 00:00:00 2001 From: chen quan Date: Wed, 23 Nov 2022 22:50:08 +0800 Subject: [PATCH] fix(rest): fix issues#2628 (#2629) --- rest/handler/tracinghandler.go | 9 +++------ rest/handler/tracinghandler_test.go | 13 ++++++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/rest/handler/tracinghandler.go b/rest/handler/tracinghandler.go index 9e258dfc..2d6b7119 100644 --- a/rest/handler/tracinghandler.go +++ b/rest/handler/tracinghandler.go @@ -26,20 +26,17 @@ func TracingHandler(serviceName, path string) func(http.Handler) http.Handler { tracer := otel.GetTracerProvider().Tracer(trace.TraceName) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer func() { - next.ServeHTTP(w, r) - }() - - ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header)) spanName := path if len(spanName) == 0 { spanName = r.URL.Path } if _, ok := notTracingSpans.Load(spanName); ok { + next.ServeHTTP(w, r) return } + ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header)) spanCtx, span := tracer.Start( ctx, spanName, @@ -51,7 +48,7 @@ func TracingHandler(serviceName, path string) func(http.Handler) http.Handler { // convenient for tracking error messages propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header())) - r = r.WithContext(spanCtx) + next.ServeHTTP(w, r.WithContext(spanCtx)) }) } } diff --git a/rest/handler/tracinghandler_test.go b/rest/handler/tracinghandler_test.go index c1d224d6..f33d5926 100644 --- a/rest/handler/tracinghandler_test.go +++ b/rest/handler/tracinghandler_test.go @@ -27,9 +27,9 @@ func TestOtelHandler(t *testing.T) { t.Run(test, func(t *testing.T) { h := chain.New(TracingHandler("foo", test)).Then( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := otel.GetTextMapPropagator().Extract(r.Context(), propagation.HeaderCarrier(r.Header)) - spanCtx := trace.SpanContextFromContext(ctx) - assert.True(t, spanCtx.IsValid()) + span := trace.SpanFromContext(r.Context()) + assert.True(t, span.SpanContext().IsValid()) + assert.True(t, span.IsRecording()) })) ts := httptest.NewServer(h) defer ts.Close() @@ -52,7 +52,7 @@ func TestOtelHandler(t *testing.T) { } } -func TestDontTracingSpanName(t *testing.T) { +func TestDontTracingSpan(t *testing.T) { ztrace.StartAgent(ztrace.Config{ Name: "go-zero-test", Endpoint: "http://localhost:14268/api/traces", @@ -66,12 +66,15 @@ func TestDontTracingSpanName(t *testing.T) { t.Run(test, func(t *testing.T) { h := chain.New(TracingHandler("foo", test)).Then( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - spanCtx := trace.SpanContextFromContext(r.Context()) + span := trace.SpanFromContext(r.Context()) + spanCtx := span.SpanContext() if test == "bar" { assert.False(t, spanCtx.IsValid()) + assert.False(t, span.IsRecording()) return } + assert.True(t, span.IsRecording()) assert.True(t, spanCtx.IsValid()) })) ts := httptest.NewServer(h)