diff --git a/core/stores/sqlx/mysql_test.go b/core/stores/sqlx/mysql_test.go index e6170411..e7027707 100644 --- a/core/stores/sqlx/mysql_test.go +++ b/core/stores/sqlx/mysql_test.go @@ -1,6 +1,8 @@ package sqlx import ( + "errors" + "reflect" "testing" "github.com/go-sql-driver/mysql" @@ -33,6 +35,15 @@ func TestBreakerOnNotHandlingDuplicateEntry(t *testing.T) { assert.True(t, found) } +func TestMysqlAcceptable(t *testing.T) { + conn := NewMysql("nomysql").(*commonSqlConn) + withMysqlAcceptable()(conn) + assert.EqualValues(t, reflect.ValueOf(mysqlAcceptable).Pointer(), reflect.ValueOf(conn.accept).Pointer()) + assert.True(t, mysqlAcceptable(nil)) + assert.False(t, mysqlAcceptable(errors.New("any"))) + assert.False(t, mysqlAcceptable(new(mysql.MySQLError))) +} + func tryOnDuplicateEntryError(t *testing.T, accept func(error) bool) error { logx.Disable() diff --git a/core/stores/sqlx/sqlconn_test.go b/core/stores/sqlx/sqlconn_test.go new file mode 100644 index 00000000..b313a35d --- /dev/null +++ b/core/stores/sqlx/sqlconn_test.go @@ -0,0 +1,61 @@ +package sqlx + +import ( + "database/sql" + "io" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/core/logx" +) + +const mockedDatasource = "sqlmock" + +func init() { + logx.Disable() +} + +func TestSqlConn(t *testing.T) { + mock := buildConn() + mock.ExpectExec("any") + mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"})) + conn := NewMysql(mockedDatasource) + badConn := NewMysql("badsql") + _, err := conn.Exec("any", "value") + assert.NotNil(t, err) + _, err = badConn.Exec("any", "value") + assert.NotNil(t, err) + _, err = conn.Prepare("any") + assert.NotNil(t, err) + _, err = badConn.Prepare("any") + assert.NotNil(t, err) + var val string + assert.NotNil(t, conn.QueryRow(&val, "any")) + assert.NotNil(t, badConn.QueryRow(&val, "any")) + assert.NotNil(t, conn.QueryRowPartial(&val, "any")) + assert.NotNil(t, badConn.QueryRowPartial(&val, "any")) + assert.NotNil(t, conn.QueryRows(&val, "any")) + assert.NotNil(t, badConn.QueryRows(&val, "any")) + assert.NotNil(t, conn.QueryRowsPartial(&val, "any")) + assert.NotNil(t, badConn.QueryRowsPartial(&val, "any")) + assert.NotNil(t, conn.Transact(func(session Session) error { + return nil + })) + assert.NotNil(t, badConn.Transact(func(session Session) error { + return nil + })) +} + +func buildConn() (mock sqlmock.Sqlmock) { + connManager.GetResource(mockedDatasource, func() (io.Closer, error) { + var db *sql.DB + var err error + db, mock, err = sqlmock.New() + return &pingedDB{ + DB: db, + }, err + }) + + return +}