package sqlx import ( "database/sql" "github.com/zeromicro/go-zero/core/breaker" "github.com/zeromicro/go-zero/core/logx" ) // ErrNotFound is an alias of sql.ErrNoRows var ErrNotFound = sql.ErrNoRows type ( // Session stands for raw connections or transaction sessions Session interface { Exec(query string, args ...interface{}) (sql.Result, error) Prepare(query string) (StmtSession, error) QueryRow(v interface{}, query string, args ...interface{}) error QueryRowPartial(v interface{}, query string, args ...interface{}) error QueryRows(v interface{}, query string, args ...interface{}) error QueryRowsPartial(v interface{}, query string, args ...interface{}) error } // SqlConn only stands for raw connections, so Transact method can be called. SqlConn interface { Session // RawDB is for other ORM to operate with, use it with caution. // Notice: don't close it. RawDB() (*sql.DB, error) Transact(func(session Session) error) error } // SqlOption defines the method to customize a sql connection. SqlOption func(*commonSqlConn) // StmtSession interface represents a session that can be used to execute statements. StmtSession interface { Close() error Exec(args ...interface{}) (sql.Result, error) QueryRow(v interface{}, args ...interface{}) error QueryRowPartial(v interface{}, args ...interface{}) error QueryRows(v interface{}, args ...interface{}) error QueryRowsPartial(v interface{}, args ...interface{}) error } // thread-safe // Because CORBA doesn't support PREPARE, so we need to combine the // query arguments into one string and do underlying query without arguments commonSqlConn struct { connProv connProvider onError func(error) beginTx beginnable brk breaker.Breaker accept func(error) bool } connProvider func() (*sql.DB, error) sessionConn interface { Exec(query string, args ...interface{}) (sql.Result, error) Query(query string, args ...interface{}) (*sql.Rows, error) } statement struct { query string stmt *sql.Stmt } stmtConn interface { Exec(args ...interface{}) (sql.Result, error) Query(args ...interface{}) (*sql.Rows, error) } ) // NewSqlConn returns a SqlConn with given driver name and datasource. func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn { conn := &commonSqlConn{ connProv: func() (*sql.DB, error) { return getSqlConn(driverName, datasource) }, onError: func(err error) { logInstanceError(datasource, err) }, beginTx: begin, brk: breaker.NewBreaker(), } for _, opt := range opts { opt(conn) } return conn } // NewSqlConnFromDB returns a SqlConn with the given sql.DB. // Use it with caution, it's provided for other ORM to interact with. func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn { conn := &commonSqlConn{ connProv: func() (*sql.DB, error) { return db, nil }, onError: func(err error) { logx.Errorf("Error on getting sql instance: %v", err) }, beginTx: begin, brk: breaker.NewBreaker(), } for _, opt := range opts { opt(conn) } return conn } func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) { err = db.brk.DoWithAcceptable(func() error { var conn *sql.DB conn, err = db.connProv() if err != nil { db.onError(err) return err } result, err = exec(conn, q, args...) return err }, db.acceptable) return } func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) { err = db.brk.DoWithAcceptable(func() error { var conn *sql.DB conn, err = db.connProv() if err != nil { db.onError(err) return err } st, err := conn.Prepare(query) if err != nil { return err } stmt = statement{ query: query, stmt: st, } return nil }, db.acceptable) return } func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error { return db.queryRows(func(rows *sql.Rows) error { return unmarshalRow(v, rows, true) }, q, args...) } func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error { return db.queryRows(func(rows *sql.Rows) error { return unmarshalRow(v, rows, false) }, q, args...) } func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error { return db.queryRows(func(rows *sql.Rows) error { return unmarshalRows(v, rows, true) }, q, args...) } func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error { return db.queryRows(func(rows *sql.Rows) error { return unmarshalRows(v, rows, false) }, q, args...) } func (db *commonSqlConn) RawDB() (*sql.DB, error) { return db.connProv() } func (db *commonSqlConn) Transact(fn func(Session) error) error { return db.brk.DoWithAcceptable(func() error { return transact(db, db.beginTx, fn) }, db.acceptable) } func (db *commonSqlConn) acceptable(err error) bool { ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone if db.accept == nil { return ok } return ok || db.accept(err) } func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error { var qerr error return db.brk.DoWithAcceptable(func() error { conn, err := db.connProv() if err != nil { db.onError(err) return err } return query(conn, func(rows *sql.Rows) error { qerr = scanner(rows) return qerr }, q, args...) }, func(err error) bool { return qerr == err || db.acceptable(err) }) } func (s statement) Close() error { return s.stmt.Close() } func (s statement) Exec(args ...interface{}) (sql.Result, error) { return execStmt(s.stmt, s.query, args...) } func (s statement) QueryRow(v interface{}, args ...interface{}) error { return queryStmt(s.stmt, func(rows *sql.Rows) error { return unmarshalRow(v, rows, true) }, s.query, args...) } func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error { return queryStmt(s.stmt, func(rows *sql.Rows) error { return unmarshalRow(v, rows, false) }, s.query, args...) } func (s statement) QueryRows(v interface{}, args ...interface{}) error { return queryStmt(s.stmt, func(rows *sql.Rows) error { return unmarshalRows(v, rows, true) }, s.query, args...) } func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error { return queryStmt(s.stmt, func(rows *sql.Rows) error { return unmarshalRows(v, rows, false) }, s.query, args...) }