diff --git a/core/stores/redis/conf.go b/core/stores/redis/conf.go index 25621b22..083a7e25 100644 --- a/core/stores/redis/conf.go +++ b/core/stores/redis/conf.go @@ -14,9 +14,10 @@ var ( type ( // A RedisConf is a redis config. RedisConf struct { - Host string - Type string `json:",default=node,options=node|cluster"` - Pass string `json:",optional"` + Host string + Type string `json:",default=node,options=node|cluster"` + Pass string `json:",optional"` + TLSFlag bool `json:",default=false,options=true|false"` } // A RedisKeyConf is a redis config with key. @@ -28,6 +29,9 @@ type ( // NewRedis returns a Redis. func (rc RedisConf) NewRedis() *Redis { + if rc.TLSFlag { + return NewRedisWithTLS(rc.Host, rc.Type, rc.TLSFlag, rc.Pass) + } return NewRedis(rc.Host, rc.Type, rc.Pass) } diff --git a/core/stores/redis/redis.go b/core/stores/redis/redis.go index 4d79b171..697cd177 100644 --- a/core/stores/redis/redis.go +++ b/core/stores/redis/redis.go @@ -37,10 +37,11 @@ type ( // Redis defines a redis node/cluster. It is thread-safe. Redis struct { - Addr string - Type string - Pass string - brk breaker.Breaker + Addr string + Type string + Pass string + brk breaker.Breaker + TLSFlag bool } // RedisNode interface represents a redis node. @@ -71,16 +72,21 @@ type ( // NewRedis returns a Redis. func NewRedis(redisAddr, redisType string, redisPass ...string) *Redis { + return NewRedisWithTLS(redisAddr, redisType, false, redisPass...) +} + +func NewRedisWithTLS(redisAddr, redisType string, tlsFlag bool, redisPass ...string) *Redis { var pass string for _, v := range redisPass { pass = v } return &Redis{ - Addr: redisAddr, - Type: redisType, - Pass: pass, - brk: breaker.NewBreaker(), + Addr: redisAddr, + Type: redisType, + Pass: pass, + brk: breaker.NewBreaker(), + TLSFlag: tlsFlag, } } @@ -1704,9 +1710,17 @@ func acceptable(err error) bool { func getRedis(r *Redis) (RedisNode, error) { switch r.Type { case ClusterType: - return getCluster(r.Addr, r.Pass) + if r.TLSFlag { + return getClusterWithTLS(r.Addr, r.Pass, r.TLSFlag) + } else { + return getCluster(r.Addr, r.Pass) + } case NodeType: - return getClient(r.Addr, r.Pass) + if r.TLSFlag { + return getClientWithTLS(r.Addr, r.Pass, r.TLSFlag) + } else { + return getClient(r.Addr, r.Pass) + } default: return nil, fmt.Errorf("redis type '%s' is not supported", r.Type) } diff --git a/core/stores/redis/redis_test.go b/core/stores/redis/redis_test.go index 6d8af051..c4950d2c 100644 --- a/core/stores/redis/redis_test.go +++ b/core/stores/redis/redis_test.go @@ -1,6 +1,7 @@ package redis import ( + "crypto/tls" "errors" "io" "strconv" @@ -26,6 +27,20 @@ func TestRedis_Exists(t *testing.T) { }) } +func TestRedisTLS_Exists(t *testing.T) { + runOnRedisTLS(t, func(client *Redis) { + _, err := NewRedisWithTLS(client.Addr, "", true).Exists("a") + assert.NotNil(t, err) + ok, err := client.Exists("a") + assert.NotNil(t, err) + assert.False(t, ok) + assert.NotNil(t, client.Set("a", "b")) + ok, err = client.Exists("a") + assert.NotNil(t, err) + assert.False(t, ok) + }) +} + func TestRedis_Eval(t *testing.T) { runOnRedis(t, func(client *Redis) { _, err := NewRedis(client.Addr, "").Eval(`redis.call("EXISTS", KEYS[1])`, []string{"notexist"}) @@ -1062,8 +1077,28 @@ func runOnRedis(t *testing.T, fn func(client *Redis)) { client.Close() } }() - fn(NewRedis(s.Addr(), NodeType)) + +} + +func runOnRedisTLS(t *testing.T, fn func(client *Redis)) { + s, err := miniredis.RunTLS(&tls.Config{ + Certificates: make([]tls.Certificate, 1), + InsecureSkipVerify: true, + }) + assert.Nil(t, err) + defer func() { + client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) { + return nil, errors.New("should already exist") + }) + if err != nil { + t.Error(err) + } + if client != nil { + client.Close() + } + }() + fn(NewRedisWithTLS(s.Addr(), NodeType, true)) } type mockedNode struct { diff --git a/core/stores/redis/redisclientmanager.go b/core/stores/redis/redisclientmanager.go index f0caf931..f5a1f718 100644 --- a/core/stores/redis/redisclientmanager.go +++ b/core/stores/redis/redisclientmanager.go @@ -1,6 +1,7 @@ package redis import ( + "crypto/tls" "io" red "github.com/go-redis/redis" @@ -16,13 +17,24 @@ const ( var clientManager = syncx.NewResourceManager() func getClient(server, pass string) (*red.Client, error) { + return getClientWithTLS(server, pass, false) +} + +func getClientWithTLS(server, pass string, tlsFlag bool) (*red.Client, error) { val, err := clientManager.GetResource(server, func() (io.Closer, error) { + var tlsConfig *tls.Config = nil + if tlsFlag { + tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } store := red.NewClient(&red.Options{ Addr: server, Password: pass, DB: defaultDatabase, MaxRetries: maxRetries, MinIdleConns: idleConns, + TLSConfig: tlsConfig, }) store.WrapProcess(process) return store, nil diff --git a/core/stores/redis/redisclustermanager.go b/core/stores/redis/redisclustermanager.go index ec52eb54..eec9e2a0 100644 --- a/core/stores/redis/redisclustermanager.go +++ b/core/stores/redis/redisclustermanager.go @@ -1,6 +1,7 @@ package redis import ( + "crypto/tls" "io" red "github.com/go-redis/redis" @@ -10,12 +11,23 @@ import ( var clusterManager = syncx.NewResourceManager() func getCluster(server, pass string) (*red.ClusterClient, error) { + return getClusterWithTLS(server, pass, false) +} + +func getClusterWithTLS(server, pass string, tlsFlag bool) (*red.ClusterClient, error) { val, err := clusterManager.GetResource(server, func() (io.Closer, error) { + var tlsConfig *tls.Config = nil + if tlsFlag { + tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } store := red.NewClusterClient(&red.ClusterOptions{ Addrs: []string{server}, Password: pass, MaxRetries: maxRetries, MinIdleConns: idleConns, + TLSConfig: tlsConfig, }) store.WrapProcess(process)