fix http header binding failure bug #885 (#887)

master
voidint 3 years ago committed by GitHub
parent 872e75e10d
commit 28a7c9d38f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -44,6 +44,7 @@ type (
unmarshalOptions struct { unmarshalOptions struct {
fromString bool fromString bool
canonicalKey func(key string) string
} }
keyCache map[string][]string keyCache map[string][]string
@ -321,9 +322,12 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect
if err != nil { if err != nil {
return err return err
} }
k := key
if u.opts.canonicalKey != nil {
k = u.opts.canonicalKey(key)
}
fullName = join(fullName, key) fullName = join(fullName, key)
mapValue, hasValue := getValue(m, key) mapValue, hasValue := getValue(m, k)
if hasValue { if hasValue {
return u.processNamedFieldWithValue(field, value, mapValue, key, opts, fullName) return u.processNamedFieldWithValue(field, value, mapValue, key, opts, fullName)
} }
@ -621,6 +625,12 @@ func WithStringValues() UnmarshalOption {
} }
} }
func WithCanonicalKeyFunc(f func(string) string) UnmarshalOption {
return func(opt *unmarshalOptions) {
opt.canonicalKey = f
}
}
func fillDurationValue(fieldKind reflect.Kind, value reflect.Value, dur string) error { func fillDurationValue(fieldKind reflect.Kind, value reflect.Value, dur string) error {
d, err := time.ParseDuration(dur) d, err := time.ParseDuration(dur)
if err != nil { if err != nil {

@ -3,6 +3,7 @@ package httpx
import ( import (
"io" "io"
"net/http" "net/http"
"net/textproto"
"strings" "strings"
"github.com/tal-tech/go-zero/core/mapping" "github.com/tal-tech/go-zero/core/mapping"
@ -23,7 +24,7 @@ const (
var ( var (
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues()) formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues()) pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
headerUnmarshaler = mapping.NewUnmarshaler(headerKey, mapping.WithStringValues()) headerUnmarshaler = mapping.NewUnmarshaler(headerKey, mapping.WithStringValues(), mapping.WithCanonicalKeyFunc(textproto.CanonicalMIMEHeaderKey))
) )
// Parse parses the request. // Parse parses the request.
@ -47,7 +48,6 @@ func Parse(r *http.Request, v interface{}) error {
func ParseHeaders(r *http.Request, v interface{}) error { func ParseHeaders(r *http.Request, v interface{}) error {
m := map[string]interface{}{} m := map[string]interface{}{}
for k, v := range r.Header { for k, v := range r.Header {
k = strings.ToLower(k)
if len(v) == 1 { if len(v) == 1 {
m[k] = v[0] m[k] = v[0]
} else { } else {

@ -203,10 +203,16 @@ func BenchmarkParseAuto(b *testing.B) {
} }
func TestParseHeaders(t *testing.T) { func TestParseHeaders(t *testing.T) {
type AnonymousStruct struct {
XRealIP string `header:"x-real-ip"`
Accept string `header:"Accept,optional"`
}
v := struct { v := struct {
Name string `header:"name"` Name string `header:"name,optional"`
Percent string `header:"percent"` Percent string `header:"percent"`
Addrs []string `header:"addrs"` Addrs []string `header:"addrs"`
XForwardedFor string `header:"X-Forwarded-For,optional"`
AnonymousStruct
}{} }{}
request, err := http.NewRequest("POST", "http://hello.com/", nil) request, err := http.NewRequest("POST", "http://hello.com/", nil)
if err != nil { if err != nil {
@ -216,6 +222,9 @@ func TestParseHeaders(t *testing.T) {
request.Header.Set("percent", "1") request.Header.Set("percent", "1")
request.Header.Add("addrs", "addr1") request.Header.Add("addrs", "addr1")
request.Header.Add("addrs", "addr2") request.Header.Add("addrs", "addr2")
request.Header.Add("X-Forwarded-For", "10.0.10.11")
request.Header.Add("x-real-ip", "10.0.11.10")
request.Header.Add("Accept", "application/json")
err = ParseHeaders(request, &v) err = ParseHeaders(request, &v)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -223,4 +232,7 @@ func TestParseHeaders(t *testing.T) {
assert.Equal(t, "chenquan", v.Name) assert.Equal(t, "chenquan", v.Name)
assert.Equal(t, "1", v.Percent) assert.Equal(t, "1", v.Percent)
assert.Equal(t, []string{"addr1", "addr2"}, v.Addrs) assert.Equal(t, []string{"addr1", "addr2"}, v.Addrs)
assert.Equal(t, "10.0.10.11", v.XForwardedFor)
assert.Equal(t, "10.0.11.10", v.XRealIP)
assert.Equal(t, "application/json", v.Accept)
} }

Loading…
Cancel
Save