chore: optimize code (#1818)

Signed-off-by: chenquan <chenquan.dev@gmail.com>
master
chen quan 3 years ago committed by GitHub
parent 095b603788
commit 22b157bb6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -16,7 +16,11 @@ import (
tracesdk "go.opentelemetry.io/otel/trace" tracesdk "go.opentelemetry.io/otel/trace"
) )
const defaultSlowThreshold = time.Millisecond * 500 const (
defaultSlowThreshold = time.Millisecond * 500
// spanName is the span name of the mongo calls.
spanName = "mongo"
)
// ErrNotFound is an alias of mongo.ErrNoDocuments // ErrNotFound is an alias of mongo.ErrNoDocuments
var ErrNotFound = mongo.ErrNoDocuments var ErrNotFound = mongo.ErrNoDocuments
@ -482,5 +486,5 @@ func acceptable(err error) bool {
func startSpan(ctx context.Context) (context.Context, tracesdk.Span) { func startSpan(ctx context.Context) (context.Context, tracesdk.Span) {
tracer := otel.GetTracerProvider().Tracer(trace.TraceName) tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
return tracer.Start(ctx, "mongo") return tracer.Start(ctx, spanName)
} }

@ -19,7 +19,6 @@ const spanName = "redis"
var ( var (
startTimeKey = contextKey("startTime") startTimeKey = contextKey("startTime")
spanKey = contextKey("span")
durationHook = hook{tracer: otel.GetTracerProvider().Tracer(trace.TraceName)} durationHook = hook{tracer: otel.GetTracerProvider().Tracer(trace.TraceName)}
) )
@ -96,17 +95,10 @@ func logDuration(ctx context.Context, cmd red.Cmder, duration time.Duration) {
} }
func (h hook) startSpan(ctx context.Context) context.Context { func (h hook) startSpan(ctx context.Context) context.Context {
ctx, span := h.tracer.Start(ctx, spanName) ctx, _ = h.tracer.Start(ctx, spanName)
return context.WithValue(ctx, spanKey, span) return ctx
} }
func (h hook) endSpan(ctx context.Context) { func (h hook) endSpan(ctx context.Context) {
spanVal := ctx.Value(spanKey) tracestd.SpanFromContext(ctx).End()
if spanVal == nil {
return
}
if span, ok := spanVal.(tracestd.Span); ok {
span.End()
}
} }

@ -10,6 +10,7 @@ import (
red "github.com/go-redis/redis/v8" red "github.com/go-redis/redis/v8"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
ztrace "github.com/zeromicro/go-zero/core/trace" ztrace "github.com/zeromicro/go-zero/core/trace"
tracesdk "go.opentelemetry.io/otel/trace"
) )
func TestHookProcessCase1(t *testing.T) { func TestHookProcessCase1(t *testing.T) {
@ -32,7 +33,7 @@ func TestHookProcessCase1(t *testing.T) {
assert.Nil(t, durationHook.AfterProcess(ctx, red.NewCmd(context.Background()))) assert.Nil(t, durationHook.AfterProcess(ctx, red.NewCmd(context.Background())))
assert.False(t, strings.Contains(buf.String(), "slow")) assert.False(t, strings.Contains(buf.String(), "slow"))
assert.Equal(t, "redis", ctx.Value(spanKey).(interface{ Name() string }).Name()) assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
} }
func TestHookProcessCase2(t *testing.T) { func TestHookProcessCase2(t *testing.T) {
@ -52,7 +53,7 @@ func TestHookProcessCase2(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
assert.Equal(t, "redis", ctx.Value(spanKey).(interface{ Name() string }).Name()) assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
time.Sleep(slowThreshold.Load() + time.Millisecond) time.Sleep(slowThreshold.Load() + time.Millisecond)
@ -93,7 +94,7 @@ func TestHookProcessPipelineCase1(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
assert.Equal(t, "redis", ctx.Value(spanKey).(interface{ Name() string }).Name()) assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
assert.Nil(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{ assert.Nil(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{
red.NewCmd(context.Background()), red.NewCmd(context.Background()),
@ -118,7 +119,7 @@ func TestHookProcessPipelineCase2(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
assert.Equal(t, "redis", ctx.Value(spanKey).(interface{ Name() string }).Name()) assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
time.Sleep(slowThreshold.Load() + time.Millisecond) time.Sleep(slowThreshold.Load() + time.Millisecond)

Loading…
Cancel
Save