chore: refactor (#1814)

master
Kevin Wan 3 years ago committed by GitHub
parent 162e9cef86
commit bc3c9484d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -473,9 +473,10 @@ func (p keepablePromise) keep(err error) error {
func acceptable(err error) bool { func acceptable(err error) bool {
return err == nil || err == mongo.ErrNoDocuments || err == mongo.ErrNilValue || return err == nil || err == mongo.ErrNoDocuments || err == mongo.ErrNilValue ||
err == mongo.ErrNilDocument || err == mongo.ErrNilCursor || err == mongo.ErrEmptySlice || err == mongo.ErrNilDocument || err == mongo.ErrNilCursor || err == mongo.ErrEmptySlice ||
// session err // session errors
err == session.ErrSessionEnded || err == session.ErrNoTransactStarted || err == session.ErrTransactInProgress || err == session.ErrSessionEnded || err == session.ErrNoTransactStarted ||
err == session.ErrAbortAfterCommit || err == session.ErrAbortTwice || err == session.ErrCommitAfterAbort || err == session.ErrTransactInProgress || err == session.ErrAbortAfterCommit ||
err == session.ErrAbortTwice || err == session.ErrCommitAfterAbort ||
err == session.ErrUnackWCUnsupported || err == session.ErrSnapshotTransaction err == session.ErrUnackWCUnsupported || err == session.ErrSnapshotTransaction
} }

@ -21,7 +21,7 @@ type (
opts []Option opts []Option
} }
wrapSession struct { wrappedSession struct {
mongo.Session mongo.Session
brk breaker.Breaker brk breaker.Breaker
} }
@ -74,7 +74,10 @@ func (m *Model) StartSession(opts ...*mopt.SessionOptions) (sess mongo.Session,
return sessionErr return sessionErr
} }
sess = &wrapSession{Session: session, brk: m.brk} sess = &wrappedSession{
Session: session,
brk: m.brk,
}
return nil return nil
}, acceptable) }, acceptable)
@ -166,7 +169,7 @@ func (m *Model) FindOneAndUpdate(ctx context.Context, v, filter interface{}, upd
return res.Decode(v) return res.Decode(v)
} }
func (w *wrapSession) AbortTransaction(ctx context.Context) error { func (w *wrappedSession) AbortTransaction(ctx context.Context) error {
ctx, span := startSpan(ctx) ctx, span := startSpan(ctx)
defer span.End() defer span.End()
@ -175,7 +178,7 @@ func (w *wrapSession) AbortTransaction(ctx context.Context) error {
}, acceptable) }, acceptable)
} }
func (w *wrapSession) CommitTransaction(ctx context.Context) error { func (w *wrappedSession) CommitTransaction(ctx context.Context) error {
ctx, span := startSpan(ctx) ctx, span := startSpan(ctx)
defer span.End() defer span.End()
@ -184,7 +187,11 @@ func (w *wrapSession) CommitTransaction(ctx context.Context) error {
}, acceptable) }, acceptable)
} }
func (w *wrapSession) WithTransaction(ctx context.Context, fn func(sessCtx mongo.SessionContext) (interface{}, error), opts ...*mopt.TransactionOptions) (res interface{}, err error) { func (w *wrappedSession) WithTransaction(
ctx context.Context,
fn func(sessCtx mongo.SessionContext) (interface{}, error),
opts ...*mopt.TransactionOptions,
) (res interface{}, err error) {
ctx, span := startSpan(ctx) ctx, span := startSpan(ctx)
defer span.End() defer span.End()
@ -196,7 +203,7 @@ func (w *wrapSession) WithTransaction(ctx context.Context, fn func(sessCtx mongo
return return
} }
func (w *wrapSession) EndSession(ctx context.Context) { func (w *wrappedSession) EndSession(ctx context.Context) {
ctx, span := startSpan(ctx) ctx, span := startSpan(ctx)
defer span.End() defer span.End()

@ -18,6 +18,7 @@ func TestModel_StartSession(t *testing.T) {
m := createModel(mt) m := createModel(mt)
sess, err := m.StartSession() sess, err := m.StartSession()
assert.Nil(t, err) assert.Nil(t, err)
defer sess.EndSession(context.Background())
_, err = sess.WithTransaction(context.Background(), func(sessCtx mongo.SessionContext) (interface{}, error) { _, err = sess.WithTransaction(context.Background(), func(sessCtx mongo.SessionContext) (interface{}, error) {
_ = sessCtx.StartTransaction() _ = sessCtx.StartTransaction()
@ -26,10 +27,8 @@ func TestModel_StartSession(t *testing.T) {
return nil, nil return nil, nil
}) })
assert.Nil(t, err) assert.Nil(t, err)
assert.NoError(t, sess.CommitTransaction(context.Background())) assert.NoError(t, sess.CommitTransaction(context.Background()))
assert.Error(t, sess.AbortTransaction(context.Background())) assert.Error(t, sess.AbortTransaction(context.Background()))
sess.EndSession(context.Background())
}) })
} }

