diff --git a/core/stores/redis/redis.go b/core/stores/redis/redis.go index cb6cbcbd..9095d7d3 100644 --- a/core/stores/redis/redis.go +++ b/core/stores/redis/redis.go @@ -1404,21 +1404,28 @@ func (s *Redis) ScanCtx(ctx context.Context, cursor uint64, match string, count } // SetBit is the implementation of redis setbit command. -func (s *Redis) SetBit(key string, offset int64, value int) error { +func (s *Redis) SetBit(key string, offset int64, value int) (int, error) { return s.SetBitCtx(context.Background(), key, offset, value) } // SetBitCtx is the implementation of redis setbit command. -func (s *Redis) SetBitCtx(ctx context.Context, key string, offset int64, value int) error { - return s.brk.DoWithAcceptable(func() error { +func (s *Redis) SetBitCtx(ctx context.Context, key string, offset int64, value int) (val int, err error) { + err = s.brk.DoWithAcceptable(func() error { conn, err := getRedis(s) if err != nil { return err } - _, err = conn.SetBit(ctx, key, offset, value).Result() - return err + v, err := conn.SetBit(ctx, key, offset, value).Result() + if err != nil { + return err + } + + val = int(v) + return nil }, acceptable) + + return } // Sscan is the implementation of redis sscan command. diff --git a/core/stores/redis/redis_test.go b/core/stores/redis/redis_test.go index fe8aa7ca..fdf90e71 100644 --- a/core/stores/redis/redis_test.go +++ b/core/stores/redis/redis_test.go @@ -387,30 +387,33 @@ func TestRedis_Mget(t *testing.T) { func TestRedis_SetBit(t *testing.T) { runOnRedis(t, func(client *Redis) { - err := New(client.Addr, badType()).SetBit("key", 1, 1) + _, err := New(client.Addr, badType()).SetBit("key", 1, 1) assert.NotNil(t, err) - err = client.SetBit("key", 1, 1) + val, err := client.SetBit("key", 1, 1) assert.Nil(t, err) + assert.Equal(t, 0, val) }) } func TestRedis_GetBit(t *testing.T) { runOnRedis(t, func(client *Redis) { - err := client.SetBit("key", 2, 1) + val, err := client.SetBit("key", 2, 1) assert.Nil(t, err) + assert.Equal(t, 0, val) _, err = New(client.Addr, badType()).GetBit("key", 2) assert.NotNil(t, err) - val, err := client.GetBit("key", 2) + v, err := client.GetBit("key", 2) assert.Nil(t, err) - assert.Equal(t, 1, val) + assert.Equal(t, 1, v) }) } func TestRedis_BitCount(t *testing.T) { runOnRedis(t, func(client *Redis) { for i := 0; i < 11; i++ { - err := client.SetBit("key", int64(i), 1) + val, err := client.SetBit("key", int64(i), 1) assert.Nil(t, err) + assert.Equal(t, 0, val) } _, err := New(client.Addr, badType()).BitCount("key", 0, -1)