diff --git a/core/conf/config_test.go b/core/conf/config_test.go index 3e20a095..7735e619 100644 --- a/core/conf/config_test.go +++ b/core/conf/config_test.go @@ -35,11 +35,11 @@ func TestConfigJson(t *testing.T) { "c": "${FOO}", "d": "abcd!@#$112" }` + t.Setenv("FOO", "2") + for _, test := range tests { test := test t.Run(test, func(t *testing.T) { - os.Setenv("FOO", "2") - defer os.Unsetenv("FOO") tmpfile, err := createTempFile(test, text) assert.Nil(t, err) defer os.Remove(tmpfile) @@ -81,8 +81,7 @@ b = 1 c = "${FOO}" d = "abcd!@#$112" ` - os.Setenv("FOO", "2") - defer os.Unsetenv("FOO") + t.Setenv("FOO", "2") tmpfile, err := createTempFile(".toml", text) assert.Nil(t, err) defer os.Remove(tmpfile) @@ -207,8 +206,7 @@ b = 1 c = "${FOO}" d = "abcd!@#112" ` - os.Setenv("FOO", "2") - defer os.Unsetenv("FOO") + t.Setenv("FOO", "2") tmpfile, err := createTempFile(".toml", text) assert.Nil(t, err) defer os.Remove(tmpfile) @@ -239,11 +237,10 @@ func TestConfigJsonEnv(t *testing.T) { "c": "${FOO}", "d": "abcd!@#$a12 3" }` + t.Setenv("FOO", "2") for _, test := range tests { test := test t.Run(test, func(t *testing.T) { - os.Setenv("FOO", "2") - defer os.Unsetenv("FOO") tmpfile, err := createTempFile(test, text) assert.Nil(t, err) defer os.Remove(tmpfile) diff --git a/core/conf/properties_test.go b/core/conf/properties_test.go index 8d7ba0e2..19564c23 100644 --- a/core/conf/properties_test.go +++ b/core/conf/properties_test.go @@ -45,8 +45,7 @@ func TestPropertiesEnv(t *testing.T) { assert.Nil(t, err) defer os.Remove(tmpfile) - os.Setenv("FOO", "2") - defer os.Unsetenv("FOO") + t.Setenv("FOO", "2") props, err := LoadProperties(tmpfile, UseEnv()) assert.Nil(t, err) diff --git a/core/mapping/unmarshaler.go b/core/mapping/unmarshaler.go index 91939a27..1f5da75b 100644 --- a/core/mapping/unmarshaler.go +++ b/core/mapping/unmarshaler.go @@ -513,8 +513,8 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re vp valueWithParent, opts *fieldOptionsWithContext, fullName string) error { derefedFieldType := Deref(fieldType) typeKind := derefedFieldType.Kind() - valueKind := reflect.TypeOf(vp.value).Kind() mapValue := vp.value + valueKind := reflect.TypeOf(mapValue).Kind() switch { case valueKind == reflect.Map && typeKind == reflect.Struct: @@ -527,6 +527,8 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re current: mapValuer(mv), parent: vp.parent, }, fullName) + case typeKind == reflect.Slice && valueKind == reflect.Slice: + return u.fillSlice(fieldType, value, mapValue) case valueKind == reflect.Map && typeKind == reflect.Map: return u.fillMap(fieldType, value, mapValue) case valueKind == reflect.String && typeKind == reflect.Map: @@ -545,23 +547,16 @@ func (u *Unmarshaler) processFieldPrimitive(fieldType reflect.Type, value reflec typeKind := Deref(fieldType).Kind() valueKind := reflect.TypeOf(mapValue).Kind() - switch { - case typeKind == reflect.Slice && valueKind == reflect.Slice: - return u.fillSlice(fieldType, value, mapValue) - case typeKind == reflect.Map && valueKind == reflect.Map: - return u.fillMap(fieldType, value, mapValue) + switch v := mapValue.(type) { + case json.Number: + return u.processFieldPrimitiveWithJSONNumber(fieldType, value, v, opts, fullName) default: - switch v := mapValue.(type) { - case json.Number: - return u.processFieldPrimitiveWithJSONNumber(fieldType, value, v, opts, fullName) - default: - if typeKind == valueKind { - if err := validateValueInOptions(mapValue, opts.options()); err != nil { - return err - } - - return fillWithSameType(fieldType, value, mapValue, opts) + if typeKind == valueKind { + if err := validateValueInOptions(mapValue, opts.options()); err != nil { + return err } + + return fillWithSameType(fieldType, value, mapValue, opts) } } diff --git a/core/mapping/unmarshaler_test.go b/core/mapping/unmarshaler_test.go index 08c5b834..ed21f32f 100644 --- a/core/mapping/unmarshaler_test.go +++ b/core/mapping/unmarshaler_test.go @@ -3,7 +3,6 @@ package mapping import ( "encoding/json" "fmt" - "os" "strconv" "strings" "testing" @@ -1454,18 +1453,42 @@ func TestUnmarshalMapOfStructError(t *testing.T) { } func TestUnmarshalSlice(t *testing.T) { - m := map[string]any{ - "Ids": []any{"first", "second"}, - } - var v struct { - Ids []string - } - ast := assert.New(t) - if ast.NoError(UnmarshalKey(m, &v)) { - ast.Equal(2, len(v.Ids)) - ast.Equal("first", v.Ids[0]) - ast.Equal("second", v.Ids[1]) - } + t.Run("slice of string", func(t *testing.T) { + m := map[string]any{ + "Ids": []any{"first", "second"}, + } + var v struct { + Ids []string + } + ast := assert.New(t) + if ast.NoError(UnmarshalKey(m, &v)) { + ast.Equal(2, len(v.Ids)) + ast.Equal("first", v.Ids[0]) + ast.Equal("second", v.Ids[1]) + } + }) + + t.Run("slice with type mismatch", func(t *testing.T) { + var v struct { + Ids string + } + assert.Error(t, NewUnmarshaler(jsonTagKey).Unmarshal([]any{1, 2}, &v)) + }) + + t.Run("slice", func(t *testing.T) { + var v []int + ast := assert.New(t) + if ast.NoError(NewUnmarshaler(jsonTagKey).Unmarshal([]any{1, 2}, &v)) { + ast.Equal(2, len(v)) + ast.Equal(1, v[0]) + ast.Equal(2, v[1]) + } + }) + + t.Run("slice with unsupported type", func(t *testing.T) { + var v int + assert.Error(t, NewUnmarshaler(jsonTagKey).Unmarshal(1, &v)) + }) } func TestUnmarshalSliceOfStruct(t *testing.T) { @@ -3529,8 +3552,7 @@ func TestUnmarshal_EnvString(t *testing.T) { envName = "TEST_NAME_STRING" envVal = "this is a name" ) - os.Setenv(envName, envVal) - defer os.Unsetenv(envName) + t.Setenv(envName, envVal) var v Value if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { @@ -3547,8 +3569,7 @@ func TestUnmarshal_EnvStringOverwrite(t *testing.T) { envName = "TEST_NAME_STRING" envVal = "this is a name" ) - os.Setenv(envName, envVal) - defer os.Unsetenv(envName) + t.Setenv(envName, envVal) var v Value if assert.NoError(t, UnmarshalKey(map[string]any{ @@ -3567,8 +3588,7 @@ func TestUnmarshal_EnvInt(t *testing.T) { envName = "TEST_NAME_INT" envVal = "123" ) - os.Setenv(envName, envVal) - defer os.Unsetenv(envName) + t.Setenv(envName, envVal) var v Value if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { @@ -3585,8 +3605,7 @@ func TestUnmarshal_EnvIntOverwrite(t *testing.T) { envName = "TEST_NAME_INT" envVal = "123" ) - os.Setenv(envName, envVal) - defer os.Unsetenv(envName) + t.Setenv(envName, envVal) var v Value if assert.NoError(t, UnmarshalKey(map[string]any{ @@ -3605,8 +3624,7 @@ func TestUnmarshal_EnvFloat(t *testing.T) { envName = "TEST_NAME_FLOAT" envVal = "123.45" ) - os.Setenv(envName, envVal) - defer os.Unsetenv(envName) + t.Setenv(envName, envVal) var v Value if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { @@ -3623,8 +3641,7 @@ func TestUnmarshal_EnvFloatOverwrite(t *testing.T) { envName = "TEST_NAME_FLOAT" envVal = "123.45" ) - os.Setenv(envName, envVal) - defer os.Unsetenv(envName) + t.Setenv(envName, envVal) var v Value if assert.NoError(t, UnmarshalKey(map[string]any{ @@ -3643,8 +3660,7 @@ func TestUnmarshal_EnvBoolTrue(t *testing.T) { envName = "TEST_NAME_BOOL_TRUE" envVal = "true" ) - os.Setenv(envName, envVal) - defer os.Unsetenv(envName) + t.Setenv(envName, envVal) var v Value if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { @@ -3661,8 +3677,7 @@ func TestUnmarshal_EnvBoolFalse(t *testing.T) { envName = "TEST_NAME_BOOL_FALSE" envVal = "false" ) - os.Setenv(envName, envVal) - defer os.Unsetenv(envName) + t.Setenv(envName, envVal) var v Value if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { @@ -3679,8 +3694,7 @@ func TestUnmarshal_EnvBoolBad(t *testing.T) { envName = "TEST_NAME_BOOL_BAD" envVal = "bad" ) - os.Setenv(envName, envVal) - defer os.Unsetenv(envName) + t.Setenv(envName, envVal) var v Value assert.Error(t, UnmarshalKey(emptyMap, &v)) @@ -3695,8 +3709,7 @@ func TestUnmarshal_EnvDuration(t *testing.T) { envName = "TEST_NAME_DURATION" envVal = "1s" ) - os.Setenv(envName, envVal) - defer os.Unsetenv(envName) + t.Setenv(envName, envVal) var v Value if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { @@ -3713,8 +3726,7 @@ func TestUnmarshal_EnvDurationBadValue(t *testing.T) { envName = "TEST_NAME_BAD_DURATION" envVal = "bad" ) - os.Setenv(envName, envVal) - defer os.Unsetenv(envName) + t.Setenv(envName, envVal) var v Value assert.Error(t, UnmarshalKey(emptyMap, &v)) @@ -3729,8 +3741,7 @@ func TestUnmarshal_EnvWithOptions(t *testing.T) { envName = "TEST_NAME_ENV_OPTIONS_MATCH" envVal = "123" ) - os.Setenv(envName, envVal) - defer os.Unsetenv(envName) + t.Setenv(envName, envVal) var v Value if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { @@ -3747,8 +3758,7 @@ func TestUnmarshal_EnvWithOptionsWrongValueBool(t *testing.T) { envName = "TEST_NAME_ENV_OPTIONS_BOOL" envVal = "false" ) - os.Setenv(envName, envVal) - defer os.Unsetenv(envName) + t.Setenv(envName, envVal) var v Value assert.Error(t, UnmarshalKey(emptyMap, &v)) @@ -3763,8 +3773,7 @@ func TestUnmarshal_EnvWithOptionsWrongValueDuration(t *testing.T) { envName = "TEST_NAME_ENV_OPTIONS_DURATION" envVal = "4s" ) - os.Setenv(envName, envVal) - defer os.Unsetenv(envName) + t.Setenv(envName, envVal) var v Value assert.Error(t, UnmarshalKey(emptyMap, &v)) @@ -3779,8 +3788,7 @@ func TestUnmarshal_EnvWithOptionsWrongValueNumber(t *testing.T) { envName = "TEST_NAME_ENV_OPTIONS_AGE" envVal = "30" ) - os.Setenv(envName, envVal) - defer os.Unsetenv(envName) + t.Setenv(envName, envVal) var v Value assert.Error(t, UnmarshalKey(emptyMap, &v)) @@ -3795,8 +3803,7 @@ func TestUnmarshal_EnvWithOptionsWrongValueString(t *testing.T) { envName = "TEST_NAME_ENV_OPTIONS_STRING" envVal = "this is a name" ) - os.Setenv(envName, envVal) - defer os.Unsetenv(envName) + t.Setenv(envName, envVal) var v Value assert.Error(t, UnmarshalKey(emptyMap, &v)) @@ -4408,18 +4415,80 @@ func TestFillDefaultUnmarshal(t *testing.T) { } func Test_UnmarshalMap(t *testing.T) { - type Customer struct { - Names map[int]string `key:"names"` - } + t.Run("type mismatch", func(t *testing.T) { + type Customer struct { + Names map[int]string `key:"names"` + } - input := map[string]any{ - "names": map[string]any{ - "19": "Tom", - }, - } + input := map[string]any{ + "names": map[string]any{ + "19": "Tom", + }, + } + + var customer Customer + assert.ErrorIs(t, UnmarshalKey(input, &customer), errTypeMismatch) + }) + + t.Run("map type mismatch", func(t *testing.T) { + type Customer struct { + Names struct { + Values map[string]string + } `key:"names"` + } + + input := map[string]any{ + "names": map[string]string{ + "19": "Tom", + }, + } + + var customer Customer + assert.ErrorIs(t, UnmarshalKey(input, &customer), errTypeMismatch) + }) +} + +func TestGetValueWithChainedKeys(t *testing.T) { + t.Run("no key", func(t *testing.T) { + _, ok := getValueWithChainedKeys(nil, []string{}) + assert.False(t, ok) + }) - var customer Customer - assert.ErrorIs(t, UnmarshalKey(input, &customer), errTypeMismatch) + t.Run("one key", func(t *testing.T) { + v, ok := getValueWithChainedKeys(mockValuerWithParent{ + value: "bar", + ok: true, + }, []string{"foo"}) + assert.True(t, ok) + assert.Equal(t, "bar", v) + }) + + t.Run("two keys", func(t *testing.T) { + v, ok := getValueWithChainedKeys(mockValuerWithParent{ + value: map[string]any{ + "bar": "baz", + }, + ok: true, + }, []string{"foo", "bar"}) + assert.True(t, ok) + assert.Equal(t, "baz", v) + }) + + t.Run("two keys not found", func(t *testing.T) { + _, ok := getValueWithChainedKeys(mockValuerWithParent{ + value: "bar", + ok: false, + }, []string{"foo", "bar"}) + assert.False(t, ok) + }) + + t.Run("two keys type mismatch", func(t *testing.T) { + _, ok := getValueWithChainedKeys(mockValuerWithParent{ + value: "bar", + ok: true, + }, []string{"foo", "bar"}) + assert.False(t, ok) + }) } func BenchmarkDefaultValue(b *testing.B) { @@ -4521,3 +4590,17 @@ func BenchmarkUnmarshal(b *testing.B) { UnmarshalKey(data, &an) } } + +type mockValuerWithParent struct { + parent valuerWithParent + value any + ok bool +} + +func (m mockValuerWithParent) Value(key string) (any, bool) { + return m.value, m.ok +} + +func (m mockValuerWithParent) Parent() valuerWithParent { + return m.parent +} diff --git a/core/proc/env_test.go b/core/proc/env_test.go index 1187ff25..0307d7cd 100644 --- a/core/proc/env_test.go +++ b/core/proc/env_test.go @@ -1,7 +1,6 @@ package proc import ( - "os" "testing" "github.com/stretchr/testify/assert" @@ -21,13 +20,11 @@ func TestEnvInt(t *testing.T) { val, ok := EnvInt("any") assert.Equal(t, 0, val) assert.False(t, ok) - err := os.Setenv("anyInt", "10") - assert.Nil(t, err) + t.Setenv("anyInt", "10") val, ok = EnvInt("anyInt") assert.Equal(t, 10, val) assert.True(t, ok) - err = os.Setenv("anyString", "a") - assert.Nil(t, err) + t.Setenv("anyString", "a") val, ok = EnvInt("anyString") assert.Equal(t, 0, val) assert.False(t, ok) diff --git a/core/stat/alert_test.go b/core/stat/alert_test.go index 3a9a84a9..0d303f1f 100644 --- a/core/stat/alert_test.go +++ b/core/stat/alert_test.go @@ -3,7 +3,6 @@ package stat import ( - "os" "strconv" "sync/atomic" "testing" @@ -12,8 +11,7 @@ import ( ) func TestReport(t *testing.T) { - os.Setenv(clusterNameKey, "test-cluster") - defer os.Unsetenv(clusterNameKey) + t.Setenv(clusterNameKey, "test-cluster") var count int32 SetReporter(func(s string) {