diff --git a/gateway/requestparser.go b/gateway/requestparser.go index fe495b32..4de6d06c 100644 --- a/gateway/requestparser.go +++ b/gateway/requestparser.go @@ -7,12 +7,14 @@ import ( "github.com/fullstorydev/grpcurl" "github.com/golang/protobuf/jsonpb" + "github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/pathvar" ) -func buildJsonRequestParser(v interface{}, resolver jsonpb.AnyResolver) (grpcurl.RequestParser, error) { +func buildJsonRequestParser(m map[string]interface{}, resolver jsonpb.AnyResolver) ( + grpcurl.RequestParser, error) { var buf bytes.Buffer - if err := json.NewEncoder(&buf).Encode(v); err != nil { + if err := json.NewEncoder(&buf).Encode(m); err != nil { return nil, err } @@ -21,12 +23,20 @@ func buildJsonRequestParser(v interface{}, resolver jsonpb.AnyResolver) (grpcurl func newRequestParser(r *http.Request, resolver jsonpb.AnyResolver) (grpcurl.RequestParser, error) { vars := pathvar.Vars(r) - if len(vars) == 0 { + params, err := httpx.GetFormValues(r) + if err != nil { + return nil, err + } + + for k, v := range vars { + params[k] = v + } + if len(params) == 0 { return grpcurl.NewJSONRequestParser(r.Body, resolver), nil } if r.ContentLength == 0 { - return buildJsonRequestParser(vars, resolver) + return buildJsonRequestParser(params, resolver) } m := make(map[string]interface{}) @@ -34,7 +44,7 @@ func newRequestParser(r *http.Request, resolver jsonpb.AnyResolver) (grpcurl.Req return nil, err } - for k, v := range vars { + for k, v := range params { m[k] = v } diff --git a/gateway/requestparser_test.go b/gateway/requestparser_test.go index 3659b4d1..3207e072 100644 --- a/gateway/requestparser_test.go +++ b/gateway/requestparser_test.go @@ -46,3 +46,10 @@ func TestNewRequestParserWithVarsWithWrongBody(t *testing.T) { assert.NotNil(t, err) assert.Nil(t, parser) } + +func TestNewRequestParserWithForm(t *testing.T) { + req := httptest.NewRequest("GET", "/val?a=b", nil) + parser, err := newRequestParser(req, nil) + assert.Nil(t, err) + assert.NotNil(t, parser) +} diff --git a/rest/httpx/requests.go b/rest/httpx/requests.go index f8fe7169..e46025b5 100644 --- a/rest/httpx/requests.go +++ b/rest/httpx/requests.go @@ -49,24 +49,11 @@ func ParseHeaders(r *http.Request, v interface{}) error { // ParseForm parses the form request. func ParseForm(r *http.Request, v interface{}) error { - if err := r.ParseForm(); err != nil { + params, err := GetFormValues(r) + if err != nil { return err } - if err := r.ParseMultipartForm(maxMemory); err != nil { - if err != http.ErrNotMultipart { - return err - } - } - - params := make(map[string]interface{}, len(r.Form)) - for name := range r.Form { - formValue := r.Form.Get(name) - if len(formValue) > 0 { - params[name] = formValue - } - } - return formUnmarshaler.Unmarshal(params, v) } diff --git a/rest/httpx/util.go b/rest/httpx/util.go index a5f5196f..216abaa0 100644 --- a/rest/httpx/util.go +++ b/rest/httpx/util.go @@ -4,6 +4,29 @@ import "net/http" const xForwardedFor = "X-Forwarded-For" +// GetFormValues returns the form values. +func GetFormValues(r *http.Request) (map[string]interface{}, error) { + if err := r.ParseForm(); err != nil { + return nil, err + } + + if err := r.ParseMultipartForm(maxMemory); err != nil { + if err != http.ErrNotMultipart { + return nil, err + } + } + + params := make(map[string]interface{}, len(r.Form)) + for name := range r.Form { + formValue := r.Form.Get(name) + if len(formValue) > 0 { + params[name] = formValue + } + } + + return params, nil +} + // GetRemoteAddr returns the peer address, supports X-Forward-For. func GetRemoteAddr(r *http.Request) string { v := r.Header.Get(xForwardedFor)