diff --git a/core/collection/timingwheel.go b/core/collection/timingwheel.go index 9fb9ebf4..945d4e19 100644 --- a/core/collection/timingwheel.go +++ b/core/collection/timingwheel.go @@ -204,6 +204,7 @@ func (tw *TimingWheel) removeTask(key interface{}) { timer := val.(*positionEntry) timer.item.removed = true + tw.timers.Del(key) } func (tw *TimingWheel) run() { @@ -248,7 +249,6 @@ func (tw *TimingWheel) scanAndRunTasks(l *list.List) { if task.removed { next := e.Next() l.Remove(e) - tw.timers.Del(task.key) e = next continue } else if task.circle > 0 { @@ -301,6 +301,7 @@ func (tw *TimingWheel) setTask(task *timingEntry) { func (tw *TimingWheel) setTimerPosition(pos int, task *timingEntry) { if val, ok := tw.timers.Get(task.key); ok { timer := val.(*positionEntry) + timer.item = task timer.pos = pos } else { tw.timers.Set(task.key, &positionEntry{ diff --git a/core/collection/timingwheel_test.go b/core/collection/timingwheel_test.go index d85955fa..6cd9aabe 100644 --- a/core/collection/timingwheel_test.go +++ b/core/collection/timingwheel_test.go @@ -594,6 +594,31 @@ func TestTimingWheel_ElapsedAndSetThenMove(t *testing.T) { } } +func TestMoveAndRemoveTask(t *testing.T) { + ticker := timex.NewFakeTicker() + tick := func(v int) { + for i := 0; i < v; i++ { + ticker.Tick() + } + } + var keys []int + tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) { + assert.Equal(t, "any", k) + assert.Equal(t, 3, v.(int)) + keys = append(keys, v.(int)) + ticker.Done() + }, ticker) + defer tw.Stop() + tw.SetTimer("any", 3, testStep*8) + tick(6) + tw.MoveTimer("any", testStep*7) + tick(3) + tw.RemoveTimer("any") + tick(30) + time.Sleep(time.Millisecond) + assert.Equal(t, 0, len(keys)) +} + func BenchmarkTimingWheel(b *testing.B) { b.ReportAllocs()