@ -14,6 +14,9 @@ import (
tracestd "go.opentelemetry.io/otel/trace" tracestd "go.opentelemetry.io/otel/trace"
) )
// spanName is the span name of the redis calls.
const spanName = "redis"
var ( var (
startTimeKey = contextKey("startTime") startTimeKey = contextKey("startTime")
spanKey = contextKey("span") spanKey = contextKey("span")
@ -28,11 +31,11 @@ type (
) )
func (h hook) BeforeProcess(ctx context.Context, _ red.Cmder) (context.Context, error) { func (h hook) BeforeProcess(ctx context.Context, _ red.Cmder) (context.Context, error) {
return h.spanStart(context.WithValue(ctx, startTimeKey, timex.Now())), nil return h.startSpan(context.WithValue(ctx, startTimeKey, timex.Now())), nil
} }
func (h hook) AfterProcess(ctx context.Context, cmd red.Cmder) error { func (h hook) AfterProcess(ctx context.Context, cmd red.Cmder) error {
h.spanEnd(ctx) h.endSpan(ctx)
val := ctx.Value(startTimeKey) val := ctx.Value(startTimeKey)
if val == nil { if val == nil {
@ -53,11 +56,11 @@ func (h hook) AfterProcess(ctx context.Context, cmd red.Cmder) error {
} }
func (h hook) BeforeProcessPipeline(ctx context.Context, _ []red.Cmder) (context.Context, error) { func (h hook) BeforeProcessPipeline(ctx context.Context, _ []red.Cmder) (context.Context, error) {
return h.spanStart(context.WithValue(ctx, startTimeKey, timex.Now())), nil return h.startSpan(context.WithValue(ctx, startTimeKey, timex.Now())), nil
} }
func (h hook) AfterProcessPipeline(ctx context.Context, cmds []red.Cmder) error { func (h hook) AfterProcessPipeline(ctx context.Context, cmds []red.Cmder) error {
h.spanEnd(ctx) h.endSpan(ctx)
if len(cmds) == 0 { if len(cmds) == 0 {
return nil return nil
@ -92,12 +95,12 @@ func logDuration(ctx context.Context, cmd red.Cmder, duration time.Duration) {
logx.WithContext(ctx).WithDuration(duration).Slowf("[REDIS] slowcall on executing: %s", buf.String()) logx.WithContext(ctx).WithDuration(duration).Slowf("[REDIS] slowcall on executing: %s", buf.String())
} }
func (h hook) spanStart(ctx context.Context) context.Context { func (h hook) startSpan(ctx context.Context) context.Context {
ctx, span := h.tracer.Start(ctx, "redis") ctx, span := h.tracer.Start(ctx, spanName)
return context.WithValue(ctx, spanKey, span) return context.WithValue(ctx, spanKey, span)
} }
func (h hook) spanEnd(ctx context.Context) { func (h hook) endSpan(ctx context.Context) {
spanVal := ctx.Value(spanKey) spanVal := ctx.Value(spanKey)
if spanVal == nil { if spanVal == nil {
return return

@ -11,6 +11,9 @@ import (
tracesdk "go.opentelemetry.io/otel/trace" tracesdk "go.opentelemetry.io/otel/trace"
) )
// spanName is used to identify the span name for the SQL execution.
const spanName = "sql"
// ErrNotFound is an alias of sql.ErrNoRows // ErrNotFound is an alias of sql.ErrNoRows
var ErrNotFound = sql.ErrNoRows var ErrNotFound = sql.ErrNoRows
@ -240,7 +243,6 @@ func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v interface{},
return db.queryRows(ctx, func(rows *sql.Rows) error { return db.queryRows(ctx, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false) return unmarshalRows(v, rows, false)
}, q, args...) }, q, args...)
} }
func (db *commonSqlConn) RawDB() (*sql.DB, error) { func (db *commonSqlConn) RawDB() (*sql.DB, error) {
@ -362,5 +364,5 @@ func (s statement) QueryRowsPartialCtx(ctx context.Context, v interface{}, args
func startSpan(ctx context.Context) (context.Context, tracesdk.Span) { func startSpan(ctx context.Context) (context.Context, tracesdk.Span) {
tracer := otel.GetTracerProvider().Tracer(trace.TraceName) tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
return tracer.Start(ctx, "sql") return tracer.Start(ctx, spanName)
} }

Loading…
Cancel
Save