diff --git a/core/fx/retry.go b/core/fx/retry.go index 035be0e2..4c0c0376 100644 --- a/core/fx/retry.go +++ b/core/fx/retry.go @@ -24,27 +24,33 @@ type ( ) // DoWithRetry runs fn, and retries if failed. Default to retry 3 times. -// Note that if the fn function accesses global variables outside the function and performs modification operations, -// it is best to lock them, otherwise there may be data race issues +// Note that if the fn function accesses global variables outside the function +// and performs modification operations, it is best to lock them, +// otherwise there may be data race issues func DoWithRetry(fn func() error, opts ...RetryOption) error { - return retry(fn, opts...) + return retry(func(errChan chan error, retryCount int) { + errChan <- fn() + }, opts...) } // DoWithRetryCtx runs fn, and retries if failed. Default to retry 3 times. -// fn retryCount indicates the current number of retries,starting from 0 -// Note that if the fn function accesses global variables outside the function and performs modification operations, -// it is best to lock them, otherwise there may be data race issues -func DoWithRetryCtx(fn func(ctx context.Context, retryCount int) error, opts ...RetryOption) error { - return retry(fn, opts...) +// fn retryCount indicates the current number of retries, starting from 0 +// Note that if the fn function accesses global variables outside the function +// and performs modification operations, it is best to lock them, +// otherwise there may be data race issues +func DoWithRetryCtx(ctx context.Context, fn func(ctx context.Context, retryCount int) error, + opts ...RetryOption) error { + return retry(func(errChan chan error, retryCount int) { + errChan <- fn(ctx, retryCount) + }, opts...) } -func retry(fn interface{}, opts ...RetryOption) error { +func retry(fn func(errChan chan error, retryCount int), opts ...RetryOption) error { options := newRetryOptions() for _, opt := range opts { opt(options) } - sign := make(chan error, 1) var berr errorx.BatchError var cancelFunc context.CancelFunc ctx := context.Background() @@ -53,18 +59,12 @@ func retry(fn interface{}, opts ...RetryOption) error { defer cancelFunc() } + errChan := make(chan error, 1) for i := 0; i < options.times; i++ { - go func(retryCount int) { - switch f := fn.(type) { - case func() error: - sign <- f() - case func(ctx context.Context, retryCount int) error: - sign <- f(ctx, retryCount) - } - }(i) + go fn(errChan, i) select { - case err := <-sign: + case err := <-errChan: if err != nil { berr.Add(err) } else { @@ -109,8 +109,6 @@ func WithTimeout(timeout time.Duration) RetryOption { func newRetryOptions() *retryOptions { return &retryOptions{ - times: defaultRetryTimes, - interval: 0, - timeout: 0, + times: defaultRetryTimes, } } diff --git a/core/fx/retry_test.go b/core/fx/retry_test.go index 2686ca17..d4569dc4 100644 --- a/core/fx/retry_test.go +++ b/core/fx/retry_test.go @@ -46,7 +46,7 @@ func TestRetry(t *testing.T) { func TestRetryWithTimeout(t *testing.T) { assert.Nil(t, DoWithRetry(func() error { return nil - }, WithTimeout(time.Second*10))) + }, WithTimeout(time.Millisecond*500))) times1 := 0 assert.Nil(t, DoWithRetry(func() error { @@ -54,9 +54,9 @@ func TestRetryWithTimeout(t *testing.T) { if times1 == 1 { return errors.New("any ") } - time.Sleep(time.Second * 3) + time.Sleep(time.Millisecond * 150) return nil - }, WithTimeout(time.Second*5))) + }, WithTimeout(time.Millisecond*250))) total := defaultRetryTimes times2 := 0 @@ -65,13 +65,13 @@ func TestRetryWithTimeout(t *testing.T) { if times2 == total { return nil } - time.Sleep(time.Second) + time.Sleep(time.Millisecond * 50) return errors.New("any") - }, WithTimeout(time.Second*(time.Duration(total)+2)))) + }, WithTimeout(time.Millisecond*50*(time.Duration(total)+2)))) assert.NotNil(t, DoWithRetry(func() error { return errors.New("any") - }, WithTimeout(time.Second*5))) + }, WithTimeout(time.Millisecond*250))) } func TestRetryWithInterval(t *testing.T) { @@ -81,9 +81,9 @@ func TestRetryWithInterval(t *testing.T) { if times1 == 1 { return errors.New("any") } - time.Sleep(time.Second * 3) + time.Sleep(time.Millisecond * 150) return nil - }, WithTimeout(time.Second*5), WithInterval(time.Second*3))) + }, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150))) times2 := 0 assert.NotNil(t, DoWithRetry(func() error { @@ -91,26 +91,26 @@ func TestRetryWithInterval(t *testing.T) { if times2 == 2 { return nil } - time.Sleep(time.Second * 3) + time.Sleep(time.Millisecond * 150) return errors.New("any ") - }, WithTimeout(time.Second*5), WithInterval(time.Second*3))) + }, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150))) } func TestRetryCtx(t *testing.T) { - assert.NotNil(t, DoWithRetryCtx(func(ctx context.Context, retryCount int) error { + assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error { if retryCount == 0 { return errors.New("any") } - time.Sleep(time.Second * 3) + time.Sleep(time.Millisecond * 150) return nil - }, WithTimeout(time.Second*5), WithInterval(time.Second*3))) + }, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150))) - assert.NotNil(t, DoWithRetryCtx(func(ctx context.Context, retryCount int) error { + assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error { if retryCount == 1 { return nil } - time.Sleep(time.Second * 3) + time.Sleep(time.Millisecond * 150) return errors.New("any ") - }, WithTimeout(time.Second*5), WithInterval(time.Second*3))) + }, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150))) }