diff --git a/core/collection/timingwheel.go b/core/collection/timingwheel.go index 2fe836ba..38953c8b 100644 --- a/core/collection/timingwheel.go +++ b/core/collection/timingwheel.go @@ -2,6 +2,7 @@ package collection import ( "container/list" + "errors" "fmt" "time" @@ -12,6 +13,11 @@ import ( const drainWorkers = 8 +var ( + ErrClosed = errors.New("TimingWheel is closed already") + ErrArgument = errors.New("incorrect task argument") +) + type ( // Execute defines the method to execute the task. Execute func(key, value interface{}) @@ -89,43 +95,63 @@ func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execu } // Drain drains all items and executes them. -func (tw *TimingWheel) Drain(fn func(key, value interface{})) { - tw.drainChannel <- fn +func (tw *TimingWheel) Drain(fn func(key, value interface{})) error { + select { + case tw.drainChannel <- fn: + return nil + case <-tw.stopChannel: + return ErrClosed + } } // MoveTimer moves the task with the given key to the given delay. -func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) { +func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) error { if delay <= 0 || key == nil { - return + return ErrArgument } - tw.moveChannel <- baseEntry{ + select { + case tw.moveChannel <- baseEntry{ delay: delay, key: key, + }: + return nil + case <-tw.stopChannel: + return ErrClosed } } // RemoveTimer removes the task with the given key. -func (tw *TimingWheel) RemoveTimer(key interface{}) { +func (tw *TimingWheel) RemoveTimer(key interface{}) error { if key == nil { - return + return ErrArgument } - tw.removeChannel <- key + select { + case tw.removeChannel <- key: + return nil + case <-tw.stopChannel: + return ErrClosed + } } // SetTimer sets the task value with the given key to the delay. -func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) { +func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) error { if delay <= 0 || key == nil { - return + return ErrArgument } - tw.setChannel <- timingEntry{ + select { + case tw.setChannel <- timingEntry{ baseEntry: baseEntry{ delay: delay, key: key, }, value: value, + }: + return nil + case <-tw.stopChannel: + return ErrClosed } } diff --git a/core/collection/timingwheel_test.go b/core/collection/timingwheel_test.go index 93e669b2..6d51a094 100644 --- a/core/collection/timingwheel_test.go +++ b/core/collection/timingwheel_test.go @@ -28,7 +28,6 @@ func TestTimingWheel_Drain(t *testing.T) { ticker := timex.NewFakeTicker() tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) { }, ticker) - defer tw.Stop() tw.SetTimer("first", 3, testStep*4) tw.SetTimer("second", 5, testStep*7) tw.SetTimer("third", 7, testStep*7) @@ -56,6 +55,8 @@ func TestTimingWheel_Drain(t *testing.T) { }) time.Sleep(time.Millisecond * 100) assert.Equal(t, 0, count) + tw.Stop() + assert.Equal(t, ErrClosed, tw.Drain(func(key, value interface{}) {})) } func TestTimingWheel_SetTimerSoon(t *testing.T) { @@ -102,6 +103,13 @@ func TestTimingWheel_SetTimerWrongDelay(t *testing.T) { }) } +func TestTimingWheel_SetTimerAfterClose(t *testing.T) { + ticker := timex.NewFakeTicker() + tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker) + tw.Stop() + assert.Equal(t, ErrClosed, tw.SetTimer("any", 3, testStep)) +} + func TestTimingWheel_MoveTimer(t *testing.T) { run := syncx.NewAtomicBool() ticker := timex.NewFakeTicker() @@ -111,7 +119,6 @@ func TestTimingWheel_MoveTimer(t *testing.T) { assert.Equal(t, 3, v.(int)) ticker.Done() }, ticker) - defer tw.Stop() tw.SetTimer("any", 3, testStep*4) tw.MoveTimer("any", testStep*7) tw.MoveTimer("any", -testStep) @@ -125,6 +132,8 @@ func TestTimingWheel_MoveTimer(t *testing.T) { } assert.Nil(t, ticker.Wait(waitTime)) assert.True(t, run.True()) + tw.Stop() + assert.Equal(t, ErrClosed, tw.MoveTimer("any", time.Millisecond)) } func TestTimingWheel_MoveTimerSoon(t *testing.T) { @@ -175,6 +184,7 @@ func TestTimingWheel_RemoveTimer(t *testing.T) { ticker.Tick() } tw.Stop() + assert.Equal(t, ErrClosed, tw.RemoveTimer("any")) } func TestTimingWheel_SetTimer(t *testing.T) {