diff --git a/core/stores/mon/collection.go b/core/stores/mon/collection.go index c9e3ed2a..cfed2f68 100644 --- a/core/stores/mon/collection.go +++ b/core/stores/mon/collection.go @@ -8,8 +8,12 @@ import ( "github.com/zeromicro/go-zero/core/breaker" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/timex" + "github.com/zeromicro/go-zero/core/trace" "go.mongodb.org/mongo-driver/mongo" mopt "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/x/mongo/driver/session" + "go.opentelemetry.io/otel" + tracesdk "go.opentelemetry.io/otel/trace" ) const defaultSlowThreshold = time.Millisecond * 500 @@ -112,6 +116,9 @@ func newCollection(collection *mongo.Collection, brk breaker.Breaker) Collection func (c *decoratedCollection) Aggregate(ctx context.Context, pipeline interface{}, opts ...*mopt.AggregateOptions) (cur *mongo.Cursor, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { starTime := timex.Now() defer func() { @@ -126,6 +133,9 @@ func (c *decoratedCollection) Aggregate(ctx context.Context, pipeline interface{ func (c *decoratedCollection) BulkWrite(ctx context.Context, models []mongo.WriteModel, opts ...*mopt.BulkWriteOptions) (res *mongo.BulkWriteResult, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { startTime := timex.Now() defer func() { @@ -140,6 +150,9 @@ func (c *decoratedCollection) BulkWrite(ctx context.Context, models []mongo.Writ func (c *decoratedCollection) CountDocuments(ctx context.Context, filter interface{}, opts ...*mopt.CountOptions) (count int64, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { startTime := timex.Now() defer func() { @@ -154,6 +167,9 @@ func (c *decoratedCollection) CountDocuments(ctx context.Context, filter interfa func (c *decoratedCollection) DeleteMany(ctx context.Context, filter interface{}, opts ...*mopt.DeleteOptions) (res *mongo.DeleteResult, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { startTime := timex.Now() defer func() { @@ -168,6 +184,9 @@ func (c *decoratedCollection) DeleteMany(ctx context.Context, filter interface{} func (c *decoratedCollection) DeleteOne(ctx context.Context, filter interface{}, opts ...*mopt.DeleteOptions) (res *mongo.DeleteResult, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { startTime := timex.Now() defer func() { @@ -182,6 +201,9 @@ func (c *decoratedCollection) DeleteOne(ctx context.Context, filter interface{}, func (c *decoratedCollection) Distinct(ctx context.Context, fieldName string, filter interface{}, opts ...*mopt.DistinctOptions) (val []interface{}, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { startTime := timex.Now() defer func() { @@ -196,6 +218,9 @@ func (c *decoratedCollection) Distinct(ctx context.Context, fieldName string, fi func (c *decoratedCollection) EstimatedDocumentCount(ctx context.Context, opts ...*mopt.EstimatedDocumentCountOptions) (val int64, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { startTime := timex.Now() defer func() { @@ -210,6 +235,9 @@ func (c *decoratedCollection) EstimatedDocumentCount(ctx context.Context, func (c *decoratedCollection) Find(ctx context.Context, filter interface{}, opts ...*mopt.FindOptions) (cur *mongo.Cursor, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { startTime := timex.Now() defer func() { @@ -224,6 +252,9 @@ func (c *decoratedCollection) Find(ctx context.Context, filter interface{}, func (c *decoratedCollection) FindOne(ctx context.Context, filter interface{}, opts ...*mopt.FindOneOptions) (res *mongo.SingleResult, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { startTime := timex.Now() defer func() { @@ -239,6 +270,9 @@ func (c *decoratedCollection) FindOne(ctx context.Context, filter interface{}, func (c *decoratedCollection) FindOneAndDelete(ctx context.Context, filter interface{}, opts ...*mopt.FindOneAndDeleteOptions) (res *mongo.SingleResult, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { startTime := timex.Now() defer func() { @@ -255,6 +289,9 @@ func (c *decoratedCollection) FindOneAndDelete(ctx context.Context, filter inter func (c *decoratedCollection) FindOneAndReplace(ctx context.Context, filter interface{}, replacement interface{}, opts ...*mopt.FindOneAndReplaceOptions) ( res *mongo.SingleResult, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { startTime := timex.Now() defer func() { @@ -270,6 +307,9 @@ func (c *decoratedCollection) FindOneAndReplace(ctx context.Context, filter inte func (c *decoratedCollection) FindOneAndUpdate(ctx context.Context, filter interface{}, update interface{}, opts ...*mopt.FindOneAndUpdateOptions) (res *mongo.SingleResult, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { startTime := timex.Now() defer func() { @@ -285,6 +325,9 @@ func (c *decoratedCollection) FindOneAndUpdate(ctx context.Context, filter inter func (c *decoratedCollection) InsertMany(ctx context.Context, documents []interface{}, opts ...*mopt.InsertManyOptions) (res *mongo.InsertManyResult, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { startTime := timex.Now() defer func() { @@ -299,6 +342,9 @@ func (c *decoratedCollection) InsertMany(ctx context.Context, documents []interf func (c *decoratedCollection) InsertOne(ctx context.Context, document interface{}, opts ...*mopt.InsertOneOptions) (res *mongo.InsertOneResult, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { startTime := timex.Now() defer func() { @@ -313,6 +359,9 @@ func (c *decoratedCollection) InsertOne(ctx context.Context, document interface{ func (c *decoratedCollection) ReplaceOne(ctx context.Context, filter interface{}, replacement interface{}, opts ...*mopt.ReplaceOptions) (res *mongo.UpdateResult, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { startTime := timex.Now() defer func() { @@ -327,6 +376,9 @@ func (c *decoratedCollection) ReplaceOne(ctx context.Context, filter interface{} func (c *decoratedCollection) UpdateByID(ctx context.Context, id interface{}, update interface{}, opts ...*mopt.UpdateOptions) (res *mongo.UpdateResult, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { startTime := timex.Now() defer func() { @@ -341,6 +393,9 @@ func (c *decoratedCollection) UpdateByID(ctx context.Context, id interface{}, up func (c *decoratedCollection) UpdateMany(ctx context.Context, filter interface{}, update interface{}, opts ...*mopt.UpdateOptions) (res *mongo.UpdateResult, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { startTime := timex.Now() defer func() { @@ -355,6 +410,9 @@ func (c *decoratedCollection) UpdateMany(ctx context.Context, filter interface{} func (c *decoratedCollection) UpdateOne(ctx context.Context, filter interface{}, update interface{}, opts ...*mopt.UpdateOptions) (res *mongo.UpdateResult, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = c.brk.DoWithAcceptable(func() error { startTime := timex.Now() defer func() { @@ -414,5 +472,14 @@ func (p keepablePromise) keep(err error) error { func acceptable(err error) bool { return err == nil || err == mongo.ErrNoDocuments || err == mongo.ErrNilValue || - err == mongo.ErrNilDocument || err == mongo.ErrNilCursor || err == mongo.ErrEmptySlice + err == mongo.ErrNilDocument || err == mongo.ErrNilCursor || err == mongo.ErrEmptySlice || + // session err + err == session.ErrSessionEnded || err == session.ErrNoTransactStarted || err == session.ErrTransactInProgress || + err == session.ErrAbortAfterCommit || err == session.ErrAbortTwice || err == session.ErrCommitAfterAbort || + err == session.ErrUnackWCUnsupported || err == session.ErrSnapshotTransaction +} + +func startSpan(ctx context.Context) (context.Context, tracesdk.Span) { + tracer := otel.GetTracerProvider().Tracer(trace.TraceName) + return tracer.Start(ctx, "mongo") } diff --git a/core/stores/mon/model.go b/core/stores/mon/model.go index 7d07bbcd..dd4bab2e 100644 --- a/core/stores/mon/model.go +++ b/core/stores/mon/model.go @@ -11,14 +11,21 @@ import ( mopt "go.mongodb.org/mongo-driver/mongo/options" ) -// Model is a mongodb store model that represents a collection. -type Model struct { - Collection - name string - cli *mongo.Client - brk breaker.Breaker - opts []Option -} +type ( + // Model is a mongodb store model that represents a collection. + Model struct { + Collection + name string + cli *mongo.Client + brk breaker.Breaker + opts []Option + } + + wrapSession struct { + mongo.Session + brk breaker.Breaker + } +) // MustNewModel returns a Model, exits on errors. func MustNewModel(uri, db, collection string, opts ...Option) *Model { @@ -62,8 +69,14 @@ func (m *Model) StartSession(opts ...*mopt.SessionOptions) (sess mongo.Session, logDuration(m.name, "StartSession", starTime, err) }() - sess, err = m.cli.StartSession(opts...) - return err + session, sessionErr := m.cli.StartSession(opts...) + if sessionErr != nil { + return sessionErr + } + + sess = &wrapSession{Session: session, brk: m.brk} + + return nil }, acceptable) return } @@ -152,3 +165,43 @@ func (m *Model) FindOneAndUpdate(ctx context.Context, v, filter interface{}, upd return res.Decode(v) } + +func (w *wrapSession) AbortTransaction(ctx context.Context) error { + ctx, span := startSpan(ctx) + defer span.End() + + return w.brk.DoWithAcceptable(func() error { + return w.Session.AbortTransaction(ctx) + }, acceptable) +} + +func (w *wrapSession) CommitTransaction(ctx context.Context) error { + ctx, span := startSpan(ctx) + defer span.End() + + return w.brk.DoWithAcceptable(func() error { + return w.Session.CommitTransaction(ctx) + }, acceptable) +} + +func (w *wrapSession) WithTransaction(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), opts ...*mopt.TransactionOptions) (res interface{}, err error) { + ctx, span := startSpan(ctx) + defer span.End() + + err = w.brk.DoWithAcceptable(func() error { + res, err = w.Session.WithTransaction(ctx, fn, opts...) + return err + }, acceptable) + + return +} + +func (w *wrapSession) EndSession(ctx context.Context) { + ctx, span := startSpan(ctx) + defer span.End() + + _ = w.brk.DoWithAcceptable(func() error { + w.Session.EndSession(ctx) + return nil + }, acceptable) +} diff --git a/core/stores/mon/model_test.go b/core/stores/mon/model_test.go index 7adb0a0e..f7dcd4e2 100644 --- a/core/stores/mon/model_test.go +++ b/core/stores/mon/model_test.go @@ -18,6 +18,17 @@ func TestModel_StartSession(t *testing.T) { m := createModel(mt) sess, err := m.StartSession() assert.Nil(t, err) + + _, err = sess.WithTransaction(context.Background(), func(sessCtx mongo.SessionContext) (interface{}, error) { + _ = sessCtx.StartTransaction() + sessCtx.Client().Database("1") + sessCtx.EndSession(context.Background()) + return nil, nil + }) + assert.Nil(t, err) + + assert.NoError(t, sess.CommitTransaction(context.Background())) + assert.Error(t, sess.AbortTransaction(context.Background())) sess.EndSession(context.Background()) }) } diff --git a/core/stores/redis/hook.go b/core/stores/redis/hook.go index a0013c34..00373113 100644 --- a/core/stores/redis/hook.go +++ b/core/stores/redis/hook.go @@ -9,23 +9,31 @@ import ( "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/mapping" "github.com/zeromicro/go-zero/core/timex" + "github.com/zeromicro/go-zero/core/trace" + "go.opentelemetry.io/otel" + tracestd "go.opentelemetry.io/otel/trace" ) var ( startTimeKey = contextKey("startTime") - durationHook = hook{} + spanKey = contextKey("span") + durationHook = hook{tracer: otel.GetTracerProvider().Tracer(trace.TraceName)} ) type ( contextKey string - hook struct{} + hook struct { + tracer tracestd.Tracer + } ) func (h hook) BeforeProcess(ctx context.Context, _ red.Cmder) (context.Context, error) { - return context.WithValue(ctx, startTimeKey, timex.Now()), nil + return h.spanStart(context.WithValue(ctx, startTimeKey, timex.Now())), nil } func (h hook) AfterProcess(ctx context.Context, cmd red.Cmder) error { + h.spanEnd(ctx) + val := ctx.Value(startTimeKey) if val == nil { return nil @@ -45,10 +53,12 @@ func (h hook) AfterProcess(ctx context.Context, cmd red.Cmder) error { } func (h hook) BeforeProcessPipeline(ctx context.Context, _ []red.Cmder) (context.Context, error) { - return context.WithValue(ctx, startTimeKey, timex.Now()), nil + return h.spanStart(context.WithValue(ctx, startTimeKey, timex.Now())), nil } func (h hook) AfterProcessPipeline(ctx context.Context, cmds []red.Cmder) error { + h.spanEnd(ctx) + if len(cmds) == 0 { return nil } @@ -81,3 +91,19 @@ func logDuration(ctx context.Context, cmd red.Cmder, duration time.Duration) { } logx.WithContext(ctx).WithDuration(duration).Slowf("[REDIS] slowcall on executing: %s", buf.String()) } + +func (h hook) spanStart(ctx context.Context) context.Context { + ctx, span := h.tracer.Start(ctx, "redis") + return context.WithValue(ctx, spanKey, span) +} + +func (h hook) spanEnd(ctx context.Context) { + spanVal := ctx.Value(spanKey) + if spanVal == nil { + return + } + + if span, ok := spanVal.(tracestd.Span); ok { + span.End() + } +} diff --git a/core/stores/redis/hook_test.go b/core/stores/redis/hook_test.go index f1e993ab..554e869c 100644 --- a/core/stores/redis/hook_test.go +++ b/core/stores/redis/hook_test.go @@ -9,9 +9,17 @@ import ( red "github.com/go-redis/redis/v8" "github.com/stretchr/testify/assert" + ztrace "github.com/zeromicro/go-zero/core/trace" ) func TestHookProcessCase1(t *testing.T) { + ztrace.StartAgent(ztrace.Config{ + Name: "go-zero-test", + Endpoint: "http://localhost:14268/api/traces", + Batcher: "jaeger", + Sampler: 1.0, + }) + writer := log.Writer() var buf strings.Builder log.SetOutput(&buf) @@ -24,9 +32,17 @@ func TestHookProcessCase1(t *testing.T) { assert.Nil(t, durationHook.AfterProcess(ctx, red.NewCmd(context.Background()))) assert.False(t, strings.Contains(buf.String(), "slow")) + assert.Equal(t, "redis", ctx.Value(spanKey).(interface{ Name() string }).Name()) } func TestHookProcessCase2(t *testing.T) { + ztrace.StartAgent(ztrace.Config{ + Name: "go-zero-test", + Endpoint: "http://localhost:14268/api/traces", + Batcher: "jaeger", + Sampler: 1.0, + }) + writer := log.Writer() var buf strings.Builder log.SetOutput(&buf) @@ -36,11 +52,14 @@ func TestHookProcessCase2(t *testing.T) { if err != nil { t.Fatal(err) } + assert.Equal(t, "redis", ctx.Value(spanKey).(interface{ Name() string }).Name()) time.Sleep(slowThreshold.Load() + time.Millisecond) assert.Nil(t, durationHook.AfterProcess(ctx, red.NewCmd(context.Background(), "foo", "bar"))) assert.True(t, strings.Contains(buf.String(), "slow")) + assert.True(t, strings.Contains(buf.String(), "trace")) + assert.True(t, strings.Contains(buf.String(), "span")) } func TestHookProcessCase3(t *testing.T) { @@ -74,6 +93,7 @@ func TestHookProcessPipelineCase1(t *testing.T) { if err != nil { t.Fatal(err) } + assert.Equal(t, "redis", ctx.Value(spanKey).(interface{ Name() string }).Name()) assert.Nil(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{ red.NewCmd(context.Background()), @@ -82,6 +102,13 @@ func TestHookProcessPipelineCase1(t *testing.T) { } func TestHookProcessPipelineCase2(t *testing.T) { + ztrace.StartAgent(ztrace.Config{ + Name: "go-zero-test", + Endpoint: "http://localhost:14268/api/traces", + Batcher: "jaeger", + Sampler: 1.0, + }) + writer := log.Writer() var buf strings.Builder log.SetOutput(&buf) @@ -91,6 +118,7 @@ func TestHookProcessPipelineCase2(t *testing.T) { if err != nil { t.Fatal(err) } + assert.Equal(t, "redis", ctx.Value(spanKey).(interface{ Name() string }).Name()) time.Sleep(slowThreshold.Load() + time.Millisecond) @@ -98,6 +126,8 @@ func TestHookProcessPipelineCase2(t *testing.T) { red.NewCmd(context.Background(), "foo", "bar"), })) assert.True(t, strings.Contains(buf.String(), "slow")) + assert.True(t, strings.Contains(buf.String(), "trace")) + assert.True(t, strings.Contains(buf.String(), "span")) } func TestHookProcessPipelineCase3(t *testing.T) { diff --git a/core/stores/sqlx/sqlconn.go b/core/stores/sqlx/sqlconn.go index e2b017d2..6a2737c9 100644 --- a/core/stores/sqlx/sqlconn.go +++ b/core/stores/sqlx/sqlconn.go @@ -6,6 +6,9 @@ import ( "github.com/zeromicro/go-zero/core/breaker" "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/core/trace" + "go.opentelemetry.io/otel" + tracesdk "go.opentelemetry.io/otel/trace" ) // ErrNotFound is an alias of sql.ErrNoRows @@ -134,6 +137,9 @@ func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...interface{}) ( result sql.Result, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = db.brk.DoWithAcceptable(func() error { var conn *sql.DB conn, err = db.connProv() @@ -154,6 +160,9 @@ func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) { } func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) { + ctx, span := startSpan(ctx) + defer span.End() + err = db.brk.DoWithAcceptable(func() error { var conn *sql.DB conn, err = db.connProv() @@ -183,6 +192,9 @@ func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error { + ctx, span := startSpan(ctx) + defer span.End() + return db.queryRows(ctx, func(rows *sql.Rows) error { return unmarshalRow(v, rows, true) }, q, args...) @@ -194,6 +206,9 @@ func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interf func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error { + ctx, span := startSpan(ctx) + defer span.End() + return db.queryRows(ctx, func(rows *sql.Rows) error { return unmarshalRow(v, rows, false) }, q, args...) @@ -205,6 +220,9 @@ func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error { + ctx, span := startSpan(ctx) + defer span.End() + return db.queryRows(ctx, func(rows *sql.Rows) error { return unmarshalRows(v, rows, true) }, q, args...) @@ -216,9 +234,13 @@ func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...inter func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error { + ctx, span := startSpan(ctx) + defer span.End() + return db.queryRows(ctx, func(rows *sql.Rows) error { return unmarshalRows(v, rows, false) }, q, args...) + } func (db *commonSqlConn) RawDB() (*sql.DB, error) { @@ -232,6 +254,9 @@ func (db *commonSqlConn) Transact(fn func(Session) error) error { } func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error { + ctx, span := startSpan(ctx) + defer span.End() + return db.brk.DoWithAcceptable(func() error { return transact(ctx, db, db.beginTx, fn) }, db.acceptable) @@ -248,6 +273,9 @@ func (db *commonSqlConn) acceptable(err error) bool { func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error, q string, args ...interface{}) error { + ctx, span := startSpan(ctx) + defer span.End() + var qerr error return db.brk.DoWithAcceptable(func() error { conn, err := db.connProv() @@ -274,6 +302,9 @@ func (s statement) Exec(args ...interface{}) (sql.Result, error) { } func (s statement) ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error) { + ctx, span := startSpan(ctx) + defer span.End() + return execStmt(ctx, s.stmt, s.query, args...) } @@ -282,6 +313,9 @@ func (s statement) QueryRow(v interface{}, args ...interface{}) error { } func (s statement) QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error { + ctx, span := startSpan(ctx) + defer span.End() + return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { return unmarshalRow(v, rows, true) }, s.query, args...) @@ -292,6 +326,9 @@ func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error { } func (s statement) QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error { + ctx, span := startSpan(ctx) + defer span.End() + return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { return unmarshalRow(v, rows, false) }, s.query, args...) @@ -302,6 +339,9 @@ func (s statement) QueryRows(v interface{}, args ...interface{}) error { } func (s statement) QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error { + ctx, span := startSpan(ctx) + defer span.End() + return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { return unmarshalRows(v, rows, true) }, s.query, args...) @@ -312,7 +352,15 @@ func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error { } func (s statement) QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error { + ctx, span := startSpan(ctx) + defer span.End() + return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { return unmarshalRows(v, rows, false) }, s.query, args...) } + +func startSpan(ctx context.Context) (context.Context, tracesdk.Span) { + tracer := otel.GetTracerProvider().Tracer(trace.TraceName) + return tracer.Start(ctx, "sql") +} diff --git a/core/stores/sqlx/tx.go b/core/stores/sqlx/tx.go index 67c02ff1..98d3b1d8 100644 --- a/core/stores/sqlx/tx.go +++ b/core/stores/sqlx/tx.go @@ -31,6 +31,9 @@ func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) { } func (t txSession) ExecCtx(ctx context.Context, q string, args ...interface{}) (sql.Result, error) { + ctx, span := startSpan(ctx) + defer span.End() + return exec(ctx, t.Tx, q, args...) } @@ -39,6 +42,9 @@ func (t txSession) Prepare(q string) (StmtSession, error) { } func (t txSession) PrepareCtx(ctx context.Context, q string) (StmtSession, error) { + ctx, span := startSpan(ctx) + defer span.End() + stmt, err := t.Tx.PrepareContext(ctx, q) if err != nil { return nil, err @@ -55,6 +61,9 @@ func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error } func (t txSession) QueryRowCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error { + ctx, span := startSpan(ctx) + defer span.End() + return query(ctx, t.Tx, func(rows *sql.Rows) error { return unmarshalRow(v, rows, true) }, q, args...) @@ -66,6 +75,9 @@ func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) func (t txSession) QueryRowPartialCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error { + ctx, span := startSpan(ctx) + defer span.End() + return query(ctx, t.Tx, func(rows *sql.Rows) error { return unmarshalRow(v, rows, false) }, q, args...) @@ -76,6 +88,9 @@ func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error } func (t txSession) QueryRowsCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error { + ctx, span := startSpan(ctx) + defer span.End() + return query(ctx, t.Tx, func(rows *sql.Rows) error { return unmarshalRows(v, rows, true) }, q, args...) @@ -87,6 +102,9 @@ func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{} func (t txSession) QueryRowsPartialCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error { + ctx, span := startSpan(ctx) + defer span.End() + return query(ctx, t.Tx, func(rows *sql.Rows) error { return unmarshalRows(v, rows, false) }, q, args...)