diff --git a/core/stores/sqlx/stmt.go b/core/stores/sqlx/stmt.go index 2bbdf7b0..7b6ad26b 100644 --- a/core/stores/sqlx/stmt.go +++ b/core/stores/sqlx/stmt.go @@ -12,7 +12,22 @@ import ( const defaultSlowThreshold = time.Millisecond * 500 -var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold) +var ( + slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold) + logSql = syncx.ForAtomicBool(true) + logSlowSql = syncx.ForAtomicBool(true) +) + +// DisableLog disables logging of sql statements, includes info and slow logs. +func DisableLog() { + logSql.Set(false) + logSlowSql.Set(false) +} + +// DisableStmtLog disables info logging of sql statements, but keeps slow logs. +func DisableStmtLog() { + logSql.Set(false) +} // SetSlowThreshold sets the slow threshold. func SetSlowThreshold(threshold time.Duration) { @@ -20,64 +35,39 @@ func SetSlowThreshold(threshold time.Duration) { } func exec(ctx context.Context, conn sessionConn, q string, args ...interface{}) (sql.Result, error) { - stmt, err := format(q, args...) - if err != nil { + guard := newGuard("exec") + if err := guard.start(q, args...); err != nil { return nil, err } - startTime := timex.Now() result, err := conn.ExecContext(ctx, q, args...) - duration := timex.Since(startTime) - if duration > slowThreshold.Load() { - logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt) - } else { - logx.WithContext(ctx).WithDuration(duration).Infof("sql exec: %s", stmt) - } - if err != nil { - logSqlError(ctx, stmt, err) - } + guard.finish(ctx, err) return result, err } func execStmt(ctx context.Context, conn stmtConn, q string, args ...interface{}) (sql.Result, error) { - stmt, err := format(q, args...) - if err != nil { + guard := newGuard("execStmt") + if err := guard.start(q, args...); err != nil { return nil, err } - startTime := timex.Now() result, err := conn.ExecContext(ctx, args...) - duration := timex.Since(startTime) - if duration > slowThreshold.Load() { - logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt) - } else { - logx.WithContext(ctx).WithDuration(duration).Infof("sql execStmt: %s", stmt) - } - if err != nil { - logSqlError(ctx, stmt, err) - } + guard.finish(ctx, err) return result, err } func query(ctx context.Context, conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error { - stmt, err := format(q, args...) - if err != nil { + guard := newGuard("query") + if err := guard.start(q, args...); err != nil { return err } - startTime := timex.Now() rows, err := conn.QueryContext(ctx, q, args...) - duration := timex.Since(startTime) - if duration > slowThreshold.Load() { - logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt) - } else { - logx.WithContext(ctx).WithDuration(duration).Infof("sql query: %s", stmt) - } + guard.finish(ctx, err) if err != nil { - logSqlError(ctx, stmt, err) return err } defer rows.Close() @@ -87,24 +77,74 @@ func query(ctx context.Context, conn sessionConn, scanner func(*sql.Rows) error, func queryStmt(ctx context.Context, conn stmtConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error { - stmt, err := format(q, args...) - if err != nil { + guard := newGuard("queryStmt") + if err := guard.start(q, args...); err != nil { return err } - startTime := timex.Now() rows, err := conn.QueryContext(ctx, args...) - duration := timex.Since(startTime) - if duration > slowThreshold.Load() { - logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt) - } else { - logx.WithContext(ctx).WithDuration(duration).Infof("sql queryStmt: %s", stmt) - } + guard.finish(ctx, err) if err != nil { - logSqlError(ctx, stmt, err) return err } defer rows.Close() return scanner(rows) } + +type ( + sqlGuard interface { + start(q string, args ...interface{}) error + finish(ctx context.Context, err error) + } + + nilGuard struct{} + + realSqlGuard struct { + command string + stmt string + startTime time.Duration + } +) + +func newGuard(command string) sqlGuard { + if logSql.True() || logSlowSql.True() { + return &realSqlGuard{ + command: command, + } + } + + return nilGuard{} +} + +func (n nilGuard) start(_ string, _ ...interface{}) error { + return nil +} + +func (n nilGuard) finish(_ context.Context, _ error) { +} + +func (e *realSqlGuard) finish(ctx context.Context, err error) { + duration := timex.Since(e.startTime) + if duration > slowThreshold.Load() { + logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] %s: slowcall - %s", e.command, e.stmt) + } else if logSql.True() { + logx.WithContext(ctx).WithDuration(duration).Infof("sql %s: %s", e.command, e.stmt) + } + + if err != nil { + logSqlError(ctx, e.stmt, err) + } +} + +func (e *realSqlGuard) start(q string, args ...interface{}) error { + stmt, err := format(q, args...) + if err != nil { + return err + } + + e.stmt = stmt + e.startTime = timex.Now() + + return nil +} diff --git a/core/stores/sqlx/stmt_test.go b/core/stores/sqlx/stmt_test.go index 074ba2c6..0647e9c1 100644 --- a/core/stores/sqlx/stmt_test.go +++ b/core/stores/sqlx/stmt_test.go @@ -178,6 +178,47 @@ func TestSetSlowThreshold(t *testing.T) { assert.Equal(t, time.Second, slowThreshold.Load()) } +func TestDisableLog(t *testing.T) { + assert.True(t, logSql.True()) + assert.True(t, logSlowSql.True()) + defer func() { + logSql.Set(true) + logSlowSql.Set(true) + }() + + DisableLog() + assert.False(t, logSql.True()) + assert.False(t, logSlowSql.True()) +} + +func TestDisableStmtLog(t *testing.T) { + assert.True(t, logSql.True()) + assert.True(t, logSlowSql.True()) + defer func() { + logSql.Set(true) + logSlowSql.Set(true) + }() + + DisableStmtLog() + assert.False(t, logSql.True()) + assert.True(t, logSlowSql.True()) +} + +func TestNilGuard(t *testing.T) { + assert.True(t, logSql.True()) + assert.True(t, logSlowSql.True()) + defer func() { + logSql.Set(true) + logSlowSql.Set(true) + }() + + DisableLog() + guard := newGuard("any") + assert.Nil(t, guard.start("foo", "bar")) + guard.finish(context.Background(), nil) + assert.Equal(t, nilGuard{}, guard) +} + type mockedSessionConn struct { lastInsertId int64 rowsAffected int64