package sqlx

import (
	"context"
	"database/sql"
	"errors"
	"time"

	"github.com/zeromicro/go-zero/core/breaker"
	"github.com/zeromicro/go-zero/core/logx"
	"github.com/zeromicro/go-zero/core/syncx"
	"github.com/zeromicro/go-zero/core/timex"
)

// defaultSlowThreshold is the initial duration above which a SQL call is
// reported as a slow call.
const defaultSlowThreshold = time.Millisecond * 500

// Package-level SQL logging settings, stored atomically so they can be changed
// at runtime via DisableLog, DisableStmtLog and SetSlowThreshold:
// slowThreshold is the slow-call cutoff, logSql enables info logging of every
// statement, and logSlowSql enables slow-level logging of slow statements.
var (
	slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
	logSql        = syncx.ForAtomicBool(true)
	logSlowSql    = syncx.ForAtomicBool(true)
)

type (
	// StmtSession interface represents a session that can be used to execute statements.
	StmtSession interface {
		Close() error
		Exec(args ...any) (sql.Result, error)
		ExecCtx(ctx context.Context, args ...any) (sql.Result, error)
		QueryRow(v any, args ...any) error
		QueryRowCtx(ctx context.Context, v any, args ...any) error
		QueryRowPartial(v any, args ...any) error
		QueryRowPartialCtx(ctx context.Context, v any, args ...any) error
		QueryRows(v any, args ...any) error
		QueryRowsCtx(ctx context.Context, v any, args ...any) error
		QueryRowsPartial(v any, args ...any) error
		QueryRowsPartialCtx(ctx context.Context, v any, args ...any) error
	}

	// statement implements StmtSession on top of a prepared *sql.Stmt.
	// query is the SQL text the statement was prepared from (used only for
	// logging), brk guards every call, and accept decides which errors the
	// breaker should treat as acceptable.
	statement struct {
		query  string
		stmt   *sql.Stmt
		brk    breaker.Breaker
		accept breaker.Acceptable
	}

	// stmtConn is the subset of *sql.Stmt used to execute a prepared
	// statement; the statement's SQL is already bound, so only args are passed.
	stmtConn interface {
		Exec(args ...any) (sql.Result, error)
		ExecContext(ctx context.Context, args ...any) (sql.Result, error)
		Query(args ...any) (*sql.Rows, error)
		QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
	}
)

// Close closes the underlying prepared statement.
func (s statement) Close() error {
	return s.stmt.Close()
}

// Exec executes the statement with args using context.Background().
func (s statement) Exec(args ...any) (sql.Result, error) {
	return s.ExecCtx(context.Background(), args...)
}

// ExecCtx executes the statement with args under ctx. The call is wrapped in
// a tracing span and guarded by the circuit breaker; when the breaker rejects
// it (breaker.ErrServiceUnavailable), the "stmt_exec"/"breaker" error metric
// is incremented.
func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, err error) {
	ctx, span := startSpan(ctx, "Exec")
	defer func() {
		endSpan(span, err)
	}()

	err = s.brk.DoWithAcceptable(func() error {
		result, err = execStmt(ctx, s.stmt, s.query, args...)
		return err
	}, s.accept)
	if errors.Is(err, breaker.ErrServiceUnavailable) {
		metricReqErr.Inc("stmt_exec", "breaker")
	}

	return result, err
}

// QueryRow queries a single row into v using context.Background().
func (s statement) QueryRow(v any, args ...any) error {
	return s.QueryRowCtx(context.Background(), v, args...)
}

// QueryRowCtx queries a single row into v under ctx, inside a tracing span.
// It calls unmarshalRow with its final flag set to true (the non-partial mode).
func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err error) {
	ctx, span := startSpan(ctx, "QueryRow")
	defer func() {
		endSpan(span, err)
	}()

	return s.queryRows(ctx, func(v any, scanner rowsScanner) error {
		return unmarshalRow(v, scanner, true)
	}, v, args...)
}

// QueryRowPartial queries a single row into v using context.Background(),
// in partial mode.
func (s statement) QueryRowPartial(v any, args ...any) error {
	return s.QueryRowPartialCtx(context.Background(), v, args...)
}

