You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-zero/core/stores/sqlx/tx.go

161 lines
3.7 KiB
Go

4 years ago
package sqlx
import (
"context"
4 years ago
"database/sql"
"fmt"
)
// beginnable opens a transaction on the given *sql.DB and returns it as a trans.
type beginnable func(*sql.DB) (trans, error)

// trans is a Session that can additionally be committed or rolled back.
type trans interface {
	Session
	Commit() error
	Rollback() error
}

// txSession exposes an embedded *sql.Tx through the Session methods
// declared on it in this file.
type txSession struct {
	*sql.Tx
}
// NewSessionFromTx returns a Session with the given sql.Tx.
// Use it with caution, it's provided for other ORM to interact with.
func NewSessionFromTx(tx *sql.Tx) Session {
return txSession{Tx: tx}
}
4 years ago
// Exec is ExecCtx called with context.Background().
func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) {
	bg := context.Background()
	return t.ExecCtx(bg, q, args...)
}
func (t txSession) ExecCtx(ctx context.Context, q string, args ...interface{}) (sql.Result, error) {
ctx, span := startSpan(ctx)
defer span.End()
return exec(ctx, t.Tx, q, args...)
4 years ago
}
// Prepare is PrepareCtx called with context.Background().
func (t txSession) Prepare(q string) (StmtSession, error) {
	bg := context.Background()
	return t.PrepareCtx(bg, q)
}
func (t txSession) PrepareCtx(ctx context.Context, q string) (StmtSession, error) {
ctx, span := startSpan(ctx)
defer span.End()
stmt, err := t.Tx.PrepareContext(ctx, q)
if err != nil {
4 years ago
return nil, err
}
return statement{
query: q,
stmt: stmt,
}, nil
4 years ago
}
// QueryRow is QueryRowCtx called with context.Background().
func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error {
	bg := context.Background()
	return t.QueryRowCtx(bg, v, q, args...)
}
func (t txSession) QueryRowCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error {
ctx, span := startSpan(ctx)
defer span.End()
return query(ctx, t.Tx, func(rows *sql.Rows) error {
4 years ago
return unmarshalRow(v, rows, true)
}, q, args...)
}
// QueryRowPartial is QueryRowPartialCtx called with context.Background().
func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
	bg := context.Background()
	return t.QueryRowPartialCtx(bg, v, q, args...)
}
func (t txSession) QueryRowPartialCtx(ctx context.Context, v interface{}, q string,
args ...interface{}) error {
ctx, span := startSpan(ctx)
defer span.End()
return query(ctx, t.Tx, func(rows *sql.Rows) error {
4 years ago
return unmarshalRow(v, rows, false)
}, q, args...)
}
// QueryRows is QueryRowsCtx called with context.Background().
func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error {
	bg := context.Background()
	return t.QueryRowsCtx(bg, v, q, args...)
}
func (t txSession) QueryRowsCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error {
ctx, span := startSpan(ctx)
defer span.End()
return query(ctx, t.Tx, func(rows *sql.Rows) error {
4 years ago
return unmarshalRows(v, rows, true)
}, q, args...)
}
// QueryRowsPartial is QueryRowsPartialCtx called with context.Background().
func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
	bg := context.Background()
	return t.QueryRowsPartialCtx(bg, v, q, args...)
}
func (t txSession) QueryRowsPartialCtx(ctx context.Context, v interface{}, q string,
args ...interface{}) error {
ctx, span := startSpan(ctx)
defer span.End()
return query(ctx, t.Tx, func(rows *sql.Rows) error {
4 years ago
return unmarshalRows(v, rows, false)
}, q, args...)
}
func begin(db *sql.DB) (trans, error) {
tx, err := db.Begin()
if err != nil {
4 years ago
return nil, err
}
return txSession{
Tx: tx,
}, nil
4 years ago
}
func transact(ctx context.Context, db *commonSqlConn, b beginnable,
fn func(context.Context, Session) error) (err error) {
conn, err := db.connProv()
4 years ago
if err != nil {
db.onError(err)
4 years ago
return err
}
return transactOnConn(ctx, conn, b, fn)
4 years ago
}
func transactOnConn(ctx context.Context, conn *sql.DB, b beginnable,
fn func(context.Context, Session) error) (err error) {
4 years ago
var tx trans
tx, err = b(conn)
if err != nil {
return
}
4 years ago
defer func() {
if p := recover(); p != nil {
if e := tx.Rollback(); e != nil {
err = fmt.Errorf("recover from %#v, rollback failed: %w", p, e)
4 years ago
} else {
err = fmt.Errorf("recoveer from %#v", p)
}
} else if err != nil {
if e := tx.Rollback(); e != nil {
err = fmt.Errorf("transaction failed: %s, rollback failed: %w", err, e)
4 years ago
}
} else {
err = tx.Commit()
}
}()
return fn(ctx, tx)
4 years ago
}