diff --git a/core/stores/sqlx/stmt.go b/core/stores/sqlx/stmt.go index e140a064..cc35f2b5 100644 --- a/core/stores/sqlx/stmt.go +++ b/core/stores/sqlx/stmt.go @@ -3,6 +3,7 @@ package sqlx import ( "context" "database/sql" + "errors" "time" "github.com/zeromicro/go-zero/core/breaker" @@ -70,6 +71,9 @@ func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, }, func(err error) bool { return s.accept(err) }) + if errors.Is(err, breaker.ErrServiceUnavailable) { + metricReqErr.Inc("stmt_exec", "breaker") + } return } @@ -137,7 +141,8 @@ func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) func (s statement) queryRows(ctx context.Context, scanFn func(any, rowsScanner) error, v any, args ...any) error { var scanFailed bool - return s.brk.DoWithAcceptable(func() error { + + err := s.brk.DoWithAcceptable(func() error { return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { err := scanFn(v, rows) if err != nil { @@ -148,6 +153,11 @@ func (s statement) queryRows(ctx context.Context, scanFn func(any, rowsScanner) }, func(err error) bool { return scanFailed || s.accept(err) }) + if errors.Is(err, breaker.ErrServiceUnavailable) { + metricReqErr.Inc("stmt_queryRows", "breaker") + } + + return err } // DisableLog disables logging of sql statements, includes info and slow logs. diff --git a/core/stores/sqlx/stmt_test.go b/core/stores/sqlx/stmt_test.go index 68aed9b6..84de7577 100644 --- a/core/stores/sqlx/stmt_test.go +++ b/core/stores/sqlx/stmt_test.go @@ -222,7 +222,7 @@ func TestNilGuard(t *testing.T) { assert.Equal(t, nilGuard{}, guard) } -func TestStmtScanFailed(t *testing.T) { +func TestStmtBreaker(t *testing.T) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { mock.ExpectPrepare("any") @@ -242,6 +242,52 @@ func TestStmtScanFailed(t *testing.T) { assert.NotErrorIs(t, err, breaker.ErrServiceUnavailable) } }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectPrepare("any") + conn := NewSqlConnFromDB(db) + stmt, err := conn.Prepare("any") + assert.NoError(t, err) + + for i := 0; i < 1000; i++ { + assert.Error(t, conn.Transact(func(session Session) error { + return nil + })) + } + + var breakerTriggered bool + for i := 0; i < 1000; i++ { + _, err = stmt.Exec("any") + if errors.Is(err, breaker.ErrServiceUnavailable) { + breakerTriggered = true + break + } + } + assert.True(t, breakerTriggered) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectPrepare("any") + conn := NewSqlConnFromDB(db) + stmt, err := conn.Prepare("any") + assert.NoError(t, err) + + for i := 0; i < 1000; i++ { + assert.Error(t, conn.Transact(func(session Session) error { + return nil + })) + } + + var breakerTriggered bool + for i := 0; i < 1000; i++ { + err = stmt.QueryRows(&struct{}{}, "any") + if errors.Is(err, breaker.ErrServiceUnavailable) { + breakerTriggered = true + break + } + } + assert.True(t, breakerTriggered) + }) } type mockedSessionConn struct {