From 64ab00e8e3ac21366e6c817d4e1e6356c3bd5082 Mon Sep 17 00:00:00 2001 From: cong Date: Sun, 19 Feb 2023 17:18:19 +0800 Subject: [PATCH] refactor: simplify sqlx fail fast ping and simplify miniredis setup in test (#2897) * chore(redistest): simplify miniredis setup in test * refactor(sqlx): simplify sqlx fail fast ping * chore: close connection if not available --- core/stores/redis/redis_test.go | 32 ++------------------------------ core/stores/sqlx/sqlconn_test.go | 7 ++----- core/stores/sqlx/sqlmanager.go | 30 ++++++++++-------------------- 3 files changed, 14 insertions(+), 55 deletions(-) diff --git a/core/stores/redis/redis_test.go b/core/stores/redis/redis_test.go index 7a83b6f4..88ddd48b 100644 --- a/core/stores/redis/redis_test.go +++ b/core/stores/redis/redis_test.go @@ -1761,21 +1761,7 @@ func TestRedis_WithPass(t *testing.T) { func runOnRedis(t *testing.T, fn func(client *Redis)) { logx.Disable() - s, err := miniredis.Run() - assert.Nil(t, err) - defer func() { - client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) { - return nil, errors.New("should already exist") - }) - if err != nil { - t.Error(err) - } - - if client != nil { - _ = client.Close() - } - }() - + s := miniredis.RunT(t) fn(MustNewRedis(RedisConf{ Host: s.Addr(), Type: NodeType, @@ -1785,21 +1771,7 @@ func runOnRedis(t *testing.T, fn func(client *Redis)) { func runOnRedisWithError(t *testing.T, fn func(client *Redis)) { logx.Disable() - s, err := miniredis.Run() - assert.NoError(t, err) - defer func() { - client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) { - return nil, errors.New("should already exist") - }) - if err != nil { - t.Error(err) - } - - if client != nil { - _ = client.Close() - } - }() - + s := miniredis.RunT(t) s.SetError("mock error") fn(New(s.Addr())) } diff --git a/core/stores/sqlx/sqlconn_test.go b/core/stores/sqlx/sqlconn_test.go index 8c68c5e6..50dd23c8 100644 --- a/core/stores/sqlx/sqlconn_test.go +++ b/core/stores/sqlx/sqlconn_test.go @@ -52,14 +52,11 @@ func TestSqlConn(t *testing.T) { } func buildConn() (mock sqlmock.Sqlmock, err error) { - _, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) { + connManager.GetResource(mockedDatasource, func() (io.Closer, error) { var db *sql.DB var err error db, mock, err = sqlmock.New() - return &pingedDB{ - DB: db, - }, err + return db, err }) - return } diff --git a/core/stores/sqlx/sqlmanager.go b/core/stores/sqlx/sqlmanager.go index b2c1e169..db652edc 100644 --- a/core/stores/sqlx/sqlmanager.go +++ b/core/stores/sqlx/sqlmanager.go @@ -3,7 +3,6 @@ package sqlx import ( "database/sql" "io" - "sync" "time" "github.com/zeromicro/go-zero/core/syncx" @@ -17,43 +16,29 @@ const ( var connManager = syncx.NewResourceManager() -type pingedDB struct { - *sql.DB - once sync.Once -} - -func getCachedSqlConn(driverName, server string) (*pingedDB, error) { +func getCachedSqlConn(driverName, server string) (*sql.DB, error) { val, err := connManager.GetResource(server, func() (io.Closer, error) { conn, err := newDBConnection(driverName, server) if err != nil { return nil, err } - return &pingedDB{ - DB: conn, - }, nil + return conn, nil }) if err != nil { return nil, err } - return val.(*pingedDB), nil + return val.(*sql.DB), nil } func getSqlConn(driverName, server string) (*sql.DB, error) { - pdb, err := getCachedSqlConn(driverName, server) + conn, err := getCachedSqlConn(driverName, server) if err != nil { return nil, err } - pdb.once.Do(func() { - err = pdb.Ping() - }) - if err != nil { - return nil, err - } - - return pdb.DB, nil + return conn, nil } func newDBConnection(driverName, datasource string) (*sql.DB, error) { @@ -70,5 +55,10 @@ func newDBConnection(driverName, datasource string) (*sql.DB, error) { conn.SetMaxOpenConns(maxOpenConns) conn.SetConnMaxLifetime(maxLifetime) + if err := conn.Ping(); err != nil { + _ = conn.Close() + return nil, err + } + return conn, nil }