You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-zero/core/stores/sqlx/bulkinserter_test.go

99 lines
2.8 KiB
Go

package sqlx
import (
"database/sql"
"strconv"
"testing"
"zero/core/logx"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
)
// mockedConn is a test double for the connection handed to NewBulkInserter.
// It remembers the most recent Exec call and panics on every other method,
// so a test fails loudly if anything besides Exec is used.
type mockedConn struct {
	query string        // statement text passed to the last Exec call
	args  []interface{} // variadic arguments passed to the last Exec call
}

// Exec stores the statement and its arguments on the receiver and returns a
// nil result together with a nil error.
func (m *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
	m.query, m.args = query, args
	return nil, nil
}

// Prepare is not expected to be reached by the tests; it always panics.
func (m *mockedConn) Prepare(query string) (StmtSession, error) {
	panic("should not called")
}

// QueryRow is not expected to be reached by the tests; it always panics.
func (m *mockedConn) QueryRow(v interface{}, query string, args ...interface{}) error {
	panic("should not called")
}

// QueryRowPartial is not expected to be reached by the tests; it always panics.
func (m *mockedConn) QueryRowPartial(v interface{}, query string, args ...interface{}) error {
	panic("should not called")
}

// QueryRows is not expected to be reached by the tests; it always panics.
func (m *mockedConn) QueryRows(v interface{}, query string, args ...interface{}) error {
	panic("should not called")
}

// QueryRowsPartial is not expected to be reached by the tests; it always panics.
func (m *mockedConn) QueryRowsPartial(v interface{}, query string, args ...interface{}) error {
	panic("should not called")
}

// Transact is not expected to be reached by the tests; it always panics.
func (m *mockedConn) Transact(func(session Session) error) error {
	panic("should not called")
}
// TestBulkInserter feeds five rows through a bulk inserter backed by
// mockedConn and, after Flush, expects the connection to have recorded a
// single multi-row INSERT with the values inlined and no bound arguments.
func TestBulkInserter(t *testing.T) {
	const expected = `INSERT INTO classroom_dau(classroom, user, count) VALUES ` +
		`('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), ` +
		`('class_3', 'user_3', 3), ('class_4', 'user_4', 4)`

	runSqlTest(t, func(_ *sql.DB, _ sqlmock.Sqlmock) {
		var recorder mockedConn
		bulk, err := NewBulkInserter(&recorder, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
		assert.Nil(t, err)

		for i := 0; i < 5; i++ {
			suffix := strconv.Itoa(i)
			assert.Nil(t, bulk.Insert("class_"+suffix, "user_"+suffix, i))
		}
		bulk.Flush()

		// The statement must carry every row; nothing is passed as an argument.
		assert.Equal(t, expected, recorder.query)
		assert.Nil(t, recorder.args)
	})
}
// TestBulkInserterSuffix is the same scenario as the plain bulk insert test
// but with a trailing ON DUPLICATE KEY UPDATE clause in the template. After
// Flush the recorded statement is expected to list all five rows followed by
// that clause, again with no bound arguments.
func TestBulkInserterSuffix(t *testing.T) {
	const (
		stmt = `INSERT INTO classroom_dau(classroom, user, count) VALUES` +
			`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`
		expected = `INSERT INTO classroom_dau(classroom, user, count) VALUES ` +
			`('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), ` +
			`('class_3', 'user_3', 3), ('class_4', 'user_4', 4) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`
	)

	runSqlTest(t, func(_ *sql.DB, _ sqlmock.Sqlmock) {
		var recorder mockedConn
		bulk, err := NewBulkInserter(&recorder, stmt)
		assert.Nil(t, err)

		for i := 0; i < 5; i++ {
			suffix := strconv.Itoa(i)
			assert.Nil(t, bulk.Insert("class_"+suffix, "user_"+suffix, i))
		}
		bulk.Flush()

		// The suffix must come after the row list, not after each row.
		assert.Equal(t, expected, recorder.query)
		assert.Nil(t, recorder.args)
	})
}
// runSqlTest runs fn against a sqlmock-backed *sql.DB.
//
// It calls logx.Disable() first (presumably to silence logging during the
// test), opens a stub database with sqlmock.New, hands the database and
// its mock controller to fn, and finally reports a test error if any
// expectation registered on the mock was left unmet. The stub database is
// closed when the function returns.
//
// A failure to create the stub database aborts the calling test via t.Fatalf.
func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
	// Mark this function as a helper so that failures reported below are
	// attributed to the calling test's line rather than to this file-level
	// utility.
	t.Helper()

	logx.Disable()

	db, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
	}
	defer db.Close()

	fn(db, mock)

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
}