From aaa39e17a3a8cb558bbaa40f512db372f6823b57 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Thu, 20 May 2021 16:14:44 +0800 Subject: [PATCH] print entire sql statements in logx if necessary (#704) --- core/stores/sqlx/sqlconn.go | 16 ++++--- core/stores/sqlx/stmt.go | 29 +++++++++--- core/stores/sqlx/stmt_test.go | 36 +++++++++++--- core/stores/sqlx/utils.go | 86 ++++++++++++++++++++-------------- core/stores/sqlx/utils_test.go | 53 +++++++++++++++++---- 5 files changed, 153 insertions(+), 67 deletions(-) diff --git a/core/stores/sqlx/sqlconn.go b/core/stores/sqlx/sqlconn.go index 5bf636e0..d30ce0ba 100644 --- a/core/stores/sqlx/sqlconn.go +++ b/core/stores/sqlx/sqlconn.go @@ -56,7 +56,8 @@ type ( } statement struct { - stmt *sql.Stmt + query string + stmt *sql.Stmt } stmtConn interface { @@ -111,7 +112,8 @@ func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) { } stmt = statement{ - stmt: st, + query: query, + stmt: st, } return nil }, db.acceptable) @@ -181,29 +183,29 @@ func (s statement) Close() error { } func (s statement) Exec(args ...interface{}) (sql.Result, error) { - return execStmt(s.stmt, args...) + return execStmt(s.stmt, s.query, args...) } func (s statement) QueryRow(v interface{}, args ...interface{}) error { return queryStmt(s.stmt, func(rows *sql.Rows) error { return unmarshalRow(v, rows, true) - }, args...) + }, s.query, args...) } func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error { return queryStmt(s.stmt, func(rows *sql.Rows) error { return unmarshalRow(v, rows, false) - }, args...) + }, s.query, args...) } func (s statement) QueryRows(v interface{}, args ...interface{}) error { return queryStmt(s.stmt, func(rows *sql.Rows) error { return unmarshalRows(v, rows, true) - }, args...) + }, s.query, args...) } func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error { return queryStmt(s.stmt, func(rows *sql.Rows) error { return unmarshalRows(v, rows, false) - }, args...) + }, s.query, args...) } diff --git a/core/stores/sqlx/stmt.go b/core/stores/sqlx/stmt.go index 62651245..ebefa1f4 100644 --- a/core/stores/sqlx/stmt.go +++ b/core/stores/sqlx/stmt.go @@ -2,7 +2,6 @@ package sqlx import ( "database/sql" - "fmt" "time" "github.com/tal-tech/go-zero/core/logx" @@ -12,10 +11,14 @@ import ( const slowThreshold = time.Millisecond * 500 func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) { + stmt, err := format(q, args...) + if err != nil { + return nil, err + } + startTime := timex.Now() result, err := conn.Exec(q, args...) duration := timex.Since(startTime) - stmt := formatForPrint(q, args) if duration > slowThreshold { logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt) } else { @@ -28,11 +31,15 @@ func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) { return result, err } -func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) { +func execStmt(conn stmtConn, q string, args ...interface{}) (sql.Result, error) { + stmt, err := format(q, args...) + if err != nil { + return nil, err + } + startTime := timex.Now() result, err := conn.Exec(args...) duration := timex.Since(startTime) - stmt := fmt.Sprint(args...) if duration > slowThreshold { logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt) } else { @@ -46,10 +53,14 @@ func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) { } func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error { + stmt, err := format(q, args...) + if err != nil { + return err + } + startTime := timex.Now() rows, err := conn.Query(q, args...) duration := timex.Since(startTime) - stmt := fmt.Sprint(args...) if duration > slowThreshold { logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt) } else { @@ -64,8 +75,12 @@ func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...in return scanner(rows) } -func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, args ...interface{}) error { - stmt := fmt.Sprint(args...) +func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error { + stmt, err := format(q, args...) + if err != nil { + return err + } + startTime := timex.Now() rows, err := conn.Query(args...) duration := timex.Since(startTime) diff --git a/core/stores/sqlx/stmt_test.go b/core/stores/sqlx/stmt_test.go index 7247b40c..bff970a6 100644 --- a/core/stores/sqlx/stmt_test.go +++ b/core/stores/sqlx/stmt_test.go @@ -14,6 +14,7 @@ var errMockedPlaceholder = errors.New("placeholder") func TestStmt_exec(t *testing.T) { tests := []struct { name string + query string args []interface{} delay bool hasError bool @@ -23,18 +24,28 @@ func TestStmt_exec(t *testing.T) { }{ { name: "normal", + query: "select user from users where id=?", args: []interface{}{1}, lastInsertId: 1, rowsAffected: 2, }, { name: "exec error", + query: "select user from users where id=?", + args: []interface{}{1}, + hasError: true, + err: errors.New("exec"), + }, + { + name: "exec more args error", + query: "select user from users where id=? and name=?", args: []interface{}{1}, hasError: true, err: errors.New("exec"), }, { name: "slowcall", + query: "select user from users where id=?", args: []interface{}{1}, delay: true, lastInsertId: 1, @@ -51,7 +62,7 @@ func TestStmt_exec(t *testing.T) { rowsAffected: test.rowsAffected, err: test.err, delay: test.delay, - }, "select user from users where id=?", args...) + }, test.query, args...) }, func(args ...interface{}) (sql.Result, error) { return execStmt(&mockedStmtConn{ @@ -59,7 +70,7 @@ func TestStmt_exec(t *testing.T) { rowsAffected: test.rowsAffected, err: test.err, delay: test.delay, - }, args...) + }, test.query, args...) }, } @@ -89,23 +100,34 @@ func TestStmt_exec(t *testing.T) { func TestStmt_query(t *testing.T) { tests := []struct { name string + query string args []interface{} delay bool hasError bool err error }{ { - name: "normal", - args: []interface{}{1}, + name: "normal", + query: "select user from users where id=?", + args: []interface{}{1}, }, { name: "query error", + query: "select user from users where id=?", + args: []interface{}{1}, + hasError: true, + err: errors.New("exec"), + }, + { + name: "query more args error", + query: "select user from users where id=? and name=?", args: []interface{}{1}, hasError: true, err: errors.New("exec"), }, { name: "slowcall", + query: "select user from users where id=?", args: []interface{}{1}, delay: true, }, @@ -120,7 +142,7 @@ func TestStmt_query(t *testing.T) { delay: test.delay, }, func(rows *sql.Rows) error { return nil - }, "select user from users where id=?", args...) + }, test.query, args...) }, func(args ...interface{}) error { return queryStmt(&mockedStmtConn{ @@ -128,7 +150,7 @@ func TestStmt_query(t *testing.T) { delay: test.delay, }, func(rows *sql.Rows) error { return nil - }, args...) + }, test.query, args...) }, } @@ -143,7 +165,7 @@ func TestStmt_query(t *testing.T) { return } - assert.Equal(t, errMockedPlaceholder, err) + assert.NotNil(t, err) }) } } diff --git a/core/stores/sqlx/utils.go b/core/stores/sqlx/utils.go index 48995be6..01b96a72 100644 --- a/core/stores/sqlx/utils.go +++ b/core/stores/sqlx/utils.go @@ -2,6 +2,7 @@ package sqlx import ( "fmt" + "strconv" "strings" "github.com/tal-tech/go-zero/core/logx" @@ -45,24 +46,6 @@ func escape(input string) string { return b.String() } -func formatForPrint(query string, args ...interface{}) string { - if len(args) == 0 { - return query - } - - var vals []string - for _, arg := range args { - vals = append(vals, fmt.Sprintf("%q", mapping.Repr(arg))) - } - - var b strings.Builder - b.WriteByte('[') - b.WriteString(strings.Join(vals, ", ")) - b.WriteByte(']') - - return strings.Join([]string{query, b.String()}, " ") -} - func format(query string, args ...interface{}) (string, error) { numArgs := len(args) if numArgs == 0 { @@ -72,36 +55,50 @@ func format(query string, args ...interface{}) (string, error) { var b strings.Builder argIndex := 0 - for _, ch := range query { - if ch == '?' { + bytes := len(query) + for i := 0; i < bytes; i++ { + ch := query[i] + switch ch { + case '?': if argIndex >= numArgs { return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex) } - arg := args[argIndex] + writeValue(&b, args[argIndex]) argIndex++ + case '$': + var j int + for j = i + 1; j < bytes; j++ { + char := query[j] + if char < '0' || '9' < char { + break + } + } + if j > i+1 { + index, err := strconv.Atoi(query[i+1 : j]) + if err != nil { + return "", err + } - switch v := arg.(type) { - case bool: - if v { - b.WriteByte('1') - } else { - b.WriteByte('0') + // index starts from 1 for pg + if index > argIndex { + argIndex = index + } + index-- + if index < 0 || numArgs <= index { + return "", fmt.Errorf("error: wrong index %d in sql", index) } - case string: - b.WriteByte('\'') - b.WriteString(escape(v)) - b.WriteByte('\'') - default: - b.WriteString(mapping.Repr(v)) + + writeValue(&b, args[index]) + i = j - 1 } - } else { - b.WriteRune(ch) + default: + b.WriteByte(ch) } } if argIndex < numArgs { - return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex) + return "", fmt.Errorf("error: %d arguments provided, not matching sql", argIndex) } return b.String(), nil @@ -117,3 +114,20 @@ func logSqlError(stmt string, err error) { logx.Errorf("stmt: %s, error: %s", stmt, err.Error()) } } + +func writeValue(buf *strings.Builder, arg interface{}) { + switch v := arg.(type) { + case bool: + if v { + buf.WriteByte('1') + } else { + buf.WriteByte('0') + } + case string: + buf.WriteByte('\'') + buf.WriteString(escape(v)) + buf.WriteByte('\'') + default: + buf.WriteString(mapping.Repr(v)) + } +} diff --git a/core/stores/sqlx/utils_test.go b/core/stores/sqlx/utils_test.go index 1dcfa3ad..af6c8067 100644 --- a/core/stores/sqlx/utils_test.go +++ b/core/stores/sqlx/utils_test.go @@ -29,30 +29,63 @@ func TestDesensitize_WithoutAccount(t *testing.T) { assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)")) } -func TestFormatForPrint(t *testing.T) { +func TestFormat(t *testing.T) { tests := []struct { name string query string args []interface{} expect string + hasErr bool }{ { - name: "no args", - query: "select user, name from table where id=?", - expect: `select user, name from table where id=?`, + name: "mysql normal", + query: "select name, age from users where bool=? and phone=?", + args: []interface{}{true, "133"}, + expect: "select name, age from users where bool=1 and phone='133'", }, { - name: "one arg", - query: "select user, name from table where id=?", - args: []interface{}{"kevin"}, - expect: `select user, name from table where id=? ["kevin"]`, + name: "mysql normal", + query: "select name, age from users where bool=? and phone=?", + args: []interface{}{false, "133"}, + expect: "select name, age from users where bool=0 and phone='133'", + }, + { + name: "pg normal", + query: "select name, age from users where bool=$1 and phone=$2", + args: []interface{}{true, "133"}, + expect: "select name, age from users where bool=1 and phone='133'", + }, + { + name: "pg normal reverse", + query: "select name, age from users where bool=$2 and phone=$1", + args: []interface{}{"133", false}, + expect: "select name, age from users where bool=0 and phone='133'", + }, + { + name: "pg error not number", + query: "select name, age from users where bool=$a and phone=$1", + args: []interface{}{"133", false}, + hasErr: true, + }, + { + name: "pg error more args", + query: "select name, age from users where bool=$2 and phone=$1 and nickname=$3", + args: []interface{}{"133", false}, + hasErr: true, }, } for _, test := range tests { + test := test t.Run(test.name, func(t *testing.T) { - actual := formatForPrint(test.query, test.args...) - assert.Equal(t, test.expect, actual) + t.Parallel() + + actual, err := format(test.query, test.args...) + if test.hasErr { + assert.NotNil(t, err) + } else { + assert.Equal(t, test.expect, actual) + } }) } }