diff --git a/core/conf/config.go b/core/conf/config.go index f63efb0b..3e6a3abf 100644 --- a/core/conf/config.go +++ b/core/conf/config.go @@ -13,18 +13,17 @@ import ( "github.com/zeromicro/go-zero/internal/encoding" ) -var ( - loaders = map[string]func([]byte, any) error{ - ".json": LoadFromJsonBytes, - ".toml": LoadFromTomlBytes, - ".yaml": LoadFromYamlBytes, - ".yml": LoadFromYamlBytes, - } - emptyFieldInfo fieldInfo -) +var loaders = map[string]func([]byte, any) error{ + ".json": LoadFromJsonBytes, + ".toml": LoadFromTomlBytes, + ".yaml": LoadFromYamlBytes, + ".yml": LoadFromYamlBytes, +} +// children and mapField should not be both filled. +// named fields and map cannot be bound to the same field name. type fieldInfo struct { - children map[string]fieldInfo + children map[string]*fieldInfo mapField *fieldInfo } @@ -60,13 +59,13 @@ func LoadConfig(file string, v any, opts ...Option) error { // LoadFromJsonBytes loads config into v from content json bytes. func LoadFromJsonBytes(content []byte, v any) error { - var m map[string]any - if err := jsonx.Unmarshal(content, &m); err != nil { + finfo, err := buildFieldsInfo(reflect.TypeOf(v)) + if err != nil { return err } - finfo, err := buildFieldsInfo(reflect.TypeOf(v)) - if err != nil { + var m map[string]any + if err := jsonx.Unmarshal(content, &m); err != nil { return err } @@ -114,21 +113,15 @@ func MustLoad(path string, v any, opts ...Option) { } } -func addOrMergeFields(info fieldInfo, key string, child fieldInfo) error { +func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo) error { if prev, ok := info.children[key]; ok { - if len(child.children) == 0 && child.mapField == nil { + if child.mapField != nil { return newDupKeyError(key) } - // merge fields - for k, v := range child.children { - if _, ok = prev.children[k]; ok { - return newDupKeyError(k) - } - - prev.children[k] = v + if err := mergeFields(prev, key, child.children); err != nil { + return err } - prev.mapField = child.mapField } else { info.children[key] = child } @@ -136,7 +129,47 @@ func addOrMergeFields(info fieldInfo, key string, child fieldInfo) error { return nil } -func buildFieldsInfo(tp reflect.Type) (fieldInfo, error) { +func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error { + switch ft.Kind() { + case reflect.Struct: + fields, err := buildFieldsInfo(ft) + if err != nil { + return err + } + + for k, v := range fields.children { + if err = addOrMergeFields(info, k, v); err != nil { + return err + } + } + case reflect.Map: + elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem())) + if err != nil { + return err + } + + if _, ok := info.children[lowerCaseName]; ok { + return newDupKeyError(lowerCaseName) + } + + info.children[lowerCaseName] = &fieldInfo{ + children: make(map[string]*fieldInfo), + mapField: elemField, + } + default: + if _, ok := info.children[lowerCaseName]; ok { + return newDupKeyError(lowerCaseName) + } + + info.children[lowerCaseName] = &fieldInfo{ + children: make(map[string]*fieldInfo), + } + } + + return nil +} + +func buildFieldsInfo(tp reflect.Type) (*fieldInfo, error) { tp = mapping.Deref(tp) switch tp.Kind() { @@ -145,13 +178,50 @@ func buildFieldsInfo(tp reflect.Type) (fieldInfo, error) { case reflect.Array, reflect.Slice: return buildFieldsInfo(mapping.Deref(tp.Elem())) default: - return emptyFieldInfo, nil + return &fieldInfo{ + children: make(map[string]*fieldInfo), + }, nil + } +} + +func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error { + var finfo *fieldInfo + var err error + + switch ft.Kind() { + case reflect.Struct: + finfo, err = buildFieldsInfo(ft) + if err != nil { + return err + } + case reflect.Array, reflect.Slice: + finfo, err = buildFieldsInfo(ft.Elem()) + if err != nil { + return err + } + case reflect.Map: + elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem())) + if err != nil { + return err + } + + finfo = &fieldInfo{ + children: make(map[string]*fieldInfo), + mapField: elemInfo, + } + default: + finfo, err = buildFieldsInfo(ft) + if err != nil { + return err + } } + + return addOrMergeFields(info, lowerCaseName, finfo) } -func buildStructFieldsInfo(tp reflect.Type) (fieldInfo, error) { - info := fieldInfo{ - children: make(map[string]fieldInfo), +func buildStructFieldsInfo(tp reflect.Type) (*fieldInfo, error) { + info := &fieldInfo{ + children: make(map[string]*fieldInfo), } for i := 0; i < tp.NumField(); i++ { @@ -161,79 +231,39 @@ func buildStructFieldsInfo(tp reflect.Type) (fieldInfo, error) { ft := mapping.Deref(field.Type) // flatten anonymous fields if field.Anonymous { - switch ft.Kind() { - case reflect.Struct: - fields, err := buildFieldsInfo(ft) - if err != nil { - return emptyFieldInfo, err - } - for k, v := range fields.children { - if err = addOrMergeFields(info, k, v); err != nil { - return emptyFieldInfo, err - } - } - info.mapField = fields.mapField - case reflect.Map: - elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem())) - if err != nil { - return emptyFieldInfo, err - } - if _, ok := info.children[lowerCaseName]; ok { - return emptyFieldInfo, newDupKeyError(lowerCaseName) - } - info.children[lowerCaseName] = fieldInfo{ - mapField: &elemField, - } - default: - if _, ok := info.children[lowerCaseName]; ok { - return emptyFieldInfo, newDupKeyError(lowerCaseName) - } - info.children[lowerCaseName] = fieldInfo{ - children: make(map[string]fieldInfo), - } + if err := buildAnonymousFieldInfo(info, lowerCaseName, ft); err != nil { + return nil, err } - continue + } else if err := buildNamedFieldInfo(info, lowerCaseName, ft); err != nil { + return nil, err } + } - var finfo fieldInfo - var err error - switch ft.Kind() { - case reflect.Struct: - finfo, err = buildFieldsInfo(ft) - if err != nil { - return emptyFieldInfo, err - } - case reflect.Array, reflect.Slice: - finfo, err = buildFieldsInfo(ft.Elem()) - if err != nil { - return emptyFieldInfo, err - } - case reflect.Map: - elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem())) - if err != nil { - return emptyFieldInfo, err - } - finfo.mapField = &elemInfo - default: - finfo, err = buildFieldsInfo(ft) - if err != nil { - return emptyFieldInfo, err - } - } + return info, nil +} + +func mergeFields(prev *fieldInfo, key string, children map[string]*fieldInfo) error { + if len(children) == 0 { + return newDupKeyError(key) + } - if err := addOrMergeFields(info, lowerCaseName, finfo); err != nil { - return emptyFieldInfo, err + // merge fields + for k, v := range children { + if _, ok := prev.children[k]; ok { + return newDupKeyError(k) } + + prev.children[k] = v } - return info, nil + return nil } func toLowerCase(s string) string { return strings.ToLower(s) } -func toLowerCaseInterface(v any, info fieldInfo) any { +func toLowerCaseInterface(v any, info *fieldInfo) any { switch vv := v.(type) { case map[string]any: return toLowerCaseKeyMap(vv, info) @@ -248,7 +278,7 @@ func toLowerCaseInterface(v any, info fieldInfo) any { } } -func toLowerCaseKeyMap(m map[string]any, info fieldInfo) map[string]any { +func toLowerCaseKeyMap(m map[string]any, info *fieldInfo) map[string]any { res := make(map[string]any) for k, v := range m { @@ -262,7 +292,7 @@ func toLowerCaseKeyMap(m map[string]any, info fieldInfo) map[string]any { if ti, ok = info.children[lk]; ok { res[lk] = toLowerCaseInterface(v, ti) } else if info.mapField != nil { - res[k] = toLowerCaseInterface(v, *info.mapField) + res[k] = toLowerCaseInterface(v, info.mapField) } else { res[k] = v } diff --git a/core/conf/config_test.go b/core/conf/config_test.go index c1b88d45..1d4e0247 100644 --- a/core/conf/config_test.go +++ b/core/conf/config_test.go @@ -479,11 +479,7 @@ func TestLoadFromYamlItemOverlayWithMap(t *testing.T) { `) var c TestConfig - if assert.NoError(t, LoadFromYamlBytes(input, &c)) { - assert.Equal(t, "localhost", c.Server.Redis.Host) - assert.Equal(t, 6379, c.Server.Redis.Port) - assert.Equal(t, "test", c.Server.Redis.Key) - } + assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr) } func TestUnmarshalJsonBytesMap(t *testing.T) { @@ -610,7 +606,7 @@ func TestUnmarshalJsonBytesWithMapTypeValueOfStruct(t *testing.T) { } } -func Test_checkInheritOverwrite(t *testing.T) { +func Test_FieldOverwrite(t *testing.T) { t.Run("normal", func(t *testing.T) { type Base struct { Name string @@ -730,6 +726,292 @@ func Test_checkInheritOverwrite(t *testing.T) { }) } +func TestFieldOverwriteComplicated(t *testing.T) { + t.Run("double maps", func(t *testing.T) { + type ( + Base1 struct { + Values map[string]string + } + Base2 struct { + Values map[string]string + } + Config struct { + Base1 + Base2 + } + ) + + var c Config + input := []byte(`{"Values": {"Key": "Value"}}`) + assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr) + }) + + t.Run("merge children", func(t *testing.T) { + type ( + Inner1 struct { + Name string + } + Inner2 struct { + Age int + } + Base1 struct { + Inner Inner1 + } + Base2 struct { + Inner Inner2 + } + Config struct { + Base1 + Base2 + } + ) + + var c Config + input := []byte(`{"Inner": {"Name": "foo", "Age": 10}}`) + if assert.NoError(t, LoadFromJsonBytes(input, &c)) { + assert.Equal(t, "foo", c.Base1.Inner.Name) + assert.Equal(t, 10, c.Base2.Inner.Age) + } + }) + + t.Run("overwritten maps", func(t *testing.T) { + type ( + Inner struct { + Map map[string]string + } + Config struct { + Map map[string]string + Inner + } + ) + + var c Config + input := []byte(`{"Inner": {"Map": {"Key": "Value"}}}`) + assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr) + }) + + t.Run("overwritten nested maps", func(t *testing.T) { + type ( + Inner struct { + Map map[string]string + } + Middle1 struct { + Map map[string]string + Inner + } + Middle2 struct { + Map map[string]string + Inner + } + Config struct { + Middle1 + Middle2 + } + ) + + var c Config + input := []byte(`{"Middle1": {"Inner": {"Map": {"Key": "Value"}}}}`) + assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr) + }) + + t.Run("overwritten outer/inner maps", func(t *testing.T) { + type ( + Inner struct { + Map map[string]string + } + Middle struct { + Inner + Map map[string]string + } + Config struct { + Middle + } + ) + + var c Config + input := []byte(`{"Middle": {"Inner": {"Map": {"Key": "Value"}}}}`) + assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr) + }) + + t.Run("overwritten anonymous maps", func(t *testing.T) { + type ( + Inner struct { + Map map[string]string + } + Middle struct { + Inner + Map map[string]string + } + Elem map[string]Middle + Config struct { + Elem + } + ) + + var c Config + input := []byte(`{"Elem": {"Key": {"Inner": {"Map": {"Key": "Value"}}}}}`) + assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr) + }) + + t.Run("overwritten primitive and map", func(t *testing.T) { + type ( + Inner struct { + Value string + } + Elem map[string]Inner + Named struct { + Elem string + } + Config struct { + Named + Elem + } + ) + + var c Config + input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`) + assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr) + }) + + t.Run("overwritten map and slice", func(t *testing.T) { + type ( + Inner struct { + Value string + } + Elem []Inner + Named struct { + Elem string + } + Config struct { + Named + Elem + } + ) + + var c Config + input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`) + assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr) + }) + + t.Run("overwritten map and string", func(t *testing.T) { + type ( + Elem string + Named struct { + Elem string + } + Config struct { + Named + Elem + } + ) + + var c Config + input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`) + assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr) + }) +} + +func TestLoadNamedFieldOverwritten(t *testing.T) { + t.Run("overwritten named struct", func(t *testing.T) { + type ( + Elem string + Named struct { + Elem string + } + Base struct { + Named + Elem + } + Config struct { + Val Base + } + ) + + var c Config + input := []byte(`{"Val": {"Elem": {"Key": {"Value": "Value"}}}}`) + assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr) + }) + + t.Run("overwritten named []struct", func(t *testing.T) { + type ( + Elem string + Named struct { + Elem string + } + Base struct { + Named + Elem + } + Config struct { + Vals []Base + } + ) + + var c Config + input := []byte(`{"Vals": [{"Elem": {"Key": {"Value": "Value"}}}]}`) + assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr) + }) + + t.Run("overwritten named map[string]struct", func(t *testing.T) { + type ( + Elem string + Named struct { + Elem string + } + Base struct { + Named + Elem + } + Config struct { + Vals map[string]Base + } + ) + + var c Config + input := []byte(`{"Vals": {"Key": {"Elem": {"Key": {"Value": "Value"}}}}}`) + assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr) + }) + + t.Run("overwritten named *struct", func(t *testing.T) { + type ( + Elem string + Named struct { + Elem string + } + Base struct { + Named + Elem + } + Config struct { + Vals *Base + } + ) + + var c Config + input := []byte(`{"Vals": [{"Elem": {"Key": {"Value": "Value"}}}]}`) + assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr) + }) + + t.Run("overwritten named struct", func(t *testing.T) { + type ( + Named struct { + Elem string + } + Base struct { + Named + Elem Named + } + Config struct { + Val Base + } + ) + + var c Config + input := []byte(`{"Val": {"Elem": {"Key": {"Value": "Value"}}}}`) + assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr) + }) +} + func createTempFile(ext, text string) (string, error) { tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext) if err != nil {