// QueryRowPartialCtx queries a single row into v under ctx, inside a tracing
// span. It calls unmarshalRow with its final flag set to false (partial mode;
// presumably tolerating columns that do not map onto v — confirm against
// unmarshalRow).
func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (err error) {
	ctx, span := startSpan(ctx, "QueryRowPartial")
	defer func() {
		endSpan(span, err)
	}()

	return s.queryRows(ctx, func(v any, scanner rowsScanner) error {
		return unmarshalRow(v, scanner, false)
	}, v, args...)
}

// QueryRows queries all result rows into v using context.Background().
func (s statement) QueryRows(v any, args ...any) error {
	return s.QueryRowsCtx(context.Background(), v, args...)
}

// QueryRowsCtx queries all result rows into v under ctx, inside a tracing
// span. It calls unmarshalRows with its final flag set to true (non-partial).
func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err error) {
	ctx, span := startSpan(ctx, "QueryRows")
	defer func() {
		endSpan(span, err)
	}()

	return s.queryRows(ctx, func(v any, scanner rowsScanner) error {
		return unmarshalRows(v, scanner, true)
	}, v, args...)
}

// QueryRowsPartial queries all result rows into v using
// context.Background(), in partial mode.
func (s statement) QueryRowsPartial(v any, args ...any) error {
	return s.QueryRowsPartialCtx(context.Background(), v, args...)
}

// QueryRowsPartialCtx queries all result rows into v under ctx, inside a
// tracing span. It calls unmarshalRows with its final flag set to false
// (partial mode).
func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) (err error) {
	ctx, span := startSpan(ctx, "QueryRowsPartial")
	defer func() {
		endSpan(span, err)
	}()

	return s.queryRows(ctx, func(v any, scanner rowsScanner) error {
		return unmarshalRows(v, scanner, false)
	}, v, args...)
}

// queryRows runs the statement under the circuit breaker and hands the
// resulting rows to scanFn together with v. Any error returned by scanFn is
// reported to the breaker as acceptable, so failures to map rows onto v are
// (presumably, per breaker.Acceptable semantics) not counted as backend
// failures. A breaker rejection increments the "stmt_queryRows"/"breaker"
// error metric.
func (s statement) queryRows(ctx context.Context, scanFn func(any, rowsScanner) error, v any, args ...any) error {
	var scanFailed bool
	err := s.brk.DoWithAcceptable(func() error {
		return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
			scanErr := scanFn(v, rows)
			if scanErr != nil {
				scanFailed = true
			}
			return scanErr
		}, s.query, args...)
	}, func(err error) bool {
		return scanFailed || s.accept(err)
	})
	if errors.Is(err, breaker.ErrServiceUnavailable) {
		metricReqErr.Inc("stmt_queryRows", "breaker")
	}

	return err
}

// DisableLog disables logging of sql statements, includes info and slow logs.
// Because no guard is created once both kinds of logging are off, the slow-call
// and request-duration metrics recorded by the guard are disabled as well.
func DisableLog() {
	logSql.Set(false)
	logSlowSql.Set(false)
}

// DisableStmtLog disables info logging of sql statements, but keeps slow logs.
func DisableStmtLog() {
	logSql.Set(false)
}

// SetSlowThreshold sets the slow threshold, the duration above which a SQL
// call is reported as a slow call.
func SetSlowThreshold(threshold time.Duration) {
	slowThreshold.Set(threshold)
}

// exec runs q with args on conn, bracketed by a guard labeled "exec". If the
// guard fails to format the statement, q is not executed.
func exec(ctx context.Context, conn sessionConn, q string, args ...any) (sql.Result, error) {
	guard := newGuard("exec")
	if err := guard.start(q, args...); err != nil {
		return nil, err
	}

	result, err := conn.ExecContext(ctx, q, args...)
	guard.finish(ctx, err)

	return result, err
}

// execStmt executes a prepared statement with args, bracketed by a guard
// labeled "execStmt". q is only used for logging; the SQL is already bound
// to conn.
func execStmt(ctx context.Context, conn stmtConn, q string, args ...any) (sql.Result, error) {
	guard := newGuard("execStmt")
	if err := guard.start(q, args...); err != nil {
		return nil, err
	}

	result, err := conn.ExecContext(ctx, args...)
	guard.finish(ctx, err)

	return result, err
}

