package sqlx

import (
	"context"
	"database/sql"
	"errors"
	"testing"

	"github.com/DATA-DOG/go-sqlmock"
	"github.com/stretchr/testify/assert"
	"github.com/zeromicro/go-zero/core/breaker"
	"github.com/zeromicro/go-zero/internal/dbtest"
)

// Bit flags OR-ed into mockTx.status so a test can tell whether Commit,
// Rollback, or both were invoked on the mock transaction.
const (
	mockCommit   = 1
	mockRollback = 2
)

// Compile-time check that *mockTx can be used wherever a trans is expected
// (beginMock relies on this).
var _ trans = (*mockTx)(nil)

// mockTx is a stand-in transaction that only records calls to Commit and
// Rollback. Every other method is a no-op returning zero values and a nil
// error.
type mockTx struct {
	// status accumulates mockCommit / mockRollback via bitwise OR.
	status int
}

func (mt *mockTx) Commit() error {
	mt.status |= mockCommit
	return nil
}

func (mt *mockTx) Exec(_ string, _ ...any) (sql.Result, error) {
	return nil, nil
}

func (mt *mockTx) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) {
	return nil, nil
}

func (mt *mockTx) Prepare(_ string) (StmtSession, error) {
	return nil, nil
}

func (mt *mockTx) PrepareCtx(_ context.Context, _ string) (StmtSession, error) {
	return nil, nil
}

func (mt *mockTx) QueryRow(_ any, _ string, _ ...any) error {
	return nil
}

func (mt *mockTx) QueryRowCtx(_ context.Context, _ any, _ string, _ ...any) error {
	return nil
}

func (mt *mockTx) QueryRowPartial(_ any, _ string, _ ...any) error {
	return nil
}

func (mt *mockTx) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
	return nil
}

func (mt *mockTx) QueryRows(_ any, _ string, _ ...any) error {
	return nil
}

func (mt *mockTx) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error {
	return nil
}

func (mt *mockTx) QueryRowsPartial(_ any, _ string, _ ...any) error {
	return nil
}

func (mt *mockTx) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
	return nil
}

func (mt *mockTx) Rollback() error {
	mt.status |= mockRollback
	return nil
}

// beginMock returns a beginnable that ignores the *sql.DB it is given and
// always hands back the supplied mock transaction.
func beginMock(mock *mockTx) beginnable {
	return func(*sql.DB) (trans, error) {
		return mock, nil
	}
}

// TestTransactCommit verifies that a successful callback leads to Commit only.
func TestTransactCommit(t *testing.T) {
	mock := &mockTx{}
	err := transactOnConn(context.Background(), nil, beginMock(mock),
		func(context.Context, Session) error {
			return nil
		})
	assert.Equal(t, mockCommit, mock.status)
	assert.NoError(t, err)
}

// TestTransactRollback verifies that a failing callback leads to Rollback only
// and that the failure is reported to the caller.
func TestTransactRollback(t *testing.T) {
	mock := &mockTx{}
	err := transactOnConn(context.Background(), nil, beginMock(mock),
		func(context.Context, Session) error {
			return errors.New("rollback")
		})
	assert.Equal(t, mockRollback, mock.status)
	assert.Error(t, err)
}

// TestTxExceptions covers the success and failure paths of Transact against a
// sqlmock-backed database.
func TestTxExceptions(t *testing.T) {
	// Happy path: begin, run a no-op callback, commit.
	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
		mock.ExpectBegin()
		mock.ExpectCommit()
		conn := NewSqlConnFromDB(db)
		assert.NoError(t, conn.Transact(func(session Session) error {
			return nil
		}))
	})

	// The connection provider fails, so Transact must report an error.
	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
		conn := &commonSqlConn{
			connProv: func() (*sql.DB, error) {
				return nil, errors.New("foo")
			},
			beginTx: begin,
			onError: func(ctx context.Context, err error) {},
			brk:     breaker.NewBreaker(),
		}
		assert.Error(t, conn.Transact(func(session Session) error {
			return nil
		}))
	})

	// A SqlConn built on top of a transaction session cannot expose the raw
	// DB and cannot start a nested transaction.
	runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
		_, err := conn.RawDB()
		assert.Equal(t, errNoRawDBFromTx, err)
		assert.Equal(t, errCantNestTx, conn.Transact(nil))
		assert.Equal(t, errCantNestTx, conn.TransactCtx(context.Background(), nil))
	})

	// The callback returns an error: the transaction is rolled back and the
	// error is propagated. The rollback is expected explicitly so that it is
	// actually verified rather than surfacing as an unexpected sqlmock call.
	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
		mock.ExpectBegin()
		mock.ExpectRollback()
		conn := NewSqlConnFromDB(db)
		assert.Error(t, conn.Transact(func(session Session) error {
			return errors.New("foo")
		}))
	})

	// The callback panics with a string and the rollback itself fails.
	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
		mock.ExpectBegin()
		mock.ExpectRollback().WillReturnError(errors.New("foo"))
		conn := NewSqlConnFromDB(db)
		assert.Error(t, conn.Transact(func(session Session) error {
			panic("foo")
		}))
	})

	// The callback panics with an error value; the rollback succeeds.
	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
		mock.ExpectBegin()
		mock.ExpectRollback()
		conn := NewSqlConnFromDB(db)
		assert.Error(t, conn.Transact(func(session Session) error {
			panic(errors.New("foo"))
		}))
	})
}

