You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-zero/core/stores/sqlx/sqlconn.go

328 lines
8.5 KiB
Go

package sqlx
import (
"context"
"database/sql"
"errors"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx"
)
// spanName is the name used to identify spans that trace SQL execution.
// NOTE(review): nothing in this part of the file references it; presumably
// startSpan (defined elsewhere) consumes it — confirm.
const spanName = "sql"
type (
// Session stands for raw connections or transaction sessions.
// Each operation is declared twice: a plain form and a Ctx form that
// additionally takes a context.Context as its first argument.
Session interface {
Exec(query string, args ...any) (sql.Result, error)
ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error)
Prepare(query string) (StmtSession, error)
PrepareCtx(ctx context.Context, query string) (StmtSession, error)
QueryRow(v any, query string, args ...any) error
QueryRowCtx(ctx context.Context, v any, query string, args ...any) error
QueryRowPartial(v any, query string, args ...any) error
QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error
QueryRows(v any, query string, args ...any) error
QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error
QueryRowsPartial(v any, query string, args ...any) error
QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error
}
// SqlConn only stands for raw connections, so Transact method can be called.
SqlConn interface {
Session
// RawDB is for other ORM to operate with, use it with caution.
// Notice: don't close it.
RawDB() (*sql.DB, error)
// Transact runs fn with a Session; TransactCtx is the same with a
// caller-supplied context that is also handed to fn.
Transact(fn func(Session) error) error
TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error
}
// SqlOption defines the method to customize a sql connection.
SqlOption func(*commonSqlConn)
// thread-safe
// Because CORBA doesn't support PREPARE, so we need to combine the
// query arguments into one string and do underlying query without arguments
// NOTE(review): the reference to CORBA is not explained by anything in
// this file and the argument-combining logic is not visible here —
// confirm which backend is meant.
commonSqlConn struct {
// connProv is called by every operation to obtain the *sql.DB to use.
connProv connProvider
// onError is invoked with the error whenever connProv fails.
onError func(context.Context, error)
// beginTx is handed to transact when a transaction is run.
beginTx beginnable
// brk is the circuit breaker every operation is executed through.
brk breaker.Breaker
// accept is an optional extra check consulted by acceptable after the
// built-in ones; it may be nil.
accept breaker.Acceptable
}
// connProvider returns the *sql.DB to operate on, or an error.
connProvider func() (*sql.DB, error)
// sessionConn lists the Exec and Query methods, with and without a
// context, that an underlying connection has to provide.
sessionConn interface {
Exec(query string, args ...any) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
Query(query string, args ...any) (*sql.Rows, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
)
// NewSqlConn builds a SqlConn for the given driver name and datasource.
// The *sql.DB is resolved through getSqlConn every time an operation needs
// it, provider failures are reported via logInstanceError together with the
// datasource, and a new breaker guards the connection. opts are applied in
// order once the defaults are in place.
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
	provide := func() (*sql.DB, error) {
		return getSqlConn(driverName, datasource)
	}
	report := func(ctx context.Context, err error) {
		logInstanceError(ctx, datasource, err)
	}

	conn := &commonSqlConn{
		connProv: provide,
		onError:  report,
		beginTx:  begin,
		brk:      breaker.NewBreaker(),
	}
	for i := range opts {
		opts[i](conn)
	}

	return conn
}
// NewSqlConnFromDB wraps an existing *sql.DB in a SqlConn.
// Use it with caution, it's provided for other ORM to interact with.
// The given db is returned as-is by the connection provider, a new breaker
// guards the connection, and opts are applied in order after the defaults.
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
	provide := func() (*sql.DB, error) {
		return db, nil
	}
	report := func(ctx context.Context, err error) {
		logx.WithContext(ctx).Errorf("Error on getting sql instance: %v", err)
	}

	conn := &commonSqlConn{
		connProv: provide,
		onError:  report,
		beginTx:  begin,
		brk:      breaker.NewBreaker(),
	}
	for i := range opts {
		opts[i](conn)
	}

	return conn
}
// NewSqlConnFromSession wraps the given session in a txConn and returns it
// as a SqlConn.
func NewSqlConnFromSession(session Session) SqlConn {
	var conn txConn
	conn.Session = session

	return conn
}
// Exec is ExecCtx with a background context.
func (db *commonSqlConn) Exec(q string, args ...any) (result sql.Result, err error) {
	ctx := context.Background()
	return db.ExecCtx(ctx, q, args...)
}
// ExecCtx executes q with args through the circuit breaker, inside an
// "Exec" span that is ended with the final error.
//
// The *sql.DB comes from the connection provider; if the provider fails the
// error is passed to onError and returned. When the breaker reports
// ErrServiceUnavailable the "Exec"/"breaker" request-error metric is
// incremented.
func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
	result sql.Result, err error) {
	ctx, span := startSpan(ctx, "Exec")
	defer func() {
		endSpan(span, err)
	}()

	run := func() error {
		conn, connErr := db.connProv()
		if connErr != nil {
			db.onError(ctx, connErr)
			return connErr
		}

		var execErr error
		result, execErr = exec(ctx, conn, q, args...)
		return execErr
	}

	if err = db.brk.DoWithAcceptable(run, db.acceptable); errors.Is(err, breaker.ErrServiceUnavailable) {
		metricReqErr.Inc("Exec", "breaker")
	}

	return result, err
}
// Prepare is PrepareCtx with a background context.
func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
	ctx := context.Background()
	return db.PrepareCtx(ctx, query)
}
// PrepareCtx prepares query through the circuit breaker, inside a "Prepare"
// span that is ended with the final error.
//
// On success the returned StmtSession is a statement that carries the query
// text, the prepared *sql.Stmt, and this connection's breaker and
// acceptable check. A connection-provider failure is passed to onError and
// returned. When the breaker reports ErrServiceUnavailable the
// "Prepare"/"breaker" request-error metric is incremented.
func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) {
	ctx, span := startSpan(ctx, "Prepare")
	defer func() {
		endSpan(span, err)
	}()

	prepare := func() error {
		conn, connErr := db.connProv()
		if connErr != nil {
			db.onError(ctx, connErr)
			return connErr
		}

		st, prepErr := conn.PrepareContext(ctx, query)
		if prepErr != nil {
			return prepErr
		}

		stmt = statement{
			query:  query,
			stmt:   st,
			brk:    db.brk,
			accept: db.acceptable,
		}
		return nil
	}

	err = db.brk.DoWithAcceptable(prepare, db.acceptable)
	if errors.Is(err, breaker.ErrServiceUnavailable) {
		metricReqErr.Inc("Prepare", "breaker")
	}

	return stmt, err
}
// QueryRow is QueryRowCtx with a background context.
func (db *commonSqlConn) QueryRow(v any, q string, args ...any) error {
	ctx := context.Background()
	return db.QueryRowCtx(ctx, v, q, args...)
}
// QueryRowCtx runs q with args inside a "QueryRow" span and hands the
// resulting rows to unmarshalRow with v as the destination and its last
// argument set to true. The span is ended with the returned error.
func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v any, q string,
	args ...any) (err error) {
	ctx, span := startSpan(ctx, "QueryRow")
	defer func() {
		endSpan(span, err)
	}()

	scan := func(rows *sql.Rows) error {
		return unmarshalRow(v, rows, true)
	}
	err = db.queryRows(ctx, scan, q, args...)

	return err
}
// QueryRowPartial is QueryRowPartialCtx with a background context.
func (db *commonSqlConn) QueryRowPartial(v any, q string, args ...any) error {
	ctx := context.Background()
	return db.QueryRowPartialCtx(ctx, v, q, args...)
}
// QueryRowPartialCtx runs q with args inside a "QueryRowPartial" span and
// hands the resulting rows to unmarshalRow with v as the destination and
// its last argument set to false. The span is ended with the returned
// error.
func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v any,
	q string, args ...any) (err error) {
	ctx, span := startSpan(ctx, "QueryRowPartial")
	defer func() {
		endSpan(span, err)
	}()

	scan := func(rows *sql.Rows) error {
		return unmarshalRow(v, rows, false)
	}
	err = db.queryRows(ctx, scan, q, args...)

	return err
}
// QueryRows is QueryRowsCtx with a background context.
func (db *commonSqlConn) QueryRows(v any, q string, args ...any) error {
	ctx := context.Background()
	return db.QueryRowsCtx(ctx, v, q, args...)
}
// QueryRowsCtx runs q with args inside a "QueryRows" span and hands the
// resulting rows to unmarshalRows with v as the destination and its last
// argument set to true. The span is ended with the returned error.
func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v any, q string,
	args ...any) (err error) {
	ctx, span := startSpan(ctx, "QueryRows")
	defer func() {
		endSpan(span, err)
	}()

	scan := func(rows *sql.Rows) error {
		return unmarshalRows(v, rows, true)
	}
	err = db.queryRows(ctx, scan, q, args...)

	return err
}
// QueryRowsPartial is QueryRowsPartialCtx with a background context.
func (db *commonSqlConn) QueryRowsPartial(v any, q string, args ...any) error {
	ctx := context.Background()
	return db.QueryRowsPartialCtx(ctx, v, q, args...)
}
// QueryRowsPartialCtx runs q with args inside a "QueryRowsPartial" span and
// hands the resulting rows to unmarshalRows with v as the destination and
// its last argument set to false. The span is ended with the returned
// error.
func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
	q string, args ...any) (err error) {
	ctx, span := startSpan(ctx, "QueryRowsPartial")
	defer func() {
		endSpan(span, err)
	}()

	scan := func(rows *sql.Rows) error {
		return unmarshalRows(v, rows, false)
	}
	err = db.queryRows(ctx, scan, q, args...)

	return err
}
// RawDB returns whatever the connection provider yields: the underlying
// *sql.DB, or the provider's error.
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
	raw, err := db.connProv()
	return raw, err
}
// Transact is TransactCtx with a background context; the context that
// TransactCtx supplies to its callback is dropped before fn is called.
func (db *commonSqlConn) Transact(fn func(Session) error) error {
	withoutCtx := func(_ context.Context, session Session) error {
		return fn(session)
	}

	return db.TransactCtx(context.Background(), withoutCtx)
}
// TransactCtx runs fn via transact, using this connection and its beginTx,
// through the circuit breaker and inside a "Transact" span that is ended
// with the final error. When the breaker reports ErrServiceUnavailable the
// "Transact"/"breaker" request-error metric is incremented.
func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) (err error) {
	ctx, span := startSpan(ctx, "Transact")
	defer func() {
		endSpan(span, err)
	}()

	runTx := func() error {
		return transact(ctx, db, db.beginTx, fn)
	}
	if err = db.brk.DoWithAcceptable(runTx, db.acceptable); errors.Is(err, breaker.ErrServiceUnavailable) {
		metricReqErr.Inc("Transact", "breaker")
	}

	return err
}
// acceptable reports whether err should be treated as acceptable when
// handed to the breaker. It returns true for a nil error, for errors
// matching sql.ErrNoRows, sql.ErrTxDone or context.Canceled, and for any
// error that errors.As can convert to acceptableError. Anything else is
// decided by the optional db.accept callback, and is rejected when no
// callback is set.
func (db *commonSqlConn) acceptable(err error) bool {
	switch {
	case err == nil:
		return true
	case errors.Is(err, sql.ErrNoRows),
		errors.Is(err, sql.ErrTxDone),
		errors.Is(err, context.Canceled):
		return true
	}

	var target acceptableError
	if errors.As(err, &target) {
		return true
	}

	return db.accept != nil && db.accept(err)
}
// queryRows runs q with args through the circuit breaker and feeds the
// resulting rows to scanner.
//
// A connection-provider failure is passed to onError and returned. A
// failure that comes from scanner itself is remembered so that the breaker
// treats the outcome as acceptable, whatever db.acceptable would say about
// the error. When the breaker reports ErrServiceUnavailable the
// "queryRows"/"breaker" request-error metric is incremented.
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
	q string, args ...any) (err error) {
	var scanFailed bool

	// Wrap the scanner so a scan failure can be told apart from a query
	// failure when the breaker asks whether the error is acceptable.
	trackedScanner := func(rows *sql.Rows) error {
		scanErr := scanner(rows)
		if scanErr != nil {
			scanFailed = true
		}
		return scanErr
	}

	run := func() error {
		conn, connErr := db.connProv()
		if connErr != nil {
			db.onError(ctx, connErr)
			return connErr
		}

		return query(ctx, conn, trackedScanner, q, args...)
	}

	accept := func(reqErr error) bool {
		return scanFailed || db.acceptable(reqErr)
	}

	err = db.brk.DoWithAcceptable(run, accept)
	if errors.Is(err, breaker.ErrServiceUnavailable) {
		metricReqErr.Inc("queryRows", "breaker")
	}

	return err
}
// WithAcceptable returns a SqlOption that sets the acceptable function.
// acceptable is the func to check if the error can be accepted.
//
// When the connection already has an acceptable function, the two are
// combined: an error is accepted if either the existing function or the new
// one accepts it, with the existing function consulted first.
//
// A nil acceptable leaves the connection unchanged. Without this guard,
// combining a nil function with an existing one would panic at call time
// whenever the existing function returned false.
func WithAcceptable(acceptable func(err error) bool) SqlOption {
	return func(conn *commonSqlConn) {
		if acceptable == nil {
			return
		}

		if conn.accept == nil {
			conn.accept = acceptable
			return
		}

		pre := conn.accept
		conn.accept = func(err error) bool {
			return pre(err) || acceptable(err)
		}
	}
}