|
|
|
package sqlx
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
|
|
|
"strconv"
|
|
|
|
"strings"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|
|
|
"github.com/zeromicro/go-zero/core/mapping"
|
|
|
|
)
|
|
|
|
|
|
|
|
var errUnbalancedEscape = errors.New("no char after escape char")
|
|
|
|
|
|
|
|
func desensitize(datasource string) string {
|
|
|
|
// remove account
|
|
|
|
pos := strings.LastIndex(datasource, "@")
|
|
|
|
if 0 <= pos && pos+1 < len(datasource) {
|
|
|
|
datasource = datasource[pos+1:]
|
|
|
|
}
|
|
|
|
|
|
|
|
return datasource
|
|
|
|
}
|
|
|
|
|
|
|
|
func escape(input string) string {
|
|
|
|
var b strings.Builder
|
|
|
|
|
|
|
|
for _, ch := range input {
|
|
|
|
switch ch {
|
|
|
|
case '\x00':
|
|
|
|
b.WriteString(`\x00`)
|
|
|
|
case '\r':
|
|
|
|
b.WriteString(`\r`)
|
|
|
|
case '\n':
|
|
|
|
b.WriteString(`\n`)
|
|
|
|
case '\\':
|
|
|
|
b.WriteString(`\\`)
|
|
|
|
case '\'':
|
|
|
|
b.WriteString(`\'`)
|
|
|
|
case '"':
|
|
|
|
b.WriteString(`\"`)
|
|
|
|
case '\x1a':
|
|
|
|
b.WriteString(`\x1a`)
|
|
|
|
default:
|
|
|
|
b.WriteRune(ch)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return b.String()
|
|
|
|
}
|
|
|
|
|
|
|
|
func format(query string, args ...any) (val string, err error) {
|
|
|
|
defer func() {
|
|
|
|
if err != nil {
|
|
|
|
err = newAcceptableError(err)
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
|
|
|
numArgs := len(args)
|
|
|
|
if numArgs == 0 {
|
|
|
|
return query, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
var b strings.Builder
|
|
|
|
var argIndex int
|
|
|
|
bytes := len(query)
|
|
|
|
|
|
|
|
for i := 0; i < bytes; i++ {
|
|
|
|
ch := query[i]
|
|
|
|
switch ch {
|
|
|
|
case '?':
|
|
|
|
if argIndex >= numArgs {
|
|
|
|
return "", fmt.Errorf("%d ? in sql, but only %d arguments provided",
|
|
|
|
argIndex+1, numArgs)
|
|
|
|
}
|
|
|
|
|
|
|
|
writeValue(&b, args[argIndex])
|
|
|
|
argIndex++
|
|
|
|
case ':', '$':
|
|
|
|
var j int
|
|
|
|
for j = i + 1; j < bytes; j++ {
|
|
|
|
char := query[j]
|
|
|
|
if char < '0' || '9' < char {
|
|
|
|
break
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if j > i+1 {
|
|
|
|
index, err := strconv.Atoi(query[i+1 : j])
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
// index starts from 1 for pg or oracle
|
|
|
|
if index > argIndex {
|
|
|
|
argIndex = index
|
|
|
|
}
|
|
|
|
|
|
|
|
index--
|
|
|
|
if index < 0 || numArgs <= index {
|
|
|
|
return "", fmt.Errorf("wrong index %d in sql", index)
|
|
|
|
}
|
|
|
|
|
|
|
|
writeValue(&b, args[index])
|
|
|
|
i = j - 1
|
|
|
|
}
|
|
|
|
case '\'', '"', '`':
|
|
|
|
b.WriteByte(ch)
|
|
|
|
|
|
|
|
for j := i + 1; j < bytes; j++ {
|
|
|
|
cur := query[j]
|
|
|
|
b.WriteByte(cur)
|
|
|
|
|
|
|
|
if cur == '\\' {
|
|
|
|
j++
|
|
|
|
if j >= bytes {
|
|
|
|
return "", errUnbalancedEscape
|
|
|
|
}
|
|
|
|
|
|
|
|
b.WriteByte(query[j])
|
|
|
|
} else if cur == ch {
|
|
|
|
i = j
|
|
|
|
break
|
|
|
|
}
|
|
|
|
}
|
|
|
|
default:
|
|
|
|
b.WriteByte(ch)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if argIndex < numArgs {
|
|
|
|
return "", fmt.Errorf("%d arguments provided, not matching sql", argIndex)
|
|
|
|
}
|
|
|
|
|
|
|
|
return b.String(), nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func logInstanceError(ctx context.Context, datasource string, err error) {
|
|
|
|
datasource = desensitize(datasource)
|
|
|
|
logx.WithContext(ctx).Errorf("Error on getting sql instance of %s: %v", datasource, err)
|
|
|
|
}
|
|
|
|
|
|
|
|
func logSqlError(ctx context.Context, stmt string, err error) {
|
|
|
|
if err != nil && err != ErrNotFound {
|
|
|
|
logx.WithContext(ctx).Errorf("stmt: %s, error: %s", stmt, err.Error())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func writeValue(buf *strings.Builder, arg any) {
|
|
|
|
switch v := arg.(type) {
|
|
|
|
case bool:
|
|
|
|
if v {
|
|
|
|
buf.WriteByte('1')
|
|
|
|
} else {
|
|
|
|
buf.WriteByte('0')
|
|
|
|
}
|
|
|
|
case string:
|
|
|
|
buf.WriteByte('\'')
|
|
|
|
buf.WriteString(escape(v))
|
|
|
|
buf.WriteByte('\'')
|
|
|
|
case time.Time:
|
|
|
|
buf.WriteByte('\'')
|
|
|
|
buf.WriteString(v.String())
|
|
|
|
buf.WriteByte('\'')
|
|
|
|
case *time.Time:
|
|
|
|
buf.WriteByte('\'')
|
|
|
|
buf.WriteString(v.String())
|
|
|
|
buf.WriteByte('\'')
|
|
|
|
default:
|
|
|
|
buf.WriteString(mapping.Repr(v))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
type acceptableError struct {
|
|
|
|
err error
|
|
|
|
}
|
|
|
|
|
|
|
|
func newAcceptableError(err error) error {
|
|
|
|
return acceptableError{
|
|
|
|
err: err,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (e acceptableError) Error() string {
|
|
|
|
return e.err.Error()
|
|
|
|
}
|