You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-zero/core/stores/sqlx/bulkinserter_test.go

202 lines
5.8 KiB
Go

4 years ago
package sqlx
import (
"context"
4 years ago
"database/sql"
4 years ago
"errors"
4 years ago
"strconv"
"strings"
"sync"
"sync/atomic"
4 years ago
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stores/dbtest"
4 years ago
)
// mockedConn is a stand-in connection handed to NewBulkInserter in these tests.
// ExecCtx (and Exec, which delegates to it) records the last query and args
// into query/args, returns execErr, and forwards each call to updateCallback
// when it is set; every other method panics because the tests never expect
// it to be reached.
type mockedConn struct {
	query          string
	args           []any
	execErr        error
	updateCallback func(query string, args []any)
}
// ExecCtx stores query and args on the receiver, passes the same values to
// updateCallback when it is non-nil, and returns a nil sql.Result together
// with the configured execErr. The context argument is ignored.
func (c *mockedConn) ExecCtx(_ context.Context, query string, args ...any) (sql.Result, error) {
	c.query = query
	c.args = args
	if c.updateCallback != nil {
		c.updateCallback(query, args)
	}
	return nil, c.execErr
}
// PrepareCtx is not used by the bulk inserter under test; reaching it is a test bug.
func (c *mockedConn) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
	panic("should not be called")
}
// QueryRowCtx is not used by the bulk inserter under test; reaching it is a test bug.
func (c *mockedConn) QueryRowCtx(ctx context.Context, v any, query string, args ...any) error {
	panic("should not be called")
}
// QueryRowPartialCtx is not used by the bulk inserter under test; reaching it is a test bug.
func (c *mockedConn) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error {
	panic("should not be called")
}
// QueryRowsCtx is not used by the bulk inserter under test; reaching it is a test bug.
func (c *mockedConn) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error {
	panic("should not be called")
}
// QueryRowsPartialCtx is not used by the bulk inserter under test; reaching it is a test bug.
func (c *mockedConn) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error {
	panic("should not be called")
}
// TransactCtx is not used by the bulk inserter under test; reaching it is a test bug.
func (c *mockedConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error {
	panic("should not be called")
}
// Exec is the context-free variant of ExecCtx; it forwards to ExecCtx using
// a background context so both entry points share one recording path.
func (c *mockedConn) Exec(query string, args ...any) (sql.Result, error) {
	background := context.Background()
	return c.ExecCtx(background, query, args...)
}
4 years ago
// Prepare is not used by the bulk inserter under test; reaching it is a test bug.
func (c *mockedConn) Prepare(query string) (StmtSession, error) {
	panic("should not be called")
}
func (c *mockedConn) QueryRow(v any, query string, args ...any) error {
4 years ago
panic("should not called")
}
func (c *mockedConn) QueryRowPartial(v any, query string, args ...any) error {
4 years ago
panic("should not called")
}
func (c *mockedConn) QueryRows(v any, query string, args ...any) error {
4 years ago
panic("should not called")
}
func (c *mockedConn) QueryRowsPartial(v any, query string, args ...any) error {
4 years ago
panic("should not called")
}
// RawDB is not used by the bulk inserter under test; reaching it is a test bug.
func (c *mockedConn) RawDB() (*sql.DB, error) {
	panic("should not be called")
}
4 years ago
// Transact is not used by the bulk inserter under test; reaching it is a test bug.
func (c *mockedConn) Transact(func(session Session) error) error {
	panic("should not be called")
}
// TestBulkInserter checks that rows queued with Insert are rendered by Flush
// into one multi-row INSERT whose values are inlined into the query text, so
// no separate args reach the connection.
func TestBulkInserter(t *testing.T) {
	dbtest.RunTest(t, func(_ *sql.DB, _ sqlmock.Sqlmock) {
		var conn mockedConn
		inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
		assert.Nil(t, err)
		for i := 0; i < 5; i++ {
			assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
		}
		inserter.Flush()
		assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
			`('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
			`('class_3', 'user_3', 3), ('class_4', 'user_4', 4)`,
			conn.query)
		assert.Nil(t, conn.args)
	})
}
// TestBulkInserterSuffix checks that a trailing ON DUPLICATE KEY UPDATE clause
// in the statement is kept after the generated multi-row VALUES list, both for
// the initial statement and after UpdateStmt re-applies the same statement.
func TestBulkInserterSuffix(t *testing.T) {
	dbtest.RunTest(t, func(_ *sql.DB, _ sqlmock.Sqlmock) {
		var conn mockedConn
		inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+
			`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)
		assert.Nil(t, err)
		assert.Nil(t, inserter.UpdateStmt(`INSERT INTO classroom_dau(classroom, user, count) VALUES`+
			`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`))
		for i := 0; i < 5; i++ {
			assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
		}
		inserter.SetResultHandler(func(result sql.Result, err error) {})
		inserter.Flush()
		assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
			`('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
			`('class_3', 'user_3', 3), ('class_4', 'user_4', 4) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`,
			conn.query)
		assert.Nil(t, conn.args)
	})
}
4 years ago
// TestBulkInserterBadStatement checks that NewBulkInserter rejects a malformed
// (non-INSERT) statement with an error.
func TestBulkInserterBadStatement(t *testing.T) {
	dbtest.RunTest(t, func(_ *sql.DB, _ sqlmock.Sqlmock) {
		var conn mockedConn
		_, err := NewBulkInserter(&conn, "foo")
		assert.NotNil(t, err)
	})
}
// TestBulkInserter_Update exercises constructor validation and the inserter's
// update/execute paths against a connection whose ExecCtx always fails.
func TestBulkInserter_Update(t *testing.T) {
	failing := mockedConn{execErr: errors.New("foo")}

	// Statements whose VALUES list is empty or has too few placeholders are
	// expected to be rejected.
	for _, bad := range []string{
		`INSERT INTO classroom_dau(classroom, user, count) VALUES()`,
		`INSERT INTO classroom_dau(classroom, user, count) VALUES(?)`,
	} {
		_, err := NewBulkInserter(&failing, bad)
		assert.NotNil(t, err)
	}

	bulk, err := NewBulkInserter(&failing, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
	assert.Nil(t, err)

	bulk.inserter.Execute([]string{"bar"})
	bulk.SetResultHandler(func(result sql.Result, err error) {})
	bulk.UpdateOrDelete(func() {})
	bulk.inserter.Execute([]string(nil))

	// "foo" is not a valid statement, and a two-value row does not match the
	// three-column statement.
	assert.NotNil(t, bulk.UpdateStmt("foo"))
	assert.NotNil(t, bulk.Insert("foo", "bar"))
}
// TestBulkInserter_UpdateStmt checks that rows queued before UpdateStmt are
// executed with the old single-column statement and rows queued afterwards with
// the new two-column statement, for exactly two executions in total.
func TestBulkInserter_UpdateStmt(t *testing.T) {
	var execCount int32
	conn := mockedConn{
		execErr: errors.New("foo"),
		updateCallback: func(query string, args []any) {
			n := atomic.AddInt32(&execCount, 1)
			assert.Empty(t, args)
			assert.Equal(t, 100, strings.Count(query, "foo"))
			// Only the first execution is expected to carry the single-column
			// rows; the second should contain the "foo", "bar" pairs.
			wantBar := 100
			if n == 1 {
				wantBar = 0
			}
			assert.Equal(t, wantBar, strings.Count(query, "bar"))
		},
	}
	inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom) VALUES(?)`)
	assert.NoError(t, err)

	// Two concurrent writers queue 50 single-column rows each.
	var writers sync.WaitGroup
	for w := 0; w < 2; w++ {
		writers.Add(1)
		go func() {
			defer writers.Done()
			for j := 0; j < 50; j++ {
				assert.NoError(t, inserter.Insert("foo"))
			}
		}()
	}
	writers.Wait()

	assert.NoError(t, inserter.UpdateStmt(`INSERT INTO classroom_dau(classroom, user) VALUES(?, ?)`))

	done := make(chan struct{})
	go func() {
		defer close(done)
		for j := 0; j < 100; j++ {
			assert.NoError(t, inserter.Insert("foo", "bar"))
		}
		inserter.Flush()
	}()
	<-done

	assert.Equal(t, int32(2), atomic.LoadInt32(&execCount))
}