From c7dacb01468bad3a21dc34b18eb8635f875fd29c Mon Sep 17 00:00:00 2001 From: MarkJoyMa <64180138+MarkJoyMa@users.noreply.github.com> Date: Fri, 8 Mar 2024 12:23:41 +0800 Subject: [PATCH] fix: mysql WithAcceptable bug (#3986) --- core/stores/sqlx/mysql.go | 2 +- core/stores/sqlx/mysql_test.go | 3 +-- core/stores/sqlx/sqlconn.go | 9 ++++++- core/stores/sqlx/sqlconn_test.go | 40 ++++++++++++++++++++++++++++++++ 4 files changed, 50 insertions(+), 4 deletions(-) diff --git a/core/stores/sqlx/mysql.go b/core/stores/sqlx/mysql.go index 3c026921..e7797095 100644 --- a/core/stores/sqlx/mysql.go +++ b/core/stores/sqlx/mysql.go @@ -13,7 +13,7 @@ const ( // NewMysql returns a mysql connection. func NewMysql(datasource string, opts ...SqlOption) SqlConn { - opts = append(opts, withMysqlAcceptable()) + opts = append([]SqlOption{withMysqlAcceptable()}, opts...) return NewSqlConn(mysqlDriverName, datasource, opts...) } diff --git a/core/stores/sqlx/mysql_test.go b/core/stores/sqlx/mysql_test.go index 68698e4b..d343de69 100644 --- a/core/stores/sqlx/mysql_test.go +++ b/core/stores/sqlx/mysql_test.go @@ -2,11 +2,11 @@ package sqlx import ( "errors" - "reflect" "testing" "github.com/go-sql-driver/mysql" "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/core/breaker" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/stat" @@ -38,7 +38,6 @@ func TestBreakerOnNotHandlingDuplicateEntry(t *testing.T) { func TestMysqlAcceptable(t *testing.T) { conn := NewMysql("nomysql").(*commonSqlConn) withMysqlAcceptable()(conn) - assert.EqualValues(t, reflect.ValueOf(mysqlAcceptable).Pointer(), reflect.ValueOf(conn.accept).Pointer()) assert.True(t, mysqlAcceptable(nil)) assert.False(t, mysqlAcceptable(errors.New("any"))) assert.False(t, mysqlAcceptable(new(mysql.MySQLError))) diff --git a/core/stores/sqlx/sqlconn.go b/core/stores/sqlx/sqlconn.go index 11838056..d6ca746d 100644 --- a/core/stores/sqlx/sqlconn.go +++ b/core/stores/sqlx/sqlconn.go @@ -315,6 +315,13 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) // acceptable is the func to check if the error can be accepted. func WithAcceptable(acceptable func(err error) bool) SqlOption { return func(conn *commonSqlConn) { - conn.accept = acceptable + if conn.accept == nil { + conn.accept = acceptable + } else { + pre := conn.accept + conn.accept = func(err error) bool { + return pre(err) || acceptable(err) + } + } } } diff --git a/core/stores/sqlx/sqlconn_test.go b/core/stores/sqlx/sqlconn_test.go index ab83e2c6..f599bdf2 100644 --- a/core/stores/sqlx/sqlconn_test.go +++ b/core/stores/sqlx/sqlconn_test.go @@ -8,6 +8,7 @@ import ( "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/core/breaker" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/stores/dbtest" @@ -264,6 +265,45 @@ func TestBreakerWithScanError(t *testing.T) { }) } +func TestWithAcceptable(t *testing.T) { + var ( + acceptableErr = errors.New("acceptable") + acceptableErr2 = errors.New("acceptable2") + acceptableErr3 = errors.New("acceptable3") + ) + opts := []SqlOption{ + WithAcceptable(func(err error) bool { + if err == nil { + return true + } + return errors.Is(err, acceptableErr) + }), + WithAcceptable(func(err error) bool { + if err == nil { + return true + } + return errors.Is(err, acceptableErr2) + }), + WithAcceptable(func(err error) bool { + if err == nil { + return true + } + return errors.Is(err, acceptableErr3) + }), + } + + var conn = &commonSqlConn{} + for _, opt := range opts { + opt(conn) + } + + assert.True(t, conn.accept(nil)) + assert.False(t, conn.accept(assert.AnError)) + assert.True(t, conn.accept(acceptableErr)) + assert.True(t, conn.accept(acceptableErr2)) + assert.True(t, conn.accept(acceptableErr3)) +} + func buildConn() (mock sqlmock.Sqlmock, err error) { _, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) { var db *sql.DB