diff --git a/core/breaker/breakers.go b/core/breaker/breakers.go index 314965c4..1df42d3e 100644 --- a/core/breaker/breakers.go +++ b/core/breaker/breakers.go @@ -59,7 +59,7 @@ func GetBreaker(name string) Breaker { // NoBreakerFor disables the circuit breaker for the given name. func NoBreakerFor(name string) { lock.Lock() - breakers[name] = newNopBreaker() + breakers[name] = NopBreaker() lock.Unlock() } diff --git a/core/breaker/nopbreaker.go b/core/breaker/nopbreaker.go index cd030b63..baa09801 100644 --- a/core/breaker/nopbreaker.go +++ b/core/breaker/nopbreaker.go @@ -4,7 +4,8 @@ const nopBreakerName = "nopBreaker" type nopBreaker struct{} -func newNopBreaker() Breaker { +// NopBreaker returns a breaker that never trigger breaker circuit. +func NopBreaker() Breaker { return nopBreaker{} } diff --git a/core/breaker/nopbreaker_test.go b/core/breaker/nopbreaker_test.go index 1756aa2a..ac26428d 100644 --- a/core/breaker/nopbreaker_test.go +++ b/core/breaker/nopbreaker_test.go @@ -8,7 +8,7 @@ import ( ) func TestNopBreaker(t *testing.T) { - b := newNopBreaker() + b := NopBreaker() assert.Equal(t, nopBreakerName, b.Name()) p, err := b.Allow() assert.Nil(t, err) diff --git a/core/stores/sqlx/sqlconn.go b/core/stores/sqlx/sqlconn.go index 9603af5b..11838056 100644 --- a/core/stores/sqlx/sqlconn.go +++ b/core/stores/sqlx/sqlconn.go @@ -42,21 +42,6 @@ type ( // SqlOption defines the method to customize a sql connection. SqlOption func(*commonSqlConn) - // StmtSession interface represents a session that can be used to execute statements. - StmtSession interface { - Close() error - Exec(args ...any) (sql.Result, error) - ExecCtx(ctx context.Context, args ...any) (sql.Result, error) - QueryRow(v any, args ...any) error - QueryRowCtx(ctx context.Context, v any, args ...any) error - QueryRowPartial(v any, args ...any) error - QueryRowPartialCtx(ctx context.Context, v any, args ...any) error - QueryRows(v any, args ...any) error - QueryRowsCtx(ctx context.Context, v any, args ...any) error - QueryRowsPartial(v any, args ...any) error - QueryRowsPartialCtx(ctx context.Context, v any, args ...any) error - } - // thread-safe // Because CORBA doesn't support PREPARE, so we need to combine the // query arguments into one string and do underlying query without arguments @@ -65,7 +50,7 @@ type ( onError func(context.Context, error) beginTx beginnable brk breaker.Breaker - accept func(error) bool + accept breaker.Acceptable } connProvider func() (*sql.DB, error) @@ -76,18 +61,6 @@ type ( Query(query string, args ...any) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) } - - statement struct { - query string - stmt *sql.Stmt - } - - stmtConn interface { - Exec(args ...any) (sql.Result, error) - ExecContext(ctx context.Context, args ...any) (sql.Result, error) - Query(args ...any) (*sql.Rows, error) - QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) - } ) // NewSqlConn returns a SqlConn with given driver name and datasource. @@ -189,8 +162,10 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm } stmt = statement{ - query: query, - stmt: st, + query: query, + stmt: st, + brk: db.brk, + accept: db.acceptable, } return nil }, db.acceptable) @@ -311,7 +286,7 @@ func (db *commonSqlConn) acceptable(err error) bool { func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error, q string, args ...any) (err error) { - var qerr error + var scanFailed bool err = db.brk.DoWithAcceptable(func() error { conn, err := db.connProv() if err != nil { @@ -320,11 +295,14 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) } return query(ctx, conn, func(rows *sql.Rows) error { - qerr = scanner(rows) - return qerr + e := scanner(rows) + if e != nil { + scanFailed = true + } + return e }, q, args...) }, func(err error) bool { - return errors.Is(err, qerr) || db.acceptable(err) + return scanFailed || db.acceptable(err) }) if errors.Is(err, breaker.ErrServiceUnavailable) { metricReqErr.Inc("queryRows", "breaker") @@ -333,83 +311,6 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) return } -func (s statement) Close() error { - return s.stmt.Close() -} - -func (s statement) Exec(args ...any) (sql.Result, error) { - return s.ExecCtx(context.Background(), args...) -} - -func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, err error) { - ctx, span := startSpan(ctx, "Exec") - defer func() { - endSpan(span, err) - }() - - return execStmt(ctx, s.stmt, s.query, args...) -} - -func (s statement) QueryRow(v any, args ...any) error { - return s.QueryRowCtx(context.Background(), v, args...) -} - -func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err error) { - ctx, span := startSpan(ctx, "QueryRow") - defer func() { - endSpan(span, err) - }() - - return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { - return unmarshalRow(v, rows, true) - }, s.query, args...) -} - -func (s statement) QueryRowPartial(v any, args ...any) error { - return s.QueryRowPartialCtx(context.Background(), v, args...) -} - -func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (err error) { - ctx, span := startSpan(ctx, "QueryRowPartial") - defer func() { - endSpan(span, err) - }() - - return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { - return unmarshalRow(v, rows, false) - }, s.query, args...) -} - -func (s statement) QueryRows(v any, args ...any) error { - return s.QueryRowsCtx(context.Background(), v, args...) -} - -func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err error) { - ctx, span := startSpan(ctx, "QueryRows") - defer func() { - endSpan(span, err) - }() - - return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { - return unmarshalRows(v, rows, true) - }, s.query, args...) -} - -func (s statement) QueryRowsPartial(v any, args ...any) error { - return s.QueryRowsPartialCtx(context.Background(), v, args...) -} - -func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) (err error) { - ctx, span := startSpan(ctx, "QueryRowsPartial") - defer func() { - endSpan(span, err) - }() - - return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { - return unmarshalRows(v, rows, false) - }, s.query, args...) -} - // WithAcceptable returns a SqlOption that setting the acceptable function. // acceptable is the func to check if the error can be accepted. func WithAcceptable(acceptable func(err error) bool) SqlOption { diff --git a/core/stores/sqlx/sqlconn_test.go b/core/stores/sqlx/sqlconn_test.go index 339d30cc..ab83e2c6 100644 --- a/core/stores/sqlx/sqlconn_test.go +++ b/core/stores/sqlx/sqlconn_test.go @@ -156,6 +156,7 @@ func TestStatement(t *testing.T) { st := statement{ query: "foo", stmt: stmt, + brk: breaker.NopBreaker(), } assert.NoError(t, st.Close()) }) diff --git a/core/stores/sqlx/stmt.go b/core/stores/sqlx/stmt.go index ab688abb..e140a064 100644 --- a/core/stores/sqlx/stmt.go +++ b/core/stores/sqlx/stmt.go @@ -5,6 +5,7 @@ import ( "database/sql" "time" + "github.com/zeromicro/go-zero/core/breaker" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/syncx" "github.com/zeromicro/go-zero/core/timex" @@ -18,6 +19,137 @@ var ( logSlowSql = syncx.ForAtomicBool(true) ) +type ( + // StmtSession interface represents a session that can be used to execute statements. + StmtSession interface { + Close() error + Exec(args ...any) (sql.Result, error) + ExecCtx(ctx context.Context, args ...any) (sql.Result, error) + QueryRow(v any, args ...any) error + QueryRowCtx(ctx context.Context, v any, args ...any) error + QueryRowPartial(v any, args ...any) error + QueryRowPartialCtx(ctx context.Context, v any, args ...any) error + QueryRows(v any, args ...any) error + QueryRowsCtx(ctx context.Context, v any, args ...any) error + QueryRowsPartial(v any, args ...any) error + QueryRowsPartialCtx(ctx context.Context, v any, args ...any) error + } + + statement struct { + query string + stmt *sql.Stmt + brk breaker.Breaker + accept breaker.Acceptable + } + + stmtConn interface { + Exec(args ...any) (sql.Result, error) + ExecContext(ctx context.Context, args ...any) (sql.Result, error) + Query(args ...any) (*sql.Rows, error) + QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) + } +) + +func (s statement) Close() error { + return s.stmt.Close() +} + +func (s statement) Exec(args ...any) (sql.Result, error) { + return s.ExecCtx(context.Background(), args...) +} + +func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, err error) { + ctx, span := startSpan(ctx, "Exec") + defer func() { + endSpan(span, err) + }() + + err = s.brk.DoWithAcceptable(func() error { + result, err = execStmt(ctx, s.stmt, s.query, args...) + return err + }, func(err error) bool { + return s.accept(err) + }) + + return +} + +func (s statement) QueryRow(v any, args ...any) error { + return s.QueryRowCtx(context.Background(), v, args...) +} + +func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err error) { + ctx, span := startSpan(ctx, "QueryRow") + defer func() { + endSpan(span, err) + }() + + return s.queryRows(ctx, func(v any, scanner rowsScanner) error { + return unmarshalRow(v, scanner, true) + }, v, args...) +} + +func (s statement) QueryRowPartial(v any, args ...any) error { + return s.QueryRowPartialCtx(context.Background(), v, args...) +} + +func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (err error) { + ctx, span := startSpan(ctx, "QueryRowPartial") + defer func() { + endSpan(span, err) + }() + + return s.queryRows(ctx, func(v any, scanner rowsScanner) error { + return unmarshalRow(v, scanner, false) + }, v, args...) +} + +func (s statement) QueryRows(v any, args ...any) error { + return s.QueryRowsCtx(context.Background(), v, args...) +} + +func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err error) { + ctx, span := startSpan(ctx, "QueryRows") + defer func() { + endSpan(span, err) + }() + + return s.queryRows(ctx, func(v any, scanner rowsScanner) error { + return unmarshalRows(v, scanner, true) + }, v, args...) +} + +func (s statement) QueryRowsPartial(v any, args ...any) error { + return s.QueryRowsPartialCtx(context.Background(), v, args...) +} + +func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) (err error) { + ctx, span := startSpan(ctx, "QueryRowsPartial") + defer func() { + endSpan(span, err) + }() + + return s.queryRows(ctx, func(v any, scanner rowsScanner) error { + return unmarshalRows(v, scanner, false) + }, v, args...) +} + +func (s statement) queryRows(ctx context.Context, scanFn func(any, rowsScanner) error, + v any, args ...any) error { + var scanFailed bool + return s.brk.DoWithAcceptable(func() error { + return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { + err := scanFn(v, rows) + if err != nil { + scanFailed = true + } + return err + }, s.query, args...) + }, func(err error) bool { + return scanFailed || s.accept(err) + }) +} + // DisableLog disables logging of sql statements, includes info and slow logs. func DisableLog() { logSql.Set(false) diff --git a/core/stores/sqlx/stmt_test.go b/core/stores/sqlx/stmt_test.go index 215bf2b2..68aed9b6 100644 --- a/core/stores/sqlx/stmt_test.go +++ b/core/stores/sqlx/stmt_test.go @@ -7,7 +7,10 @@ import ( "testing" "time" + "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/core/breaker" + "github.com/zeromicro/go-zero/core/stores/dbtest" ) var errMockedPlaceholder = errors.New("placeholder") @@ -219,6 +222,28 @@ func TestNilGuard(t *testing.T) { assert.Equal(t, nilGuard{}, guard) } +func TestStmtScanFailed(t *testing.T) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectPrepare("any") + + conn := NewSqlConnFromDB(db) + stmt, err := conn.Prepare("any") + assert.NoError(t, err) + + var val struct { + Foo int + Bar string + } + for i := 0; i < 1000; i++ { + row := sqlmock.NewRows([]string{"foo"}).AddRow("bar") + mock.ExpectQuery("any").WillReturnRows(row) + err := stmt.QueryRow(&val) + assert.Error(t, err) + assert.NotErrorIs(t, err, breaker.ErrServiceUnavailable) + } + }) +} + type mockedSessionConn struct { lastInsertId int64 rowsAffected int64 diff --git a/core/stores/sqlx/tx.go b/core/stores/sqlx/tx.go index d983077c..ea2fd2be 100644 --- a/core/stores/sqlx/tx.go +++ b/core/stores/sqlx/tx.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "fmt" + + "github.com/zeromicro/go-zero/core/breaker" ) type ( @@ -75,6 +77,7 @@ func (t txSession) PrepareCtx(ctx context.Context, q string) (stmtSession StmtSe return statement{ query: q, stmt: stmt, + brk: breaker.NopBreaker(), }, nil }