// You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
// go-zero/core/stores/sqlx/tx_test.go
//
// 313 lines
// 8.0 KiB
// Go

package sqlx
import (
"context"
"database/sql"
"errors"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/stores/dbtest"
)
// Bit flags OR-ed into mockTx.status so a test can see whether Commit,
// Rollback, or both were called on the mock transaction.
const (
	mockCommit = 1 << iota
	mockRollback
)
// mockTx is a fake transaction used to drive transactOnConn without a
// database. All statement methods are no-ops that return zero values and a
// nil error; Commit and Rollback only record that they were called.
type mockTx struct {
	// status accumulates the mockCommit / mockRollback bits.
	status int
}

// Commit records a commit by setting the mockCommit bit; it never fails.
func (m *mockTx) Commit() error {
	m.status |= mockCommit
	return nil
}

// Rollback records a rollback by setting the mockRollback bit; it never fails.
func (m *mockTx) Rollback() error {
	m.status |= mockRollback
	return nil
}

// Exec is a no-op returning a nil result and a nil error.
func (*mockTx) Exec(string, ...any) (sql.Result, error) {
	return nil, nil
}

// ExecCtx is a no-op returning a nil result and a nil error.
func (*mockTx) ExecCtx(context.Context, string, ...any) (sql.Result, error) {
	return nil, nil
}

// Prepare is a no-op returning a nil statement and a nil error.
func (*mockTx) Prepare(string) (StmtSession, error) {
	return nil, nil
}

// PrepareCtx is a no-op returning a nil statement and a nil error.
func (*mockTx) PrepareCtx(context.Context, string) (StmtSession, error) {
	return nil, nil
}

// QueryRow is a no-op; the destination is left untouched.
func (*mockTx) QueryRow(any, string, ...any) error {
	return nil
}

// QueryRowCtx is a no-op; the destination is left untouched.
func (*mockTx) QueryRowCtx(context.Context, any, string, ...any) error {
	return nil
}

// QueryRowPartial is a no-op; the destination is left untouched.
func (*mockTx) QueryRowPartial(any, string, ...any) error {
	return nil
}

// QueryRowPartialCtx is a no-op; the destination is left untouched.
func (*mockTx) QueryRowPartialCtx(context.Context, any, string, ...any) error {
	return nil
}

// QueryRows is a no-op; the destination is left untouched.
func (*mockTx) QueryRows(any, string, ...any) error {
	return nil
}

// QueryRowsCtx is a no-op; the destination is left untouched.
func (*mockTx) QueryRowsCtx(context.Context, any, string, ...any) error {
	return nil
}

// QueryRowsPartial is a no-op; the destination is left untouched.
func (*mockTx) QueryRowsPartial(any, string, ...any) error {
	return nil
}

