You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-zero/core/stores/sqlx/sqlconn.go

410 lines
11 KiB
Go

4 years ago
package sqlx
import (
"context"
4 years ago
"database/sql"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx"
4 years ago
)
// spanName is used to identify the span name for the SQL execution.
const spanName = "sql"
4 years ago
type (
// Session stands for raw connections or transaction sessions
Session interface {
Exec(query string, args ...any) (sql.Result, error)
ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error)
4 years ago
Prepare(query string) (StmtSession, error)
PrepareCtx(ctx context.Context, query string) (StmtSession, error)
QueryRow(v any, query string, args ...any) error
QueryRowCtx(ctx context.Context, v any, query string, args ...any) error
QueryRowPartial(v any, query string, args ...any) error
QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error
QueryRows(v any, query string, args ...any) error
QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error
QueryRowsPartial(v any, query string, args ...any) error
QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error
4 years ago
}
// SqlConn only stands for raw connections, so Transact method can be called.
SqlConn interface {
Session
// RawDB is for other ORM to operate with, use it with caution.
// Notice: don't close it.
RawDB() (*sql.DB, error)
Transact(fn func(Session) error) error
TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error
4 years ago
}
// SqlOption defines the method to customize a sql connection.
4 years ago
SqlOption func(*commonSqlConn)
// StmtSession interface represents a session that can be used to execute statements.
4 years ago
StmtSession interface {
Close() error
Exec(args ...any) (sql.Result, error)
ExecCtx(ctx context.Context, args ...any) (sql.Result, error)
QueryRow(v any, args ...any) error
QueryRowCtx(ctx context.Context, v any, args ...any) error
QueryRowPartial(v any, args ...any) error
QueryRowPartialCtx(ctx context.Context, v any, args ...any) error
QueryRows(v any, args ...any) error
QueryRowsCtx(ctx context.Context, v any, args ...any) error
QueryRowsPartial(v any, args ...any) error
QueryRowsPartialCtx(ctx context.Context, v any, args ...any) error
4 years ago
}
// thread-safe
// Because CORBA doesn't support PREPARE, so we need to combine the
// query arguments into one string and do underlying query without arguments
commonSqlConn struct {
connProv connProvider
onError func(context.Context, error)
beginTx beginnable
brk breaker.Breaker
accept func(error) bool
4 years ago
}
connProvider func() (*sql.DB, error)
4 years ago
sessionConn interface {
Exec(query string, args ...any) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
Query(query string, args ...any) (*sql.Rows, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
4 years ago
}
statement struct {
query string
stmt *sql.Stmt
4 years ago
}
stmtConn interface {
Exec(args ...any) (sql.Result, error)
ExecContext(ctx context.Context, args ...any) (sql.Result, error)
Query(args ...any) (*sql.Rows, error)
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
4 years ago
}
)
// NewSqlConn returns a SqlConn with given driver name and datasource.
4 years ago
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
conn := &commonSqlConn{
connProv: func() (*sql.DB, error) {
return getSqlConn(driverName, datasource)
},
onError: func(ctx context.Context, err error) {
logInstanceError(ctx, datasource, err)
},
beginTx: begin,
brk: breaker.NewBreaker(),
}
for _, opt := range opts {
opt(conn)
}
return conn
}
// NewSqlConnFromDB returns a SqlConn with the given sql.DB.
// Use it with caution, it's provided for other ORM to interact with.
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
conn := &commonSqlConn{
connProv: func() (*sql.DB, error) {
return db, nil
},
onError: func(ctx context.Context, err error) {
logx.WithContext(ctx).Errorf("Error on getting sql instance: %v", err)
},
beginTx: begin,
brk: breaker.NewBreaker(),
4 years ago
}
for _, opt := range opts {
opt(conn)
}
return conn
}
// NewSqlConnFromSession returns a SqlConn with the given session.
func NewSqlConnFromSession(session Session) SqlConn {
return txConn{
Session: session,
}
}
func (db *commonSqlConn) Exec(q string, args ...any) (result sql.Result, err error) {
return db.ExecCtx(context.Background(), q, args...)
}
func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
result sql.Result, err error) {
ctx, span := startSpan(ctx, "Exec")
defer func() {
endSpan(span, err)
}()
4 years ago
err = db.brk.DoWithAcceptable(func() error {
var conn *sql.DB
conn, err = db.connProv()
4 years ago
if err != nil {
db.onError(ctx, err)
4 years ago
return err
}
result, err = exec(ctx, conn, q, args...)
4 years ago
return err
}, db.acceptable)
if err == breaker.ErrServiceUnavailable {
metricReqErr.Inc("Exec", "breaker")
}
4 years ago
return
}
func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
return db.PrepareCtx(context.Background(), query)
}
func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) {
ctx, span := startSpan(ctx, "Prepare")
defer func() {
endSpan(span, err)
}()
4 years ago
err = db.brk.DoWithAcceptable(func() error {
var conn *sql.DB
conn, err = db.connProv()
4 years ago
if err != nil {
db.onError(ctx, err)
4 years ago
return err
}
st, err := conn.PrepareContext(ctx, query)
if err != nil {
4 years ago
return err
}
stmt = statement{
query: query,
stmt: st,
}
return nil
4 years ago
}, db.acceptable)
if err == breaker.ErrServiceUnavailable {
metricReqErr.Inc("Prepare", "breaker")
}
4 years ago
return
}
func (db *commonSqlConn) QueryRow(v any, q string, args ...any) error {
return db.QueryRowCtx(context.Background(), v, q, args...)
}
func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v any, q string,
args ...any) (err error) {
ctx, span := startSpan(ctx, "QueryRow")
defer func() {
endSpan(span, err)
}()
return db.queryRows(ctx, func(rows *sql.Rows) error {
4 years ago
return unmarshalRow(v, rows, true)
}, q, args...)
}
func (db *commonSqlConn) QueryRowPartial(v any, q string, args ...any) error {
return db.QueryRowPartialCtx(context.Background(), v, q, args...)
}
func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v any,
q string, args ...any) (err error) {
ctx, span := startSpan(ctx, "QueryRowPartial")
defer func() {
endSpan(span, err)
}()
return db.queryRows(ctx, func(rows *sql.Rows) error {
4 years ago
return unmarshalRow(v, rows, false)
}, q, args...)
}
func (db *commonSqlConn) QueryRows(v any, q string, args ...any) error {
return db.QueryRowsCtx(context.Background(), v, q, args...)
}
func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v any, q string,
args ...any) (err error) {
ctx, span := startSpan(ctx, "QueryRows")
defer func() {
endSpan(span, err)
}()
return db.queryRows(ctx, func(rows *sql.Rows) error {
4 years ago
return unmarshalRows(v, rows, true)
}, q, args...)
}
func (db *commonSqlConn) QueryRowsPartial(v any, q string, args ...any) error {
return db.QueryRowsPartialCtx(context.Background(), v, q, args...)
}
func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
q string, args ...any) (err error) {
ctx, span := startSpan(ctx, "QueryRowsPartial")
defer func() {
endSpan(span, err)
}()
return db.queryRows(ctx, func(rows *sql.Rows) error {
4 years ago
return unmarshalRows(v, rows, false)
}, q, args...)
}
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
return db.connProv()
}
4 years ago
func (db *commonSqlConn) Transact(fn func(Session) error) error {
return db.TransactCtx(context.Background(), func(_ context.Context, session Session) error {
return fn(session)
})
}
func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) (err error) {
ctx, span := startSpan(ctx, "Transact")
defer func() {
endSpan(span, err)
}()
err = db.brk.DoWithAcceptable(func() error {
return transact(ctx, db, db.beginTx, fn)
4 years ago
}, db.acceptable)
if err == breaker.ErrServiceUnavailable {
metricReqErr.Inc("Transact", "breaker")
}
return
4 years ago
}
func (db *commonSqlConn) acceptable(err error) bool {
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled
4 years ago
if db.accept == nil {
return ok
}
return ok || db.accept(err)
4 years ago
}
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
q string, args ...any) (err error) {
4 years ago
var qerr error
err = db.brk.DoWithAcceptable(func() error {
conn, err := db.connProv()
4 years ago
if err != nil {
db.onError(ctx, err)
4 years ago
return err
}
return query(ctx, conn, func(rows *sql.Rows) error {
4 years ago
qerr = scanner(rows)
return qerr
}, q, args...)
}, func(err error) bool {
return qerr == err || db.acceptable(err)
})
if err == breaker.ErrServiceUnavailable {
metricReqErr.Inc("queryRows", "breaker")
}
return
4 years ago
}
func (s statement) Close() error {
return s.stmt.Close()
}
func (s statement) Exec(args ...any) (sql.Result, error) {
return s.ExecCtx(context.Background(), args...)
}
func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, err error) {
ctx, span := startSpan(ctx, "Exec")
defer func() {
endSpan(span, err)
}()
return execStmt(ctx, s.stmt, s.query, args...)
4 years ago
}
func (s statement) QueryRow(v any, args ...any) error {
return s.QueryRowCtx(context.Background(), v, args...)
}
func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err error) {
ctx, span := startSpan(ctx, "QueryRow")
defer func() {
endSpan(span, err)
}()
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
4 years ago
return unmarshalRow(v, rows, true)
}, s.query, args...)
4 years ago
}
func (s statement) QueryRowPartial(v any, args ...any) error {
return s.QueryRowPartialCtx(context.Background(), v, args...)
}
func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (err error) {
ctx, span := startSpan(ctx, "QueryRowPartial")
defer func() {
endSpan(span, err)
}()
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
4 years ago
return unmarshalRow(v, rows, false)
}, s.query, args...)
4 years ago
}
func (s statement) QueryRows(v any, args ...any) error {
return s.QueryRowsCtx(context.Background(), v, args...)
}
func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err error) {
ctx, span := startSpan(ctx, "QueryRows")
defer func() {
endSpan(span, err)
}()
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
4 years ago
return unmarshalRows(v, rows, true)
}, s.query, args...)
4 years ago
}
func (s statement) QueryRowsPartial(v any, args ...any) error {
return s.QueryRowsPartialCtx(context.Background(), v, args...)
}
func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) (err error) {
ctx, span := startSpan(ctx, "QueryRowsPartial")
defer func() {
endSpan(span, err)
}()
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
4 years ago
return unmarshalRows(v, rows, false)
}, s.query, args...)
4 years ago
}
// WithAcceptable returns a SqlOption that setting the acceptable function.
// acceptable is the func to check if the error can be accepted.
func WithAcceptable(acceptable func(err error) bool) SqlOption {
return func(conn *commonSqlConn) {
conn.accept = acceptable
}
}