* fix #1806

* chore: refine error text
master
Kevin Wan 3 years ago committed by GitHub
parent 5c9fae7e62
commit 5bcee4cf7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -275,7 +275,7 @@ func Infov(v interface{}) {
infoAnySync(v) infoAnySync(v)
} }
// Must checks if err is nil, otherwise logs the err and exits. // Must checks if err is nil, otherwise logs the error and exits.
func Must(err error) { func Must(err error) {
if err != nil { if err != nil {
msg := formatWithCaller(err.Error(), 3) msg := formatWithCaller(err.Error(), 3)

@ -2,6 +2,7 @@ package sqlx
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@ -10,6 +11,8 @@ import (
"github.com/zeromicro/go-zero/core/mapping" "github.com/zeromicro/go-zero/core/mapping"
) )
var errUnbalancedEscape = errors.New("no char after escape char")
func desensitize(datasource string) string { func desensitize(datasource string) string {
// remove account // remove account
pos := strings.LastIndex(datasource, "@") pos := strings.LastIndex(datasource, "@")
@ -95,6 +98,30 @@ func format(query string, args ...interface{}) (string, error) {
writeValue(&b, args[index]) writeValue(&b, args[index])
i = j - 1 i = j - 1
} }
case '\'', '"', '`':
b.WriteByte(ch)
for j := i + 1; j < bytes; j++ {
cur := query[j]
b.WriteByte(cur)
switch cur {
case '\\':
j++
if j >= bytes {
return "", errUnbalancedEscape
}
b.WriteByte(query[j])
case '\'', '"', '`':
if cur == ch {
i = j
goto end
}
}
}
end:
break
default: default:
b.WriteByte(ch) b.WriteByte(ch)
} }

@ -97,6 +97,30 @@ func TestFormat(t *testing.T) {
args: []interface{}{"133", false}, args: []interface{}{"133", false},
hasErr: true, hasErr: true,
}, },
{
name: "select with date",
query: "select * from user where date='2006-01-02 15:04:05' and name=:1",
args: []interface{}{"foo"},
expect: "select * from user where date='2006-01-02 15:04:05' and name='foo'",
},
{
name: "select with date and escape",
query: `select * from user where date=' 2006-01-02 15:04:05 \'' and name=:1`,
args: []interface{}{"foo"},
expect: `select * from user where date=' 2006-01-02 15:04:05 \'' and name='foo'`,
},
{
name: "select with date and bad arg",
query: `select * from user where date='2006-01-02 15:04:05 \'' and name=:a`,
args: []interface{}{"foo"},
hasErr: true,
},
{
name: "select with date and escape error",
query: `select * from user where date='2006-01-02 15:04:05 \`,
args: []interface{}{"foo"},
hasErr: true,
},
} }
for _, test := range tests { for _, test := range tests {
@ -108,6 +132,7 @@ func TestFormat(t *testing.T) {
if test.hasErr { if test.hasErr {
assert.NotNil(t, err) assert.NotNil(t, err)
} else { } else {
assert.Nil(t, err)
assert.Equal(t, test.expect, actual) assert.Equal(t, test.expect, actual)
} }
}) })

@ -69,7 +69,6 @@ func Parse(filename, database string) ([]*Table, error) {
} }
nameOriginals := parseNameOriginal(tables) nameOriginals := parseNameOriginal(tables)
indexNameGen := func(column ...string) string { indexNameGen := func(column ...string) string {
return strings.Join(column, "_") return strings.Join(column, "_")
} }
@ -77,14 +76,12 @@ func Parse(filename, database string) ([]*Table, error) {
prefix := filepath.Base(filename) prefix := filepath.Base(filename)
var list []*Table var list []*Table
for indexTable, e := range tables { for indexTable, e := range tables {
columns := e.Columns
var ( var (
primaryColumn string
primaryColumnSet = collection.NewSet() primaryColumnSet = collection.NewSet()
uniqueKeyMap = make(map[string][]string)
primaryColumn string normalKeyMap = make(map[string][]string)
uniqueKeyMap = make(map[string][]string) columns = e.Columns
normalKeyMap = make(map[string][]string)
) )
for _, column := range columns { for _, column := range columns {

Loading…
Cancel
Save