diff --git a/core/prometheus/agent.go b/core/prometheus/agent.go index d3d67cb0..64f04bd8 100644 --- a/core/prometheus/agent.go +++ b/core/prometheus/agent.go @@ -7,10 +7,19 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/tal-tech/go-zero/core/logx" + "github.com/tal-tech/go-zero/core/syncx" "github.com/tal-tech/go-zero/core/threading" ) -var once sync.Once +var ( + once sync.Once + enabled syncx.AtomicBool +) + +// Enabled returns if prometheus is enabled. +func Enabled() bool { + return enabled.True() +} // StartAgent starts a prometheus agent. func StartAgent(c Config) { @@ -19,6 +28,7 @@ func StartAgent(c Config) { return } + enabled.Set(true) threading.GoSafe(func() { http.Handle(c.Path, promhttp.Handler()) addr := fmt.Sprintf("%s:%d", c.Host, c.Port) diff --git a/rest/handler/prometheushandler.go b/rest/handler/prometheushandler.go index 43995cd9..7ca29537 100644 --- a/rest/handler/prometheushandler.go +++ b/rest/handler/prometheushandler.go @@ -6,6 +6,7 @@ import ( "time" "github.com/tal-tech/go-zero/core/metric" + "github.com/tal-tech/go-zero/core/prometheus" "github.com/tal-tech/go-zero/core/timex" "github.com/tal-tech/go-zero/rest/internal/security" ) @@ -34,6 +35,10 @@ var ( // PrometheusHandler returns a middleware that reports stats to prometheus. func PrometheusHandler(path string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { + if !prometheus.Enabled() { + return next + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { startTime := timex.Now() cw := &security.WithCodeResponseWriter{Writer: w} diff --git a/rest/handler/prometheushandler_test.go b/rest/handler/prometheushandler_test.go index 720a6356..a951145e 100644 --- a/rest/handler/prometheushandler_test.go +++ b/rest/handler/prometheushandler_test.go @@ -6,9 +6,26 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/core/prometheus" ) -func TestPromMetricHandler(t *testing.T) { +func TestPromMetricHandler_Disabled(t *testing.T) { + promMetricHandler := PrometheusHandler("/user/login") + handler := promMetricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) +} + +func TestPromMetricHandler_Enabled(t *testing.T) { + prometheus.StartAgent(prometheus.Config{ + Host: "localhost", + Path: "/", + }) promMetricHandler := PrometheusHandler("/user/login") handler := promMetricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) diff --git a/zrpc/internal/clientinterceptors/prometheusinterceptor.go b/zrpc/internal/clientinterceptors/prometheusinterceptor.go index 4efb3385..635e7ee2 100644 --- a/zrpc/internal/clientinterceptors/prometheusinterceptor.go +++ b/zrpc/internal/clientinterceptors/prometheusinterceptor.go @@ -6,6 +6,7 @@ import ( "time" "github.com/tal-tech/go-zero/core/metric" + "github.com/tal-tech/go-zero/core/prometheus" "github.com/tal-tech/go-zero/core/timex" "google.golang.org/grpc" "google.golang.org/grpc/status" @@ -35,6 +36,10 @@ var ( // PrometheusInterceptor is an interceptor that reports to prometheus server. func PrometheusInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if !prometheus.Enabled() { + return invoker(ctx, method, req, reply, cc, opts...) + } + startTime := timex.Now() err := invoker(ctx, method, req, reply, cc, opts...) metricClientReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), method) diff --git a/zrpc/internal/clientinterceptors/prometheusinterceptor_test.go b/zrpc/internal/clientinterceptors/prometheusinterceptor_test.go index c7bdb237..70c46026 100644 --- a/zrpc/internal/clientinterceptors/prometheusinterceptor_test.go +++ b/zrpc/internal/clientinterceptors/prometheusinterceptor_test.go @@ -6,25 +6,38 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/core/prometheus" "google.golang.org/grpc" ) func TestPromMetricInterceptor(t *testing.T) { tests := []struct { - name string - err error + name string + enable bool + err error }{ { - name: "nil", - err: nil, + name: "nil", + enable: true, + err: nil, }, { - name: "with error", - err: errors.New("mock"), + name: "with error", + enable: true, + err: errors.New("mock"), + }, + { + name: "disabled", }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { + if test.enable { + prometheus.StartAgent(prometheus.Config{ + Host: "localhost", + Path: "/", + }) + } cc := new(grpc.ClientConn) err := PrometheusInterceptor(context.Background(), "/foo", nil, nil, cc, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, diff --git a/zrpc/internal/serverinterceptors/prometheusinterceptor.go b/zrpc/internal/serverinterceptors/prometheusinterceptor.go index b0120daf..decc0d59 100644 --- a/zrpc/internal/serverinterceptors/prometheusinterceptor.go +++ b/zrpc/internal/serverinterceptors/prometheusinterceptor.go @@ -6,6 +6,7 @@ import ( "time" "github.com/tal-tech/go-zero/core/metric" + "github.com/tal-tech/go-zero/core/prometheus" "github.com/tal-tech/go-zero/core/timex" "google.golang.org/grpc" "google.golang.org/grpc/status" @@ -36,6 +37,10 @@ var ( func UnaryPrometheusInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( interface{}, error) { + if !prometheus.Enabled() { + return handler(ctx, req) + } + startTime := timex.Now() resp, err := handler(ctx, req) metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), info.FullMethod) diff --git a/zrpc/internal/serverinterceptors/prometheusinterceptor_test.go b/zrpc/internal/serverinterceptors/prometheusinterceptor_test.go index 9d31f4f9..5eb1840e 100644 --- a/zrpc/internal/serverinterceptors/prometheusinterceptor_test.go +++ b/zrpc/internal/serverinterceptors/prometheusinterceptor_test.go @@ -5,10 +5,25 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/core/prometheus" "google.golang.org/grpc" ) -func TestUnaryPromMetricInterceptor(t *testing.T) { +func TestUnaryPromMetricInterceptor_Disabled(t *testing.T) { + interceptor := UnaryPrometheusInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/", + }, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + assert.Nil(t, err) +} + +func TestUnaryPromMetricInterceptor_Enabled(t *testing.T) { + prometheus.StartAgent(prometheus.Config{ + Host: "localhost", + Path: "/", + }) interceptor := UnaryPrometheusInterceptor() _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ FullMethod: "/",