diff --git a/zrpc/server.go b/zrpc/server.go index 5327b83e..158d8109 100644 --- a/zrpc/server.go +++ b/zrpc/server.go @@ -109,30 +109,38 @@ func SetServerSlowThreshold(threshold time.Duration) { serverinterceptors.SetSlowThreshold(threshold) } -func setupInterceptors(server internal.Server, c RpcServerConf, metrics *stat.Metrics) error { +func setupAuthInterceptors(svr internal.Server, c RpcServerConf) error { + rds, err := redis.NewRedis(c.Redis.RedisConf) + if err != nil { + return err + } + + authenticator, err := auth.NewAuthenticator(rds, c.Redis.Key, c.StrictControl) + if err != nil { + return err + } + + svr.AddStreamInterceptors(serverinterceptors.StreamAuthorizeInterceptor(authenticator)) + svr.AddUnaryInterceptors(serverinterceptors.UnaryAuthorizeInterceptor(authenticator)) + + return nil +} + +func setupInterceptors(svr internal.Server, c RpcServerConf, metrics *stat.Metrics) error { if c.CpuThreshold > 0 { shedder := load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold)) - server.AddUnaryInterceptors(serverinterceptors.UnarySheddingInterceptor(shedder, metrics)) + svr.AddUnaryInterceptors(serverinterceptors.UnarySheddingInterceptor(shedder, metrics)) } if c.Timeout > 0 { - server.AddUnaryInterceptors(serverinterceptors.UnaryTimeoutInterceptor( + svr.AddUnaryInterceptors(serverinterceptors.UnaryTimeoutInterceptor( time.Duration(c.Timeout) * time.Millisecond)) } if c.Auth { - rds, err := redis.NewRedis(c.Redis.RedisConf) - if err != nil { - return err - } - - authenticator, err := auth.NewAuthenticator(rds, c.Redis.Key, c.StrictControl) - if err != nil { + if err := setupAuthInterceptors(svr, c); err != nil { return err } - - server.AddStreamInterceptors(serverinterceptors.StreamAuthorizeInterceptor(authenticator)) - server.AddUnaryInterceptors(serverinterceptors.UnaryAuthorizeInterceptor(authenticator)) } return nil