fix: mysql WithAcceptable bug (#3986)

master^2
MarkJoyMa 8 months ago committed by GitHub
parent 2207477b60
commit c7dacb0146
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -13,7 +13,7 @@ const (
// NewMysql returns a mysql connection.
func NewMysql(datasource string, opts ...SqlOption) SqlConn {
opts = append(opts, withMysqlAcceptable())
opts = append([]SqlOption{withMysqlAcceptable()}, opts...)
return NewSqlConn(mysqlDriverName, datasource, opts...)
}

@ -2,11 +2,11 @@ package sqlx
import (
"errors"
"reflect"
"testing"
"github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stat"
@ -38,7 +38,6 @@ func TestBreakerOnNotHandlingDuplicateEntry(t *testing.T) {
func TestMysqlAcceptable(t *testing.T) {
conn := NewMysql("nomysql").(*commonSqlConn)
withMysqlAcceptable()(conn)
assert.EqualValues(t, reflect.ValueOf(mysqlAcceptable).Pointer(), reflect.ValueOf(conn.accept).Pointer())
assert.True(t, mysqlAcceptable(nil))
assert.False(t, mysqlAcceptable(errors.New("any")))
assert.False(t, mysqlAcceptable(new(mysql.MySQLError)))

@ -315,6 +315,13 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
// acceptable is the func to check if the error can be accepted.
func WithAcceptable(acceptable func(err error) bool) SqlOption {
return func(conn *commonSqlConn) {
conn.accept = acceptable
if conn.accept == nil {
conn.accept = acceptable
} else {
pre := conn.accept
conn.accept = func(err error) bool {
return pre(err) || acceptable(err)
}
}
}
}

@ -8,6 +8,7 @@ import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/dbtest"
@ -264,6 +265,45 @@ func TestBreakerWithScanError(t *testing.T) {
})
}
func TestWithAcceptable(t *testing.T) {
var (
acceptableErr = errors.New("acceptable")
acceptableErr2 = errors.New("acceptable2")
acceptableErr3 = errors.New("acceptable3")
)
opts := []SqlOption{
WithAcceptable(func(err error) bool {
if err == nil {
return true
}
return errors.Is(err, acceptableErr)
}),
WithAcceptable(func(err error) bool {
if err == nil {
return true
}
return errors.Is(err, acceptableErr2)
}),
WithAcceptable(func(err error) bool {
if err == nil {
return true
}
return errors.Is(err, acceptableErr3)
}),
}
var conn = &commonSqlConn{}
for _, opt := range opts {
opt(conn)
}
assert.True(t, conn.accept(nil))
assert.False(t, conn.accept(assert.AnError))
assert.True(t, conn.accept(acceptableErr))
assert.True(t, conn.accept(acceptableErr2))
assert.True(t, conn.accept(acceptableErr3))
}
func buildConn() (mock sqlmock.Sqlmock, err error) {
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
var db *sql.DB

Loading…
Cancel
Save