diff --git a/rest/httpx/requests.go b/rest/httpx/requests.go index 072b9af1..45dc7024 100644 --- a/rest/httpx/requests.go +++ b/rest/httpx/requests.go @@ -12,6 +12,7 @@ import ( const ( formKey = "form" pathKey = "path" + headerKey = "header" emptyJson = "{}" maxMemory = 32 << 20 // 32MB maxBodyLen = 8 << 20 // 8MB @@ -20,8 +21,9 @@ const ( ) var ( - formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues()) - pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues()) + formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues()) + pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues()) + headerUnmarshaler = mapping.NewUnmarshaler(headerKey, mapping.WithStringValues()) ) // Parse parses the request. @@ -34,9 +36,28 @@ func Parse(r *http.Request, v interface{}) error { return err } + if err := ParseHeaders(r, v); err != nil { + return err + } + return ParseJsonBody(r, v) } +// ParseHeaders parses the headers request. +func ParseHeaders(r *http.Request, v interface{}) error { + m := map[string]interface{}{} + for k, v := range r.Header { + k = strings.ToLower(k) + if len(v) == 1 { + m[k] = v[0] + } else { + m[k] = v + } + } + + return headerUnmarshaler.Unmarshal(m, v) +} + // ParseForm parses the form request. func ParseForm(r *http.Request, v interface{}) error { if err := r.ParseForm(); err != nil { diff --git a/rest/httpx/requests_test.go b/rest/httpx/requests_test.go index 896b0eb1..cd590723 100644 --- a/rest/httpx/requests_test.go +++ b/rest/httpx/requests_test.go @@ -201,3 +201,26 @@ func BenchmarkParseAuto(b *testing.B) { } } } + +func TestParseHeaders(t *testing.T) { + v := struct { + Name string `header:"name"` + Percent string `header:"percent"` + Addrs []string `header:"addrs"` + }{} + request, err := http.NewRequest("POST", "http://hello.com/", nil) + if err != nil { + t.Fatal(err) + } + request.Header.Set("name", "chenquan") + request.Header.Set("percent", "1") + request.Header.Add("addrs", "addr1") + request.Header.Add("addrs", "addr2") + err = ParseHeaders(request, &v) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "chenquan", v.Name) + assert.Equal(t, "1", v.Percent) + assert.Equal(t, []string{"addr1", "addr2"}, v.Addrs) +}