From abd1fa96a9f615d407de9b09ca7041e0df7e0d44 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Mon, 9 Oct 2023 21:57:26 +0800 Subject: [PATCH] fix: UpdateStmt doesn't update the statement correctly in sqlx/bulkinserter.go (#3607) --- core/mapping/utils.go | 2 +- core/stores/sqlx/bulkinserter.go | 10 +++++ core/stores/sqlx/bulkinserter_test.go | 61 +++++++++++++++++++++++++-- 3 files changed, 69 insertions(+), 4 deletions(-) diff --git a/core/mapping/utils.go b/core/mapping/utils.go index 72216bfc..67ea0ae3 100644 --- a/core/mapping/utils.go +++ b/core/mapping/utils.go @@ -30,7 +30,7 @@ const ( leftSquareBracket = '[' rightSquareBracket = ']' segmentSeparator = ',' - intSize = 32 << (^uint(0) >> 63) + intSize = 32 << (^uint(0) >> 63) // 32 or 64 ) var ( diff --git a/core/stores/sqlx/bulkinserter.go b/core/stores/sqlx/bulkinserter.go index 789d1e80..f251b76b 100644 --- a/core/stores/sqlx/bulkinserter.go +++ b/core/stores/sqlx/bulkinserter.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" "strings" + "sync" "time" "github.com/zeromicro/go-zero/core/executors" @@ -30,6 +31,7 @@ type ( executor *executors.PeriodicalExecutor inserter *dbInserter stmt bulkStmt + lock sync.RWMutex // guards stmt } bulkStmt struct { @@ -65,6 +67,9 @@ func (bi *BulkInserter) Flush() { // Insert inserts given args. func (bi *BulkInserter) Insert(args ...any) error { + bi.lock.RLock() + defer bi.lock.RUnlock() + value, err := format(bi.stmt.valueFormat, args...) if err != nil { return err @@ -95,6 +100,11 @@ func (bi *BulkInserter) UpdateStmt(stmt string) error { return err } + bi.lock.Lock() + defer bi.lock.Unlock() + + // with write lock, it doesn't matter what's the order of setting bi.stmt and calling flush. + bi.stmt = bkStmt bi.executor.Flush() bi.executor.Sync(func() { bi.inserter.stmt = bkStmt diff --git a/core/stores/sqlx/bulkinserter_test.go b/core/stores/sqlx/bulkinserter_test.go index ae4bca1b..6ffa349e 100644 --- a/core/stores/sqlx/bulkinserter_test.go +++ b/core/stores/sqlx/bulkinserter_test.go @@ -5,6 +5,9 @@ import ( "database/sql" "errors" "strconv" + "strings" + "sync" + "sync/atomic" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -13,14 +16,19 @@ import ( ) type mockedConn struct { - query string - args []any - execErr error + query string + args []any + execErr error + updateCallback func(query string, args []any) } func (c *mockedConn) ExecCtx(_ context.Context, query string, args ...any) (sql.Result, error) { c.query = query c.args = args + if c.updateCallback != nil { + c.updateCallback(query, args) + } + return nil, c.execErr } @@ -144,3 +152,50 @@ func TestBulkInserter_Update(t *testing.T) { assert.NotNil(t, inserter.UpdateStmt("foo")) assert.NotNil(t, inserter.Insert("foo", "bar")) } + +func TestBulkInserter_UpdateStmt(t *testing.T) { + var updated int32 + conn := mockedConn{ + execErr: errors.New("foo"), + updateCallback: func(query string, args []any) { + count := atomic.AddInt32(&updated, 1) + assert.Empty(t, args) + assert.Equal(t, 100, strings.Count(query, "foo")) + if count == 1 { + assert.Equal(t, 0, strings.Count(query, "bar")) + } else { + assert.Equal(t, 100, strings.Count(query, "bar")) + } + }, + } + + inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom) VALUES(?)`) + assert.NoError(t, err) + + var wg1 sync.WaitGroup + wg1.Add(2) + for i := 0; i < 2; i++ { + go func() { + defer wg1.Done() + for i := 0; i < 50; i++ { + assert.NoError(t, inserter.Insert("foo")) + } + }() + } + wg1.Wait() + + assert.NoError(t, inserter.UpdateStmt(`INSERT INTO classroom_dau(classroom, user) VALUES(?, ?)`)) + + var wg2 sync.WaitGroup + wg2.Add(1) + go func() { + defer wg2.Done() + for i := 0; i < 100; i++ { + assert.NoError(t, inserter.Insert("foo", "bar")) + } + inserter.Flush() + }() + wg2.Wait() + + assert.Equal(t, int32(2), atomic.LoadInt32(&updated)) +}