// QueryRowsPartialCtx is a no-op; the destination is left untouched.
func (*mockTx) QueryRowsPartialCtx(context.Context, any, string, ...any) error {
	return nil
}
// beginMock builds a beginnable that ignores the *sql.DB it is handed and
// always yields the given mock transaction with a nil error.
func beginMock(mock *mockTx) beginnable {
	fn := func(_ *sql.DB) (trans, error) {
		return mock, nil
	}
	return fn
}
// TestTransactCommit checks that when the transaction body returns nil,
// transactOnConn reports no error and the mock sees only a commit.
func TestTransactCommit(t *testing.T) {
	mock := &mockTx{}
	err := transactOnConn(context.Background(), nil, beginMock(mock),
		func(context.Context, Session) error {
			return nil
		})
	// NoError (not Nil) so a typed-nil error is also rejected and the
	// failure message shows the error text.
	assert.NoError(t, err)
	// Exactly the commit bit: a rollback bit here would mean both paths ran.
	assert.Equal(t, mockCommit, mock.status)
}
// TestTransactRollback checks that when the transaction body returns an
// error, transactOnConn reports an error and the mock sees only a rollback.
func TestTransactRollback(t *testing.T) {
	mock := &mockTx{}
	err := transactOnConn(context.Background(), nil, beginMock(mock),
		func(context.Context, Session) error {
			return errors.New("rollback")
		})
	// Error is the error-specific assertion; NotNil is meant for arbitrary values.
	assert.Error(t, err)
	// Exactly the rollback bit: no commit may follow a failed body.
	assert.Equal(t, mockRollback, mock.status)
}
// TestTxExceptions covers Transact's success and failure paths:
//   - a clean begin/commit,
//   - a connection provider that returns an error,
//   - nesting on a conn that is already bound to a transaction,
//   - a callback error with no rollback expectation registered,
//   - a callback panic with a string value where rollback fails,
//   - a callback panic with an error value where rollback succeeds.
//
// Each scenario is a named subtest, so a failure is reported against the
// scenario that caused it and a fatal error in one does not skip the others.
func TestTxExceptions(t *testing.T) {
	t.Run("commit", func(t *testing.T) {
		dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
			mock.ExpectBegin()
			mock.ExpectCommit()
			conn := NewSqlConnFromDB(db)
			assert.NoError(t, conn.Transact(func(session Session) error {
				return nil
			}))
		})
	})

	t.Run("conn provider error", func(t *testing.T) {
		dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
			conn := &commonSqlConn{
				connProv: func() (*sql.DB, error) {
					return nil, errors.New("foo")
				},
				beginTx: begin,
				onError: func(ctx context.Context, err error) {},
				brk:     breaker.NewBreaker(),
			}
			assert.Error(t, conn.Transact(func(session Session) error {
				return nil
			}))
		})
	})

	t.Run("nested transaction", func(t *testing.T) {
		runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
			_, err := conn.RawDB()
			assert.ErrorIs(t, err, errNoRawDBFromTx)
			assert.ErrorIs(t, conn.Transact(nil), errCantNestTx)
			assert.ErrorIs(t, conn.TransactCtx(context.Background(), nil), errCantNestTx)
		})
	})

	t.Run("callback error", func(t *testing.T) {
		dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
			// No ExpectRollback: the rollback issued after the callback error is
			// unexpected by sqlmock, and Transact must still surface an error.
			mock.ExpectBegin()
			conn := NewSqlConnFromDB(db)
			assert.Error(t, conn.Transact(func(session Session) error {
				return errors.New("foo")
			}))
		})
	})

	t.Run("panic with failed rollback", func(t *testing.T) {
		dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
			mock.ExpectBegin()
			mock.ExpectRollback().WillReturnError(errors.New("foo"))
			conn := NewSqlConnFromDB(db)
			assert.Error(t, conn.Transact(func(session Session) error {
				panic("foo")
			}))
		})
	})

	t.Run("panic with error value", func(t *testing.T) {
		dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
			mock.ExpectBegin()
			mock.ExpectRollback()
			conn := NewSqlConnFromDB(db)
			assert.Error(t, conn.Transact(func(session Session) error {
				panic(errors.New("foo"))
			}))
		})
	})
}
// TestTxSession exercises each statement method of a SqlConn bound to a
// transaction, once with a mocked success and once with a mocked "foo" error.
//
// Errors are checked with assert.EqualError rather than err.Error(). If a
// regression made err nil, err.Error() would panic and abort the whole test
// binary instead of reporting a clean assertion failure.
func TestTxSession(t *testing.T) {
	t.Run("Exec", func(t *testing.T) {
		runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
			mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
			res, err := conn.Exec("any")
			// res is only usable when Exec succeeded; calling methods on a nil
			// result would panic.
			if assert.NoError(t, err) {
				lastID, idErr := res.LastInsertId()
				assert.NoError(t, idErr)
				assert.Equal(t, int64(2), lastID)
				affected, rowsErr := res.RowsAffected()
				assert.NoError(t, rowsErr)
				assert.Equal(t, int64(3), affected)
			}

			mock.ExpectExec("any").WillReturnError(errors.New("foo"))
			_, err = conn.Exec("any")
			assert.EqualError(t, err, "foo")
		})
	})

	t.Run("Prepare", func(t *testing.T) {
		runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
			mock.ExpectPrepare("any")
			stmt, err := conn.Prepare("any")
			assert.NoError(t, err)
			assert.NotNil(t, stmt)

			mock.ExpectPrepare("any").WillReturnError(errors.New("foo"))
			_, err = conn.Prepare("any")
			assert.EqualError(t, err, "foo")
		})
	})

	t.Run("QueryRow", func(t *testing.T) {
		runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
			rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
			mock.ExpectQuery("any").WillReturnRows(rows)
			var val string
			err := conn.QueryRow(&val, "any")
			assert.NoError(t, err)
			assert.Equal(t, "foo", val)

			mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
			err = conn.QueryRow(&val, "any")
			assert.EqualError(t, err, "foo")
		})
	})

	t.Run("QueryRowPartial", func(t *testing.T) {
		runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
			rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
			mock.ExpectQuery("any").WillReturnRows(rows)
			var val string
			err := conn.QueryRowPartial(&val, "any")
			assert.NoError(t, err)
			assert.Equal(t, "foo", val)

			mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
			err = conn.QueryRowPartial(&val, "any")
			assert.EqualError(t, err, "foo")
		})
	})

	t.Run("QueryRows", func(t *testing.T) {
		runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
			rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
			mock.ExpectQuery("any").WillReturnRows(rows)
			var val []string
			err := conn.QueryRows(&val, "any")
			assert.NoError(t, err)
			assert.Equal(t, []string{"foo", "bar"}, val)

			mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
			err = conn.QueryRows(&val, "any")
			assert.EqualError(t, err, "foo")
		})
	})

	t.Run("QueryRowsPartial", func(t *testing.T) {
		runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
			rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
			mock.ExpectQuery("any").WillReturnRows(rows)
			var val []string
			err := conn.QueryRowsPartial(&val, "any")
			assert.NoError(t, err)
			assert.Equal(t, []string{"foo", "bar"}, val)

			mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
			err = conn.QueryRowsPartial(&val, "any")
			assert.EqualError(t, err, "foo")
		})
	})
}
// TestTxRollback drives Transact through a session-derived SqlConn in three
// scenarios:
//   - a query failure after a successful exec rolls back,
//   - an exec failure returned from the callback rolls back,
//   - a fully successful callback commits.
func TestTxRollback(t *testing.T) {
	// Exec succeeds, QueryRow fails: the callback's error triggers a rollback.
	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
		mock.ExpectBegin()
		mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
		mock.ExpectQuery("foo").WillReturnError(errors.New("foo"))
		mock.ExpectRollback()

		txErr := NewSqlConnFromDB(db).Transact(func(session Session) error {
			txConn := NewSqlConnFromSession(session)
			_, execErr := txConn.Exec("any")
			assert.NoError(t, execErr)

			var got string
			return txConn.QueryRow(&got, "foo")
		})
		assert.Error(t, txErr)
	})

	// Exec fails: the callback returns early and the transaction rolls back.
	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
		mock.ExpectBegin()
		mock.ExpectExec("any").WillReturnError(errors.New("foo"))
		mock.ExpectRollback()

		txErr := NewSqlConnFromDB(db).Transact(func(session Session) error {
			txConn := NewSqlConnFromSession(session)
			_, execErr := txConn.Exec("any")
			if execErr != nil {
				return execErr
			}

			// Not reached while Exec is mocked to fail; it would flag a
			// callback that ignored the Exec error.
			var got string
			assert.NoError(t, txConn.QueryRow(&got, "foo"))
			return nil
		})
		assert.Error(t, txErr)
	})

	// Exec and QueryRow both succeed: the transaction commits.
	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
		mock.ExpectBegin()
		mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
		mock.ExpectQuery("foo").WillReturnRows(sqlmock.NewRows([]string{"col"}).AddRow("bar"))
		mock.ExpectCommit()

		txErr := NewSqlConnFromDB(db).Transact(func(session Session) error {
			txConn := NewSqlConnFromSession(session)
			_, execErr := txConn.Exec("any")
			assert.NoError(t, execErr)

			var got string
			assert.NoError(t, txConn.QueryRow(&got, "foo"))
			assert.Equal(t, "bar", got)
			return nil
		})
		assert.NoError(t, txErr)
	})
}
// runTxTest runs f inside dbtest.RunTxTest. It wraps the *sql.Tx supplied
// there in a Session and then a SqlConn, and passes that conn to f along
// with the sqlmock handle.
func runTxTest(t *testing.T, f func(conn SqlConn, mock sqlmock.Sqlmock)) {
	dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
		f(NewSqlConnFromSession(NewSessionFromTx(tx)), mock)
	})
}