From f6d9e19ecba691b3601d8039ba7f8bb943f8d118 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Thu, 9 Sep 2021 11:40:28 +0800 Subject: [PATCH] expose sql.DB to let orm operate on it (#1015) * expose sql.DB to let orm operate on it * add missing RawDB methods * add NewSqlConnFromDB for cooperate with dtm --- core/stores/sqlc/cachedsql_test.go | 8 +++++ core/stores/sqlx/bulkinserter_test.go | 4 +++ core/stores/sqlx/sqlconn.go | 44 ++++++++++++++++++++++----- core/stores/sqlx/sqlconn_test.go | 7 +++-- core/stores/sqlx/tx.go | 2 +- tools/goctl/model/sql/test/sqlconn.go | 6 ++++ 6 files changed, 61 insertions(+), 10 deletions(-) diff --git a/core/stores/sqlc/cachedsql_test.go b/core/stores/sqlc/cachedsql_test.go index c4651801..5afffe05 100644 --- a/core/stores/sqlc/cachedsql_test.go +++ b/core/stores/sqlc/cachedsql_test.go @@ -600,6 +600,10 @@ func (d dummySqlConn) QueryRowsPartial(v interface{}, query string, args ...inte return nil } +func (d dummySqlConn) RawDB() (*sql.DB, error) { + return nil, nil +} + func (d dummySqlConn) Transact(func(session sqlx.Session) error) error { return nil } @@ -621,6 +625,10 @@ func (c *trackedConn) QueryRows(v interface{}, query string, args ...interface{} return c.dummySqlConn.QueryRows(v, query, args...) } +func (c *trackedConn) RawDB() (*sql.DB, error) { + return nil, nil +} + func (c *trackedConn) Transact(fn func(session sqlx.Session) error) error { c.transactValue = true return c.dummySqlConn.Transact(fn) diff --git a/core/stores/sqlx/bulkinserter_test.go b/core/stores/sqlx/bulkinserter_test.go index 264aee89..24aebb4d 100644 --- a/core/stores/sqlx/bulkinserter_test.go +++ b/core/stores/sqlx/bulkinserter_test.go @@ -43,6 +43,10 @@ func (c *mockedConn) QueryRowsPartial(v interface{}, query string, args ...inter panic("should not called") } +func (c *mockedConn) RawDB() (*sql.DB, error) { + panic("should not called") +} + func (c *mockedConn) Transact(func(session Session) error) error { panic("should not called") } diff --git a/core/stores/sqlx/sqlconn.go b/core/stores/sqlx/sqlconn.go index d30ce0ba..cb4f8861 100644 --- a/core/stores/sqlx/sqlconn.go +++ b/core/stores/sqlx/sqlconn.go @@ -6,6 +6,9 @@ import ( "github.com/tal-tech/go-zero/core/breaker" ) +// datasource placeholder for logging error. +const rawDB = "sql.DB" + // ErrNotFound is an alias of sql.ErrNoRows var ErrNotFound = sql.ErrNoRows @@ -23,6 +26,7 @@ type ( // SqlConn only stands for raw connections, so Transact method can be called. SqlConn interface { Session + RawDB() (*sql.DB, error) Transact(func(session Session) error) error } @@ -43,13 +47,15 @@ type ( // Because CORBA doesn't support PREPARE, so we need to combine the // query arguments into one string and do underlying query without arguments commonSqlConn struct { - driverName string datasource string + connProv connProvider beginTx beginnable brk breaker.Breaker accept func(error) bool } + connProvider func() (*sql.DB, error) + sessionConn interface { Exec(query string, args ...interface{}) (sql.Result, error) Query(query string, args ...interface{}) (*sql.Rows, error) @@ -69,10 +75,30 @@ type ( // NewSqlConn returns a SqlConn with given driver name and datasource. func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn { conn := &commonSqlConn{ - driverName: driverName, datasource: datasource, - beginTx: begin, - brk: breaker.NewBreaker(), + connProv: func() (*sql.DB, error) { + return getSqlConn(driverName, datasource) + }, + beginTx: begin, + brk: breaker.NewBreaker(), + } + for _, opt := range opts { + opt(conn) + } + + return conn +} + +// NewSqlConnFromDB returns a SqlConn with the given sql.DB. +// Use it with caution, it's provided for other ORM to interact with. +func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn { + conn := &commonSqlConn{ + datasource: rawDB, + connProv: func() (*sql.DB, error) { + return db, nil + }, + beginTx: begin, + brk: breaker.NewBreaker(), } for _, opt := range opts { opt(conn) @@ -84,7 +110,7 @@ func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn { func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) { err = db.brk.DoWithAcceptable(func() error { var conn *sql.DB - conn, err = getSqlConn(db.driverName, db.datasource) + conn, err = db.connProv() if err != nil { logInstanceError(db.datasource, err) return err @@ -100,7 +126,7 @@ func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) { err = db.brk.DoWithAcceptable(func() error { var conn *sql.DB - conn, err = getSqlConn(db.driverName, db.datasource) + conn, err = db.connProv() if err != nil { logInstanceError(db.datasource, err) return err @@ -145,6 +171,10 @@ func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...inter }, q, args...) } +func (db *commonSqlConn) RawDB() (*sql.DB, error) { + return db.connProv() +} + func (db *commonSqlConn) Transact(fn func(Session) error) error { return db.brk.DoWithAcceptable(func() error { return transact(db, db.beginTx, fn) @@ -163,7 +193,7 @@ func (db *commonSqlConn) acceptable(err error) bool { func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error { var qerr error return db.brk.DoWithAcceptable(func() error { - conn, err := getSqlConn(db.driverName, db.datasource) + conn, err := db.connProv() if err != nil { logInstanceError(db.datasource, err) return err diff --git a/core/stores/sqlx/sqlconn_test.go b/core/stores/sqlx/sqlconn_test.go index b313a35d..676aaf5f 100644 --- a/core/stores/sqlx/sqlconn_test.go +++ b/core/stores/sqlx/sqlconn_test.go @@ -21,12 +21,15 @@ func TestSqlConn(t *testing.T) { mock.ExpectExec("any") mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"})) conn := NewMysql(mockedDatasource) + db, err := conn.RawDB() + assert.Nil(t, err) + rawConn := NewSqlConnFromDB(db, withMysqlAcceptable()) badConn := NewMysql("badsql") - _, err := conn.Exec("any", "value") + _, err = conn.Exec("any", "value") assert.NotNil(t, err) _, err = badConn.Exec("any", "value") assert.NotNil(t, err) - _, err = conn.Prepare("any") + _, err = rawConn.Prepare("any") assert.NotNil(t, err) _, err = badConn.Prepare("any") assert.NotNil(t, err) diff --git a/core/stores/sqlx/tx.go b/core/stores/sqlx/tx.go index 97991569..bf7d280e 100644 --- a/core/stores/sqlx/tx.go +++ b/core/stores/sqlx/tx.go @@ -71,7 +71,7 @@ func begin(db *sql.DB) (trans, error) { } func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) { - conn, err := getSqlConn(db.driverName, db.datasource) + conn, err := db.connProv() if err != nil { logInstanceError(db.datasource, err) return err diff --git a/tools/goctl/model/sql/test/sqlconn.go b/tools/goctl/model/sql/test/sqlconn.go index 164dedfc..77a4d94b 100644 --- a/tools/goctl/model/sql/test/sqlconn.go +++ b/tools/goctl/model/sql/test/sqlconn.go @@ -13,6 +13,7 @@ type ( MockConn struct { db *sql.DB } + statement struct { stmt *sql.Stmt } @@ -62,6 +63,11 @@ func (conn *MockConn) QueryRowsPartial(v interface{}, q string, args ...interfac }, q, args...) } +// RawDB returns the underlying sql.DB. +func (conn *MockConn) RawDB() (*sql.DB, error) { + return conn.db, nil +} + // Transact is the implemention of sqlx.SqlConn, nothing to do func (conn *MockConn) Transact(func(session sqlx.Session) error) error { return nil