// query runs q with args on conn and passes the resulting rows to scanner.
// The guard times only QueryContext, not the scan; rows are closed after
// scanner returns.
func query(ctx context.Context, conn sessionConn, scanner func(*sql.Rows) error,
	q string, args ...any) error {
	guard := newGuard("query")
	if err := guard.start(q, args...); err != nil {
		return err
	}

	rows, err := conn.QueryContext(ctx, q, args...)
	guard.finish(ctx, err)
	if err != nil {
		return err
	}
	defer rows.Close()

	return scanner(rows)
}

// queryStmt queries a prepared statement with args and passes the resulting
// rows to scanner. q is only used for logging. The guard times only
// QueryContext, not the scan; rows are closed after scanner returns.
func queryStmt(ctx context.Context, conn stmtConn, scanner func(*sql.Rows) error,
	q string, args ...any) error {
	guard := newGuard("queryStmt")
	if err := guard.start(q, args...); err != nil {
		return err
	}

	rows, err := conn.QueryContext(ctx, args...)
	guard.finish(ctx, err)
	if err != nil {
		return err
	}
	defer rows.Close()

	return scanner(rows)
}

type (
	// sqlGuard brackets a single SQL call: start runs before the call and may
	// veto it by returning an error; finish runs after it with the call's error.
	sqlGuard interface {
		start(q string, args ...any) error
		finish(ctx context.Context, err error)
	}

	// nilGuard is the no-op sqlGuard used when all SQL logging is disabled.
	nilGuard struct{}

	// realSqlGuard formats and times a SQL call, then logs it and records
	// metrics. command labels the call in logs and metrics, stmt holds the
	// formatted statement, and startTime is the timex timestamp taken in start.
	realSqlGuard struct {
		command   string
		stmt      string
		startTime time.Duration
	}
)

// newGuard returns a guard labeled command. When both info and slow SQL
// logging are disabled it returns a nilGuard, which skips formatting,
// logging and metrics entirely.
//
// NOTE(review): this couples metricSlowCount/metricReqDur to the logging
// flags — confirm that disabling logs is meant to disable these metrics too.
func newGuard(command string) sqlGuard {
	if logSql.True() || logSlowSql.True() {
		return &realSqlGuard{
			command: command,
		}
	}

	return nilGuard{}
}

// start is a no-op that never vetoes the call.
func (n nilGuard) start(_ string, _ ...any) error {
	return nil
}

// finish is a no-op.
func (n nilGuard) finish(_ context.Context, _ error) {
}

// finish measures the time elapsed since start and reports the call:
//   - a call exceeding slowThreshold is counted in metricSlowCount;
//   - it is logged at slow level when slow and slow-SQL logging is enabled,
//     otherwise at info level when statement logging is enabled;
//   - a non-nil err is logged via logSqlError;
//   - the duration, in milliseconds, is observed in metricReqDur.
func (e *realSqlGuard) finish(ctx context.Context, err error) {
	duration := timex.Since(e.startTime)
	slow := duration > slowThreshold.Load()
	if slow {
		metricSlowCount.Inc(e.command)
	}

	// A real guard exists whenever either logging flag is on, so logSlowSql
	// must be checked here rather than assumed from the guard's existence.
	switch {
	case slow && logSlowSql.True():
		logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] %s: slowcall - %s", e.command, e.stmt)
	case logSql.True():
		logx.WithContext(ctx).WithDuration(duration).Infof("sql %s: %s", e.command, e.stmt)
	}

	if err != nil {
		logSqlError(ctx, e.stmt, err)
	}

	metricReqDur.ObserveFloat(float64(duration)/float64(time.Millisecond), e.command)
}

// start renders q with args via format for logging and records the start
// time. A formatting error is returned, and the callers in this file then
// skip executing the statement.
func (e *realSqlGuard) start(q string, args ...any) error {
	stmt, err := format(q, args...)
	if err != nil {
		return err
	}

	e.stmt = stmt
	e.startTime = timex.Now()
	return nil
}