diff --git a/core/mapping/fieldoptions.go b/core/mapping/fieldoptions.go index 37cc13e4..1e8a2ab6 100644 --- a/core/mapping/fieldoptions.go +++ b/core/mapping/fieldoptions.go @@ -8,6 +8,7 @@ type ( // use context and OptionalDep option to determine the value of Optional // nothing to do with context.Context fieldOptionsWithContext struct { + Inherit bool FromString bool Optional bool Options []string @@ -40,6 +41,10 @@ func (o *fieldOptionsWithContext) getDefault() (string, bool) { return o.Default, len(o.Default) > 0 } +func (o *fieldOptionsWithContext) inherit() bool { + return o != nil && o.Inherit +} + func (o *fieldOptionsWithContext) optional() bool { return o != nil && o.Optional } diff --git a/core/mapping/unmarshaler.go b/core/mapping/unmarshaler.go index 3f52f1fb..0f456a35 100644 --- a/core/mapping/unmarshaler.go +++ b/core/mapping/unmarshaler.go @@ -70,15 +70,15 @@ func UnmarshalKey(m map[string]interface{}, v interface{}) error { // Unmarshal unmarshals m into v. func (u *Unmarshaler) Unmarshal(m map[string]interface{}, v interface{}) error { - return u.UnmarshalValuer(MapValuer(m), v) + return u.UnmarshalValuer(mapValuer(m), v) } // UnmarshalValuer unmarshals m into v. func (u *Unmarshaler) UnmarshalValuer(m Valuer, v interface{}) error { - return u.unmarshalWithFullName(m, v, "") + return u.unmarshalWithFullName(simpleValuer{current: m}, v, "") } -func (u *Unmarshaler) unmarshalWithFullName(m Valuer, v interface{}, fullName string) error { +func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v interface{}, fullName string) error { rv := reflect.ValueOf(v) if err := ValidatePtr(&rv); err != nil { return err @@ -102,7 +102,7 @@ func (u *Unmarshaler) unmarshalWithFullName(m Valuer, v interface{}, fullName st } func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value reflect.Value, - m Valuer, fullName string) error { + m valuerWithParent, fullName string) error { key, options, err := u.parseOptionsWithContext(field, m, fullName) if err != nil { return err @@ -120,7 +120,7 @@ func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value ref } func (u *Unmarshaler) processAnonymousFieldOptional(field reflect.StructField, value reflect.Value, - key string, m Valuer, fullName string) error { + key string, m valuerWithParent, fullName string) error { var filled bool var required int var requiredFilled int @@ -161,7 +161,7 @@ func (u *Unmarshaler) processAnonymousFieldOptional(field reflect.StructField, v } func (u *Unmarshaler) processAnonymousFieldRequired(field reflect.StructField, value reflect.Value, - m Valuer, fullName string) error { + m valuerWithParent, fullName string) error { maybeNewValue(field, value) fieldType := Deref(field.Type) indirectValue := reflect.Indirect(value) @@ -175,8 +175,8 @@ func (u *Unmarshaler) processAnonymousFieldRequired(field reflect.StructField, v return nil } -func (u *Unmarshaler) processField(field reflect.StructField, value reflect.Value, m Valuer, - fullName string) error { +func (u *Unmarshaler) processField(field reflect.StructField, value reflect.Value, + m valuerWithParent, fullName string) error { if usingDifferentKeys(u.key, field) { return nil } @@ -189,15 +189,23 @@ func (u *Unmarshaler) processField(field reflect.StructField, value reflect.Valu } func (u *Unmarshaler) processFieldNotFromString(field reflect.StructField, value reflect.Value, - mapValue interface{}, opts *fieldOptionsWithContext, fullName string) error { + vp valueWithParent, opts *fieldOptionsWithContext, fullName string) error { fieldType := field.Type derefedFieldType := Deref(fieldType) typeKind := derefedFieldType.Kind() - valueKind := reflect.TypeOf(mapValue).Kind() + valueKind := reflect.TypeOf(vp.value).Kind() + mapValue := vp.value switch { case valueKind == reflect.Map && typeKind == reflect.Struct: - return u.processFieldStruct(field, value, mapValue, fullName) + if mv, ok := mapValue.(map[string]interface{}); ok { + return u.processFieldStruct(field, value, &simpleValuer{ + current: mapValuer(mv), + parent: vp.parent, + }, fullName) + } else { + return errTypeMismatch + } case valueKind == reflect.Map && typeKind == reflect.Map: return u.fillMap(field, value, mapValue) case valueKind == reflect.String && typeKind == reflect.Map: @@ -292,18 +300,7 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(field reflect.StructFi } func (u *Unmarshaler) processFieldStruct(field reflect.StructField, value reflect.Value, - mapValue interface{}, fullName string) error { - convertedValue, ok := mapValue.(map[string]interface{}) - if !ok { - valueKind := reflect.TypeOf(mapValue).Kind() - return fmt.Errorf("error: field: %s, expect map[string]interface{}, actual %v", fullName, valueKind) - } - - return u.processFieldStructWithMap(field, value, MapValuer(convertedValue), fullName) -} - -func (u *Unmarshaler) processFieldStructWithMap(field reflect.StructField, value reflect.Value, - m Valuer, fullName string) error { + m valuerWithParent, fullName string) error { if field.Type.Kind() == reflect.Ptr { baseType := Deref(field.Type) target := reflect.New(baseType).Elem() @@ -342,7 +339,7 @@ func (u *Unmarshaler) processFieldTextUnmarshaler(field reflect.StructField, val } func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect.Value, - m Valuer, fullName string) error { + m valuerWithParent, fullName string) error { key, opts, err := u.parseOptionsWithContext(field, m, fullName) if err != nil { return err @@ -353,16 +350,22 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect if u.opts.canonicalKey != nil { canonicalKey = u.opts.canonicalKey(key) } - mapValue, hasValue := getValue(m, canonicalKey) - if hasValue { - return u.processNamedFieldWithValue(field, value, mapValue, key, opts, fullName) + + valuer := createValuer(m, opts) + mapValue, hasValue := getValue(valuer, canonicalKey) + if !hasValue { + return u.processNamedFieldWithoutValue(field, value, opts, fullName) } - return u.processNamedFieldWithoutValue(field, value, opts, fullName) + return u.processNamedFieldWithValue(field, value, valueWithParent{ + value: mapValue, + parent: valuer, + }, key, opts, fullName) } func (u *Unmarshaler) processNamedFieldWithValue(field reflect.StructField, value reflect.Value, - mapValue interface{}, key string, opts *fieldOptionsWithContext, fullName string) error { + vp valueWithParent, key string, opts *fieldOptionsWithContext, fullName string) error { + mapValue := vp.value if mapValue == nil { if opts.optional() { return nil @@ -384,7 +387,7 @@ func (u *Unmarshaler) processNamedFieldWithValue(field reflect.StructField, valu fieldKind := Deref(field.Type).Kind() switch fieldKind { case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct: - return u.processFieldNotFromString(field, value, mapValue, opts, fullName) + return u.processFieldNotFromString(field, value, vp, opts, fullName) default: if u.opts.fromString || opts.fromString() { valueKind := reflect.TypeOf(mapValue).Kind() @@ -403,7 +406,7 @@ func (u *Unmarshaler) processNamedFieldWithValue(field reflect.StructField, valu return fillPrimitive(field.Type, value, mapValue, opts, fullName) } - return u.processFieldNotFromString(field, value, mapValue, opts, fullName) + return u.processFieldNotFromString(field, value, vp, opts, fullName) } } @@ -431,7 +434,9 @@ func (u *Unmarshaler) processNamedFieldWithoutValue(field reflect.StructField, v switch fieldKind { case reflect.Array, reflect.Map, reflect.Slice: if !opts.optional() { - return u.processFieldNotFromString(field, value, emptyMap, opts, fullName) + return u.processFieldNotFromString(field, value, valueWithParent{ + value: emptyMap, + }, opts, fullName) } case reflect.Struct: if !opts.optional() { @@ -439,10 +444,14 @@ func (u *Unmarshaler) processNamedFieldWithoutValue(field reflect.StructField, v if err != nil { return err } + if required { return fmt.Errorf("%q is not set", fullName) } - return u.processFieldNotFromString(field, value, emptyMap, opts, fullName) + + return u.processFieldNotFromString(field, value, valueWithParent{ + value: emptyMap, + }, opts, fullName) } default: if !opts.optional() { @@ -738,6 +747,20 @@ func WithCanonicalKeyFunc(f func(string) string) UnmarshalOption { } } +func createValuer(v valuerWithParent, opts *fieldOptionsWithContext) valuerWithParent { + if opts.inherit() { + return recursiveValuer{ + current: v, + parent: v.Parent(), + } + } + + return simpleValuer{ + current: v, + parent: v.Parent(), + } +} + func fillDurationValue(fieldKind reflect.Kind, value reflect.Value, dur string) error { d, err := time.ParseDuration(dur) if err != nil { @@ -805,26 +828,30 @@ func fillWithSameType(field reflect.StructField, value reflect.Value, mapValue i } // getValue gets the value for the specific key, the key can be in the format of parentKey.childKey -func getValue(m Valuer, key string) (interface{}, bool) { +func getValue(m valuerWithParent, key string) (interface{}, bool) { keys := readKeys(key) return getValueWithChainedKeys(m, keys) } -func getValueWithChainedKeys(m Valuer, keys []string) (interface{}, bool) { - if len(keys) == 1 { +func getValueWithChainedKeys(m valuerWithParent, keys []string) (interface{}, bool) { + switch len(keys) { + case 0: + return nil, false + case 1: v, ok := m.Value(keys[0]) return v, ok - } - - if len(keys) > 1 { + default: if v, ok := m.Value(keys[0]); ok { if nextm, ok := v.(map[string]interface{}); ok { - return getValueWithChainedKeys(MapValuer(nextm), keys[1:]) + return getValueWithChainedKeys(recursiveValuer{ + current: mapValuer(nextm), + parent: m, + }, keys[1:]) } } - } - return nil, false + return nil, false + } } func join(elem ...string) string { diff --git a/core/mapping/unmarshaler_test.go b/core/mapping/unmarshaler_test.go index f445817d..e4fd78fd 100644 --- a/core/mapping/unmarshaler_test.go +++ b/core/mapping/unmarshaler_test.go @@ -7,10 +7,10 @@ import ( "strings" "testing" "time" + "unicode" "github.com/google/uuid" "github.com/stretchr/testify/assert" - "github.com/zeromicro/go-zero/core/stringx" ) @@ -38,6 +38,29 @@ func TestUnmarshalWithoutTagName(t *testing.T) { assert.True(t, in.Optional) } +func TestUnmarshalWithoutTagNameWithCanonicalKey(t *testing.T) { + type inner struct { + Name string `key:"name"` + } + m := map[string]interface{}{ + "Name": "go-zero", + } + + var in inner + unmarshaler := NewUnmarshaler(defaultKeyName, WithCanonicalKeyFunc(func(s string) string { + first := true + return strings.Map(func(r rune) rune { + if first { + first = false + return unicode.ToTitle(r) + } + return r + }, s) + })) + assert.Nil(t, unmarshaler.Unmarshal(m, &in)) + assert.Equal(t, "go-zero", in.Name) +} + func TestUnmarshalBool(t *testing.T) { type inner struct { True bool `key:"yes"` @@ -2718,6 +2741,256 @@ func TestUnmarshalNestedMapSimpleTypeMatch(t *testing.T) { assert.Equal(t, "1", c.Anything["id"]) } +func TestUnmarshalInheritPrimitiveUseParent(t *testing.T) { + type ( + component struct { + Name string `key:"name"` + Discovery string `key:"discovery,inherit"` + } + server struct { + Discovery string `key:"discovery"` + Component component `key:"component"` + } + ) + + var s server + assert.NoError(t, UnmarshalKey(map[string]interface{}{ + "discovery": "localhost:8080", + "component": map[string]interface{}{ + "name": "test", + }, + }, &s)) + assert.Equal(t, "localhost:8080", s.Discovery) + assert.Equal(t, "localhost:8080", s.Component.Discovery) +} + +func TestUnmarshalInheritPrimitiveUseSelf(t *testing.T) { + type ( + component struct { + Name string `key:"name"` + Discovery string `key:"discovery,inherit"` + } + server struct { + Discovery string `key:"discovery"` + Component component `key:"component"` + } + ) + + var s server + assert.NoError(t, UnmarshalKey(map[string]interface{}{ + "discovery": "localhost:8080", + "component": map[string]interface{}{ + "name": "test", + "discovery": "localhost:8888", + }, + }, &s)) + assert.Equal(t, "localhost:8080", s.Discovery) + assert.Equal(t, "localhost:8888", s.Component.Discovery) +} + +func TestUnmarshalInheritPrimitiveNotExist(t *testing.T) { + type ( + component struct { + Name string `key:"name"` + Discovery string `key:"discovery,inherit"` + } + server struct { + Component component `key:"component"` + } + ) + + var s server + assert.NotNil(t, UnmarshalKey(map[string]interface{}{ + "component": map[string]interface{}{ + "name": "test", + }, + }, &s)) +} + +func TestUnmarshalInheritStructUseParent(t *testing.T) { + type ( + discovery struct { + Host string `key:"host"` + Port int `key:"port"` + } + component struct { + Name string `key:"name"` + Discovery discovery `key:"discovery,inherit"` + } + server struct { + Discovery discovery `key:"discovery"` + Component component `key:"component"` + } + ) + + var s server + assert.NoError(t, UnmarshalKey(map[string]interface{}{ + "discovery": map[string]interface{}{ + "host": "localhost", + "port": 8080, + }, + "component": map[string]interface{}{ + "name": "test", + }, + }, &s)) + assert.Equal(t, "localhost", s.Discovery.Host) + assert.Equal(t, 8080, s.Discovery.Port) + assert.Equal(t, "localhost", s.Component.Discovery.Host) + assert.Equal(t, 8080, s.Component.Discovery.Port) +} + +func TestUnmarshalInheritStructUseSelf(t *testing.T) { + type ( + discovery struct { + Host string `key:"host"` + Port int `key:"port"` + } + component struct { + Name string `key:"name"` + Discovery discovery `key:"discovery,inherit"` + } + server struct { + Discovery discovery `key:"discovery"` + Component component `key:"component"` + } + ) + + var s server + assert.NoError(t, UnmarshalKey(map[string]interface{}{ + "discovery": map[string]interface{}{ + "host": "localhost", + "port": 8080, + }, + "component": map[string]interface{}{ + "name": "test", + "discovery": map[string]interface{}{ + "host": "remotehost", + "port": 8888, + }, + }, + }, &s)) + assert.Equal(t, "localhost", s.Discovery.Host) + assert.Equal(t, 8080, s.Discovery.Port) + assert.Equal(t, "remotehost", s.Component.Discovery.Host) + assert.Equal(t, 8888, s.Component.Discovery.Port) +} + +func TestUnmarshalInheritStructNotExist(t *testing.T) { + type ( + discovery struct { + Host string `key:"host"` + Port int `key:"port"` + } + component struct { + Name string `key:"name"` + Discovery discovery `key:"discovery,inherit"` + } + server struct { + Component component `key:"component"` + } + ) + + var s server + assert.NotNil(t, UnmarshalKey(map[string]interface{}{ + "component": map[string]interface{}{ + "name": "test", + }, + }, &s)) +} + +func TestUnmarshalInheritStructUsePartial(t *testing.T) { + type ( + discovery struct { + Host string `key:"host"` + Port int `key:"port"` + } + component struct { + Name string `key:"name"` + Discovery discovery `key:"discovery,inherit"` + } + server struct { + Discovery discovery `key:"discovery"` + Component component `key:"component"` + } + ) + + var s server + assert.NoError(t, UnmarshalKey(map[string]interface{}{ + "discovery": map[string]interface{}{ + "host": "localhost", + "port": 8080, + }, + "component": map[string]interface{}{ + "name": "test", + "discovery": map[string]interface{}{ + "port": 8888, + }, + }, + }, &s)) + assert.Equal(t, "localhost", s.Discovery.Host) + assert.Equal(t, 8080, s.Discovery.Port) + assert.Equal(t, "localhost", s.Component.Discovery.Host) + assert.Equal(t, 8888, s.Component.Discovery.Port) +} + +func TestUnmarshalInheritStructUseSelfIncorrectType(t *testing.T) { + type ( + discovery struct { + Host string `key:"host"` + Port int `key:"port"` + } + component struct { + Name string `key:"name"` + Discovery discovery `key:"discovery,inherit"` + } + server struct { + Discovery discovery `key:"discovery"` + Component component `key:"component"` + } + ) + + var s server + assert.NotNil(t, UnmarshalKey(map[string]interface{}{ + "discovery": map[string]interface{}{ + "host": "localhost", + }, + "component": map[string]interface{}{ + "name": "test", + "discovery": map[string]string{ + "host": "remotehost", + }, + }, + }, &s)) +} + +func TestUnmarshaler_InheritFromGrandparent(t *testing.T) { + type ( + component struct { + Name string `key:"name"` + Discovery string `key:"discovery,inherit"` + } + middle struct { + Value component `key:"value"` + } + server struct { + Discovery string `key:"discovery"` + Middle middle `key:"middle"` + } + ) + + var s server + assert.NoError(t, UnmarshalKey(map[string]interface{}{ + "discovery": "localhost:8080", + "middle": map[string]interface{}{ + "value": map[string]interface{}{ + "name": "test", + }, + }, + }, &s)) + assert.Equal(t, "localhost:8080", s.Discovery) + assert.Equal(t, "localhost:8080", s.Middle.Value.Discovery) +} + func TestUnmarshalValuer(t *testing.T) { unmarshaler := NewUnmarshaler(jsonTagKey) var foo string diff --git a/core/mapping/utils.go b/core/mapping/utils.go index 7d8af4df..8d21fd58 100644 --- a/core/mapping/utils.go +++ b/core/mapping/utils.go @@ -15,6 +15,7 @@ import ( const ( defaultOption = "default" + inheritOption = "inherit" stringOption = "string" optionalOption = "optional" optionsOption = "options" @@ -335,6 +336,8 @@ func parseNumberRange(str string) (*numberRange, error) { func parseOption(fieldOpts *fieldOptions, fieldName, option string) error { switch { + case option == inheritOption: + fieldOpts.Inherit = true case option == stringOption: fieldOpts.FromString = true case strings.HasPrefix(option, optionalOption): diff --git a/core/mapping/valuer.go b/core/mapping/valuer.go index d1a0dfea..e22c877f 100644 --- a/core/mapping/valuer.go +++ b/core/mapping/valuer.go @@ -7,12 +7,94 @@ type ( Value(key string) (interface{}, bool) } - // A MapValuer is a map that can use Value method to get values with given keys. - MapValuer map[string]interface{} + // A valuerWithParent defines a node that has a parent node. + valuerWithParent interface { + Valuer + // Parent get the parent valuer for current node. + Parent() valuerWithParent + } + + // A node is a map that can use Value method to get values with given keys. + node struct { + current Valuer + parent valuerWithParent + } + + // A valueWithParent is used to wrap the value with its parent. + valueWithParent struct { + value interface{} + parent valuerWithParent + } + + // mapValuer is a type for map to meet the Valuer interface. + mapValuer map[string]interface{} + // simpleValuer is a type to get value from current node. + simpleValuer node + // recursiveValuer is a type to get the value recursively from current and parent nodes. + recursiveValuer node ) -// Value gets the value associated with the given key from mv. -func (mv MapValuer) Value(key string) (interface{}, bool) { +// Value gets the value assciated with the given key from mv. +func (mv mapValuer) Value(key string) (interface{}, bool) { v, ok := mv[key] return v, ok } + +// Value gets the value associated with the given key from sv. +func (sv simpleValuer) Value(key string) (interface{}, bool) { + v, ok := sv.current.Value(key) + return v, ok +} + +// Parent get the parent valuer from sv. +func (sv simpleValuer) Parent() valuerWithParent { + if sv.parent == nil { + return nil + } + + return recursiveValuer{ + current: sv.parent, + parent: sv.parent.Parent(), + } +} + +// Value gets the value associated with the given key from rv, +// and it will inherit the value from parent nodes. +func (rv recursiveValuer) Value(key string) (interface{}, bool) { + val, ok := rv.current.Value(key) + if !ok { + if parent := rv.Parent(); parent != nil { + return parent.Value(key) + } + + return nil, false + } + + if vm, ok := val.(map[string]interface{}); ok { + if parent := rv.Parent(); parent != nil { + pv, pok := parent.Value(key) + if pok { + if pm, ok := pv.(map[string]interface{}); ok { + for k, v := range vm { + pm[k] = v + } + return pm, true + } + } + } + } + + return val, true +} + +// Parent get the parent valuer from rv. +func (rv recursiveValuer) Parent() valuerWithParent { + if rv.parent == nil { + return nil + } + + return recursiveValuer{ + current: rv.parent, + parent: rv.parent.Parent(), + } +} diff --git a/core/mapping/valuer_test.go b/core/mapping/valuer_test.go new file mode 100644 index 00000000..01ee56f5 --- /dev/null +++ b/core/mapping/valuer_test.go @@ -0,0 +1,33 @@ +package mapping + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMapValuerWithInherit_Value(t *testing.T) { + input := map[string]interface{}{ + "discovery": map[string]interface{}{ + "host": "localhost", + "port": 8080, + }, + "component": map[string]interface{}{ + "name": "test", + }, + } + valuer := recursiveValuer{ + current: mapValuer(input["component"].(map[string]interface{})), + parent: simpleValuer{ + current: mapValuer(input), + }, + } + + val, ok := valuer.Value("discovery") + assert.True(t, ok) + + m, ok := val.(map[string]interface{}) + assert.True(t, ok) + assert.Equal(t, "localhost", m["host"]) + assert.Equal(t, 8080, m["port"]) +} diff --git a/zrpc/config.go b/zrpc/config.go index 36dd0074..225ee2eb 100644 --- a/zrpc/config.go +++ b/zrpc/config.go @@ -12,7 +12,7 @@ type ( RpcServerConf struct { service.ServiceConf ListenOn string - Etcd discov.EtcdConf `json:",optional"` + Etcd discov.EtcdConf `json:",optional,inherit"` Auth bool `json:",optional"` Redis redis.RedisKeyConf `json:",optional"` StrictControl bool `json:",optional"` @@ -25,7 +25,7 @@ type ( // A RpcClientConf is a rpc client config. RpcClientConf struct { - Etcd discov.EtcdConf `json:",optional"` + Etcd discov.EtcdConf `json:",optional,inherit"` Endpoints []string `json:",optional"` Target string `json:",optional"` App string `json:",optional"`