diff --git a/core/mapping/unmarshaler.go b/core/mapping/unmarshaler.go index f6b0ed1a..4069492b 100644 --- a/core/mapping/unmarshaler.go +++ b/core/mapping/unmarshaler.go @@ -43,7 +43,8 @@ type ( UnmarshalOption func(*unmarshalOptions) unmarshalOptions struct { - fromString bool + fromString bool + canonicalKey func(key string) string } keyCache map[string][]string @@ -321,9 +322,12 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect if err != nil { return err } - + k := key + if u.opts.canonicalKey != nil { + k = u.opts.canonicalKey(key) + } fullName = join(fullName, key) - mapValue, hasValue := getValue(m, key) + mapValue, hasValue := getValue(m, k) if hasValue { return u.processNamedFieldWithValue(field, value, mapValue, key, opts, fullName) } @@ -621,6 +625,12 @@ func WithStringValues() UnmarshalOption { } } +func WithCanonicalKeyFunc(f func(string) string) UnmarshalOption { + return func(opt *unmarshalOptions) { + opt.canonicalKey = f + } +} + func fillDurationValue(fieldKind reflect.Kind, value reflect.Value, dur string) error { d, err := time.ParseDuration(dur) if err != nil { diff --git a/rest/httpx/requests.go b/rest/httpx/requests.go index 45dc7024..c6469bee 100644 --- a/rest/httpx/requests.go +++ b/rest/httpx/requests.go @@ -3,6 +3,7 @@ package httpx import ( "io" "net/http" + "net/textproto" "strings" "github.com/tal-tech/go-zero/core/mapping" @@ -23,7 +24,7 @@ const ( var ( formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues()) pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues()) - headerUnmarshaler = mapping.NewUnmarshaler(headerKey, mapping.WithStringValues()) + headerUnmarshaler = mapping.NewUnmarshaler(headerKey, mapping.WithStringValues(), mapping.WithCanonicalKeyFunc(textproto.CanonicalMIMEHeaderKey)) ) // Parse parses the request. @@ -47,7 +48,6 @@ func Parse(r *http.Request, v interface{}) error { func ParseHeaders(r *http.Request, v interface{}) error { m := map[string]interface{}{} for k, v := range r.Header { - k = strings.ToLower(k) if len(v) == 1 { m[k] = v[0] } else { diff --git a/rest/httpx/requests_test.go b/rest/httpx/requests_test.go index cd590723..fa79366c 100644 --- a/rest/httpx/requests_test.go +++ b/rest/httpx/requests_test.go @@ -203,10 +203,16 @@ func BenchmarkParseAuto(b *testing.B) { } func TestParseHeaders(t *testing.T) { + type AnonymousStruct struct { + XRealIP string `header:"x-real-ip"` + Accept string `header:"Accept,optional"` + } v := struct { - Name string `header:"name"` - Percent string `header:"percent"` - Addrs []string `header:"addrs"` + Name string `header:"name,optional"` + Percent string `header:"percent"` + Addrs []string `header:"addrs"` + XForwardedFor string `header:"X-Forwarded-For,optional"` + AnonymousStruct }{} request, err := http.NewRequest("POST", "http://hello.com/", nil) if err != nil { @@ -216,6 +222,9 @@ func TestParseHeaders(t *testing.T) { request.Header.Set("percent", "1") request.Header.Add("addrs", "addr1") request.Header.Add("addrs", "addr2") + request.Header.Add("X-Forwarded-For", "10.0.10.11") + request.Header.Add("x-real-ip", "10.0.11.10") + request.Header.Add("Accept", "application/json") err = ParseHeaders(request, &v) if err != nil { t.Fatal(err) @@ -223,4 +232,7 @@ func TestParseHeaders(t *testing.T) { assert.Equal(t, "chenquan", v.Name) assert.Equal(t, "1", v.Percent) assert.Equal(t, []string{"addr1", "addr2"}, v.Addrs) + assert.Equal(t, "10.0.10.11", v.XForwardedFor) + assert.Equal(t, "10.0.11.10", v.XRealIP) + assert.Equal(t, "application/json", v.Accept) }