refactor: simplify sqlx fail fast ping and simplify miniredis setup in test (#2897)

* chore(redistest): simplify miniredis setup in test

* refactor(sqlx): simplify sqlx fail fast ping

* chore: close connection if not available
master
cong 2 years ago committed by GitHub
parent d113e1352c
commit 64ab00e8e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1761,21 +1761,7 @@ func TestRedis_WithPass(t *testing.T) {
func runOnRedis(t *testing.T, fn func(client *Redis)) { func runOnRedis(t *testing.T, fn func(client *Redis)) {
logx.Disable() logx.Disable()
s, err := miniredis.Run() s := miniredis.RunT(t)
assert.Nil(t, err)
defer func() {
client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) {
return nil, errors.New("should already exist")
})
if err != nil {
t.Error(err)
}
if client != nil {
_ = client.Close()
}
}()
fn(MustNewRedis(RedisConf{ fn(MustNewRedis(RedisConf{
Host: s.Addr(), Host: s.Addr(),
Type: NodeType, Type: NodeType,
@ -1785,21 +1771,7 @@ func runOnRedis(t *testing.T, fn func(client *Redis)) {
func runOnRedisWithError(t *testing.T, fn func(client *Redis)) { func runOnRedisWithError(t *testing.T, fn func(client *Redis)) {
logx.Disable() logx.Disable()
s, err := miniredis.Run() s := miniredis.RunT(t)
assert.NoError(t, err)
defer func() {
client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) {
return nil, errors.New("should already exist")
})
if err != nil {
t.Error(err)
}
if client != nil {
_ = client.Close()
}
}()
s.SetError("mock error") s.SetError("mock error")
fn(New(s.Addr())) fn(New(s.Addr()))
} }

@ -52,14 +52,11 @@ func TestSqlConn(t *testing.T) {
} }
func buildConn() (mock sqlmock.Sqlmock, err error) { func buildConn() (mock sqlmock.Sqlmock, err error) {
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) { connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
var db *sql.DB var db *sql.DB
var err error var err error
db, mock, err = sqlmock.New() db, mock, err = sqlmock.New()
return &pingedDB{ return db, err
DB: db,
}, err
}) })
return return
} }

@ -3,7 +3,6 @@ package sqlx
import ( import (
"database/sql" "database/sql"
"io" "io"
"sync"
"time" "time"
"github.com/zeromicro/go-zero/core/syncx" "github.com/zeromicro/go-zero/core/syncx"
@ -17,43 +16,29 @@ const (
var connManager = syncx.NewResourceManager() var connManager = syncx.NewResourceManager()
type pingedDB struct { func getCachedSqlConn(driverName, server string) (*sql.DB, error) {
*sql.DB
once sync.Once
}
func getCachedSqlConn(driverName, server string) (*pingedDB, error) {
val, err := connManager.GetResource(server, func() (io.Closer, error) { val, err := connManager.GetResource(server, func() (io.Closer, error) {
conn, err := newDBConnection(driverName, server) conn, err := newDBConnection(driverName, server)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &pingedDB{ return conn, nil
DB: conn,
}, nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
return val.(*pingedDB), nil return val.(*sql.DB), nil
} }
func getSqlConn(driverName, server string) (*sql.DB, error) { func getSqlConn(driverName, server string) (*sql.DB, error) {
pdb, err := getCachedSqlConn(driverName, server) conn, err := getCachedSqlConn(driverName, server)
if err != nil { if err != nil {
return nil, err return nil, err
} }
pdb.once.Do(func() { return conn, nil
err = pdb.Ping()
})
if err != nil {
return nil, err
}
return pdb.DB, nil
} }
func newDBConnection(driverName, datasource string) (*sql.DB, error) { func newDBConnection(driverName, datasource string) (*sql.DB, error) {
@ -70,5 +55,10 @@ func newDBConnection(driverName, datasource string) (*sql.DB, error) {
conn.SetMaxOpenConns(maxOpenConns) conn.SetMaxOpenConns(maxOpenConns)
conn.SetConnMaxLifetime(maxLifetime) conn.SetConnMaxLifetime(maxLifetime)
if err := conn.Ping(); err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil return conn, nil
} }

Loading…
Cancel
Save