diff --git a/rest/handler/tracinghandler.go b/rest/handler/tracinghandler.go index be220f1c..4e9e5f1b 100644 --- a/rest/handler/tracinghandler.go +++ b/rest/handler/tracinghandler.go @@ -2,7 +2,9 @@ package handler import ( "net/http" + "sync" + "github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/core/trace" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/propagation" @@ -10,6 +12,13 @@ import ( oteltrace "go.opentelemetry.io/otel/trace" ) +var dontTracingSpanNames sync.Map + +// DontTracingSpanName disable tracing for the specified spanName. +func DontTracingSpanName(spanName string) { + dontTracingSpanNames.Store(spanName, lang.Placeholder) +} + // TracingHandler return a middleware that process the opentelemetry. func TracingHandler(serviceName, path string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { @@ -17,11 +26,21 @@ func TracingHandler(serviceName, path string) func(http.Handler) http.Handler { tracer := otel.GetTracerProvider().Tracer(trace.TraceName) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + next.ServeHTTP(w, r) + }() + ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header)) spanName := path if len(spanName) == 0 { spanName = r.URL.Path } + + _, ok := dontTracingSpanNames.Load(spanName) + if ok { + return + } + spanCtx, span := tracer.Start( ctx, spanName, @@ -33,7 +52,7 @@ func TracingHandler(serviceName, path string) func(http.Handler) http.Handler { // convenient for tracking error messages propagator.Inject(spanCtx, propagation.HeaderCarrier(w.Header())) - next.ServeHTTP(w, r.WithContext(spanCtx)) + r = r.WithContext(spanCtx) }) } } diff --git a/rest/handler/tracinghandler_test.go b/rest/handler/tracinghandler_test.go index abbc5127..bbc60911 100644 --- a/rest/handler/tracinghandler_test.go +++ b/rest/handler/tracinghandler_test.go @@ -51,3 +51,46 @@ func TestOtelHandler(t *testing.T) { }) } } + +func TestDontTracingSpanName(t *testing.T) { + ztrace.StartAgent(ztrace.Config{ + Name: "go-zero-test", + Endpoint: "http://localhost:14268/api/traces", + Batcher: "jaeger", + Sampler: 1.0, + }) + + DontTracingSpanName("bar") + + for _, test := range []string{"", "bar", "foo"} { + t.Run(test, func(t *testing.T) { + h := chain.New(TracingHandler("foo", test)).Then( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + spanCtx := trace.SpanContextFromContext(r.Context()) + if test == "bar" { + assert.False(t, spanCtx.IsValid()) + return + } + + assert.True(t, spanCtx.IsValid()) + })) + ts := httptest.NewServer(h) + defer ts.Close() + + client := ts.Client() + err := func(ctx context.Context) error { + ctx, span := otel.Tracer("httptrace/client").Start(ctx, "test") + defer span.End() + + req, _ := http.NewRequest("GET", ts.URL, nil) + otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(req.Header)) + + res, err := client.Do(req) + assert.Nil(t, err) + return res.Body.Close() + }(context.Background()) + + assert.Nil(t, err) + }) + } +}