diff --git a/core/conf/config.go b/core/conf/config.go index 002a4d91..f63efb0b 100644 --- a/core/conf/config.go +++ b/core/conf/config.go @@ -13,12 +13,15 @@ import ( "github.com/zeromicro/go-zero/internal/encoding" ) -var loaders = map[string]func([]byte, any) error{ - ".json": LoadFromJsonBytes, - ".toml": LoadFromTomlBytes, - ".yaml": LoadFromYamlBytes, - ".yml": LoadFromYamlBytes, -} +var ( + loaders = map[string]func([]byte, any) error{ + ".json": LoadFromJsonBytes, + ".toml": LoadFromTomlBytes, + ".yaml": LoadFromYamlBytes, + ".yml": LoadFromYamlBytes, + } + emptyFieldInfo fieldInfo +) type fieldInfo struct { children map[string]fieldInfo @@ -62,7 +65,11 @@ func LoadFromJsonBytes(content []byte, v any) error { return err } - finfo := buildFieldsInfo(reflect.TypeOf(v)) + finfo, err := buildFieldsInfo(reflect.TypeOf(v)) + if err != nil { + return err + } + lowerCaseKeyMap := toLowerCaseKeyMap(m, finfo) return mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toLowerCase)) @@ -107,19 +114,29 @@ func MustLoad(path string, v any, opts ...Option) { } } -func addOrMergeFields(info fieldInfo, key string, child fieldInfo) { +func addOrMergeFields(info fieldInfo, key string, child fieldInfo) error { if prev, ok := info.children[key]; ok { + if len(child.children) == 0 && child.mapField == nil { + return newDupKeyError(key) + } + // merge fields for k, v := range child.children { + if _, ok = prev.children[k]; ok { + return newDupKeyError(k) + } + prev.children[k] = v } prev.mapField = child.mapField } else { info.children[key] = child } + + return nil } -func buildFieldsInfo(tp reflect.Type) fieldInfo { +func buildFieldsInfo(tp reflect.Type) (fieldInfo, error) { tp = mapping.Deref(tp) switch tp.Kind() { @@ -128,11 +145,11 @@ func buildFieldsInfo(tp reflect.Type) fieldInfo { case reflect.Array, reflect.Slice: return buildFieldsInfo(mapping.Deref(tp.Elem())) default: - return fieldInfo{} + return emptyFieldInfo, nil } } -func buildStructFieldsInfo(tp reflect.Type) fieldInfo { +func buildStructFieldsInfo(tp reflect.Type) (fieldInfo, error) { info := fieldInfo{ children: make(map[string]fieldInfo), } @@ -146,17 +163,31 @@ func buildStructFieldsInfo(tp reflect.Type) fieldInfo { if field.Anonymous { switch ft.Kind() { case reflect.Struct: - fields := buildFieldsInfo(ft) + fields, err := buildFieldsInfo(ft) + if err != nil { + return emptyFieldInfo, err + } for k, v := range fields.children { - addOrMergeFields(info, k, v) + if err = addOrMergeFields(info, k, v); err != nil { + return emptyFieldInfo, err + } } info.mapField = fields.mapField case reflect.Map: - elemField := buildFieldsInfo(mapping.Deref(ft.Elem())) + elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem())) + if err != nil { + return emptyFieldInfo, err + } + if _, ok := info.children[lowerCaseName]; ok { + return emptyFieldInfo, newDupKeyError(lowerCaseName) + } info.children[lowerCaseName] = fieldInfo{ mapField: &elemField, } default: + if _, ok := info.children[lowerCaseName]; ok { + return emptyFieldInfo, newDupKeyError(lowerCaseName) + } info.children[lowerCaseName] = fieldInfo{ children: make(map[string]fieldInfo), } @@ -165,20 +196,37 @@ func buildStructFieldsInfo(tp reflect.Type) fieldInfo { } var finfo fieldInfo + var err error switch ft.Kind() { case reflect.Struct: - finfo = buildFieldsInfo(ft) + finfo, err = buildFieldsInfo(ft) + if err != nil { + return emptyFieldInfo, err + } case reflect.Array, reflect.Slice: - finfo = buildFieldsInfo(ft.Elem()) + finfo, err = buildFieldsInfo(ft.Elem()) + if err != nil { + return emptyFieldInfo, err + } case reflect.Map: - elemInfo := buildFieldsInfo(mapping.Deref(ft.Elem())) + elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem())) + if err != nil { + return emptyFieldInfo, err + } finfo.mapField = &elemInfo + default: + finfo, err = buildFieldsInfo(ft) + if err != nil { + return emptyFieldInfo, err + } } - addOrMergeFields(info, lowerCaseName, finfo) + if err := addOrMergeFields(info, lowerCaseName, finfo); err != nil { + return emptyFieldInfo, err + } } - return info + return info, nil } func toLowerCase(s string) string { @@ -222,3 +270,15 @@ func toLowerCaseKeyMap(m map[string]any, info fieldInfo) map[string]any { return res } + +type dupKeyError struct { + key string +} + +func newDupKeyError(key string) dupKeyError { + return dupKeyError{key: key} +} + +func (e dupKeyError) Error() string { + return fmt.Sprintf("duplicated key %s", e.key) +} diff --git a/core/conf/config_test.go b/core/conf/config_test.go index e92152e7..c1b88d45 100644 --- a/core/conf/config_test.go +++ b/core/conf/config_test.go @@ -9,6 +9,8 @@ import ( "github.com/zeromicro/go-zero/core/hash" ) +var dupErr dupKeyError + func TestLoadConfig_notExists(t *testing.T) { assert.NotNil(t, Load("not_a_file", nil)) } @@ -413,11 +415,7 @@ func TestLoadFromYamlItemOverlay(t *testing.T) { `) var c TestConfig - if assert.NoError(t, LoadFromYamlBytes(input, &c)) { - assert.Equal(t, "localhost", c.Redis.Host) - assert.Equal(t, 6379, c.Redis.Port) - assert.Equal(t, "test", c.Server.Redis.Key) - } + assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr) } func TestLoadFromYamlItemOverlayReverse(t *testing.T) { @@ -449,11 +447,7 @@ func TestLoadFromYamlItemOverlayReverse(t *testing.T) { `) var c TestConfig - if assert.NoError(t, LoadFromYamlBytes(input, &c)) { - assert.Equal(t, "localhost", c.Redis.Host) - assert.Equal(t, 6379, c.Redis.Port) - assert.Equal(t, "test", c.Redis.Key) - } + assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr) } func TestLoadFromYamlItemOverlayWithMap(t *testing.T) { @@ -616,6 +610,126 @@ func TestUnmarshalJsonBytesWithMapTypeValueOfStruct(t *testing.T) { } } +func Test_checkInheritOverwrite(t *testing.T) { + t.Run("normal", func(t *testing.T) { + type Base struct { + Name string + } + + type St1 struct { + Base + Name2 string + } + + type St2 struct { + Base + Name2 string + } + + type St3 struct { + *Base + Name2 string + } + + type St4 struct { + *Base + Name2 *string + } + + validate := func(val any) { + input := []byte(`{"Name": "hello", "Name2": "world"}`) + assert.NoError(t, LoadFromJsonBytes(input, val)) + } + + validate(&St1{}) + validate(&St2{}) + validate(&St3{}) + validate(&St4{}) + }) + + t.Run("Inherit Override", func(t *testing.T) { + type Base struct { + Name string + } + + type St1 struct { + Base + Name string + } + + type St2 struct { + Base + Name int + } + + type St3 struct { + *Base + Name int + } + + type St4 struct { + *Base + Name *string + } + + validate := func(val any) { + input := []byte(`{"Name": "hello"}`) + err := LoadFromJsonBytes(input, val) + assert.ErrorAs(t, err, &dupErr) + assert.Equal(t, newDupKeyError("name").Error(), err.Error()) + } + + validate(&St1{}) + validate(&St2{}) + validate(&St3{}) + validate(&St4{}) + }) + + t.Run("Inherit more", func(t *testing.T) { + type Base1 struct { + Name string + } + + type St0 struct { + Base1 + Name string + } + + type St1 struct { + St0 + Name string + } + + type St2 struct { + St0 + Name int + } + + type St3 struct { + *St0 + Name int + } + + type St4 struct { + *St0 + Name *int + } + + validate := func(val any) { + input := []byte(`{"Name": "hello"}`) + err := LoadFromJsonBytes(input, val) + assert.ErrorAs(t, err, &dupErr) + assert.Equal(t, newDupKeyError("name").Error(), err.Error()) + } + + validate(&St0{}) + validate(&St1{}) + validate(&St2{}) + validate(&St3{}) + validate(&St4{}) + }) +} + func createTempFile(ext, text string) (string, error) { tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext) if err != nil {