diff --git a/core/executors/periodicalexecutor.go b/core/executors/periodicalexecutor.go index be84c3b2..6ceea15c 100644 --- a/core/executors/periodicalexecutor.go +++ b/core/executors/periodicalexecutor.go @@ -3,6 +3,7 @@ package executors import ( "reflect" "sync" + "sync/atomic" "time" "github.com/tal-tech/go-zero/core/lang" @@ -35,9 +36,9 @@ type ( // avoid race condition on waitGroup when calling wg.Add/Done/Wait(...) wgBarrier syncx.Barrier confirmChan chan lang.PlaceholderType + inflight int32 guarded bool newTicker func(duration time.Duration) timex.Ticker - currTask int lock sync.Mutex } ) @@ -104,9 +105,8 @@ func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool) }() if pe.container.AddTask(task) { - vals := pe.container.RemoveAll() - pe.currTask++ - return vals, true + atomic.AddInt32(&pe.inflight, 1) + return pe.container.RemoveAll(), true } return nil, false @@ -123,11 +123,9 @@ func (pe *PeriodicalExecutor) backgroundFlush() { select { case vals := <-pe.commander: commanded = true + atomic.AddInt32(&pe.inflight, -1) pe.enterExecution() pe.confirmChan <- lang.Placeholder - pe.lock.Lock() - pe.currTask-- - pe.lock.Unlock() pe.executeTasks(vals) last = timex.Now() case <-ticker.Chan(): @@ -136,18 +134,7 @@ func (pe *PeriodicalExecutor) backgroundFlush() { } else if pe.Flush() { last = timex.Now() } else if timex.Since(last) > pe.interval*idleRound { - var exit bool = true - pe.lock.Lock() - if pe.currTask > 0 { - exit = false - } else { - pe.guarded = false - } - pe.lock.Unlock() - - if exit { - // flush again to avoid missing tasks - pe.Flush() + if pe.cleanup() { return } } @@ -156,6 +143,22 @@ func (pe *PeriodicalExecutor) backgroundFlush() { }) } +func (pe *PeriodicalExecutor) cleanup() (stop bool) { + pe.lock.Lock() + pe.guarded = false + if atomic.LoadInt32(&pe.inflight) == 0 { + stop = true + } + pe.lock.Unlock() + + if stop { + // flush again to avoid missing tasks + pe.Flush() + } + + return +} + func (pe *PeriodicalExecutor) doneExecution() { pe.waitGroup.Done() } diff --git a/core/executors/periodicalexecutor_test.go b/core/executors/periodicalexecutor_test.go index 66291057..d2ba21ae 100644 --- a/core/executors/periodicalexecutor_test.go +++ b/core/executors/periodicalexecutor_test.go @@ -140,6 +140,26 @@ func TestPeriodicalExecutor_WaitFast(t *testing.T) { assert.Equal(t, total, cnt) } +func TestPeriodicalExecutor_Deadlock(t *testing.T) { + executor := NewBulkExecutor(func(tasks []interface{}) { + }, WithBulkTasks(1), WithBulkInterval(time.Millisecond)) + for i := 0; i < 1e5; i++ { + executor.Add(1) + } +} + +func TestPeriodicalExecutor_hasTasks(t *testing.T) { + ticker := timex.NewFakeTicker() + defer ticker.Stop() + + exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, nil)) + exec.newTicker = func(d time.Duration) timex.Ticker { + return ticker + } + assert.False(t, exec.hasTasks(nil)) + assert.True(t, exec.hasTasks(1)) +} + // go test -benchtime 10s -bench . func BenchmarkExecutor(b *testing.B) { b.ReportAllocs() @@ -149,11 +169,3 @@ func BenchmarkExecutor(b *testing.B) { executor.Add(1) } } - -func TestPeriodicalExecutor_Deadlock(t *testing.T) { - executer := NewBulkExecutor(func(tasks []interface{}) { - }, WithBulkTasks(1), WithBulkInterval(time.Millisecond)) - for i := 0; i < 1e6; i++ { - executer.Add(1) - } -}