feat: check key overwritten

kevin 2 years ago
parent 93fcf899dc
commit b61c94bb66

@ -13,12 +13,15 @@ import (
"github.com/zeromicro/go-zero/internal/encoding" "github.com/zeromicro/go-zero/internal/encoding"
) )
var loaders = map[string]func([]byte, interface{}) error{ var (
".json": LoadFromJsonBytes, loaders = map[string]func([]byte, any) error{
".toml": LoadFromTomlBytes, ".json": LoadFromJsonBytes,
".yaml": LoadFromYamlBytes, ".toml": LoadFromTomlBytes,
".yml": LoadFromYamlBytes, ".yaml": LoadFromYamlBytes,
} ".yml": LoadFromYamlBytes,
}
emptyFieldInfo fieldInfo
)
type fieldInfo struct { type fieldInfo struct {
children map[string]fieldInfo children map[string]fieldInfo
@ -62,7 +65,11 @@ func LoadFromJsonBytes(content []byte, v interface{}) error {
return err return err
} }
finfo := buildFieldsInfo(reflect.TypeOf(v)) finfo, err := buildFieldsInfo(reflect.TypeOf(v))
if err != nil {
return err
}
lowerCaseKeyMap := toLowerCaseKeyMap(m, finfo) lowerCaseKeyMap := toLowerCaseKeyMap(m, finfo)
return mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toLowerCase)) return mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toLowerCase))
@ -107,19 +114,29 @@ func MustLoad(path string, v interface{}, opts ...Option) {
} }
} }
func addOrMergeFields(info fieldInfo, key string, child fieldInfo) { func addOrMergeFields(info fieldInfo, key string, child fieldInfo) error {
if prev, ok := info.children[key]; ok { if prev, ok := info.children[key]; ok {
if len(child.children) == 0 && child.mapField == nil {
return newDupKeyError(key)
}
// merge fields // merge fields
for k, v := range child.children { for k, v := range child.children {
if _, ok = prev.children[k]; ok {
return newDupKeyError(k)
}
prev.children[k] = v prev.children[k] = v
} }
prev.mapField = child.mapField prev.mapField = child.mapField
} else { } else {
info.children[key] = child info.children[key] = child
} }
return nil
} }
func buildFieldsInfo(tp reflect.Type) fieldInfo { func buildFieldsInfo(tp reflect.Type) (fieldInfo, error) {
tp = mapping.Deref(tp) tp = mapping.Deref(tp)
switch tp.Kind() { switch tp.Kind() {
@ -128,11 +145,11 @@ func buildFieldsInfo(tp reflect.Type) fieldInfo {
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
return buildFieldsInfo(mapping.Deref(tp.Elem())) return buildFieldsInfo(mapping.Deref(tp.Elem()))
default: default:
return fieldInfo{} return emptyFieldInfo, nil
} }
} }
func buildStructFieldsInfo(tp reflect.Type) fieldInfo { func buildStructFieldsInfo(tp reflect.Type) (fieldInfo, error) {
info := fieldInfo{ info := fieldInfo{
children: make(map[string]fieldInfo), children: make(map[string]fieldInfo),
} }
@ -146,17 +163,31 @@ func buildStructFieldsInfo(tp reflect.Type) fieldInfo {
if field.Anonymous { if field.Anonymous {
switch ft.Kind() { switch ft.Kind() {
case reflect.Struct: case reflect.Struct:
fields := buildFieldsInfo(ft) fields, err := buildFieldsInfo(ft)
if err != nil {
return emptyFieldInfo, err
}
for k, v := range fields.children { for k, v := range fields.children {
addOrMergeFields(info, k, v) if err = addOrMergeFields(info, k, v); err != nil {
return emptyFieldInfo, err
}
} }
info.mapField = fields.mapField info.mapField = fields.mapField
case reflect.Map: case reflect.Map:
elemField := buildFieldsInfo(mapping.Deref(ft.Elem())) elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
if err != nil {
return emptyFieldInfo, err
}
if _, ok := info.children[lowerCaseName]; ok {
return emptyFieldInfo, newDupKeyError(lowerCaseName)
}
info.children[lowerCaseName] = fieldInfo{ info.children[lowerCaseName] = fieldInfo{
mapField: &elemField, mapField: &elemField,
} }
default: default:
if _, ok := info.children[lowerCaseName]; ok {
return emptyFieldInfo, newDupKeyError(lowerCaseName)
}
info.children[lowerCaseName] = fieldInfo{ info.children[lowerCaseName] = fieldInfo{
children: make(map[string]fieldInfo), children: make(map[string]fieldInfo),
} }
@ -165,20 +196,37 @@ func buildStructFieldsInfo(tp reflect.Type) fieldInfo {
} }
var finfo fieldInfo var finfo fieldInfo
var err error
switch ft.Kind() { switch ft.Kind() {
case reflect.Struct: case reflect.Struct:
finfo = buildFieldsInfo(ft) finfo, err = buildFieldsInfo(ft)
if err != nil {
return emptyFieldInfo, err
}
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
finfo = buildFieldsInfo(ft.Elem()) finfo, err = buildFieldsInfo(ft.Elem())
if err != nil {
return emptyFieldInfo, err
}
case reflect.Map: case reflect.Map:
elemInfo := buildFieldsInfo(mapping.Deref(ft.Elem())) elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
if err != nil {
return emptyFieldInfo, err
}
finfo.mapField = &elemInfo finfo.mapField = &elemInfo
default:
finfo, err = buildFieldsInfo(ft)
if err != nil {
return emptyFieldInfo, err
}
} }
addOrMergeFields(info, lowerCaseName, finfo) if err := addOrMergeFields(info, lowerCaseName, finfo); err != nil {
return emptyFieldInfo, err
}
} }
return info return info, nil
} }
func toLowerCase(s string) string { func toLowerCase(s string) string {
@ -222,3 +270,15 @@ func toLowerCaseKeyMap(m map[string]any, info fieldInfo) map[string]any {
return res return res
} }
type dupKeyError struct {
key string
}
func newDupKeyError(key string) dupKeyError {
return dupKeyError{key: key}
}
func (e dupKeyError) Error() string {
return fmt.Sprintf("duplicated key %s", e.key)
}

@ -9,6 +9,8 @@ import (
"github.com/zeromicro/go-zero/core/hash" "github.com/zeromicro/go-zero/core/hash"
) )
var dupErr dupKeyError
func TestLoadConfig_notExists(t *testing.T) { func TestLoadConfig_notExists(t *testing.T) {
assert.NotNil(t, Load("not_a_file", nil)) assert.NotNil(t, Load("not_a_file", nil))
} }
@ -413,11 +415,7 @@ func TestLoadFromYamlItemOverlay(t *testing.T) {
`) `)
var c TestConfig var c TestConfig
if assert.NoError(t, LoadFromYamlBytes(input, &c)) { assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
assert.Equal(t, "localhost", c.Redis.Host)
assert.Equal(t, 6379, c.Redis.Port)
assert.Equal(t, "test", c.Server.Redis.Key)
}
} }
func TestLoadFromYamlItemOverlayReverse(t *testing.T) { func TestLoadFromYamlItemOverlayReverse(t *testing.T) {
@ -449,11 +447,7 @@ func TestLoadFromYamlItemOverlayReverse(t *testing.T) {
`) `)
var c TestConfig var c TestConfig
if assert.NoError(t, LoadFromYamlBytes(input, &c)) { assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
assert.Equal(t, "localhost", c.Redis.Host)
assert.Equal(t, 6379, c.Redis.Port)
assert.Equal(t, "test", c.Redis.Key)
}
} }
func TestLoadFromYamlItemOverlayWithMap(t *testing.T) { func TestLoadFromYamlItemOverlayWithMap(t *testing.T) {
@ -616,6 +610,126 @@ func TestUnmarshalJsonBytesWithMapTypeValueOfStruct(t *testing.T) {
} }
} }
func Test_checkInheritOverwrite(t *testing.T) {
t.Run("normal", func(t *testing.T) {
type Base struct {
Name string
}
type St1 struct {
Base
Name2 string
}
type St2 struct {
Base
Name2 string
}
type St3 struct {
*Base
Name2 string
}
type St4 struct {
*Base
Name2 *string
}
validate := func(val any) {
input := []byte(`{"Name": "hello", "Name2": "world"}`)
assert.NoError(t, LoadFromJsonBytes(input, val))
}
validate(&St1{})
validate(&St2{})
validate(&St3{})
validate(&St4{})
})
t.Run("Inherit Override", func(t *testing.T) {
type Base struct {
Name string
}
type St1 struct {
Base
Name string
}
type St2 struct {
Base
Name int
}
type St3 struct {
*Base
Name int
}
type St4 struct {
*Base
Name *string
}
validate := func(val any) {
input := []byte(`{"Name": "hello"}`)
err := LoadFromJsonBytes(input, val)
assert.ErrorAs(t, err, &dupErr)
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
}
validate(&St1{})
validate(&St2{})
validate(&St3{})
validate(&St4{})
})
t.Run("Inherit more", func(t *testing.T) {
type Base1 struct {
Name string
}
type St0 struct {
Base1
Name string
}
type St1 struct {
St0
Name string
}
type St2 struct {
St0
Name int
}
type St3 struct {
*St0
Name int
}
type St4 struct {
*St0
Name *int
}
validate := func(val any) {
input := []byte(`{"Name": "hello"}`)
err := LoadFromJsonBytes(input, val)
assert.ErrorAs(t, err, &dupErr)
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
}
validate(&St0{})
validate(&St1{})
validate(&St2{})
validate(&St3{})
validate(&St4{})
})
}
func createTempFile(ext, text string) (string, error) { func createTempFile(ext, text string) (string, error) {
tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext) tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
if err != nil { if err != nil {

Loading…
Cancel
Save