diff --git a/core/stores/mon/clientmanager.go b/core/stores/mon/clientmanager.go index 5a6c0278..24b8f4cd 100644 --- a/core/stores/mon/clientmanager.go +++ b/core/stores/mon/clientmanager.go @@ -3,15 +3,12 @@ package mon import ( "context" "io" - "time" "github.com/zeromicro/go-zero/core/syncx" "go.mongodb.org/mongo-driver/mongo" mopt "go.mongodb.org/mongo-driver/mongo/options" ) -const defaultTimeout = time.Second - var clientManager = syncx.NewResourceManager() // ClosableClient wraps *mongo.Client and provides a Close method. @@ -30,9 +27,20 @@ func Inject(key string, client *mongo.Client) { clientManager.Inject(key, &ClosableClient{client}) } -func getClient(url string) (*mongo.Client, error) { +func getClient(url string, opts ...Option) (*mongo.Client, error) { val, err := clientManager.GetResource(url, func() (io.Closer, error) { - cli, err := mongo.Connect(context.Background(), mopt.Client().ApplyURI(url)) + o := mopt.Client().ApplyURI(url) + opts = append([]Option{defaultTimeoutOption()}, opts...) + for _, opt := range opts { + opt(o) + } + + cli, err := mongo.Connect(context.Background(), o) + if err != nil { + return nil, err + } + + err = cli.Ping(context.Background(), nil) if err != nil { return nil, err } diff --git a/core/stores/mon/model.go b/core/stores/mon/model.go index 1f27ff72..d7ec04f2 100644 --- a/core/stores/mon/model.go +++ b/core/stores/mon/model.go @@ -48,7 +48,7 @@ func MustNewModel(uri, db, collection string, opts ...Option) *Model { // NewModel returns a Model. func NewModel(uri, db, collection string, opts ...Option) (*Model, error) { - cli, err := getClient(uri) + cli, err := getClient(uri, opts...) if err != nil { return nil, err } diff --git a/core/stores/mon/options.go b/core/stores/mon/options.go index 31f9ecdb..ca13bff8 100644 --- a/core/stores/mon/options.go +++ b/core/stores/mon/options.go @@ -4,14 +4,15 @@ import ( "time" "github.com/zeromicro/go-zero/core/syncx" + mopt "go.mongodb.org/mongo-driver/mongo/options" ) +const defaultTimeout = time.Second * 3 + var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold) type ( - options struct { - timeout time.Duration - } + options = mopt.ClientOptions // Option defines the method to customize a mongo model. Option func(opts *options) @@ -22,8 +23,15 @@ func SetSlowThreshold(threshold time.Duration) { slowThreshold.Set(threshold) } -func defaultOptions() *options { - return &options{ - timeout: defaultTimeout, +func defaultTimeoutOption() Option { + return func(opts *options) { + opts.SetTimeout(defaultTimeout) + } +} + +// WithTimeout set the mon client operation timeout. +func WithTimeout(timeout time.Duration) Option { + return func(opts *options) { + opts.SetTimeout(timeout) } } diff --git a/core/stores/mon/options_test.go b/core/stores/mon/options_test.go index 4cefad2a..5582f492 100644 --- a/core/stores/mon/options_test.go +++ b/core/stores/mon/options_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/stretchr/testify/assert" + mopt "go.mongodb.org/mongo-driver/mongo/options" ) func TestSetSlowThreshold(t *testing.T) { @@ -13,6 +14,14 @@ func TestSetSlowThreshold(t *testing.T) { assert.Equal(t, time.Second, slowThreshold.Load()) } -func TestDefaultOptions(t *testing.T) { - assert.Equal(t, defaultTimeout, defaultOptions().timeout) +func Test_defaultTimeoutOption(t *testing.T) { + opts := mopt.Client() + defaultTimeoutOption()(opts) + assert.Equal(t, defaultTimeout, *opts.Timeout) +} + +func TestWithTimeout(t *testing.T) { + opts := mopt.Client() + WithTimeout(time.Second)(opts) + assert.Equal(t, time.Second, *opts.Timeout) }