diff --git a/core/mapping/marshaler.go b/core/mapping/marshaler.go new file mode 100644 index 00000000..6d1bee8e --- /dev/null +++ b/core/mapping/marshaler.go @@ -0,0 +1,176 @@ +package mapping + +import ( + "fmt" + "reflect" + "strings" +) + +const ( + emptyTag = "" + tagKVSeparator = ":" +) + +// Marshal marshals the given val and returns the map that contains the fields. +// optional=another is not implemented, and it's hard to implement and not common used. +func Marshal(val interface{}) (map[string]map[string]interface{}, error) { + ret := make(map[string]map[string]interface{}) + tp := reflect.TypeOf(val) + rv := reflect.ValueOf(val) + + for i := 0; i < tp.NumField(); i++ { + field := tp.Field(i) + value := rv.Field(i) + if err := processMember(field, value, ret); err != nil { + return nil, err + } + } + + return ret, nil +} + +func getTag(field reflect.StructField) (string, bool) { + tag := string(field.Tag) + if i := strings.Index(tag, tagKVSeparator); i >= 0 { + return strings.TrimSpace(tag[:i]), true + } + + return strings.TrimSpace(tag), false +} + +func processMember(field reflect.StructField, value reflect.Value, + collector map[string]map[string]interface{}) error { + var key string + var opt *fieldOptions + var err error + tag, ok := getTag(field) + if !ok { + tag = emptyTag + key = field.Name + } else { + key, opt, err = parseKeyAndOptions(tag, field) + if err != nil { + return err + } + + if err = validate(field, value, opt); err != nil { + return err + } + } + + val := value.Interface() + if opt != nil && opt.FromString { + val = fmt.Sprint(val) + } + + m, ok := collector[tag] + if ok { + m[key] = val + } else { + m = map[string]interface{}{ + key: val, + } + } + collector[tag] = m + + return nil +} + +func validate(field reflect.StructField, value reflect.Value, opt *fieldOptions) error { + if opt == nil || !opt.Optional { + if err := validateOptional(field, value); err != nil { + return err + } + } + + if opt == nil { + return nil + } + + if len(opt.Options) > 0 { + if err := validateOptions(value, opt); err != nil { + return err + } + } + + if opt.Range != nil { + if err := validateRange(value, opt); err != nil { + return err + } + } + + return nil +} + +func validateOptional(field reflect.StructField, value reflect.Value) error { + switch field.Type.Kind() { + case reflect.Ptr: + if value.IsNil() { + return fmt.Errorf("field %q is nil", field.Name) + } + case reflect.Array, reflect.Slice, reflect.Map: + if value.IsNil() || value.Len() == 0 { + return fmt.Errorf("field %q is empty", field.Name) + } + } + + return nil +} + +func validateOptions(value reflect.Value, opt *fieldOptions) error { + var found bool + val := fmt.Sprint(value.Interface()) + for i := range opt.Options { + if opt.Options[i] == val { + found = true + break + } + } + if !found { + return fmt.Errorf("field %q not in options", val) + } + + return nil +} + +func validateRange(value reflect.Value, opt *fieldOptions) error { + var val float64 + switch v := value.Interface().(type) { + case int: + val = float64(v) + case int8: + val = float64(v) + case int16: + val = float64(v) + case int32: + val = float64(v) + case int64: + val = float64(v) + case uint: + val = float64(v) + case uint8: + val = float64(v) + case uint16: + val = float64(v) + case uint32: + val = float64(v) + case uint64: + val = float64(v) + case float32: + val = float64(v) + case float64: + val = v + default: + return fmt.Errorf("unknown support type for range %q", value.Type().String()) + } + + // validates [left, right], [left, right), (left, right], (left, right) + if val < opt.Range.left || + (!opt.Range.leftInclude && val == opt.Range.left) || + val > opt.Range.right || + (!opt.Range.rightInclude && val == opt.Range.right) { + return fmt.Errorf("%v out of range", value.Interface()) + } + + return nil +} diff --git a/core/mapping/marshaler_test.go b/core/mapping/marshaler_test.go new file mode 100644 index 00000000..b7b45bdc --- /dev/null +++ b/core/mapping/marshaler_test.go @@ -0,0 +1,233 @@ +package mapping + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMarshal(t *testing.T) { + v := struct { + Name string `path:"name"` + Address string `json:"address,options=[beijing,shanghai]"` + Age int `json:"age"` + Anonymous bool + }{ + Name: "kevin", + Address: "shanghai", + Age: 20, + Anonymous: true, + } + + m, err := Marshal(v) + assert.Nil(t, err) + assert.Equal(t, "kevin", m["path"]["name"]) + assert.Equal(t, "shanghai", m["json"]["address"]) + assert.Equal(t, 20, m["json"]["age"].(int)) + assert.True(t, m[emptyTag]["Anonymous"].(bool)) +} + +func TestMarshal_OptionalPtr(t *testing.T) { + var val = 1 + v := struct { + Age *int `json:"age"` + }{ + Age: &val, + } + + m, err := Marshal(v) + assert.Nil(t, err) + assert.Equal(t, 1, *m["json"]["age"].(*int)) +} + +func TestMarshal_OptionalPtrNil(t *testing.T) { + v := struct { + Age *int `json:"age"` + }{} + + _, err := Marshal(v) + assert.NotNil(t, err) +} + +func TestMarshal_BadOptions(t *testing.T) { + v := struct { + Name string `json:"name,options"` + }{ + Name: "kevin", + } + + _, err := Marshal(v) + assert.NotNil(t, err) +} + +func TestMarshal_NotInOptions(t *testing.T) { + v := struct { + Name string `json:"name,options=[a,b]"` + }{ + Name: "kevin", + } + + _, err := Marshal(v) + assert.NotNil(t, err) +} + +func TestMarshal_Nested(t *testing.T) { + type address struct { + Country string `json:"country"` + City string `json:"city"` + } + v := struct { + Name string `json:"name,options=[kevin,wan]"` + Address address `json:"address"` + }{ + Name: "kevin", + Address: address{ + Country: "China", + City: "Shanghai", + }, + } + + m, err := Marshal(v) + assert.Nil(t, err) + assert.Equal(t, "kevin", m["json"]["name"]) + assert.Equal(t, "China", m["json"]["address"].(address).Country) + assert.Equal(t, "Shanghai", m["json"]["address"].(address).City) +} + +func TestMarshal_NestedPtr(t *testing.T) { + type address struct { + Country string `json:"country"` + City string `json:"city"` + } + v := struct { + Name string `json:"name,options=[kevin,wan]"` + Address *address `json:"address"` + }{ + Name: "kevin", + Address: &address{ + Country: "China", + City: "Shanghai", + }, + } + + m, err := Marshal(v) + assert.Nil(t, err) + assert.Equal(t, "kevin", m["json"]["name"]) + assert.Equal(t, "China", m["json"]["address"].(*address).Country) + assert.Equal(t, "Shanghai", m["json"]["address"].(*address).City) +} + +func TestMarshal_Slice(t *testing.T) { + v := struct { + Name []string `json:"name"` + }{ + Name: []string{"kevin", "wan"}, + } + + m, err := Marshal(v) + assert.Nil(t, err) + assert.ElementsMatch(t, []string{"kevin", "wan"}, m["json"]["name"].([]string)) +} + +func TestMarshal_SliceNil(t *testing.T) { + v := struct { + Name []string `json:"name"` + }{ + Name: nil, + } + + _, err := Marshal(v) + assert.NotNil(t, err) +} + +func TestMarshal_Range(t *testing.T) { + v := struct { + Int int `json:"int,range=[1:3]"` + Int8 int8 `json:"int8,range=[1:3)"` + Int16 int16 `json:"int16,range=(1:3]"` + Int32 int32 `json:"int32,range=(1:3)"` + Int64 int64 `json:"int64,range=(1:3)"` + Uint uint `json:"uint,range=[1:3]"` + Uint8 uint8 `json:"uint8,range=[1:3)"` + Uint16 uint16 `json:"uint16,range=(1:3]"` + Uint32 uint32 `json:"uint32,range=(1:3)"` + Uint64 uint64 `json:"uint64,range=(1:3)"` + Float32 float32 `json:"float32,range=(1:3)"` + Float64 float64 `json:"float64,range=(1:3)"` + }{ + Int: 1, + Int8: 1, + Int16: 2, + Int32: 2, + Int64: 2, + Uint: 1, + Uint8: 1, + Uint16: 2, + Uint32: 2, + Uint64: 2, + Float32: 2, + Float64: 2, + } + + m, err := Marshal(v) + assert.Nil(t, err) + assert.Equal(t, 1, m["json"]["int"].(int)) + assert.Equal(t, int8(1), m["json"]["int8"].(int8)) + assert.Equal(t, int16(2), m["json"]["int16"].(int16)) + assert.Equal(t, int32(2), m["json"]["int32"].(int32)) + assert.Equal(t, int64(2), m["json"]["int64"].(int64)) + assert.Equal(t, uint(1), m["json"]["uint"].(uint)) + assert.Equal(t, uint8(1), m["json"]["uint8"].(uint8)) + assert.Equal(t, uint16(2), m["json"]["uint16"].(uint16)) + assert.Equal(t, uint32(2), m["json"]["uint32"].(uint32)) + assert.Equal(t, uint64(2), m["json"]["uint64"].(uint64)) + assert.Equal(t, float32(2), m["json"]["float32"].(float32)) + assert.Equal(t, float64(2), m["json"]["float64"].(float64)) +} + +func TestMarshal_RangeOut(t *testing.T) { + tests := []interface{}{ + struct { + Int int `json:"int,range=[1:3]"` + }{ + Int: 4, + }, + struct { + Int int `json:"int,range=(1:3]"` + }{ + Int: 1, + }, + struct { + Int int `json:"int,range=[1:3)"` + }{ + Int: 3, + }, + struct { + Int int `json:"int,range=(1:3)"` + }{ + Int: 3, + }, + struct { + Bool bool `json:"bool,range=(1:3)"` + }{ + Bool: true, + }, + } + + for _, test := range tests { + _, err := Marshal(test) + assert.NotNil(t, err) + } +} + +func TestMarshal_FromString(t *testing.T) { + v := struct { + Age int `json:"age,string"` + }{ + Age: 10, + } + + m, err := Marshal(v) + assert.Nil(t, err) + assert.Equal(t, "10", m["json"]["age"].(string)) +} diff --git a/core/mapping/unmarshaler.go b/core/mapping/unmarshaler.go index a8be682c..303805dd 100644 --- a/core/mapping/unmarshaler.go +++ b/core/mapping/unmarshaler.go @@ -97,10 +97,6 @@ func (u *Unmarshaler) unmarshalWithFullName(m Valuer, v interface{}, fullName st numFields := rte.NumField() for i := 0; i < numFields; i++ { field := rte.Field(i) - if usingDifferentKeys(u.key, field) { - continue - } - if err := u.processField(field, rve.Field(i), m, fullName); err != nil { return err } diff --git a/rest/httpc/requests.go b/rest/httpc/requests.go index 85841769..06ca6a73 100644 --- a/rest/httpc/requests.go +++ b/rest/httpc/requests.go @@ -1,8 +1,17 @@ package httpc import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" "net/http" + nurl "net/url" + "strings" + "github.com/zeromicro/go-zero/core/lang" + "github.com/zeromicro/go-zero/core/mapping" "github.com/zeromicro/go-zero/rest/httpc/internal" ) @@ -10,6 +19,17 @@ var interceptors = []internal.Interceptor{ internal.LogInterceptor, } +// Do sends an HTTP request with the given arguments and returns an HTTP response. +// data is automatically marshal into a *httpRequest, typically it's defined in an API file. +func Do(ctx context.Context, method, url string, data interface{}) (*http.Response, error) { + req, err := buildRequest(ctx, method, url, data) + if err != nil { + return nil, err + } + + return DoRequest(req) +} + // DoRequest sends an HTTP request and returns an HTTP response. func DoRequest(r *http.Request) (*http.Response, error) { return request(r, defaultClient{}) @@ -27,6 +47,107 @@ func (c defaultClient) do(r *http.Request) (*http.Response, error) { return http.DefaultClient.Do(r) } +func buildFormQuery(u *nurl.URL, val map[string]interface{}) string { + query := u.Query() + for k, v := range val { + query.Add(k, fmt.Sprint(v)) + } + + return query.Encode() +} + +func buildRequest(ctx context.Context, method, url string, data interface{}) (*http.Request, error) { + u, err := nurl.Parse(url) + if err != nil { + return nil, err + } + + var val map[string]map[string]interface{} + if data != nil { + val, err = mapping.Marshal(data) + if err != nil { + return nil, err + } + } + + if err := fillPath(u, val[pathKey]); err != nil { + return nil, err + } + + var reader io.Reader + jsonVars, hasJsonBody := val[jsonKey] + if hasJsonBody { + if method == http.MethodGet { + return nil, ErrGetWithBody + } + + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + if err := enc.Encode(jsonVars); err != nil { + return nil, err + } + + reader = &buf + } + + req, err := http.NewRequestWithContext(ctx, method, u.String(), reader) + if err != nil { + return nil, err + } + + req.URL.RawQuery = buildFormQuery(u, val[formKey]) + fillHeader(req, val[headerKey]) + if hasJsonBody { + req.Header.Set(contentType, applicationJson) + } + + return req, nil +} + +func fillHeader(r *http.Request, val map[string]interface{}) { + for k, v := range val { + r.Header.Add(k, fmt.Sprint(v)) + } +} + +func fillPath(u *nurl.URL, val map[string]interface{}) error { + used := make(map[string]lang.PlaceholderType) + fields := strings.Split(u.Path, slash) + + for i := range fields { + field := fields[i] + if len(field) > 0 && field[0] == colon { + name := field[1:] + ival, ok := val[name] + if !ok { + return fmt.Errorf("missing path variable %q", name) + } + value := fmt.Sprint(ival) + if len(value) == 0 { + return fmt.Errorf("empty path variable %q", name) + } + fields[i] = value + used[name] = lang.Placeholder + } + } + + if len(val) != len(used) { + for key := range used { + delete(val, key) + } + + var unused []string + for key := range val { + unused = append(unused, key) + } + + return fmt.Errorf("more path variables are provided: %q", strings.Join(unused, ", ")) + } + + u.Path = strings.Join(fields, slash) + return nil +} + func request(r *http.Request, cli client) (*http.Response, error) { var respHandlers []internal.ResponseHandler for _, interceptor := range interceptors { diff --git a/rest/httpc/requests_test.go b/rest/httpc/requests_test.go index 815649af..34bd82af 100644 --- a/rest/httpc/requests_test.go +++ b/rest/httpc/requests_test.go @@ -1,14 +1,17 @@ package httpc import ( + "context" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/rest/httpx" + "github.com/zeromicro/go-zero/rest/router" ) -func TestDo(t *testing.T) { +func TestDoRequest(t *testing.T) { svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { })) defer svr.Close() @@ -19,7 +22,7 @@ func TestDo(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) } -func TestDoNotFound(t *testing.T) { +func TestDoRequest_NotFound(t *testing.T) { svr := httptest.NewServer(http.NotFoundHandler()) defer svr.Close() req, err := http.NewRequest(http.MethodPost, svr.URL, nil) @@ -30,7 +33,7 @@ func TestDoNotFound(t *testing.T) { assert.Equal(t, http.StatusNotFound, resp.StatusCode) } -func TestDoMoved(t *testing.T) { +func TestDoRequest_Moved(t *testing.T) { svr := httptest.NewServer(http.RedirectHandler("/foo", http.StatusMovedPermanently)) defer svr.Close() req, err := http.NewRequest(http.MethodGet, svr.URL, nil) @@ -39,3 +42,84 @@ func TestDoMoved(t *testing.T) { // too many redirects assert.NotNil(t, err) } + +func TestDo(t *testing.T) { + type Data struct { + Key string `path:"key"` + Value int `form:"value"` + Header string `header:"X-Header"` + Body string `json:"body"` + } + + rt := router.NewRouter() + err := rt.Handle(http.MethodPost, "/nodes/:key", + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req Data + assert.Nil(t, httpx.Parse(r, &req)) + })) + assert.Nil(t, err) + + svr := httptest.NewServer(http.HandlerFunc(rt.ServeHTTP)) + defer svr.Close() + + data := Data{ + Key: "foo", + Value: 10, + Header: "my-header", + Body: "my body", + } + resp, err := Do(context.Background(), http.MethodPost, svr.URL+"/nodes/:key", data) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +func TestDo_BadRequest(t *testing.T) { + _, err := Do(context.Background(), http.MethodPost, ":/nodes/:key", nil) + assert.NotNil(t, err) + + val1 := struct { + Value string `json:"value,options=[a,b]"` + }{ + Value: "c", + } + _, err = Do(context.Background(), http.MethodPost, "/nodes/:key", val1) + assert.NotNil(t, err) + + val2 := struct { + Value string `path:"val"` + }{ + Value: "", + } + _, err = Do(context.Background(), http.MethodPost, "/nodes/:key", val2) + assert.NotNil(t, err) + + val3 := struct { + Value string `path:"key"` + Body string `json:"body"` + }{ + Value: "foo", + } + _, err = Do(context.Background(), http.MethodGet, "/nodes/:key", val3) + assert.NotNil(t, err) + + _, err = Do(context.Background(), "\n", "rtmp://nodes", nil) + assert.NotNil(t, err) + + val4 := struct { + Value string `path:"val"` + }{ + Value: "", + } + _, err = Do(context.Background(), http.MethodPost, "/nodes/:val", val4) + assert.NotNil(t, err) + + val5 := struct { + Value string `path:"val"` + Another int `path:"foo"` + }{ + Value: "1", + Another: 2, + } + _, err = Do(context.Background(), http.MethodPost, "/nodes/:val", val5) + assert.NotNil(t, err) +} diff --git a/rest/httpc/service.go b/rest/httpc/service.go index 54164962..fc80697d 100644 --- a/rest/httpc/service.go +++ b/rest/httpc/service.go @@ -1,6 +1,7 @@ package httpc import ( + "context" "net/http" "github.com/zeromicro/go-zero/core/breaker" @@ -12,6 +13,8 @@ type ( // Service represents a remote HTTP service. Service interface { + // Do sends an HTTP request with the given arguments and returns an HTTP response. + Do(ctx context.Context, method, url string, data interface{}) (*http.Response, error) // DoRequest sends a HTTP request to the service. DoRequest(r *http.Request) (*http.Response, error) } @@ -39,6 +42,16 @@ func NewServiceWithClient(name string, cli *http.Client, opts ...Option) Service } } +// Do sends an HTTP request with the given arguments and returns an HTTP response. +func (s namedService) Do(ctx context.Context, method, url string, data interface{}) (*http.Response, error) { + req, err := buildRequest(ctx, method, url, data) + if err != nil { + return nil, err + } + + return s.DoRequest(req) +} + // DoRequest sends an HTTP request to the service. func (s namedService) DoRequest(r *http.Request) (*http.Response, error) { return request(r, s) diff --git a/rest/httpc/service_test.go b/rest/httpc/service_test.go index ea259796..0af1c6ee 100644 --- a/rest/httpc/service_test.go +++ b/rest/httpc/service_test.go @@ -1,6 +1,7 @@ package httpc import ( + "context" "net/http" "net/http/httptest" "testing" @@ -8,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestNamedService_Do(t *testing.T) { +func TestNamedService_DoRequest(t *testing.T) { svr := httptest.NewServer(http.RedirectHandler("/foo", http.StatusMovedPermanently)) defer svr.Close() req, err := http.NewRequest(http.MethodGet, svr.URL, nil) @@ -19,7 +20,7 @@ func TestNamedService_Do(t *testing.T) { assert.NotNil(t, err) } -func TestNamedService_Get(t *testing.T) { +func TestNamedService_DoRequestGet(t *testing.T) { svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("foo", r.Header.Get("foo")) })) @@ -36,7 +37,7 @@ func TestNamedService_Get(t *testing.T) { assert.Equal(t, "bar", resp.Header.Get("foo")) } -func TestNamedService_Post(t *testing.T) { +func TestNamedService_DoRequestPost(t *testing.T) { svr := httptest.NewServer(http.NotFoundHandler()) defer svr.Close() service := NewService("foo") @@ -47,3 +48,38 @@ func TestNamedService_Post(t *testing.T) { assert.Nil(t, err) assert.Equal(t, http.StatusNotFound, resp.StatusCode) } + +func TestNamedService_Do(t *testing.T) { + type Data struct { + Key string `path:"key"` + Value int `form:"value"` + Header string `header:"X-Header"` + Body string `json:"body"` + } + + svr := httptest.NewServer(http.NotFoundHandler()) + defer svr.Close() + + service := NewService("foo") + data := Data{ + Key: "foo", + Value: 10, + Header: "my-header", + Body: "my body", + } + resp, err := service.Do(context.Background(), http.MethodPost, svr.URL+"/nodes/:key", data) + assert.Nil(t, err) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) +} + +func TestNamedService_DoBadRequest(t *testing.T) { + val := struct { + Value string `json:"value,options=[a,b]"` + }{ + Value: "c", + } + + service := NewService("foo") + _, err := service.Do(context.Background(), http.MethodPost, "/nodes/:key", val) + assert.NotNil(t, err) +} diff --git a/rest/httpc/vars.go b/rest/httpc/vars.go index 32d68da9..03349774 100644 --- a/rest/httpc/vars.go +++ b/rest/httpc/vars.go @@ -1,6 +1,17 @@ package httpc +import "errors" + const ( + pathKey = "path" + formKey = "form" + headerKey = "header" + jsonKey = "json" + slash = "/" + colon = ':' contentType = "Content-Type" applicationJson = "application/json" ) + +// ErrGetWithBody indicates that GET request with body. +var ErrGetWithBody = errors.New("HTTP GET should not have body")