Kevin Wan 2 years ago committed by GitHub
parent 05dd6bd743
commit 8ed22eafdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -12,7 +12,22 @@ import (
const defaultSlowThreshold = time.Millisecond * 500 const defaultSlowThreshold = time.Millisecond * 500
var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold) var (
slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
logSql = syncx.ForAtomicBool(true)
logSlowSql = syncx.ForAtomicBool(true)
)
// DisableLog disables logging of sql statements, includes info and slow logs.
func DisableLog() {
logSql.Set(false)
logSlowSql.Set(false)
}
// DisableStmtLog disables info logging of sql statements, but keeps slow logs.
func DisableStmtLog() {
logSql.Set(false)
}
// SetSlowThreshold sets the slow threshold. // SetSlowThreshold sets the slow threshold.
func SetSlowThreshold(threshold time.Duration) { func SetSlowThreshold(threshold time.Duration) {
@ -20,64 +35,39 @@ func SetSlowThreshold(threshold time.Duration) {
} }
func exec(ctx context.Context, conn sessionConn, q string, args ...interface{}) (sql.Result, error) { func exec(ctx context.Context, conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
stmt, err := format(q, args...) guard := newGuard("exec")
if err != nil { if err := guard.start(q, args...); err != nil {
return nil, err return nil, err
} }
startTime := timex.Now()
result, err := conn.ExecContext(ctx, q, args...) result, err := conn.ExecContext(ctx, q, args...)
duration := timex.Since(startTime) guard.finish(ctx, err)
if duration > slowThreshold.Load() {
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
} else {
logx.WithContext(ctx).WithDuration(duration).Infof("sql exec: %s", stmt)
}
if err != nil {
logSqlError(ctx, stmt, err)
}
return result, err return result, err
} }
func execStmt(ctx context.Context, conn stmtConn, q string, args ...interface{}) (sql.Result, error) { func execStmt(ctx context.Context, conn stmtConn, q string, args ...interface{}) (sql.Result, error) {
stmt, err := format(q, args...) guard := newGuard("execStmt")
if err != nil { if err := guard.start(q, args...); err != nil {
return nil, err return nil, err
} }
startTime := timex.Now()
result, err := conn.ExecContext(ctx, args...) result, err := conn.ExecContext(ctx, args...)
duration := timex.Since(startTime) guard.finish(ctx, err)
if duration > slowThreshold.Load() {
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
} else {
logx.WithContext(ctx).WithDuration(duration).Infof("sql execStmt: %s", stmt)
}
if err != nil {
logSqlError(ctx, stmt, err)
}
return result, err return result, err
} }
func query(ctx context.Context, conn sessionConn, scanner func(*sql.Rows) error, func query(ctx context.Context, conn sessionConn, scanner func(*sql.Rows) error,
q string, args ...interface{}) error { q string, args ...interface{}) error {
stmt, err := format(q, args...) guard := newGuard("query")
if err != nil { if err := guard.start(q, args...); err != nil {
return err return err
} }
startTime := timex.Now()
rows, err := conn.QueryContext(ctx, q, args...) rows, err := conn.QueryContext(ctx, q, args...)
duration := timex.Since(startTime) guard.finish(ctx, err)
if duration > slowThreshold.Load() {
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
} else {
logx.WithContext(ctx).WithDuration(duration).Infof("sql query: %s", stmt)
}
if err != nil { if err != nil {
logSqlError(ctx, stmt, err)
return err return err
} }
defer rows.Close() defer rows.Close()
@ -87,24 +77,74 @@ func query(ctx context.Context, conn sessionConn, scanner func(*sql.Rows) error,
func queryStmt(ctx context.Context, conn stmtConn, scanner func(*sql.Rows) error, func queryStmt(ctx context.Context, conn stmtConn, scanner func(*sql.Rows) error,
q string, args ...interface{}) error { q string, args ...interface{}) error {
stmt, err := format(q, args...) guard := newGuard("queryStmt")
if err != nil { if err := guard.start(q, args...); err != nil {
return err return err
} }
startTime := timex.Now()
rows, err := conn.QueryContext(ctx, args...) rows, err := conn.QueryContext(ctx, args...)
duration := timex.Since(startTime) guard.finish(ctx, err)
if duration > slowThreshold.Load() {
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
} else {
logx.WithContext(ctx).WithDuration(duration).Infof("sql queryStmt: %s", stmt)
}
if err != nil { if err != nil {
logSqlError(ctx, stmt, err)
return err return err
} }
defer rows.Close() defer rows.Close()
return scanner(rows) return scanner(rows)
} }
type (
sqlGuard interface {
start(q string, args ...interface{}) error
finish(ctx context.Context, err error)
}
nilGuard struct{}
realSqlGuard struct {
command string
stmt string
startTime time.Duration
}
)
func newGuard(command string) sqlGuard {
if logSql.True() || logSlowSql.True() {
return &realSqlGuard{
command: command,
}
}
return nilGuard{}
}
func (n nilGuard) start(_ string, _ ...interface{}) error {
return nil
}
func (n nilGuard) finish(_ context.Context, _ error) {
}
func (e *realSqlGuard) finish(ctx context.Context, err error) {
duration := timex.Since(e.startTime)
if duration > slowThreshold.Load() {
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] %s: slowcall - %s", e.command, e.stmt)
} else if logSql.True() {
logx.WithContext(ctx).WithDuration(duration).Infof("sql %s: %s", e.command, e.stmt)
}
if err != nil {
logSqlError(ctx, e.stmt, err)
}
}
func (e *realSqlGuard) start(q string, args ...interface{}) error {
stmt, err := format(q, args...)
if err != nil {
return err
}
e.stmt = stmt
e.startTime = timex.Now()
return nil
}

@ -178,6 +178,47 @@ func TestSetSlowThreshold(t *testing.T) {
assert.Equal(t, time.Second, slowThreshold.Load()) assert.Equal(t, time.Second, slowThreshold.Load())
} }
func TestDisableLog(t *testing.T) {
assert.True(t, logSql.True())
assert.True(t, logSlowSql.True())
defer func() {
logSql.Set(true)
logSlowSql.Set(true)
}()
DisableLog()
assert.False(t, logSql.True())
assert.False(t, logSlowSql.True())
}
func TestDisableStmtLog(t *testing.T) {
assert.True(t, logSql.True())
assert.True(t, logSlowSql.True())
defer func() {
logSql.Set(true)
logSlowSql.Set(true)
}()
DisableStmtLog()
assert.False(t, logSql.True())
assert.True(t, logSlowSql.True())
}
func TestNilGuard(t *testing.T) {
assert.True(t, logSql.True())
assert.True(t, logSlowSql.True())
defer func() {
logSql.Set(true)
logSlowSql.Set(true)
}()
DisableLog()
guard := newGuard("any")
assert.Nil(t, guard.start("foo", "bar"))
guard.finish(context.Background(), nil)
assert.Equal(t, nilGuard{}, guard)
}
type mockedSessionConn struct { type mockedSessionConn struct {
lastInsertId int64 lastInsertId int64
rowsAffected int64 rowsAffected int64

Loading…
Cancel
Save