diff --git a/core/logx/logs.go b/core/logx/logs.go index c0a3315a..817dc513 100644 --- a/core/logx/logs.go +++ b/core/logx/logs.go @@ -275,7 +275,7 @@ func Infov(v interface{}) { infoAnySync(v) } -// Must checks if err is nil, otherwise logs the err and exits. +// Must checks if err is nil, otherwise logs the error and exits. func Must(err error) { if err != nil { msg := formatWithCaller(err.Error(), 3) diff --git a/core/stores/sqlx/utils.go b/core/stores/sqlx/utils.go index 95d2cdaa..bfbef248 100644 --- a/core/stores/sqlx/utils.go +++ b/core/stores/sqlx/utils.go @@ -2,6 +2,7 @@ package sqlx import ( "context" + "errors" "fmt" "strconv" "strings" @@ -10,6 +11,8 @@ import ( "github.com/zeromicro/go-zero/core/mapping" ) +var errUnbalancedEscape = errors.New("no char after escape char") + func desensitize(datasource string) string { // remove account pos := strings.LastIndex(datasource, "@") @@ -95,6 +98,30 @@ func format(query string, args ...interface{}) (string, error) { writeValue(&b, args[index]) i = j - 1 } + case '\'', '"', '`': + b.WriteByte(ch) + for j := i + 1; j < bytes; j++ { + cur := query[j] + b.WriteByte(cur) + + switch cur { + case '\\': + j++ + if j >= bytes { + return "", errUnbalancedEscape + } + + b.WriteByte(query[j]) + case '\'', '"', '`': + if cur == ch { + i = j + goto end + } + } + } + + end: + break default: b.WriteByte(ch) } diff --git a/core/stores/sqlx/utils_test.go b/core/stores/sqlx/utils_test.go index 6631e134..c3761c5e 100644 --- a/core/stores/sqlx/utils_test.go +++ b/core/stores/sqlx/utils_test.go @@ -97,6 +97,30 @@ func TestFormat(t *testing.T) { args: []interface{}{"133", false}, hasErr: true, }, + { + name: "select with date", + query: "select * from user where date='2006-01-02 15:04:05' and name=:1", + args: []interface{}{"foo"}, + expect: "select * from user where date='2006-01-02 15:04:05' and name='foo'", + }, + { + name: "select with date and escape", + query: `select * from user where date=' 2006-01-02 15:04:05 \'' and name=:1`, + args: []interface{}{"foo"}, + expect: `select * from user where date=' 2006-01-02 15:04:05 \'' and name='foo'`, + }, + { + name: "select with date and bad arg", + query: `select * from user where date='2006-01-02 15:04:05 \'' and name=:a`, + args: []interface{}{"foo"}, + hasErr: true, + }, + { + name: "select with date and escape error", + query: `select * from user where date='2006-01-02 15:04:05 \`, + args: []interface{}{"foo"}, + hasErr: true, + }, } for _, test := range tests { @@ -108,6 +132,7 @@ func TestFormat(t *testing.T) { if test.hasErr { assert.NotNil(t, err) } else { + assert.Nil(t, err) assert.Equal(t, test.expect, actual) } }) diff --git a/tools/goctl/model/sql/parser/parser.go b/tools/goctl/model/sql/parser/parser.go index 37db2690..7a114160 100644 --- a/tools/goctl/model/sql/parser/parser.go +++ b/tools/goctl/model/sql/parser/parser.go @@ -69,7 +69,6 @@ func Parse(filename, database string) ([]*Table, error) { } nameOriginals := parseNameOriginal(tables) - indexNameGen := func(column ...string) string { return strings.Join(column, "_") } @@ -77,14 +76,12 @@ func Parse(filename, database string) ([]*Table, error) { prefix := filepath.Base(filename) var list []*Table for indexTable, e := range tables { - columns := e.Columns - var ( + primaryColumn string primaryColumnSet = collection.NewSet() - - primaryColumn string - uniqueKeyMap = make(map[string][]string) - normalKeyMap = make(map[string][]string) + uniqueKeyMap = make(map[string][]string) + normalKeyMap = make(map[string][]string) + columns = e.Columns ) for _, column := range columns {