// TestTxSession exercises every query method of a transaction-backed SqlConn,
// checking both a successful call and error propagation. EqualError is used
// instead of err.Error() so a nil error fails the assertion cleanly rather
// than panicking.
func TestTxSession(t *testing.T) {
	runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
		mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
		res, err := conn.Exec("any")
		assert.NoError(t, err)
		last, err := res.LastInsertId()
		assert.NoError(t, err)
		assert.Equal(t, int64(2), last)
		affected, err := res.RowsAffected()
		assert.NoError(t, err)
		assert.Equal(t, int64(3), affected)

		mock.ExpectExec("any").WillReturnError(errors.New("foo"))
		_, err = conn.Exec("any")
		assert.EqualError(t, err, "foo")
	})

	runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
		mock.ExpectPrepare("any")
		stmt, err := conn.Prepare("any")
		assert.NoError(t, err)
		assert.NotNil(t, stmt)

		mock.ExpectPrepare("any").WillReturnError(errors.New("foo"))
		_, err = conn.Prepare("any")
		assert.EqualError(t, err, "foo")
	})

	runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
		rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
		mock.ExpectQuery("any").WillReturnRows(rows)
		var val string
		err := conn.QueryRow(&val, "any")
		assert.NoError(t, err)
		assert.Equal(t, "foo", val)

		mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
		err = conn.QueryRow(&val, "any")
		assert.EqualError(t, err, "foo")
	})

	runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
		rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
		mock.ExpectQuery("any").WillReturnRows(rows)
		var val string
		err := conn.QueryRowPartial(&val, "any")
		assert.NoError(t, err)
		assert.Equal(t, "foo", val)

		mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
		err = conn.QueryRowPartial(&val, "any")
		assert.EqualError(t, err, "foo")
	})

	runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
		rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
		mock.ExpectQuery("any").WillReturnRows(rows)
		var val []string
		err := conn.QueryRows(&val, "any")
		assert.NoError(t, err)
		assert.Equal(t, []string{"foo", "bar"}, val)

		mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
		err = conn.QueryRows(&val, "any")
		assert.EqualError(t, err, "foo")
	})

	runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
		rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
		mock.ExpectQuery("any").WillReturnRows(rows)
		var val []string
		err := conn.QueryRowsPartial(&val, "any")
		assert.NoError(t, err)
		assert.Equal(t, []string{"foo", "bar"}, val)

		mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
		err = conn.QueryRowsPartial(&val, "any")
		assert.EqualError(t, err, "foo")
	})
}

// TestTxRollback checks that an error from any statement inside the callback
// causes a rollback, while an all-successful callback commits.
func TestTxRollback(t *testing.T) {
	// Exec succeeds, the following query fails -> rollback.
	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
		mock.ExpectBegin()
		mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
		mock.ExpectQuery("foo").WillReturnError(errors.New("foo"))
		mock.ExpectRollback()

		conn := NewSqlConnFromDB(db)
		err := conn.Transact(func(session Session) error {
			c := NewSqlConnFromSession(session)
			_, err := c.Exec("any")
			assert.NoError(t, err)

			var val string
			return c.QueryRow(&val, "foo")
		})
		assert.Error(t, err)
	})

	// Exec fails first -> the callback returns early and the tx rolls back.
	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
		mock.ExpectBegin()
		mock.ExpectExec("any").WillReturnError(errors.New("foo"))
		mock.ExpectRollback()

		conn := NewSqlConnFromDB(db)
		err := conn.Transact(func(session Session) error {
			c := NewSqlConnFromSession(session)
			if _, err := c.Exec("any"); err != nil {
				return err
			}

			var val string
			assert.NoError(t, c.QueryRow(&val, "foo"))
			return nil
		})
		assert.Error(t, err)
	})

	// Both statements succeed -> commit.
	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
		mock.ExpectBegin()
		mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
		mock.ExpectQuery("foo").WillReturnRows(sqlmock.NewRows([]string{"col"}).AddRow("bar"))
		mock.ExpectCommit()

		conn := NewSqlConnFromDB(db)
		err := conn.Transact(func(session Session) error {
			c := NewSqlConnFromSession(session)
			_, err := c.Exec("any")
			assert.NoError(t, err)

			var val string
			assert.NoError(t, c.QueryRow(&val, "foo"))
			assert.Equal(t, "bar", val)
			return nil
		})
		assert.NoError(t, err)
	})
}

// runTxTest obtains a *sql.Tx and its sqlmock from dbtest.RunTxTest, wraps the
// tx in a Session and then in a SqlConn, and passes that conn to f.
func runTxTest(t *testing.T, f func(conn SqlConn, mock sqlmock.Sqlmock)) {
	dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
		sess := NewSessionFromTx(tx)
		conn := NewSqlConnFromSession(sess)
		f(conn, mock)
	})
}