diff --git a/core/stores/mon/collection.go b/core/stores/mon/collection.go index 41cedc06..9bb5bb44 100644 --- a/core/stores/mon/collection.go +++ b/core/stores/mon/collection.go @@ -16,7 +16,11 @@ import ( tracesdk "go.opentelemetry.io/otel/trace" ) -const defaultSlowThreshold = time.Millisecond * 500 +const ( + defaultSlowThreshold = time.Millisecond * 500 + // spanName is the span name of the mongo calls. + spanName = "mongo" +) // ErrNotFound is an alias of mongo.ErrNoDocuments var ErrNotFound = mongo.ErrNoDocuments @@ -482,5 +486,5 @@ func acceptable(err error) bool { func startSpan(ctx context.Context) (context.Context, tracesdk.Span) { tracer := otel.GetTracerProvider().Tracer(trace.TraceName) - return tracer.Start(ctx, "mongo") + return tracer.Start(ctx, spanName) } diff --git a/core/stores/redis/hook.go b/core/stores/redis/hook.go index 43043be4..b3b1f144 100644 --- a/core/stores/redis/hook.go +++ b/core/stores/redis/hook.go @@ -19,7 +19,6 @@ const spanName = "redis" var ( startTimeKey = contextKey("startTime") - spanKey = contextKey("span") durationHook = hook{tracer: otel.GetTracerProvider().Tracer(trace.TraceName)} ) @@ -96,17 +95,10 @@ func logDuration(ctx context.Context, cmd red.Cmder, duration time.Duration) { } func (h hook) startSpan(ctx context.Context) context.Context { - ctx, span := h.tracer.Start(ctx, spanName) - return context.WithValue(ctx, spanKey, span) + ctx, _ = h.tracer.Start(ctx, spanName) + return ctx } func (h hook) endSpan(ctx context.Context) { - spanVal := ctx.Value(spanKey) - if spanVal == nil { - return - } - - if span, ok := spanVal.(tracestd.Span); ok { - span.End() - } + tracestd.SpanFromContext(ctx).End() } diff --git a/core/stores/redis/hook_test.go b/core/stores/redis/hook_test.go index 554e869c..696b7d7a 100644 --- a/core/stores/redis/hook_test.go +++ b/core/stores/redis/hook_test.go @@ -10,6 +10,7 @@ import ( red "github.com/go-redis/redis/v8" "github.com/stretchr/testify/assert" ztrace "github.com/zeromicro/go-zero/core/trace" + tracesdk "go.opentelemetry.io/otel/trace" ) func TestHookProcessCase1(t *testing.T) { @@ -32,7 +33,7 @@ func TestHookProcessCase1(t *testing.T) { assert.Nil(t, durationHook.AfterProcess(ctx, red.NewCmd(context.Background()))) assert.False(t, strings.Contains(buf.String(), "slow")) - assert.Equal(t, "redis", ctx.Value(spanKey).(interface{ Name() string }).Name()) + assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name()) } func TestHookProcessCase2(t *testing.T) { @@ -52,7 +53,7 @@ func TestHookProcessCase2(t *testing.T) { if err != nil { t.Fatal(err) } - assert.Equal(t, "redis", ctx.Value(spanKey).(interface{ Name() string }).Name()) + assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name()) time.Sleep(slowThreshold.Load() + time.Millisecond) @@ -93,7 +94,7 @@ func TestHookProcessPipelineCase1(t *testing.T) { if err != nil { t.Fatal(err) } - assert.Equal(t, "redis", ctx.Value(spanKey).(interface{ Name() string }).Name()) + assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name()) assert.Nil(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{ red.NewCmd(context.Background()), @@ -118,7 +119,7 @@ func TestHookProcessPipelineCase2(t *testing.T) { if err != nil { t.Fatal(err) } - assert.Equal(t, "redis", ctx.Value(spanKey).(interface{ Name() string }).Name()) + assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name()) time.Sleep(slowThreshold.Load() + time.Millisecond)