fix: UpdateStmt doesn't update the statement correctly in sqlx/bulkinserter.go (#3607)

master
Kevin Wan 1 year ago committed by GitHub
parent 5aedd9c076
commit abd1fa96a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -30,7 +30,7 @@ const (
leftSquareBracket = '['
rightSquareBracket = ']'
segmentSeparator = ','
intSize = 32 << (^uint(0) >> 63)
intSize = 32 << (^uint(0) >> 63) // 32 or 64
)
var (

@ -4,6 +4,7 @@ import (
"database/sql"
"fmt"
"strings"
"sync"
"time"
"github.com/zeromicro/go-zero/core/executors"
@ -30,6 +31,7 @@ type (
executor *executors.PeriodicalExecutor
inserter *dbInserter
stmt bulkStmt
lock sync.RWMutex // guards stmt
}
bulkStmt struct {
@ -65,6 +67,9 @@ func (bi *BulkInserter) Flush() {
// Insert inserts given args.
func (bi *BulkInserter) Insert(args ...any) error {
bi.lock.RLock()
defer bi.lock.RUnlock()
value, err := format(bi.stmt.valueFormat, args...)
if err != nil {
return err
@ -95,6 +100,11 @@ func (bi *BulkInserter) UpdateStmt(stmt string) error {
return err
}
bi.lock.Lock()
defer bi.lock.Unlock()
// with write lock, it doesn't matter what's the order of setting bi.stmt and calling flush.
bi.stmt = bkStmt
bi.executor.Flush()
bi.executor.Sync(func() {
bi.inserter.stmt = bkStmt

@ -5,6 +5,9 @@ import (
"database/sql"
"errors"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"github.com/DATA-DOG/go-sqlmock"
@ -16,11 +19,16 @@ type mockedConn struct {
query string
args []any
execErr error
updateCallback func(query string, args []any)
}
func (c *mockedConn) ExecCtx(_ context.Context, query string, args ...any) (sql.Result, error) {
c.query = query
c.args = args
if c.updateCallback != nil {
c.updateCallback(query, args)
}
return nil, c.execErr
}
@ -144,3 +152,50 @@ func TestBulkInserter_Update(t *testing.T) {
assert.NotNil(t, inserter.UpdateStmt("foo"))
assert.NotNil(t, inserter.Insert("foo", "bar"))
}
func TestBulkInserter_UpdateStmt(t *testing.T) {
var updated int32
conn := mockedConn{
execErr: errors.New("foo"),
updateCallback: func(query string, args []any) {
count := atomic.AddInt32(&updated, 1)
assert.Empty(t, args)
assert.Equal(t, 100, strings.Count(query, "foo"))
if count == 1 {
assert.Equal(t, 0, strings.Count(query, "bar"))
} else {
assert.Equal(t, 100, strings.Count(query, "bar"))
}
},
}
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom) VALUES(?)`)
assert.NoError(t, err)
var wg1 sync.WaitGroup
wg1.Add(2)
for i := 0; i < 2; i++ {
go func() {
defer wg1.Done()
for i := 0; i < 50; i++ {
assert.NoError(t, inserter.Insert("foo"))
}
}()
}
wg1.Wait()
assert.NoError(t, inserter.UpdateStmt(`INSERT INTO classroom_dau(classroom, user) VALUES(?, ?)`))
var wg2 sync.WaitGroup
wg2.Add(1)
go func() {
defer wg2.Done()
for i := 0; i < 100; i++ {
assert.NoError(t, inserter.Insert("foo", "bar"))
}
inserter.Flush()
}()
wg2.Wait()
assert.Equal(t, int32(2), atomic.LoadInt32(&updated))
}

Loading…
Cancel
Save