You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
104 lines
2.3 KiB
Go
104 lines
2.3 KiB
Go
package sqlx
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// Bit flags recorded in mockTx.status. Because each flag occupies its own
// bit, a test can tell whether Commit, Rollback, or both were invoked.
const (
	mockCommit   = 1 << iota // set by (*mockTx).Commit; value 1
	mockRollback             // set by (*mockTx).Rollback; value 2
)
|
|
|
|
type mockTx struct {
|
|
status int
|
|
}
|
|
|
|
func (mt *mockTx) Commit() error {
|
|
mt.status |= mockCommit
|
|
return nil
|
|
}
|
|
|
|
func (mt *mockTx) Exec(q string, args ...interface{}) (sql.Result, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (mt *mockTx) ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (mt *mockTx) Prepare(query string) (StmtSession, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (mt *mockTx) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (mt *mockTx) QueryRow(v interface{}, q string, args ...interface{}) error {
|
|
return nil
|
|
}
|
|
|
|
func (mt *mockTx) QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
|
|
return nil
|
|
}
|
|
|
|
func (mt *mockTx) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
|
|
return nil
|
|
}
|
|
|
|
func (mt *mockTx) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
|
|
return nil
|
|
}
|
|
|
|
func (mt *mockTx) QueryRows(v interface{}, q string, args ...interface{}) error {
|
|
return nil
|
|
}
|
|
|
|
func (mt *mockTx) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
|
|
return nil
|
|
}
|
|
|
|
func (mt *mockTx) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
|
|
return nil
|
|
}
|
|
|
|
func (mt *mockTx) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
|
|
return nil
|
|
}
|
|
|
|
func (mt *mockTx) Rollback() error {
|
|
mt.status |= mockRollback
|
|
return nil
|
|
}
|
|
|
|
func beginMock(mock *mockTx) beginnable {
|
|
return func(*sql.DB) (trans, error) {
|
|
return mock, nil
|
|
}
|
|
}
|
|
|
|
func TestTransactCommit(t *testing.T) {
|
|
mock := &mockTx{}
|
|
err := transactOnConn(context.Background(), nil, beginMock(mock),
|
|
func(context.Context, Session) error {
|
|
return nil
|
|
})
|
|
assert.Equal(t, mockCommit, mock.status)
|
|
assert.Nil(t, err)
|
|
}
|
|
|
|
func TestTransactRollback(t *testing.T) {
|
|
mock := &mockTx{}
|
|
err := transactOnConn(context.Background(), nil, beginMock(mock),
|
|
func(context.Context, Session) error {
|
|
return errors.New("rollback")
|
|
})
|
|
assert.Equal(t, mockRollback, mock.status)
|
|
assert.NotNil(t, err)
|
|
}
|