diff --git a/core/stores/sqlc/cachedsql.go b/core/stores/sqlc/cachedsql.go index 3e3b3aa2..9a09802b 100644 --- a/core/stores/sqlc/cachedsql.go +++ b/core/stores/sqlc/cachedsql.go @@ -1,6 +1,7 @@ package sqlc import ( + "context" "database/sql" "time" @@ -18,19 +19,27 @@ var ( ErrNotFound = sqlx.ErrNotFound // can't use one SingleFlight per conn, because multiple conns may share the same cache key. - exclusiveCalls = syncx.NewSingleFlight() - stats = cache.NewStat("sqlc") + singleFlights = syncx.NewSingleFlight() + stats = cache.NewStat("sqlc") ) type ( // ExecFn defines the sql exec method. ExecFn func(conn sqlx.SqlConn) (sql.Result, error) + // ExecCtxFn defines the sql exec method. + ExecCtxFn func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) // IndexQueryFn defines the query method that based on unique indexes. IndexQueryFn func(conn sqlx.SqlConn, v interface{}) (interface{}, error) + // IndexQueryCtxFn defines the query method that based on unique indexes. + IndexQueryCtxFn func(ctx context.Context, conn sqlx.SqlConn, v interface{}) (interface{}, error) // PrimaryQueryFn defines the query method that based on primary keys. PrimaryQueryFn func(conn sqlx.SqlConn, v, primary interface{}) error + // PrimaryQueryCtxFn defines the query method that based on primary keys. + PrimaryQueryCtxFn func(ctx context.Context, conn sqlx.SqlConn, v, primary interface{}) error // QueryFn defines the query method. QueryFn func(conn sqlx.SqlConn, v interface{}) error + // QueryCtxFn defines the query method. + QueryCtxFn func(ctx context.Context, conn sqlx.SqlConn, v interface{}) error // A CachedConn is a DB connection with cache capability. CachedConn struct { @@ -41,7 +50,7 @@ type ( // NewConn returns a CachedConn with a redis cluster cache. func NewConn(db sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) CachedConn { - cc := cache.New(c, exclusiveCalls, stats, sql.ErrNoRows, opts...) + cc := cache.New(c, singleFlights, stats, sql.ErrNoRows, opts...) return NewConnWithCache(db, cc) } @@ -55,28 +64,46 @@ func NewConnWithCache(db sqlx.SqlConn, c cache.Cache) CachedConn { // NewNodeConn returns a CachedConn with a redis node cache. func NewNodeConn(db sqlx.SqlConn, rds *redis.Redis, opts ...cache.Option) CachedConn { - c := cache.NewNode(rds, exclusiveCalls, stats, sql.ErrNoRows, opts...) + c := cache.NewNode(rds, singleFlights, stats, sql.ErrNoRows, opts...) return NewConnWithCache(db, c) } // DelCache deletes cache with keys. func (cc CachedConn) DelCache(keys ...string) error { - return cc.cache.Del(keys...) + return cc.DelCacheCtx(context.Background(), keys...) +} + +// DelCacheCtx deletes cache with keys. +func (cc CachedConn) DelCacheCtx(ctx context.Context, keys ...string) error { + return cc.cache.DelCtx(ctx, keys...) } // GetCache unmarshals cache with given key into v. func (cc CachedConn) GetCache(key string, v interface{}) error { - return cc.cache.Get(key, v) + return cc.GetCacheCtx(context.Background(), key, v) +} + +// GetCacheCtx unmarshals cache with given key into v. +func (cc CachedConn) GetCacheCtx(ctx context.Context, key string, v interface{}) error { + return cc.cache.GetCtx(ctx, key, v) } // Exec runs given exec on given keys, and returns execution result. func (cc CachedConn) Exec(exec ExecFn, keys ...string) (sql.Result, error) { - res, err := exec(cc.db) + execCtx := func(_ context.Context, conn sqlx.SqlConn) (sql.Result, error) { + return exec(conn) + } + return cc.ExecCtx(context.Background(), execCtx, keys...) +} + +// ExecCtx runs given exec on given keys, and returns execution result. +func (cc CachedConn) ExecCtx(ctx context.Context, exec ExecCtxFn, keys ...string) (sql.Result, error) { + res, err := exec(ctx, cc.db) if err != nil { return nil, err } - if err := cc.DelCache(keys...); err != nil { + if err := cc.DelCacheCtx(ctx, keys...); err != nil { return nil, err } @@ -85,31 +112,61 @@ func (cc CachedConn) Exec(exec ExecFn, keys ...string) (sql.Result, error) { // ExecNoCache runs exec with given sql statement, without affecting cache. func (cc CachedConn) ExecNoCache(q string, args ...interface{}) (sql.Result, error) { - return cc.db.Exec(q, args...) + return cc.ExecNoCacheCtx(context.Background(), q, args...) +} + +// ExecNoCacheCtx runs exec with given sql statement, without affecting cache. +func (cc CachedConn) ExecNoCacheCtx(ctx context.Context, q string, args ...interface{}) ( + sql.Result, error) { + return cc.db.ExecCtx(ctx, q, args...) } // QueryRow unmarshals into v with given key and query func. func (cc CachedConn) QueryRow(v interface{}, key string, query QueryFn) error { - return cc.cache.Take(v, key, func(v interface{}) error { - return query(cc.db, v) + queryCtx := func(_ context.Context, conn sqlx.SqlConn, v interface{}) error { + return query(conn, v) + } + return cc.QueryRowCtx(context.Background(), v, key, queryCtx) +} + +// QueryRowCtx unmarshals into v with given key and query func. +func (cc CachedConn) QueryRowCtx(ctx context.Context, v interface{}, key string, query QueryCtxFn) error { + return cc.cache.TakeCtx(ctx, v, key, func(v interface{}) error { + return query(ctx, cc.db, v) }) } // QueryRowIndex unmarshals into v with given key. func (cc CachedConn) QueryRowIndex(v interface{}, key string, keyer func(primary interface{}) string, indexQuery IndexQueryFn, primaryQuery PrimaryQueryFn) error { + indexQueryCtx := func(_ context.Context, conn sqlx.SqlConn, v interface{}) (interface{}, error) { + return indexQuery(conn, v) + } + primaryQueryCtx := func(_ context.Context, conn sqlx.SqlConn, v, primary interface{}) error { + return primaryQuery(conn, v, primary) + } + + return cc.QueryRowIndexCtx(context.Background(), v, key, keyer, indexQueryCtx, primaryQueryCtx) +} + +// QueryRowIndexCtx unmarshals into v with given key. +func (cc CachedConn) QueryRowIndexCtx(ctx context.Context, v interface{}, key string, + keyer func(primary interface{}) string, indexQuery IndexQueryCtxFn, + primaryQuery PrimaryQueryCtxFn) error { var primaryKey interface{} var found bool - if err := cc.cache.TakeWithExpire(&primaryKey, key, func(val interface{}, expire time.Duration) (err error) { - primaryKey, err = indexQuery(cc.db, v) - if err != nil { - return - } - - found = true - return cc.cache.SetWithExpire(keyer(primaryKey), v, expire+cacheSafeGapBetweenIndexAndPrimary) - }); err != nil { + if err := cc.cache.TakeWithExpireCtx(ctx, &primaryKey, key, + func(val interface{}, expire time.Duration) (err error) { + primaryKey, err = indexQuery(ctx, cc.db, v) + if err != nil { + return + } + + found = true + return cc.cache.SetWithExpireCtx(ctx, keyer(primaryKey), v, + expire+cacheSafeGapBetweenIndexAndPrimary) + }); err != nil { return err } @@ -117,28 +174,54 @@ func (cc CachedConn) QueryRowIndex(v interface{}, key string, keyer func(primary return nil } - return cc.cache.Take(v, keyer(primaryKey), func(v interface{}) error { - return primaryQuery(cc.db, v, primaryKey) + return cc.cache.TakeCtx(ctx, v, keyer(primaryKey), func(v interface{}) error { + return primaryQuery(ctx, cc.db, v, primaryKey) }) } // QueryRowNoCache unmarshals into v with given statement. func (cc CachedConn) QueryRowNoCache(v interface{}, q string, args ...interface{}) error { - return cc.db.QueryRow(v, q, args...) + return cc.QueryRowNoCacheCtx(context.Background(), v, q, args...) +} + +// QueryRowNoCacheCtx unmarshals into v with given statement. +func (cc CachedConn) QueryRowNoCacheCtx(ctx context.Context, v interface{}, q string, + args ...interface{}) error { + return cc.db.QueryRowCtx(ctx, v, q, args...) } // QueryRowsNoCache unmarshals into v with given statement. // It doesn't use cache, because it might cause consistency problem. func (cc CachedConn) QueryRowsNoCache(v interface{}, q string, args ...interface{}) error { - return cc.db.QueryRows(v, q, args...) + return cc.QueryRowsNoCacheCtx(context.Background(), v, q, args...) +} + +// QueryRowsNoCacheCtx unmarshals into v with given statement. +// It doesn't use cache, because it might cause consistency problem. +func (cc CachedConn) QueryRowsNoCacheCtx(ctx context.Context, v interface{}, q string, + args ...interface{}) error { + return cc.db.QueryRowsCtx(ctx, v, q, args...) } // SetCache sets v into cache with given key. -func (cc CachedConn) SetCache(key string, v interface{}) error { - return cc.cache.Set(key, v) +func (cc CachedConn) SetCache(key string, val interface{}) error { + return cc.SetCacheCtx(context.Background(), key, val) +} + +// SetCacheCtx sets v into cache with given key. +func (cc CachedConn) SetCacheCtx(ctx context.Context, key string, val interface{}) error { + return cc.cache.SetCtx(ctx, key, val) } // Transact runs given fn in transaction mode. func (cc CachedConn) Transact(fn func(sqlx.Session) error) error { - return cc.db.Transact(fn) + fnCtx := func(_ context.Context, session sqlx.Session) error { + return fn(session) + } + return cc.TransactCtx(context.Background(), fnCtx) +} + +// TransactCtx runs given fn in transaction mode. +func (cc CachedConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error { + return cc.db.TransactCtx(ctx, fn) } diff --git a/core/stores/sqlc/cachedsql_test.go b/core/stores/sqlc/cachedsql_test.go index 3652126c..f42e49ba 100644 --- a/core/stores/sqlc/cachedsql_test.go +++ b/core/stores/sqlc/cachedsql_test.go @@ -1,6 +1,7 @@ package sqlc import ( + "context" "database/sql" "encoding/json" "errors" @@ -568,7 +569,7 @@ func TestNewConnWithCache(t *testing.T) { defer clean() var conn trackedConn - c := NewConnWithCache(&conn, cache.NewNode(r, exclusiveCalls, stats, sql.ErrNoRows)) + c := NewConnWithCache(&conn, cache.NewNode(r, singleFlights, stats, sql.ErrNoRows)) _, err = c.ExecNoCache("delete from user_table where id='kevin'") assert.Nil(t, err) assert.True(t, conn.execValue) @@ -585,6 +586,30 @@ type dummySqlConn struct { queryRow func(interface{}, string, ...interface{}) error } +func (d dummySqlConn) ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return nil, nil +} + +func (d dummySqlConn) PrepareCtx(ctx context.Context, query string) (sqlx.StmtSession, error) { + return nil, nil +} + +func (d dummySqlConn) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error { + return nil +} + +func (d dummySqlConn) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error { + return nil +} + +func (d dummySqlConn) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error { + return nil +} + +func (d dummySqlConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error { + return nil +} + func (d dummySqlConn) Exec(query string, args ...interface{}) (sql.Result, error) { return nil, nil } @@ -594,6 +619,10 @@ func (d dummySqlConn) Prepare(query string) (sqlx.StmtSession, error) { } func (d dummySqlConn) QueryRow(v interface{}, query string, args ...interface{}) error { + return d.QueryRowCtx(context.Background(), v, query, args...) +} + +func (d dummySqlConn) QueryRowCtx(_ context.Context, v interface{}, query string, args ...interface{}) error { if d.queryRow != nil { return d.queryRow(v, query, args...) } @@ -628,13 +657,21 @@ type trackedConn struct { } func (c *trackedConn) Exec(query string, args ...interface{}) (sql.Result, error) { + return c.ExecCtx(context.Background(), query, args...) +} + +func (c *trackedConn) ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { c.execValue = true - return c.dummySqlConn.Exec(query, args...) + return c.dummySqlConn.ExecCtx(ctx, query, args...) } func (c *trackedConn) QueryRows(v interface{}, query string, args ...interface{}) error { + return c.QueryRowsCtx(context.Background(), v, query, args...) +} + +func (c *trackedConn) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error { c.queryRowsValue = true - return c.dummySqlConn.QueryRows(v, query, args...) + return c.dummySqlConn.QueryRowsCtx(ctx, v, query, args...) } func (c *trackedConn) RawDB() (*sql.DB, error) { @@ -642,6 +679,12 @@ func (c *trackedConn) RawDB() (*sql.DB, error) { } func (c *trackedConn) Transact(fn func(session sqlx.Session) error) error { + return c.TransactCtx(context.Background(), func(_ context.Context, session sqlx.Session) error { + return fn(session) + }) +} + +func (c *trackedConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error { c.transactValue = true - return c.dummySqlConn.Transact(fn) + return c.dummySqlConn.TransactCtx(ctx, fn) } diff --git a/core/stores/sqlx/bulkinserter_test.go b/core/stores/sqlx/bulkinserter_test.go index 7a1e1a0d..1c16a5b0 100644 --- a/core/stores/sqlx/bulkinserter_test.go +++ b/core/stores/sqlx/bulkinserter_test.go @@ -1,6 +1,7 @@ package sqlx import ( + "context" "database/sql" "errors" "strconv" @@ -17,12 +18,40 @@ type mockedConn struct { execErr error } -func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) { +func (c *mockedConn) ExecCtx(_ context.Context, query string, args ...interface{}) (sql.Result, error) { c.query = query c.args = args return nil, c.execErr } +func (c *mockedConn) PrepareCtx(ctx context.Context, query string) (StmtSession, error) { + panic("implement me") +} + +func (c *mockedConn) QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error { + panic("implement me") +} + +func (c *mockedConn) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error { + panic("implement me") +} + +func (c *mockedConn) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error { + panic("implement me") +} + +func (c *mockedConn) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error { + panic("implement me") +} + +func (c *mockedConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error { + panic("should not called") +} + +func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) { + return c.ExecCtx(context.Background(), query, args...) +} + func (c *mockedConn) Prepare(query string) (StmtSession, error) { panic("should not called") } diff --git a/core/stores/sqlx/orm_test.go b/core/stores/sqlx/orm_test.go index 2dbfa516..a629d43d 100644 --- a/core/stores/sqlx/orm_test.go +++ b/core/stores/sqlx/orm_test.go @@ -1,6 +1,7 @@ package sqlx import ( + "context" "database/sql" "errors" "testing" @@ -16,7 +17,7 @@ func TestUnmarshalRowBool(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value bool - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.True(t, value) @@ -29,7 +30,7 @@ func TestUnmarshalRowBoolNotSettable(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value bool - assert.NotNil(t, query(db, func(rows *sql.Rows) error { + assert.NotNil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(value, rows, true) }, "select value from users where user=?", "anyone")) }) @@ -41,7 +42,7 @@ func TestUnmarshalRowInt(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value int - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, 2, value) @@ -54,7 +55,7 @@ func TestUnmarshalRowInt8(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value int8 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, int8(3), value) @@ -67,7 +68,7 @@ func TestUnmarshalRowInt16(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value int16 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.Equal(t, int16(4), value) @@ -80,7 +81,7 @@ func TestUnmarshalRowInt32(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value int32 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.Equal(t, int32(5), value) @@ -93,7 +94,7 @@ func TestUnmarshalRowInt64(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value int64 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, int64(6), value) @@ -106,7 +107,7 @@ func TestUnmarshalRowUint(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value uint - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, uint(2), value) @@ -119,7 +120,7 @@ func TestUnmarshalRowUint8(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value uint8 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, uint8(3), value) @@ -132,7 +133,7 @@ func TestUnmarshalRowUint16(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value uint16 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, uint16(4), value) @@ -145,7 +146,7 @@ func TestUnmarshalRowUint32(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value uint32 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, uint32(5), value) @@ -158,7 +159,7 @@ func TestUnmarshalRowUint64(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value uint64 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, uint16(6), value) @@ -171,7 +172,7 @@ func TestUnmarshalRowFloat32(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value float32 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, float32(7), value) @@ -184,7 +185,7 @@ func TestUnmarshalRowFloat64(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value float64 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, float64(8), value) @@ -198,7 +199,7 @@ func TestUnmarshalRowString(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value string - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -215,7 +216,7 @@ func TestUnmarshalRowStruct(t *testing.T) { rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(value, rows, true) }, "select name, age from users where user=?", "anyone")) assert.Equal(t, "liao", value.Name) @@ -233,7 +234,7 @@ func TestUnmarshalRowStructWithTags(t *testing.T) { rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(value, rows, true) }, "select name, age from users where user=?", "anyone")) assert.Equal(t, "liao", value.Name) @@ -251,7 +252,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) { rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) - assert.NotNil(t, query(db, func(rows *sql.Rows) error { + assert.NotNil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(value, rows, true) }, "select name, age from users where user=?", "anyone")) }) @@ -264,7 +265,7 @@ func TestUnmarshalRowsBool(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []bool - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -278,7 +279,7 @@ func TestUnmarshalRowsInt(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []int - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -292,7 +293,7 @@ func TestUnmarshalRowsInt8(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []int8 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -306,7 +307,7 @@ func TestUnmarshalRowsInt16(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []int16 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -320,7 +321,7 @@ func TestUnmarshalRowsInt32(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []int32 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -334,7 +335,7 @@ func TestUnmarshalRowsInt64(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []int64 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -348,7 +349,7 @@ func TestUnmarshalRowsUint(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []uint - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -362,7 +363,7 @@ func TestUnmarshalRowsUint8(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []uint8 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -376,7 +377,7 @@ func TestUnmarshalRowsUint16(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []uint16 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -390,7 +391,7 @@ func TestUnmarshalRowsUint32(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []uint32 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -404,7 +405,7 @@ func TestUnmarshalRowsUint64(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []uint64 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -418,7 +419,7 @@ func TestUnmarshalRowsFloat32(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []float32 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -432,7 +433,7 @@ func TestUnmarshalRowsFloat64(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []float64 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -446,7 +447,7 @@ func TestUnmarshalRowsString(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []string - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -462,7 +463,7 @@ func TestUnmarshalRowsBoolPtr(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []*bool - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -478,7 +479,7 @@ func TestUnmarshalRowsIntPtr(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []*int - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -494,7 +495,7 @@ func TestUnmarshalRowsInt8Ptr(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []*int8 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -510,7 +511,7 @@ func TestUnmarshalRowsInt16Ptr(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []*int16 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -526,7 +527,7 @@ func TestUnmarshalRowsInt32Ptr(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []*int32 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -542,7 +543,7 @@ func TestUnmarshalRowsInt64Ptr(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []*int64 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -558,7 +559,7 @@ func TestUnmarshalRowsUintPtr(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []*uint - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -574,7 +575,7 @@ func TestUnmarshalRowsUint8Ptr(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []*uint8 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -590,7 +591,7 @@ func TestUnmarshalRowsUint16Ptr(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []*uint16 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -606,7 +607,7 @@ func TestUnmarshalRowsUint32Ptr(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []*uint32 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -622,7 +623,7 @@ func TestUnmarshalRowsUint64Ptr(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []*uint64 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -638,7 +639,7 @@ func TestUnmarshalRowsFloat32Ptr(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []*float32 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -654,7 +655,7 @@ func TestUnmarshalRowsFloat64Ptr(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []*float64 - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -670,7 +671,7 @@ func TestUnmarshalRowsStringPtr(t *testing.T) { mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) var value []*string - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select value from users where user=?", "anyone")) assert.EqualValues(t, expect, value) @@ -699,7 +700,7 @@ func TestUnmarshalRowsStruct(t *testing.T) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select name, age from users where user=?", "anyone")) @@ -739,7 +740,7 @@ func TestUnmarshalRowsStructWithNullStringType(t *testing.T) { rs := sqlmock.NewRows([]string{"name", "value"}).AddRow( "first", "firstnullstring").AddRow("second", nil) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select name, age from users where user=?", "anyone")) @@ -773,7 +774,7 @@ func TestUnmarshalRowsStructWithTags(t *testing.T) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select name, age from users where user=?", "anyone")) @@ -814,7 +815,7 @@ func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select name, age, value from users where user=?", "anyone")) @@ -856,7 +857,7 @@ func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T) runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select name, age, value from users where user=?", "anyone")) @@ -890,7 +891,7 @@ func TestUnmarshalRowsStructPtr(t *testing.T) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select name, age from users where user=?", "anyone")) @@ -923,7 +924,7 @@ func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select name, age from users where user=?", "anyone")) @@ -956,7 +957,7 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) }, "select name, age from users where user=?", "anyone")) @@ -976,7 +977,7 @@ func TestCommonSqlConn_QueryRowOptional(t *testing.T) { User string `db:"user"` Age int `db:"age"` } - assert.Nil(t, query(db, func(rows *sql.Rows) error { + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRow(&r, rows, false) }, "select age from users where user=?", "anyone")) assert.Empty(t, r.User) @@ -1027,7 +1028,7 @@ func TestUnmarshalRowError(t *testing.T) { User string `db:"user"` Age int `db:"age"` } - test.validate(query(db, func(rows *sql.Rows) error { + test.validate(query(context.Background(), db, func(rows *sql.Rows) error { scanner := mockedScanner{ colErr: test.colErr, scanErr: test.scanErr, diff --git a/core/stores/sqlx/sqlconn.go b/core/stores/sqlx/sqlconn.go index 9d4c529e..e2b017d2 100644 --- a/core/stores/sqlx/sqlconn.go +++ b/core/stores/sqlx/sqlconn.go @@ -1,6 +1,7 @@ package sqlx import ( + "context" "database/sql" "github.com/zeromicro/go-zero/core/breaker" @@ -14,11 +15,17 @@ type ( // Session stands for raw connections or transaction sessions Session interface { Exec(query string, args ...interface{}) (sql.Result, error) + ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error) Prepare(query string) (StmtSession, error) + PrepareCtx(ctx context.Context, query string) (StmtSession, error) QueryRow(v interface{}, query string, args ...interface{}) error + QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error QueryRowPartial(v interface{}, query string, args ...interface{}) error + QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error QueryRows(v interface{}, query string, args ...interface{}) error + QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error QueryRowsPartial(v interface{}, query string, args ...interface{}) error + QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error } // SqlConn only stands for raw connections, so Transact method can be called. @@ -27,7 +34,8 @@ type ( // RawDB is for other ORM to operate with, use it with caution. // Notice: don't close it. RawDB() (*sql.DB, error) - Transact(func(session Session) error) error + Transact(fn func(Session) error) error + TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error } // SqlOption defines the method to customize a sql connection. @@ -37,10 +45,15 @@ type ( StmtSession interface { Close() error Exec(args ...interface{}) (sql.Result, error) + ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error) QueryRow(v interface{}, args ...interface{}) error + QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error QueryRowPartial(v interface{}, args ...interface{}) error + QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error QueryRows(v interface{}, args ...interface{}) error + QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error QueryRowsPartial(v interface{}, args ...interface{}) error + QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error } // thread-safe @@ -58,7 +71,9 @@ type ( sessionConn interface { Exec(query string, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) Query(query string, args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } statement struct { @@ -68,7 +83,9 @@ type ( stmtConn interface { Exec(args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) Query(args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) } ) @@ -112,6 +129,11 @@ func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn { } func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) { + return db.ExecCtx(context.Background(), q, args...) +} + +func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...interface{}) ( + result sql.Result, err error) { err = db.brk.DoWithAcceptable(func() error { var conn *sql.DB conn, err = db.connProv() @@ -120,7 +142,7 @@ func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, return err } - result, err = exec(conn, q, args...) + result, err = exec(ctx, conn, q, args...) return err }, db.acceptable) @@ -128,6 +150,10 @@ func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, } func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) { + return db.PrepareCtx(context.Background(), query) +} + +func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) { err = db.brk.DoWithAcceptable(func() error { var conn *sql.DB conn, err = db.connProv() @@ -136,7 +162,7 @@ func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) { return err } - st, err := conn.Prepare(query) + st, err := conn.PrepareContext(ctx, query) if err != nil { return err } @@ -152,25 +178,45 @@ func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) { } func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error { - return db.queryRows(func(rows *sql.Rows) error { + return db.QueryRowCtx(context.Background(), v, q, args...) +} + +func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v interface{}, q string, + args ...interface{}) error { + return db.queryRows(ctx, func(rows *sql.Rows) error { return unmarshalRow(v, rows, true) }, q, args...) } func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error { - return db.queryRows(func(rows *sql.Rows) error { + return db.QueryRowPartialCtx(context.Background(), v, q, args...) +} + +func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v interface{}, + q string, args ...interface{}) error { + return db.queryRows(ctx, func(rows *sql.Rows) error { return unmarshalRow(v, rows, false) }, q, args...) } func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error { - return db.queryRows(func(rows *sql.Rows) error { + return db.QueryRowsCtx(context.Background(), v, q, args...) +} + +func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v interface{}, q string, + args ...interface{}) error { + return db.queryRows(ctx, func(rows *sql.Rows) error { return unmarshalRows(v, rows, true) }, q, args...) } func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error { - return db.queryRows(func(rows *sql.Rows) error { + return db.QueryRowsPartialCtx(context.Background(), v, q, args...) +} + +func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v interface{}, + q string, args ...interface{}) error { + return db.queryRows(ctx, func(rows *sql.Rows) error { return unmarshalRows(v, rows, false) }, q, args...) } @@ -180,13 +226,19 @@ func (db *commonSqlConn) RawDB() (*sql.DB, error) { } func (db *commonSqlConn) Transact(fn func(Session) error) error { + return db.TransactCtx(context.Background(), func(_ context.Context, session Session) error { + return fn(session) + }) +} + +func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error { return db.brk.DoWithAcceptable(func() error { - return transact(db, db.beginTx, fn) + return transact(ctx, db, db.beginTx, fn) }, db.acceptable) } func (db *commonSqlConn) acceptable(err error) bool { - ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone + ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled if db.accept == nil { return ok } @@ -194,7 +246,8 @@ func (db *commonSqlConn) acceptable(err error) bool { return ok || db.accept(err) } -func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error { +func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error, + q string, args ...interface{}) error { var qerr error return db.brk.DoWithAcceptable(func() error { conn, err := db.connProv() @@ -203,7 +256,7 @@ func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args return err } - return query(conn, func(rows *sql.Rows) error { + return query(ctx, conn, func(rows *sql.Rows) error { qerr = scanner(rows) return qerr }, q, args...) @@ -217,29 +270,49 @@ func (s statement) Close() error { } func (s statement) Exec(args ...interface{}) (sql.Result, error) { - return execStmt(s.stmt, s.query, args...) + return s.ExecCtx(context.Background(), args...) +} + +func (s statement) ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error) { + return execStmt(ctx, s.stmt, s.query, args...) } func (s statement) QueryRow(v interface{}, args ...interface{}) error { - return queryStmt(s.stmt, func(rows *sql.Rows) error { + return s.QueryRowCtx(context.Background(), v, args...) +} + +func (s statement) QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error { + return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { return unmarshalRow(v, rows, true) }, s.query, args...) } func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error { - return queryStmt(s.stmt, func(rows *sql.Rows) error { + return s.QueryRowPartialCtx(context.Background(), v, args...) +} + +func (s statement) QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error { + return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { return unmarshalRow(v, rows, false) }, s.query, args...) } func (s statement) QueryRows(v interface{}, args ...interface{}) error { - return queryStmt(s.stmt, func(rows *sql.Rows) error { + return s.QueryRowsCtx(context.Background(), v, args...) +} + +func (s statement) QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error { + return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { return unmarshalRows(v, rows, true) }, s.query, args...) } func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error { - return queryStmt(s.stmt, func(rows *sql.Rows) error { + return s.QueryRowsPartialCtx(context.Background(), v, args...) +} + +func (s statement) QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error { + return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { return unmarshalRows(v, rows, false) }, s.query, args...) } diff --git a/core/stores/sqlx/stmt.go b/core/stores/sqlx/stmt.go index 3cda4853..2bbdf7b0 100644 --- a/core/stores/sqlx/stmt.go +++ b/core/stores/sqlx/stmt.go @@ -1,6 +1,7 @@ package sqlx import ( + "context" "database/sql" "time" @@ -18,64 +19,65 @@ func SetSlowThreshold(threshold time.Duration) { slowThreshold.Set(threshold) } -func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) { +func exec(ctx context.Context, conn sessionConn, q string, args ...interface{}) (sql.Result, error) { stmt, err := format(q, args...) if err != nil { return nil, err } startTime := timex.Now() - result, err := conn.Exec(q, args...) + result, err := conn.ExecContext(ctx, q, args...) duration := timex.Since(startTime) if duration > slowThreshold.Load() { - logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt) + logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt) } else { - logx.WithDuration(duration).Infof("sql exec: %s", stmt) + logx.WithContext(ctx).WithDuration(duration).Infof("sql exec: %s", stmt) } if err != nil { - logSqlError(stmt, err) + logSqlError(ctx, stmt, err) } return result, err } -func execStmt(conn stmtConn, q string, args ...interface{}) (sql.Result, error) { +func execStmt(ctx context.Context, conn stmtConn, q string, args ...interface{}) (sql.Result, error) { stmt, err := format(q, args...) if err != nil { return nil, err } startTime := timex.Now() - result, err := conn.Exec(args...) + result, err := conn.ExecContext(ctx, args...) duration := timex.Since(startTime) if duration > slowThreshold.Load() { - logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt) + logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt) } else { - logx.WithDuration(duration).Infof("sql execStmt: %s", stmt) + logx.WithContext(ctx).WithDuration(duration).Infof("sql execStmt: %s", stmt) } if err != nil { - logSqlError(stmt, err) + logSqlError(ctx, stmt, err) } return result, err } -func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error { +func query(ctx context.Context, conn sessionConn, scanner func(*sql.Rows) error, + q string, args ...interface{}) error { stmt, err := format(q, args...) if err != nil { return err } startTime := timex.Now() - rows, err := conn.Query(q, args...) + rows, err := conn.QueryContext(ctx, q, args...) duration := timex.Since(startTime) if duration > slowThreshold.Load() { - logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt) + logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt) } else { - logx.WithDuration(duration).Infof("sql query: %s", stmt) + logx.WithContext(ctx).WithDuration(duration).Infof("sql query: %s", stmt) } if err != nil { - logSqlError(stmt, err) + logSqlError(ctx, stmt, err) return err } defer rows.Close() @@ -83,22 +85,23 @@ func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...in return scanner(rows) } -func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error { +func queryStmt(ctx context.Context, conn stmtConn, scanner func(*sql.Rows) error, + q string, args ...interface{}) error { stmt, err := format(q, args...) if err != nil { return err } startTime := timex.Now() - rows, err := conn.Query(args...) + rows, err := conn.QueryContext(ctx, args...) duration := timex.Since(startTime) if duration > slowThreshold.Load() { - logx.WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt) + logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt) } else { - logx.WithDuration(duration).Infof("sql queryStmt: %s", stmt) + logx.WithContext(ctx).WithDuration(duration).Infof("sql queryStmt: %s", stmt) } if err != nil { - logSqlError(stmt, err) + logSqlError(ctx, stmt, err) return err } defer rows.Close() diff --git a/core/stores/sqlx/stmt_test.go b/core/stores/sqlx/stmt_test.go index 9a252afa..074ba2c6 100644 --- a/core/stores/sqlx/stmt_test.go +++ b/core/stores/sqlx/stmt_test.go @@ -1,6 +1,7 @@ package sqlx import ( + "context" "database/sql" "errors" "testing" @@ -57,7 +58,7 @@ func TestStmt_exec(t *testing.T) { test := test fns := []func(args ...interface{}) (sql.Result, error){ func(args ...interface{}) (sql.Result, error) { - return exec(&mockedSessionConn{ + return exec(context.Background(), &mockedSessionConn{ lastInsertId: test.lastInsertId, rowsAffected: test.rowsAffected, err: test.err, @@ -65,7 +66,7 @@ func TestStmt_exec(t *testing.T) { }, test.query, args...) }, func(args ...interface{}) (sql.Result, error) { - return execStmt(&mockedStmtConn{ + return execStmt(context.Background(), &mockedStmtConn{ lastInsertId: test.lastInsertId, rowsAffected: test.rowsAffected, err: test.err, @@ -137,7 +138,7 @@ func TestStmt_query(t *testing.T) { test := test fns := []func(args ...interface{}) error{ func(args ...interface{}) error { - return query(&mockedSessionConn{ + return query(context.Background(), &mockedSessionConn{ err: test.err, delay: test.delay, }, func(rows *sql.Rows) error { @@ -145,7 +146,7 @@ func TestStmt_query(t *testing.T) { }, test.query, args...) }, func(args ...interface{}) error { - return queryStmt(&mockedStmtConn{ + return queryStmt(context.Background(), &mockedStmtConn{ err: test.err, delay: test.delay, }, func(rows *sql.Rows) error { @@ -185,6 +186,10 @@ type mockedSessionConn struct { } func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result, error) { + return m.ExecContext(context.Background(), query, args...) +} + +func (m *mockedSessionConn) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { if m.delay { time.Sleep(defaultSlowThreshold + time.Millisecond) } @@ -195,6 +200,10 @@ func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result, } func (m *mockedSessionConn) Query(query string, args ...interface{}) (*sql.Rows, error) { + return m.QueryContext(context.Background(), query, args...) +} + +func (m *mockedSessionConn) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { if m.delay { time.Sleep(defaultSlowThreshold + time.Millisecond) } @@ -214,6 +223,10 @@ type mockedStmtConn struct { } func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) { + return m.ExecContext(context.Background(), args...) +} + +func (m *mockedStmtConn) ExecContext(_ context.Context, _ ...interface{}) (sql.Result, error) { if m.delay { time.Sleep(defaultSlowThreshold + time.Millisecond) } @@ -224,6 +237,10 @@ func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) { } func (m *mockedStmtConn) Query(args ...interface{}) (*sql.Rows, error) { + return m.QueryContext(context.Background(), args...) +} + +func (m *mockedStmtConn) QueryContext(_ context.Context, _ ...interface{}) (*sql.Rows, error) { if m.delay { time.Sleep(defaultSlowThreshold + time.Millisecond) } diff --git a/core/stores/sqlx/tx.go b/core/stores/sqlx/tx.go index cbb5c4db..67c02ff1 100644 --- a/core/stores/sqlx/tx.go +++ b/core/stores/sqlx/tx.go @@ -1,6 +1,7 @@ package sqlx import ( + "context" "database/sql" "fmt" ) @@ -26,11 +27,19 @@ func NewSessionFromTx(tx *sql.Tx) Session { } func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) { - return exec(t.Tx, q, args...) + return t.ExecCtx(context.Background(), q, args...) +} + +func (t txSession) ExecCtx(ctx context.Context, q string, args ...interface{}) (sql.Result, error) { + return exec(ctx, t.Tx, q, args...) } func (t txSession) Prepare(q string) (StmtSession, error) { - stmt, err := t.Tx.Prepare(q) + return t.PrepareCtx(context.Background(), q) +} + +func (t txSession) PrepareCtx(ctx context.Context, q string) (StmtSession, error) { + stmt, err := t.Tx.PrepareContext(ctx, q) if err != nil { return nil, err } @@ -42,25 +51,43 @@ func (t txSession) Prepare(q string) (StmtSession, error) { } func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error { - return query(t.Tx, func(rows *sql.Rows) error { + return t.QueryRowCtx(context.Background(), v, q, args...) +} + +func (t txSession) QueryRowCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error { + return query(ctx, t.Tx, func(rows *sql.Rows) error { return unmarshalRow(v, rows, true) }, q, args...) } func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error { - return query(t.Tx, func(rows *sql.Rows) error { + return t.QueryRowPartialCtx(context.Background(), v, q, args...) +} + +func (t txSession) QueryRowPartialCtx(ctx context.Context, v interface{}, q string, + args ...interface{}) error { + return query(ctx, t.Tx, func(rows *sql.Rows) error { return unmarshalRow(v, rows, false) }, q, args...) } func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error { - return query(t.Tx, func(rows *sql.Rows) error { + return t.QueryRowsCtx(context.Background(), v, q, args...) +} + +func (t txSession) QueryRowsCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error { + return query(ctx, t.Tx, func(rows *sql.Rows) error { return unmarshalRows(v, rows, true) }, q, args...) } func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error { - return query(t.Tx, func(rows *sql.Rows) error { + return t.QueryRowsPartialCtx(context.Background(), v, q, args...) +} + +func (t txSession) QueryRowsPartialCtx(ctx context.Context, v interface{}, q string, + args ...interface{}) error { + return query(ctx, t.Tx, func(rows *sql.Rows) error { return unmarshalRows(v, rows, false) }, q, args...) } @@ -76,17 +103,19 @@ func begin(db *sql.DB) (trans, error) { }, nil } -func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) { +func transact(ctx context.Context, db *commonSqlConn, b beginnable, + fn func(context.Context, Session) error) (err error) { conn, err := db.connProv() if err != nil { db.onError(err) return err } - return transactOnConn(conn, b, fn) + return transactOnConn(ctx, conn, b, fn) } -func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err error) { +func transactOnConn(ctx context.Context, conn *sql.DB, b beginnable, + fn func(context.Context, Session) error) (err error) { var tx trans tx, err = b(conn) if err != nil { @@ -96,18 +125,18 @@ func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err err defer func() { if p := recover(); p != nil { if e := tx.Rollback(); e != nil { - err = fmt.Errorf("recover from %#v, rollback failed: %s", p, e) + err = fmt.Errorf("recover from %#v, rollback failed: %w", p, e) } else { err = fmt.Errorf("recoveer from %#v", p) } } else if err != nil { if e := tx.Rollback(); e != nil { - err = fmt.Errorf("transaction failed: %s, rollback failed: %s", err, e) + err = fmt.Errorf("transaction failed: %s, rollback failed: %w", err, e) } } else { err = tx.Commit() } }() - return fn(tx) + return fn(ctx, tx) } diff --git a/core/stores/sqlx/tx_test.go b/core/stores/sqlx/tx_test.go index 72ac5f17..297f0562 100644 --- a/core/stores/sqlx/tx_test.go +++ b/core/stores/sqlx/tx_test.go @@ -1,6 +1,7 @@ package sqlx import ( + "context" "database/sql" "errors" "testing" @@ -26,26 +27,50 @@ func (mt *mockTx) Exec(q string, args ...interface{}) (sql.Result, error) { return nil, nil } +func (mt *mockTx) ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return nil, nil +} + func (mt *mockTx) Prepare(query string) (StmtSession, error) { return nil, nil } +func (mt *mockTx) PrepareCtx(ctx context.Context, query string) (StmtSession, error) { + return nil, nil +} + func (mt *mockTx) QueryRow(v interface{}, q string, args ...interface{}) error { return nil } +func (mt *mockTx) QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error { + return nil +} + func (mt *mockTx) QueryRowPartial(v interface{}, q string, args ...interface{}) error { return nil } +func (mt *mockTx) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error { + return nil +} + func (mt *mockTx) QueryRows(v interface{}, q string, args ...interface{}) error { return nil } +func (mt *mockTx) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error { + return nil +} + func (mt *mockTx) QueryRowsPartial(v interface{}, q string, args ...interface{}) error { return nil } +func (mt *mockTx) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error { + return nil +} + func (mt *mockTx) Rollback() error { mt.status |= mockRollback return nil @@ -59,18 +84,20 @@ func beginMock(mock *mockTx) beginnable { func TestTransactCommit(t *testing.T) { mock := &mockTx{} - err := transactOnConn(nil, beginMock(mock), func(Session) error { - return nil - }) + err := transactOnConn(context.Background(), nil, beginMock(mock), + func(context.Context, Session) error { + return nil + }) assert.Equal(t, mockCommit, mock.status) assert.Nil(t, err) } func TestTransactRollback(t *testing.T) { mock := &mockTx{} - err := transactOnConn(nil, beginMock(mock), func(Session) error { - return errors.New("rollback") - }) + err := transactOnConn(context.Background(), nil, beginMock(mock), + func(context.Context, Session) error { + return errors.New("rollback") + }) assert.Equal(t, mockRollback, mock.status) assert.NotNil(t, err) } diff --git a/core/stores/sqlx/utils.go b/core/stores/sqlx/utils.go index 74faef7e..9888e400 100644 --- a/core/stores/sqlx/utils.go +++ b/core/stores/sqlx/utils.go @@ -1,6 +1,7 @@ package sqlx import ( + "context" "fmt" "strconv" "strings" @@ -109,9 +110,9 @@ func logInstanceError(datasource string, err error) { logx.Errorf("Error on getting sql instance of %s: %v", datasource, err) } -func logSqlError(stmt string, err error) { +func logSqlError(ctx context.Context, stmt string, err error) { if err != nil && err != ErrNotFound { - logx.Errorf("stmt: %s, error: %s", stmt, err.Error()) + logx.WithContext(ctx).Errorf("stmt: %s, error: %s", stmt, err.Error()) } } diff --git a/tools/goctl/api/docgen/gen.go b/tools/goctl/api/docgen/gen.go index f298576e..fee801d6 100644 --- a/tools/goctl/api/docgen/gen.go +++ b/tools/goctl/api/docgen/gen.go @@ -12,7 +12,7 @@ import ( "github.com/zeromicro/go-zero/tools/goctl/util/pathx" ) -// DocCommand generate markdown doc file +// DocCommand generate Markdown doc file func DocCommand(c *cli.Context) error { dir := c.String("dir") if len(dir) == 0 { @@ -45,7 +45,7 @@ func DocCommand(c *cli.Context) error { for _, p := range files { api, err := parser.Parse(p) if err != nil { - return fmt.Errorf("parse file: %s, err: %s", p, err.Error()) + return fmt.Errorf("parse file: %s, err: %w", p, err) } api.Service = api.Service.JoinPrefix() diff --git a/tools/goctl/migrate/migrate.go b/tools/goctl/migrate/migrate.go index b7b5082e..6094a688 100644 --- a/tools/goctl/migrate/migrate.go +++ b/tools/goctl/migrate/migrate.go @@ -164,12 +164,12 @@ func writeFile(pkgs []*ast.Package, verbose bool) error { w := bytes.NewBuffer(nil) err := format.Node(w, fset, file) if err != nil { - return fmt.Errorf("[rewriteImport] format file %s error: %+v", filename, err) + return fmt.Errorf("[rewriteImport] format file %s error: %w", filename, err) } err = ioutil.WriteFile(filename, w.Bytes(), os.ModePerm) if err != nil { - return fmt.Errorf("[rewriteImport] write file %s error: %+v", filename, err) + return fmt.Errorf("[rewriteImport] write file %s error: %w", filename, err) } if verbose { console.Success("[OK] migrated %q successfully", filepath.Base(filename))