diff --git a/core/collection/set.go b/core/collection/set.go index 732a914f..acea9269 100644 --- a/core/collection/set.go +++ b/core/collection/set.go @@ -15,6 +15,7 @@ const ( stringType ) +// Set is not thread-safe, for concurrent use, make sure to use it with synchronization. type Set struct { data map[interface{}]lang.PlaceholderType tp int @@ -182,10 +183,7 @@ func (s *Set) add(i interface{}) { } func (s *Set) setType(i interface{}) { - if s.tp != untyped { - return - } - + // s.tp can only be untyped here switch i.(type) { case int: s.tp = intType diff --git a/core/collection/set_test.go b/core/collection/set_test.go index 0841d5cd..298e3284 100644 --- a/core/collection/set_test.go +++ b/core/collection/set_test.go @@ -5,8 +5,13 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/core/logx" ) +func init() { + logx.Disable() +} + func BenchmarkRawSet(b *testing.B) { m := make(map[interface{}]struct{}) for i := 0; i < b.N; i++ { @@ -147,3 +152,51 @@ func TestCount(t *testing.T) { // then assert.Equal(t, set.Count(), 3) } + +func TestKeysIntMismatch(t *testing.T) { + set := NewSet() + set.add(int64(1)) + set.add(2) + vals := set.KeysInt() + assert.EqualValues(t, []int{2}, vals) +} + +func TestKeysInt64Mismatch(t *testing.T) { + set := NewSet() + set.add(1) + set.add(int64(2)) + vals := set.KeysInt64() + assert.EqualValues(t, []int64{2}, vals) +} + +func TestKeysUintMismatch(t *testing.T) { + set := NewSet() + set.add(1) + set.add(uint(2)) + vals := set.KeysUint() + assert.EqualValues(t, []uint{2}, vals) +} + +func TestKeysUint64Mismatch(t *testing.T) { + set := NewSet() + set.add(1) + set.add(uint64(2)) + vals := set.KeysUint64() + assert.EqualValues(t, []uint64{2}, vals) +} + +func TestKeysStrMismatch(t *testing.T) { + set := NewSet() + set.add(1) + set.add("2") + vals := set.KeysStr() + assert.EqualValues(t, []string{"2"}, vals) +} + +func TestSetType(t *testing.T) { + set := NewUnmanagedSet() + set.add(1) + set.add("2") + vals := set.Keys() + assert.ElementsMatch(t, []interface{}{1, "2"}, vals) +} diff --git a/core/stores/sqlx/orm_test.go b/core/stores/sqlx/orm_test.go index d11b4b5c..74b82b38 100644 --- a/core/stores/sqlx/orm_test.go +++ b/core/stores/sqlx/orm_test.go @@ -22,6 +22,18 @@ func TestUnmarshalRowBool(t *testing.T) { }) } +func TestUnmarshalRowBoolNotSettable(t *testing.T) { + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + var value bool + assert.NotNil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(value, rows, true) + }, "select value from users where user=?", "anyone")) + }) +} + func TestUnmarshalRowInt(t *testing.T) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2") @@ -228,6 +240,40 @@ func TestUnmarshalRowStructWithTags(t *testing.T) { }) } +func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) { + var value = new(struct { + Age *int `db:"age"` + Name string `db:"name"` + }) + + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + assert.NotNil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(value, rows, true) + }, "select name, age from users where user=?", "anyone")) + }) +} + +func TestUnmarshalRowStructWithTagsPtr(t *testing.T) { + var value = new(struct { + Age *int `db:"age"` + Name string `db:"name"` + }) + + runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") + mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) + + assert.Nil(t, query(db, func(rows *sql.Rows) error { + return unmarshalRow(value, rows, true) + }, "select name, age from users where user=?", "anyone")) + assert.Equal(t, "liao", value.Name) + assert.Equal(t, 5, *value.Age) + }) +} + func TestUnmarshalRowsBool(t *testing.T) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { var expect = []bool{true, false}