diff --git a/cmd/spicedb/memoryprotection/memory_protection_integration_test.go b/cmd/spicedb/memoryprotection/memory_protection_integration_test.go index 0dc1230b77..0a54175e12 100644 --- a/cmd/spicedb/memoryprotection/memory_protection_integration_test.go +++ b/cmd/spicedb/memoryprotection/memory_protection_integration_test.go @@ -33,9 +33,18 @@ func TestServeWithMemoryProtectionMiddleware(t *testing.T) { serverToken := "mykey" serveResource, err := pool.RunWithOptions(&dockertest.RunOptions{ - Repository: "authzed/spicedb", - Tag: "ci", - Cmd: []string{"serve", "--log-level=debug", "--grpc-preshared-key", serverToken, "--telemetry-endpoint=\"\""}, + Repository: "authzed/spicedb", + Tag: "ci", + Cmd: []string{ + "serve", + "--log-level=debug", + "--grpc-preshared-key", serverToken, + "--telemetry-endpoint", "", + // With very low GOMEMLIMIT values, percentage-based dispatch cache defaults + // can round down to zero and fail startup. + "--dispatch-cache-max-cost", "8KiB", + "--dispatch-cluster-cache-max-cost", "8KiB", + }, ExposedPorts: []string{"50051/tcp"}, Env: []string{"GOMEMLIMIT=1B"}, // NOTE: Absurdly low on purpose }, func(config *docker.HostConfig) { diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index 842e3d69e1..8a3b39ee9b 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -93,7 +93,7 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas initCtx, initCancel := context.WithTimeout(context.Background(), 5*time.Minute) defer initCancel() - healthChecker, err := pool.NewNodeHealthChecker(url) + healthChecker, err := pool.NewNodeHealthChecker(url, config.prometheusRegisterer) if err != nil { return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url) } @@ -103,7 +103,7 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas // interfere with pool setup. 
initPoolConfig := readPoolConfig.Copy() initPoolConfig.MinConns = 1 - initPool, err := pool.NewRetryPool(initCtx, "init", initPoolConfig, healthChecker, config.maxRetries, config.connectRate) + initPool, err := pool.NewRetryPool(initCtx, "init", initPoolConfig, healthChecker, config.maxRetries, config.connectRate, config.prometheusRegisterer) if err != nil { return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url) } @@ -197,23 +197,25 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas gcWindow: config.gcWindow, watchEnabled: !config.watchDisabled, schema: *schema.Schema(config.columnOptimizationOption, config.withIntegrity, false), + prometheusRegisterer: config.prometheusRegisterer, + prometheusUnregisterFunction: func() {}, } ds.SetNowFunc(ds.headRevisionInternal) // this ctx and cancel is tied to the lifetime of the datastore ds.ctx, ds.cancel = context.WithCancel(context.Background()) - ds.writePool, err = pool.NewRetryPool(ds.ctx, "write", writePoolConfig, healthChecker, config.maxRetries, config.connectRate) + ds.writePool, err = pool.NewRetryPool(ds.ctx, "write", writePoolConfig, healthChecker, config.maxRetries, config.connectRate, config.prometheusRegisterer) if err != nil { ds.cancel() return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url) } - ds.readPool, err = pool.NewRetryPool(ds.ctx, "read", readPoolConfig, healthChecker, config.maxRetries, config.connectRate) + ds.readPool, err = pool.NewRetryPool(ds.ctx, "read", readPoolConfig, healthChecker, config.maxRetries, config.connectRate, config.prometheusRegisterer) if err != nil { ds.cancel() return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url) } - err = ds.registerPrometheusCollectors(config.enablePrometheusStats) + err = ds.registerPrometheusCollectors(config.prometheusRegisterer, config.enablePrometheusStats) if err != nil { ds.cancel() return nil, err @@ -225,6 
+227,15 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas // Start goroutines for pruning if config.enableConnectionBalancing { log.Ctx(initCtx).Info().Msg("starting cockroach connection balancer") + balancerUnregister, _ := pool.RegisterNodeConnectionBalancerMetrics(config.prometheusRegisterer) + if balancerUnregister != nil { + previousUnregister := ds.prometheusUnregisterFunction + ds.prometheusUnregisterFunction = func() { + previousUnregister() + balancerUnregister() + } + } + ds.pruneGroup, ds.ctx = errgroup.WithContext(ds.ctx) writePoolBalancer := pool.NewNodeConnectionBalancer(ds.writePool, healthChecker, 5*time.Second) readPoolBalancer := pool.NewNodeConnectionBalancer(ds.readPool, healthChecker, 5*time.Second) @@ -263,7 +274,8 @@ type crdbDatastore struct { dburl string readPool, writePool *pool.RetryPool - collectors []prometheus.Collector + prometheusRegisterer prometheus.Registerer + prometheusUnregisterFunction func() watchBufferLength uint16 watchChangeBufferMaximumSize uint64 watchBufferWriteTimeout time.Duration @@ -481,12 +493,8 @@ func (cds *crdbDatastore) Close() error { } cds.readPool.Close() cds.writePool.Close() - for _, collector := range cds.collectors { - ok := prometheus.Unregister(collector) - if !ok { - errs = append(errs, errors.New("could not unregister collector for CRDB datastore")) - } - } + cds.prometheusUnregisterFunction() + return errors.Join(errs...) 
} @@ -658,7 +666,7 @@ func readClusterTTLNanos(ctx context.Context, conn pgxcommon.DBFuncQuerier) (int return gcSeconds * 1_000_000_000, nil } -func (cds *crdbDatastore) registerPrometheusCollectors(enablePrometheusStats bool) error { +func (cds *crdbDatastore) registerPrometheusCollectors(registerer prometheus.Registerer, enablePrometheusStats bool) error { if !enablePrometheusStats { return nil } @@ -668,20 +676,14 @@ func (cds *crdbDatastore) registerPrometheusCollectors(enablePrometheusStats boo "pool_usage": "read", }) - if err := prometheus.Register(readCollector); err != nil { - return fmt.Errorf("failed to register prometheus read collector: %w", err) - } - cds.collectors = append(cds.collectors, readCollector) - writeCollector := pgxpoolprometheus.NewCollector(cds.writePool, map[string]string{ "db_name": "spicedb", "pool_usage": "write", }) - if err := prometheus.Register(writeCollector); err != nil { - return fmt.Errorf("failed to register prometheus write collector: %w", err) - } - cds.collectors = append(cds.collectors, writeCollector) + unregister, err := datastore.RegisterPrometheusCollectors(registerer, "failed to register crdb pool metrics", readCollector, writeCollector) - return nil + cds.prometheusUnregisterFunction = unregister + + return err } diff --git a/internal/datastore/crdb/crdb_test.go b/internal/datastore/crdb/crdb_test.go index 69a2f9f4c7..e95108d454 100644 --- a/internal/datastore/crdb/crdb_test.go +++ b/internal/datastore/crdb/crdb_test.go @@ -896,7 +896,7 @@ func TestRegisterPrometheusCollectors(t *testing.T) { // Create read & write pools readPoolConfig, err := pgxpool.ParseConfig(fmt.Sprintf("postgres://db:password@pg.example.com:5432/mydb?pool_max_conns=%d", readMaxConns)) require.NoError(t, err) - readPool, err := pool.NewRetryPool(t.Context(), "read", readPoolConfig, nil, 18, 20) + readPool, err := pool.NewRetryPool(t.Context(), "read", readPoolConfig, nil, 18, 20, nil) require.NoError(t, err) t.Cleanup(func() { 
readPool.Close() @@ -904,7 +904,7 @@ func TestRegisterPrometheusCollectors(t *testing.T) { writePoolConfig, err := pgxpool.ParseConfig(fmt.Sprintf("postgres://db:password@pg.example.com:5432/mydb?pool_max_conns=%d", writeMaxConns)) require.NoError(t, err) - writePool, err := pool.NewRetryPool(t.Context(), "read", writePoolConfig, nil, 18, 20) + writePool, err := pool.NewRetryPool(t.Context(), "write", writePoolConfig, nil, 18, 20, nil) require.NoError(t, err) // Create datastore with those pools @@ -913,14 +913,13 @@ func TestRegisterPrometheusCollectors(t *testing.T) { _ = cds.Close() }) - err = cds.registerPrometheusCollectors(false) + registry := prometheus.NewPedanticRegistry() + err = cds.registerPrometheusCollectors(registry, false) require.NoError(t, err) - require.Empty(t, cds.collectors) // Register collectors and verify the values of the metrics - err = cds.registerPrometheusCollectors(true) + err = cds.registerPrometheusCollectors(registry, true) require.NoError(t, err) - require.Len(t, cds.collectors, 2) - metricFamily, err := prometheus.DefaultGatherer.Gather() + metricFamily, err := registry.Gather() require.NoError(t, err) @@ -977,9 +975,9 @@ func TestVersionReading(t *testing.T) { // Set up a raw connection to the DB initPoolConfig, err := pgxpool.ParseConfig(uri) require.NoError(err) - checker, err := pool.NewNodeHealthChecker(uri) + checker, err := pool.NewNodeHealthChecker(uri, nil) require.NoError(err) - initPool, err := pool.NewRetryPool(t.Context(), "pool", initPoolConfig, checker, 18, 20) + initPool, err := pool.NewRetryPool(t.Context(), "pool", initPoolConfig, checker, 18, 20, nil) require.NoError(err) t.Cleanup(func() { initPool.Close() diff --git a/internal/datastore/crdb/options.go b/internal/datastore/crdb/options.go index 9e745fc213..f35a9026de 100644 --- a/internal/datastore/crdb/options.go +++ b/internal/datastore/crdb/options.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/prometheus/client_golang/prometheus" "k8s.io/utils/ptr" 
"github.com/authzed/spicedb/internal/datastore/common" @@ -29,6 +30,7 @@ type crdbOptions struct { enableConnectionBalancing bool analyzeBeforeStatistics bool filterMaximumIDCount uint16 + prometheusRegisterer prometheus.Registerer enablePrometheusStats bool withIntegrity bool allowedMigrations []string @@ -127,6 +129,13 @@ func generateConfig(options []Option) (crdbOptions, error) { return computed, nil } +// WithPrometheusRegisterer sets the prometheus.Registerer used for CockroachDB datastore metrics. +func WithPrometheusRegisterer(registerer prometheus.Registerer) Option { + return func(po *crdbOptions) { + po.prometheusRegisterer = registerer + } +} + // ReadConnHealthCheckInterval is the frequency at which both idle and max // lifetime connections are checked, and also the frequency at which the // minimum number of connections is checked. diff --git a/internal/datastore/crdb/pool/balancer.go b/internal/datastore/crdb/pool/balancer.go index dbbbd23d27..cb1b1761d7 100644 --- a/internal/datastore/crdb/pool/balancer.go +++ b/internal/datastore/crdb/pool/balancer.go @@ -17,6 +17,7 @@ import ( "golang.org/x/sync/semaphore" log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/genutil" ) @@ -33,11 +34,6 @@ var ( }, []string{"pool"}) ) -func init() { - prometheus.MustRegister(connectionsPerCRDBNodeCountGauge) - prometheus.MustRegister(pruningTimeHistogram) -} - type balancePoolConn[C balanceConn] interface { Conn() C Release() @@ -67,9 +63,16 @@ type NodeConnectionBalancer struct { nodeConnectionBalancer[*pgxpool.Conn, *pgx.Conn] } +// RegisterNodeConnectionBalancerMetrics registers the shared connection balancer collectors. 
+func RegisterNodeConnectionBalancerMetrics(registerer prometheus.Registerer) (func(), error) { + return datastore.RegisterPrometheusCollectors(registerer, "failed to register crdb connection balancer metrics", connectionsPerCRDBNodeCountGauge, pruningTimeHistogram) +} + // NewNodeConnectionBalancer builds a new nodeConnectionBalancer for a given connection pool and health tracker. func NewNodeConnectionBalancer(pool *RetryPool, healthTracker *NodeHealthTracker, interval time.Duration) *NodeConnectionBalancer { - return &NodeConnectionBalancer{*newNodeConnectionBalancer[*pgxpool.Conn, *pgx.Conn](pool, healthTracker, interval)} + return &NodeConnectionBalancer{ + *newNodeConnectionBalancer[*pgxpool.Conn, *pgx.Conn](pool, healthTracker, interval), + } } // nodeConnectionBalancer is generic over underlying connection types for diff --git a/internal/datastore/crdb/pool/balancer_test.go b/internal/datastore/crdb/pool/balancer_test.go index 230eafa8c1..b075f46675 100644 --- a/internal/datastore/crdb/pool/balancer_test.go +++ b/internal/datastore/crdb/pool/balancer_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" ) @@ -139,7 +140,8 @@ func TestNodeConnectionBalancerPrune(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tracker, err := NewNodeHealthChecker("") + reg := prometheus.NewRegistry() + tracker, err := NewNodeHealthChecker("", reg) require.NoError(t, err) for _, n := range tt.nodes { tracker.healthyNodes[n] = struct{}{} diff --git a/internal/datastore/crdb/pool/health.go b/internal/datastore/crdb/pool/health.go index d1c511b94d..68192bda63 100644 --- a/internal/datastore/crdb/pool/health.go +++ b/internal/datastore/crdb/pool/health.go @@ -14,6 +14,7 @@ import ( pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" ) const 
errorBurst = 2 @@ -23,10 +24,6 @@ var healthyCRDBNodeCountGauge = prometheus.NewGauge(prometheus.GaugeOpts{ Help: "the number of healthy crdb nodes detected by spicedb", }) -func init() { - prometheus.MustRegister(healthyCRDBNodeCountGauge) -} - // NodeHealthTracker detects changes in the node pool by polling the cluster periodically and recording // the node ids that are seen. This is used to detect new nodes that come online that have either previously // been marked unhealthy due to connection errors or due to scale up. @@ -34,19 +31,22 @@ func init() { // Consumers can manually mark a node healthy or unhealthy as well. type NodeHealthTracker struct { sync.RWMutex - connConfig *pgx.ConnConfig - healthyNodes map[uint32]struct{} // GUARDED_BY(RWMutex) - nodesEverSeen map[uint32]*rate.Limiter // GUARDED_BY(RWMutex) - newLimiter func() *rate.Limiter + connConfig *pgx.ConnConfig + healthyNodes map[uint32]struct{} // GUARDED_BY(RWMutex) + nodesEverSeen map[uint32]*rate.Limiter // GUARDED_BY(RWMutex) + newLimiter func() *rate.Limiter + prometheusUnregisterFunction func() } // NewNodeHealthChecker builds a health checker that polls the cluster at the given url. 
-func NewNodeHealthChecker(url string) (*NodeHealthTracker, error) { +func NewNodeHealthChecker(url string, registerer prometheus.Registerer) (*NodeHealthTracker, error) { connConfig, err := pgxcommon.ParseConfigWithInstrumentation(url) if err != nil { return nil, err } + unregister, _ := datastore.RegisterPrometheusCollectors(registerer, "failed to register crdb health metrics", healthyCRDBNodeCountGauge) + return &NodeHealthTracker{ connConfig: connConfig, healthyNodes: make(map[uint32]struct{}, 0), @@ -54,6 +54,7 @@ func NewNodeHealthChecker(url string) (*NodeHealthTracker, error) { newLimiter: func() *rate.Limiter { return rate.NewLimiter(rate.Every(1*time.Minute), errorBurst) }, + prometheusUnregisterFunction: unregister, }, nil } diff --git a/internal/datastore/crdb/pool/pool.go b/internal/datastore/crdb/pool/pool.go index 766de31e9a..a5f117e013 100644 --- a/internal/datastore/crdb/pool/pool.go +++ b/internal/datastore/crdb/pool/pool.go @@ -16,6 +16,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/postgres/common" log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/spiceerrors" ) @@ -34,10 +35,6 @@ var resetHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{ Buckets: []float64{0, 1, 2, 5, 10, 20, 50}, }) -func init() { - prometheus.MustRegister(resetHistogram) -} - type ctxDisableRetries struct{} var ( @@ -51,19 +48,23 @@ type RetryPool struct { healthTracker *NodeHealthTracker sync.RWMutex - maxRetries uint8 - nodeForConn map[*pgx.Conn]uint32 // GUARDED_BY(RWMutex) - gc map[*pgx.Conn]struct{} // GUARDED_BY(RWMutex) + maxRetries uint8 + nodeForConn map[*pgx.Conn]uint32 // GUARDED_BY(RWMutex) + gc map[*pgx.Conn]struct{} // GUARDED_BY(RWMutex) + prometheusUnregisterFunction func() } -func NewRetryPool(ctx context.Context, name string, config *pgxpool.Config, healthTracker *NodeHealthTracker, maxRetries uint8, connectRate time.Duration) (*RetryPool, error) { +func 
NewRetryPool(ctx context.Context, name string, config *pgxpool.Config, healthTracker *NodeHealthTracker, maxRetries uint8, connectRate time.Duration, registerer prometheus.Registerer) (*RetryPool, error) { + unregister, _ := datastore.RegisterPrometheusCollectors(registerer, "failed to register crdb pool metrics", resetHistogram) + config = config.Copy() p := &RetryPool{ - id: name, - maxRetries: maxRetries, - healthTracker: healthTracker, - nodeForConn: make(map[*pgx.Conn]uint32, 0), - gc: make(map[*pgx.Conn]struct{}, 0), + id: name, + maxRetries: maxRetries, + healthTracker: healthTracker, + nodeForConn: make(map[*pgx.Conn]uint32, 0), + gc: make(map[*pgx.Conn]struct{}, 0), + prometheusUnregisterFunction: unregister, } limiter := rate.NewLimiter(rate.Every(connectRate), 1) diff --git a/internal/datastore/mysql/connection.go b/internal/datastore/mysql/connection.go index d986afdbe0..442cbf7edb 100644 --- a/internal/datastore/mysql/connection.go +++ b/internal/datastore/mysql/connection.go @@ -2,14 +2,17 @@ package mysql import ( "context" + "database/sql" "database/sql/driver" "fmt" "strconv" "time" "github.com/prometheus/client_golang/prometheus" + prom_collectors "github.com/prometheus/client_golang/prometheus/collectors" log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" ) // instrumentedConnector wraps the default MySQL driver connector @@ -45,14 +48,15 @@ func (d *instrumentedConnector) Driver() driver.Driver { return d.drv } -func instrumentConnector(c driver.Connector, replicaIndex string) (driver.Connector, []prometheus.Collector, error) { +func instrumentConnector(registerer prometheus.Registerer, db *sql.DB, c driver.Connector, primaryIndex int, replicaIndex int) (driver.Connector, func(), error) { + replicaID := strconv.Itoa(replicaIndex) var ( connectHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{ Namespace: "spicedb", Subsystem: "datastore", Name: "mysql_connect_duration", ConstLabels: 
prometheus.Labels{ - "replica": replicaIndex, // this is needed to avoid "duplicate metrics collector registration attempted" + "replica": replicaID, }, Help: "distribution in seconds of time spent opening a new MySQL connection.", Buckets: []float64{0.01, 0.1, 0.5, 1, 5, 10, 25, 60, 120}, @@ -62,33 +66,26 @@ func instrumentConnector(c driver.Connector, replicaIndex string) (driver.Connec Subsystem: "datastore", Name: "mysql_connect_count_total", ConstLabels: prometheus.Labels{ - "replica": replicaIndex, // this is needed to avoid "duplicate metrics collector registration attempted" + "replica": replicaID, }, Help: "number of mysql connections opened.", }, []string{"success"}) ) - var collectors []prometheus.Collector - - err := prometheus.Register(connectHistogram) - if err != nil { - return nil, collectors, err - } - - collectors = append(collectors, connectHistogram) - err = prometheus.Register(connectCount) - if err != nil { - return nil, collectors, err + dbName := "spicedb" + if replicaIndex != primaryIndex { + dbName = fmt.Sprintf("spicedb_replica_%d", replicaIndex) } + dbStatsCollector := prom_collectors.NewDBStatsCollector(db, dbName) - collectors = append(collectors, connectCount) + unregister, _ := datastore.RegisterPrometheusCollectors(registerer, "failed to register mysql connector metrics", connectHistogram, connectCount, dbStatsCollector) return &instrumentedConnector{ conn: c, drv: c.Driver(), connectHistogram: connectHistogram, connectCount: connectCount, - }, collectors, nil + }, unregister, nil } type sessionVariableConnector struct { diff --git a/internal/datastore/mysql/datastore.go b/internal/datastore/mysql/datastore.go index ec66276e04..649b727e60 100644 --- a/internal/datastore/mysql/datastore.go +++ b/internal/datastore/mysql/datastore.go @@ -15,7 +15,6 @@ import ( "github.com/go-sql-driver/mysql" "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" - prom_collectors 
"github.com/prometheus/client_golang/prometheus/collectors" "go.opentelemetry.io/otel" "golang.org/x/sync/errgroup" @@ -174,11 +173,9 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option } } - db, collectors, err := registerAndReturnPrometheusCollectors(replicaIndex, isPrimary, connector, config.enablePrometheusStats) + db, unregister, err := registerAndReturnPrometheusCollectors(config.prometheusRegisterer, replicaIndex, isPrimary, connector, config.enablePrometheusStats) if err != nil { - for _, collector := range collectors { - _ = prometheus.Unregister(collector) - } + unregister() return nil, err } @@ -251,8 +248,8 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option MigrationValidator: common.NewMigrationValidator(headMigration, config.allowedMigrations), db: db, driver: driver, - collectors: collectors, url: uri, + prometheusUnregister: unregister, revisionQuantization: config.revisionQuantization, gcWindow: config.gcWindow, gcInterval: config.gcInterval, @@ -472,12 +469,12 @@ type mysqlDatastore struct { *revisions.CachedOptimizedRevisions *common.MigrationValidator - db *sql.DB - driver *migrations.MySQLDriver - readTxOptions *sql.TxOptions - url string - analyzeBeforeStats bool - collectors []prometheus.Collector + db *sql.DB + driver *migrations.MySQLDriver + readTxOptions *sql.TxOptions + url string + analyzeBeforeStats bool + prometheusUnregister func() revisionQuantization time.Duration gcWindow time.Duration @@ -517,8 +514,8 @@ func (mds *mysqlDatastore) Close() error { log.Error().Err(err).Msg("error waiting for garbage collector to shutdown") } } - for _, collector := range mds.collectors { - _ = prometheus.Unregister(collector) + if mds.prometheusUnregister != nil { + mds.prometheusUnregister() } return mds.db.Close() } @@ -685,34 +682,27 @@ func (debugLogger) Print(v ...any) { log.Logger.Debug().CallerSkipFrame(1).Str("datastore", "mysql").Msg(fmt.Sprint(v...)) } -func 
registerAndReturnPrometheusCollectors(replicaIndex int, isPrimary bool, connector driver.Connector, enablePrometheusStats bool) (*sql.DB, []prometheus.Collector, error) { +func registerAndReturnPrometheusCollectors(registerer prometheus.Registerer, replicaIndex int, isPrimary bool, connector driver.Connector, enablePrometheusStats bool) (*sql.DB, func(), error) { if !enablePrometheusStats { return sql.OpenDB(connector), nil, nil } - connector, collectors, err := instrumentConnector(connector, strconv.Itoa(replicaIndex)) - if err != nil { - return nil, collectors, err - } - - dbName := "spicedb" - if replicaIndex != primaryInstanceID { - dbName = fmt.Sprintf("spicedb_replica_%d", replicaIndex) - } - db := sql.OpenDB(connector) - collector := prom_collectors.NewDBStatsCollector(db, dbName) - if err := prometheus.Register(collector); err != nil { - return nil, collectors, err + _, unregister, err := instrumentConnector(registerer, db, connector, primaryInstanceID, replicaIndex) + if err != nil { + return nil, unregister, err } - collectors = append(collectors, collector) if isPrimary { - gcMetrics, err := datastore.RegisterGCMetrics() + unregisterGC, err := datastore.RegisterGCMetrics(registerer) if err != nil { - return nil, collectors, err + return nil, unregister, err + } + unregisterConnector := unregister + unregister = func() { + unregisterGC(); unregisterConnector() } - collectors = append(collectors, gcMetrics...) 
} - return db, collectors, nil + + return db, unregister, nil } diff --git a/internal/datastore/mysql/datastore_test.go b/internal/datastore/mysql/datastore_test.go index 9b4dfb6a0f..45108db7b1 100644 --- a/internal/datastore/mysql/datastore_test.go +++ b/internal/datastore/mysql/datastore_test.go @@ -87,6 +87,22 @@ func createDatastoreTest(b testdatastore.RunningEngineForTest, tf datastoreTestF } } +type datastoreTestFuncWithGatherer func(t *testing.T, ds datastore.Datastore, g prometheus.Gatherer) + +func createDatastoreTestWithGatherer(b testdatastore.RunningEngineForTest, g prometheus.Gatherer, tf datastoreTestFuncWithGatherer, options ...Option) func(*testing.T) { + return func(t *testing.T) { + ctx := t.Context() + ds := b.NewDatastore(t, func(engine, uri string) datastore.Datastore { + ds, err := newMySQLDatastore(ctx, uri, primaryInstanceID, options...) + require.NoError(t, err) + return ds + }) + defer failOnError(t, ds.Close) + + tf(t, ds, g) + } +} + type multiDatastoreTestFunc func(t *testing.T, ds1 datastore.Datastore, ds2 datastore.Datastore) func createMultiDatastoreTest(b testdatastore.RunningEngineForTest, tf multiDatastoreTestFunc, options ...Option) func(*testing.T) { @@ -129,20 +145,21 @@ func TestMySQLRevisionTimestamps(t *testing.T) { func additionalMySQLTests(t *testing.T, b testdatastore.RunningEngineForTest) { reg := prometheus.NewRegistry() - prometheus.DefaultGatherer = reg - prometheus.DefaultRegisterer = reg - - t.Run("DatabaseSeeding", createDatastoreTest(b, DatabaseSeedingTest, defaultOptions...)) - t.Run("PrometheusCollector", createDatastoreTest(b, PrometheusCollectorTest, defaultOptions...)) - t.Run("GarbageCollection", createDatastoreTest(b, GarbageCollectionTest, defaultOptions...)) - t.Run("GarbageCollectionByTime", createDatastoreTest(b, GarbageCollectionByTimeTest, defaultOptions...)) - t.Run("ChunkedGarbageCollection", createDatastoreTest(b, ChunkedGarbageCollectionTest, defaultOptions...)) - t.Run("EmptyGarbageCollection", 
createDatastoreTest(b, EmptyGarbageCollectionTest, defaultOptions...)) - t.Run("NoRelationshipsGarbageCollection", createDatastoreTest(b, NoRelationshipsGarbageCollectionTest, defaultOptions...)) + regOptions := make([]Option, 0, len(defaultOptions)+1) + regOptions = append(regOptions, defaultOptions...) + regOptions = append(regOptions, WithPrometheusRegisterer(reg)) + + t.Run("DatabaseSeeding", createDatastoreTest(b, DatabaseSeedingTest, regOptions...)) + t.Run("PrometheusCollector", createDatastoreTestWithGatherer(b, reg, PrometheusCollectorTest, regOptions...)) + t.Run("GarbageCollection", createDatastoreTest(b, GarbageCollectionTest, regOptions...)) + t.Run("GarbageCollectionByTime", createDatastoreTest(b, GarbageCollectionByTimeTest, regOptions...)) + t.Run("ChunkedGarbageCollection", createDatastoreTest(b, ChunkedGarbageCollectionTest, regOptions...)) + t.Run("EmptyGarbageCollection", createDatastoreTest(b, EmptyGarbageCollectionTest, regOptions...)) + t.Run("NoRelationshipsGarbageCollection", createDatastoreTest(b, NoRelationshipsGarbageCollectionTest, regOptions...)) t.Run("QuantizedRevisions", func(t *testing.T) { QuantizedRevisionTest(t, b) }) - t.Run("Locking", createMultiDatastoreTest(b, LockingTest, defaultOptions...)) + t.Run("Locking", createMultiDatastoreTest(b, LockingTest, regOptions...)) } func LockingTest(t *testing.T, ds datastore.Datastore, ds2 datastore.Datastore) { @@ -197,14 +214,14 @@ func DatabaseSeedingTest(t *testing.T, ds datastore.Datastore) { req.True(r.IsReady) } -func PrometheusCollectorTest(t *testing.T, ds datastore.Datastore) { +func PrometheusCollectorTest(t *testing.T, ds datastore.Datastore, g prometheus.Gatherer) { req := require.New(t) // cause some use of the SQL connection pool to generate metrics _, err := ds.ReadyState(t.Context()) req.NoError(err) - metrics, err := prometheus.DefaultGatherer.Gather() + metrics, err := g.Gather() req.NoError(err, metrics) var collectorStatsFound, connectorStatsFound bool for _, metric := range metrics { 
diff --git a/internal/datastore/mysql/options.go b/internal/datastore/mysql/options.go index 63f6e195c8..73f5fdd4de 100644 --- a/internal/datastore/mysql/options.go +++ b/internal/datastore/mysql/options.go @@ -4,6 +4,8 @@ import ( "fmt" "time" + "github.com/prometheus/client_golang/prometheus" + "github.com/authzed/spicedb/internal/datastore/common" log "github.com/authzed/spicedb/internal/logging" ) @@ -56,6 +58,7 @@ type mysqlOptions struct { allowedMigrations []string columnOptimizationOption common.ColumnOptimizationOption watchDisabled bool + prometheusRegisterer prometheus.Registerer } // Option provides the facility to configure how clients within the @@ -195,8 +198,7 @@ func TablePrefix(prefix string) Option { } // WithEnablePrometheusStats marks whether Prometheus metrics provided by Go's database/sql package -// are enabled. -// +// are enabled. // Prometheus metrics are disabled by default. func WithEnablePrometheusStats(enablePrometheusStats bool) Option { return func(mo *mysqlOptions) { @@ -204,6 +206,14 @@ func WithEnablePrometheusStats(enablePrometheusStats bool) Option { } } +// WithPrometheusRegisterer sets the prometheus.Registerer used for MySQL datastore metrics. +// If not set, prometheus.DefaultRegisterer is used. +func WithPrometheusRegisterer(registerer prometheus.Registerer) Option { + return func(mo *mysqlOptions) { + mo.prometheusRegisterer = registerer + } +} + // ConnMaxIdleTime is the duration after which an idle connection will be 
// See https://pkg.go.dev/database/sql#DB.SetConnMaxIdleTime/ diff --git a/internal/datastore/postgres/options.go b/internal/datastore/postgres/options.go index 1e2e1a812a..73a62b5b3d 100644 --- a/internal/datastore/postgres/options.go +++ b/internal/datastore/postgres/options.go @@ -4,6 +4,8 @@ import ( "fmt" "time" + "github.com/prometheus/client_golang/prometheus" + "github.com/authzed/spicedb/internal/datastore/common" pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" log "github.com/authzed/spicedb/internal/logging" @@ -42,7 +44,8 @@ type postgresOptions struct { logger *tracingLogger - queryInterceptor pgxcommon.QueryInterceptor + queryInterceptor pgxcommon.QueryInterceptor + prometheusRegisterer prometheus.Registerer } type migrationPhase uint8 diff --git a/internal/datastore/postgres/postgres.go b/internal/datastore/postgres/postgres.go index 5b1e8c018c..2af4d77e2a 100644 --- a/internal/datastore/postgres/postgres.go +++ b/internal/datastore/postgres/postgres.go @@ -240,7 +240,12 @@ func newPostgresDatastore( } } - collectors, err := registerAndReturnPrometheusCollectors(replicaIndex, isPrimary, readPool, writePool, config.enablePrometheusStats) + registerer := config.prometheusRegisterer + if registerer == nil { + registerer = prometheus.DefaultRegisterer + } + + collectors, gcMetricsUnregister, err := registerAndReturnPrometheusCollectors(registerer, replicaIndex, isPrimary, readPool, writePool, config.enablePrometheusStats) if err != nil { return nil, err } @@ -306,6 +311,8 @@ readPool: pgxcommon.MustNewInterceptorPooler(readPool, config.queryInterceptor), writePool: nil, /* disabled by default */ collectors: collectors, + gcMetricsUnregister: gcMetricsUnregister, + prometheusRegisterer: registerer, watchBufferLength: config.watchBufferLength, watchChangeBufferMaximumSize: config.watchChangeBufferMaximumSize, watchBufferWriteTimeout: config.watchBufferWriteTimeout, @@ -374,6 +380,7 @@ type pgDatastore struct { dburl string readPool, 
writePool pgxcommon.ConnPooler collectors []prometheus.Collector + gcMetricsUnregister func() watchBufferLength uint16 watchChangeBufferMaximumSize uint64 watchBufferWriteTimeout time.Duration @@ -391,6 +398,7 @@ type pgDatastore struct { inStrictReadMode bool schema common.SchemaInformation includeQueryParametersInTraces bool + prometheusRegisterer prometheus.Registerer credentialsProvider datastore.CredentialsProvider uniqueID atomic.Pointer[string] @@ -647,7 +655,10 @@ func (pgd *pgDatastore) Close() error { pgd.writePool.Close() } for _, collector := range pgd.collectors { - prometheus.Unregister(collector) + pgd.prometheusRegisterer.Unregister(collector) + } + if pgd.gcMetricsUnregister != nil { + pgd.gcMetricsUnregister() } return nil } @@ -801,10 +812,15 @@ func currentlyLivingObjects(original sq.SelectBuilder) sq.SelectBuilder { var _ datastore.Datastore = &pgDatastore{} -func registerAndReturnPrometheusCollectors(replicaIndex int, isPrimary bool, readPool, writePool *pgxpool.Pool, enablePrometheusStats bool) ([]prometheus.Collector, error) { +func registerAndReturnPrometheusCollectors(registerer prometheus.Registerer, replicaIndex int, isPrimary bool, readPool, writePool *pgxpool.Pool, enablePrometheusStats bool) ([]prometheus.Collector, func(), error) { + if registerer == nil { + registerer = prometheus.DefaultRegisterer + } + collectors := []prometheus.Collector{} + gcMetricsUnregister := func() {} if !enablePrometheusStats { - return collectors, nil + return collectors, gcMetricsUnregister, nil } replicaIndexStr := strconv.Itoa(replicaIndex) @@ -817,8 +833,8 @@ func registerAndReturnPrometheusCollectors(replicaIndex int, isPrimary bool, rea "db_name": dbname, "pool_usage": "read", }) - if err := prometheus.Register(readCollector); err != nil { - return collectors, err + if err := registerer.Register(readCollector); err != nil { + return collectors, gcMetricsUnregister, err } collectors = append(collectors, readCollector) @@ -828,17 +844,17 @@ func 
registerAndReturnPrometheusCollectors(replicaIndex int, isPrimary bool, rea "pool_usage": "write", }) - if err := prometheus.Register(writeCollector); err != nil { - return collectors, nil + if err := registerer.Register(writeCollector); err != nil { + return collectors, gcMetricsUnregister, nil } collectors = append(collectors, writeCollector) - gcCollectors, err := datastore.RegisterGCMetrics() + unregisterGC, err := datastore.RegisterGCMetrics(registerer) if err != nil { - return collectors, err + return collectors, gcMetricsUnregister, err } - collectors = append(collectors, gcCollectors...) + gcMetricsUnregister = unregisterGC } - return collectors, nil + return collectors, gcMetricsUnregister, nil } diff --git a/internal/datastore/proxy/checkingreplicated.go b/internal/datastore/proxy/checkingreplicated.go index 39fe7c79c9..fcfaa04f8d 100644 --- a/internal/datastore/proxy/checkingreplicated.go +++ b/internal/datastore/proxy/checkingreplicated.go @@ -7,7 +7,6 @@ import ( "sync/atomic" "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/pkg/datastore" @@ -16,21 +15,21 @@ import ( ) var ( - checkingReplicatedTotalReaderCount = promauto.NewCounter(prometheus.CounterOpts{ + checkingReplicatedTotalReaderCount = prometheus.NewCounter(prometheus.CounterOpts{ Namespace: "spicedb", Subsystem: "datastore_replica", Name: "checking_replicated_reader_total", Help: "total number of readers created by the checking replica proxy", }) - checkingReplicatedReplicaReaderCount = promauto.NewCounterVec(prometheus.CounterOpts{ + checkingReplicatedReplicaReaderCount = prometheus.NewCounterVec(prometheus.CounterOpts{ Namespace: "spicedb", Subsystem: "datastore_replica", Name: "checking_replicated_replica_reader_total", Help: "number of readers created by the checking replica proxy that are using the replica", }, []string{"replica"}) - 
readReplicatedSelectedReplicaCount = promauto.NewCounterVec(prometheus.CounterOpts{ + readReplicatedSelectedReplicaCount = prometheus.NewCounterVec(prometheus.CounterOpts{ Namespace: "spicedb", Subsystem: "datastore_replica", Name: "selected_replica_total", @@ -38,6 +37,17 @@ var ( }, []string{"replica"}) ) +// RegisterCheckingReplicatedMetrics registers the checking replicated datastore proxy prometheus metrics. +// If registerer is nil, prometheus.DefaultRegisterer is used. +func RegisterCheckingReplicatedMetrics(registerer prometheus.Registerer) error { + _, err := datastore.RegisterPrometheusCollectors(registerer, "failed to register replicated datastore metrics", + checkingReplicatedTotalReaderCount, + checkingReplicatedReplicaReaderCount, + readReplicatedSelectedReplicaCount) + + return err +} + // NewCheckingReplicatedDatastore creates a new datastore that writes to the provided primary and reads // from the provided replicas. The replicas are chosen in a round-robin fashion. If a replica does // not have the requested revision, the primary is used instead. diff --git a/internal/datastore/proxy/observable.go b/internal/datastore/proxy/observable.go index a54cc76f22..a664ac295a 100644 --- a/internal/datastore/proxy/observable.go +++ b/internal/datastore/proxy/observable.go @@ -4,7 +4,6 @@ import ( "context" "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -22,7 +21,7 @@ import ( var ( tracer = otel.Tracer("spicedb/datastore/proxy/observable") - loadedRelationshipCount = promauto.NewHistogram(prometheus.HistogramOpts{ + loadedRelationshipCount = prometheus.NewHistogram(prometheus.HistogramOpts{ Namespace: "spicedb", Subsystem: "datastore", Name: "loaded_relationships_count", @@ -30,7 +29,7 @@ var ( Help: "Histogram of the number of relationships loaded per individual datastore query. 
High p99 values (>1000) may indicate broad permission checks or missing filters.", }) - queryLatency = promauto.NewHistogramVec(prometheus.HistogramOpts{ + queryLatency = prometheus.NewHistogramVec(prometheus.HistogramOpts{ Namespace: "spicedb", Subsystem: "datastore", Name: "query_latency", @@ -41,6 +40,16 @@ var ( }) ) +// RegisterMetrics registers the observable datastore proxy prometheus metrics with the provided registerer. +// If registerer is nil, prometheus.DefaultRegisterer is used. +func RegisterMetrics(registerer prometheus.Registerer) error { + _, err := datastore.RegisterPrometheusCollectors(registerer, "failed to register observable datastore proxy metrics", + loadedRelationshipCount, + queryLatency) + + return err +} + func filterToAttributes(filter *v1.RelationshipFilter) []attribute.KeyValue { attrs := []attribute.KeyValue{common.ObjNamespaceNameKey.String(filter.ResourceType)} if filter.OptionalResourceId != "" { diff --git a/internal/datastore/proxy/schemacaching/watchingcache.go b/internal/datastore/proxy/schemacaching/watchingcache.go index f7fa5b3f4c..16fb29a67a 100644 --- a/internal/datastore/proxy/schemacaching/watchingcache.go +++ b/internal/datastore/proxy/schemacaching/watchingcache.go @@ -56,8 +56,17 @@ var definitionsReadTotalCounter = prometheus.NewCounterVec(prometheus.CounterOpt const maximumRetryCount = 10 -func init() { - prometheus.MustRegister(namespacesFallbackModeGauge, caveatsFallbackModeGauge, schemaCacheRevisionGauge, definitionsReadCachedCounter, definitionsReadTotalCounter) +// RegisterMetrics registers the watching schema cache prometheus metrics with the provided registerer. +// If registerer is nil, prometheus.DefaultRegisterer is used. 
+func RegisterMetrics(registerer prometheus.Registerer) error { + _, err := datastore.RegisterPrometheusCollectors(registerer, "failed to register watching cache metrics", + namespacesFallbackModeGauge, + caveatsFallbackModeGauge, + schemaCacheRevisionGauge, + definitionsReadCachedCounter, + definitionsReadTotalCounter) + + return err } // watchingCachingProxy is a datastore proxy that caches schema (namespaces and caveat definitions) diff --git a/internal/datastore/proxy/strictreplicated.go b/internal/datastore/proxy/strictreplicated.go index 5d664a08cc..e94452bf3d 100644 --- a/internal/datastore/proxy/strictreplicated.go +++ b/internal/datastore/proxy/strictreplicated.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" "github.com/authzed/spicedb/internal/datastore/common" log "github.com/authzed/spicedb/internal/logging" @@ -17,14 +16,14 @@ import ( ) var ( - strictReadReplicatedTotalQueryCount = promauto.NewCounter(prometheus.CounterOpts{ + strictReadReplicatedTotalQueryCount = prometheus.NewCounter(prometheus.CounterOpts{ Namespace: "spicedb", Subsystem: "datastore_replica", Name: "strict_replicated_query_total", Help: "total number of reads made by the strict read replicated datastore", }) - strictReadReplicatedFallbackQueryCount = promauto.NewCounterVec(prometheus.CounterOpts{ + strictReadReplicatedFallbackQueryCount = prometheus.NewCounterVec(prometheus.CounterOpts{ Namespace: "spicedb", Subsystem: "datastore_replica", Name: "strict_replicated_fallback_query_total", @@ -32,6 +31,16 @@ var ( }, []string{"replica"}) ) +// RegisterStrictReplicatedMetrics registers the strict replicated datastore proxy prometheus metrics. +// If registerer is nil, prometheus.DefaultRegisterer is used. 
+func RegisterStrictReplicatedMetrics(registerer prometheus.Registerer) error { + _, err := datastore.RegisterPrometheusCollectors(registerer, "failed to register strict replicated datastore metrics", + strictReadReplicatedTotalQueryCount, + strictReadReplicatedFallbackQueryCount) + + return err +} + // NewStrictReplicatedDatastore creates a new datastore that writes to the provided primary and reads // from the provided replicas. The replicas are chosen in a round-robin fashion. If a replica does // not have the requested revision, the primary is used instead. diff --git a/internal/dispatch/caching/caching.go b/internal/dispatch/caching/caching.go index 25ae26bb10..96db3d1900 100644 --- a/internal/dispatch/caching/caching.go +++ b/internal/dispatch/caching/caching.go @@ -18,6 +18,7 @@ import ( "github.com/authzed/spicedb/internal/dispatch/keys" "github.com/authzed/spicedb/internal/telemetry/otelconv" "github.com/authzed/spicedb/pkg/cache" + ds "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/middleware/nodeid" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" ) @@ -40,6 +41,7 @@ type Dispatcher struct { lookupResourcesFromCacheCounter prometheus.Counter lookupSubjectsTotalCounter prometheus.Counter lookupSubjectsFromCacheCounter prometheus.Counter + registerer prometheus.Registerer queryPlanTotalCounter *prometheus.CounterVec queryPlanFromCacheCounter *prometheus.CounterVec } @@ -55,7 +57,7 @@ func DispatchTestCache(t testing.TB) cache.Cache[keys.DispatchCacheKey, any] { // NewCachingDispatcher creates a new dispatch.Dispatcher which delegates // dispatch requests and caches the responses when possible and desirable. 
-func NewCachingDispatcher(cacheInst cache.Cache[keys.DispatchCacheKey, any], metricsEnabled bool, prometheusSubsystem string, keyHandler keys.Handler) (*Dispatcher, error) { +func NewCachingDispatcher(cacheInst cache.Cache[keys.DispatchCacheKey, any], metricsEnabled bool, registerer prometheus.Registerer, prometheusSubsystem string, keyHandler keys.Handler) (*Dispatcher, error) { if cacheInst == nil { cacheInst = cache.NoopCache[keys.DispatchCacheKey, any]() } @@ -92,6 +94,7 @@ func NewCachingDispatcher(cacheInst cache.Cache[keys.DispatchCacheKey, any], met Name: "lookup_subjects_total", Help: "Total number of LookupSubjects dispatch requests processed.", }) + lookupSubjectsFromCacheCounter := prometheus.NewCounter(prometheus.CounterOpts{ Namespace: prometheusNamespace, Subsystem: prometheusSubsystem, @@ -113,35 +116,16 @@ func NewCachingDispatcher(cacheInst cache.Cache[keys.DispatchCacheKey, any], met }, []string{"operation"}) if metricsEnabled && prometheusSubsystem != "" { - err := prometheus.Register(checkTotalCounter) - if err != nil { - return nil, fmt.Errorf(errCachingInitialization, err) - } - err = prometheus.Register(checkFromCacheCounter) - if err != nil { - return nil, fmt.Errorf(errCachingInitialization, err) - } - err = prometheus.Register(lookupResourcesTotalCounter) - if err != nil { - return nil, fmt.Errorf(errCachingInitialization, err) - } - err = prometheus.Register(lookupResourcesFromCacheCounter) - if err != nil { - return nil, fmt.Errorf(errCachingInitialization, err) - } - err = prometheus.Register(lookupSubjectsTotalCounter) - if err != nil { - return nil, fmt.Errorf(errCachingInitialization, err) - } - err = prometheus.Register(lookupSubjectsFromCacheCounter) - if err != nil { - return nil, fmt.Errorf(errCachingInitialization, err) - } - err = prometheus.Register(queryPlanTotalCounter) - if err != nil { - return nil, fmt.Errorf(errCachingInitialization, err) - } - err = prometheus.Register(queryPlanFromCacheCounter) + _, err := 
ds.RegisterPrometheusCollectors(registerer, prometheusSubsystem+": failed to register caching dispatcher metrics", + checkTotalCounter, + checkFromCacheCounter, + lookupResourcesTotalCounter, + lookupResourcesFromCacheCounter, + lookupSubjectsTotalCounter, + lookupSubjectsFromCacheCounter, + queryPlanTotalCounter, + queryPlanFromCacheCounter, + ) if err != nil { return nil, fmt.Errorf(errCachingInitialization, err) } @@ -161,6 +145,7 @@ func NewCachingDispatcher(cacheInst cache.Cache[keys.DispatchCacheKey, any], met lookupResourcesFromCacheCounter: lookupResourcesFromCacheCounter, lookupSubjectsTotalCounter: lookupSubjectsTotalCounter, lookupSubjectsFromCacheCounter: lookupSubjectsFromCacheCounter, + registerer: registerer, queryPlanTotalCounter: queryPlanTotalCounter, queryPlanFromCacheCounter: queryPlanFromCacheCounter, }, nil @@ -482,14 +467,16 @@ func (cd *Dispatcher) dispatchQueryPlanCheckCached(req *v1.DispatchQueryPlanRequ } func (cd *Dispatcher) Close() error { - prometheus.Unregister(cd.checkTotalCounter) - prometheus.Unregister(cd.checkFromCacheCounter) - prometheus.Unregister(cd.lookupResourcesTotalCounter) - prometheus.Unregister(cd.lookupResourcesFromCacheCounter) - prometheus.Unregister(cd.lookupSubjectsTotalCounter) - prometheus.Unregister(cd.lookupSubjectsFromCacheCounter) - prometheus.Unregister(cd.queryPlanTotalCounter) - prometheus.Unregister(cd.queryPlanFromCacheCounter) + if cd.registerer != nil { + cd.registerer.Unregister(cd.checkTotalCounter) + cd.registerer.Unregister(cd.checkFromCacheCounter) + cd.registerer.Unregister(cd.lookupResourcesTotalCounter) + cd.registerer.Unregister(cd.lookupResourcesFromCacheCounter) + cd.registerer.Unregister(cd.lookupSubjectsFromCacheCounter) + cd.registerer.Unregister(cd.lookupSubjectsTotalCounter) + cd.registerer.Unregister(cd.queryPlanTotalCounter) + cd.registerer.Unregister(cd.queryPlanFromCacheCounter) + } if cache := cd.c; cache != nil { cache.Close() } diff --git 
a/internal/dispatch/caching/cachingdispatch_test.go b/internal/dispatch/caching/cachingdispatch_test.go index 56b3b2b7b8..90a2cc18a1 100644 --- a/internal/dispatch/caching/cachingdispatch_test.go +++ b/internal/dispatch/caching/cachingdispatch_test.go @@ -125,7 +125,7 @@ func TestMaxDepthCaching(t *testing.T) { } } - dispatch, err := NewCachingDispatcher(DispatchTestCache(t), false, "", nil) + dispatch, err := NewCachingDispatcher(DispatchTestCache(t), false, prometheus.DefaultRegisterer, "", nil) dispatch.SetDelegate(delegate) require.NoError(err) defer dispatch.Close() @@ -177,7 +177,7 @@ func TestConcurrentDebugInfoAccess(t *testing.T) { }, }, nil) - dispatcher, err := NewCachingDispatcher(DispatchTestCache(t), false, "", nil) + dispatcher, err := NewCachingDispatcher(DispatchTestCache(t), false, prometheus.DefaultRegisterer, "", nil) require.NoError(err) dispatcher.SetDelegate(delegate) t.Cleanup(func() { @@ -294,8 +294,9 @@ func TestDispatchQueryPlanRecordsCachingMetrics(t *testing.T) { // Use a unique subsystem so we don't collide with other tests in the global // default Prometheus registry. 
subsystem := fmt.Sprintf("test_caching_%d", time.Now().UnixNano()) + registry := prometheus.NewRegistry() - dispatcher, err := NewCachingDispatcher(DispatchTestCache(t), true, subsystem, nil) + dispatcher, err := NewCachingDispatcher(DispatchTestCache(t), true, registry, subsystem, nil) require.NoError(t, err) t.Cleanup(func() { _ = dispatcher.Close() }) @@ -326,7 +327,7 @@ func TestDispatchQueryPlanRecordsCachingMetrics(t *testing.T) { require.NoError(t, dispatcher.DispatchQueryPlan(req, stream2)) require.Len(t, stream2.Results(), 1) - gatherer := prometheus.DefaultGatherer + gatherer := registry totalCheck := sumOperationCounter(t, gatherer, "spicedb_"+subsystem+"_query_plan_total", "check") require.InEpsilon(t, float64(2), totalCheck, 1e-9, "query_plan_total{operation=check} should bump for each plan dispatch") diff --git a/internal/dispatch/cluster/cluster.go b/internal/dispatch/cluster/cluster.go index d27fe898f1..a9a484ccca 100644 --- a/internal/dispatch/cluster/cluster.go +++ b/internal/dispatch/cluster/cluster.go @@ -4,6 +4,8 @@ import ( "fmt" "time" + "github.com/prometheus/client_golang/prometheus" + "github.com/authzed/spicedb/internal/dispatch" "github.com/authzed/spicedb/internal/dispatch/caching" "github.com/authzed/spicedb/internal/dispatch/graph" @@ -20,6 +22,7 @@ type Option func(*optionState) type optionState struct { metricsEnabled bool prometheusSubsystem string + prometheusRegisterer prometheus.Registerer cache cache.Cache[keys.DispatchCacheKey, any] concurrencyLimits graph.ConcurrencyLimits remoteDispatchTimeout time.Duration @@ -148,6 +151,7 @@ func NewClusterDispatcher(dispatch dispatch.Dispatcher, options ...Option) (disp TypeSet: cts, DispatchChunkSize: opts.dispatchChunkSize, RelationshipChunkCache: relationshipChunkCache, + PrometheusRegisterer: opts.prometheusRegisterer, QueryPlanMetadata: opts.queryPlanMetadata, } clusterDispatch, err := graph.NewDispatcher(dispatch, params) @@ -159,7 +163,7 @@ func NewClusterDispatcher(dispatch 
dispatch.Dispatcher, options ...Option) (disp opts.prometheusSubsystem = "dispatch" } - cachingClusterDispatch, err := caching.NewCachingDispatcher(opts.cache, opts.metricsEnabled, opts.prometheusSubsystem, &keys.CanonicalKeyHandler{}) + cachingClusterDispatch, err := caching.NewCachingDispatcher(opts.cache, opts.metricsEnabled, opts.prometheusRegisterer, opts.prometheusSubsystem, &keys.CanonicalKeyHandler{}) if err != nil { return nil, err } diff --git a/internal/dispatch/combined/combined.go b/internal/dispatch/combined/combined.go index 5f9b3c7214..b2cb5f3f49 100644 --- a/internal/dispatch/combined/combined.go +++ b/internal/dispatch/combined/combined.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "github.com/prometheus/client_golang/prometheus" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -29,6 +30,7 @@ type Option func(*optionState) type optionState struct { metricsEnabled bool + prometheusRegisterer prometheus.Registerer prometheusSubsystem string upstreamAddr string upstreamCAPath string @@ -64,6 +66,13 @@ func MetricsEnabled(enabled bool) Option { } } +// PrometheusRegisterer sets the prometheus registerer for dispatcher metrics. 
+func PrometheusRegisterer(registerer prometheus.Registerer) Option { + return func(state *optionState) { + state.prometheusRegisterer = registerer + } +} + // PrometheusSubsystem sets the subsystem name for the prometheus metrics func PrometheusSubsystem(name string) Option { return func(state *optionState) { @@ -202,7 +211,7 @@ func NewDispatcher(options ...Option) (dispatch.Dispatcher, error) { opts.prometheusSubsystem = "dispatch_client" } - cachingRedispatch, err := caching.NewCachingDispatcher(opts.cache, opts.metricsEnabled, opts.prometheusSubsystem, &keys.CanonicalKeyHandler{}) + cachingRedispatch, err := caching.NewCachingDispatcher(opts.cache, opts.metricsEnabled, opts.prometheusRegisterer, opts.prometheusSubsystem, &keys.CanonicalKeyHandler{}) if err != nil { return nil, err } @@ -244,6 +253,7 @@ func NewDispatcher(options ...Option) (dispatch.Dispatcher, error) { TypeSet: caveattypes.TypeSetOrDefault(opts.caveatTypeSet), DispatchChunkSize: chunkSize, RelationshipChunkCache: relationshipChunkCache, + PrometheusRegisterer: opts.prometheusRegisterer, QueryPlanMetadata: opts.queryPlanMetadata, } redispatch, err = graph.NewDispatcher(cachingRedispatch, params) @@ -311,7 +321,7 @@ func NewDispatcher(options ...Option) (dispatch.Dispatcher, error) { re, err := remote.NewClusterDispatcher(v1.NewDispatchServiceClient(conn), conn, remote.ClusterDispatcherConfig{ KeyHandler: &keys.CanonicalKeyHandler{}, - DispatchOverallTimeout: opts.remoteDispatchTimeout, + DispatchOverallTimeout: opts.remoteDispatchTimeout, Registerer: opts.prometheusRegisterer, }, secondaryClients, secondaryExprs, opts.startingPrimaryHedgingDelay) if err != nil { return nil, err diff --git a/internal/dispatch/graph/check_test.go b/internal/dispatch/graph/check_test.go index 27dcff4cc9..9890e72865 100644 --- a/internal/dispatch/graph/check_test.go +++ b/internal/dispatch/graph/check_test.go @@ -6,6 +6,7 @@ import ( "strings" "testing" + "github.com/prometheus/client_golang/prometheus" 
"github.com/stretchr/testify/require" "github.com/authzed/spicedb/internal/datastore/common" @@ -2074,7 +2075,7 @@ func newLocalDispatcher(t testing.TB) (context.Context, dispatch.Dispatcher, dat _ = dispatch.Close() }) - cachingDispatcher, err := caching.NewCachingDispatcher(caching.DispatchTestCache(t), false, "", &keys.CanonicalKeyHandler{}) + cachingDispatcher, err := caching.NewCachingDispatcher(caching.DispatchTestCache(t), false, prometheus.DefaultRegisterer, "", &keys.CanonicalKeyHandler{}) require.NoError(t, err) cachingDispatcher.SetDelegate(dispatch) t.Cleanup(func() { @@ -2099,7 +2100,7 @@ func newLocalDispatcherWithSchemaAndRels(t testing.TB, schema string, rels []tup dispatch.Close() }) - cachingDispatcher, err := caching.NewCachingDispatcher(caching.DispatchTestCache(t), false, "", &keys.CanonicalKeyHandler{}) + cachingDispatcher, err := caching.NewCachingDispatcher(caching.DispatchTestCache(t), false, prometheus.DefaultRegisterer, "", &keys.CanonicalKeyHandler{}) require.NoError(t, err) cachingDispatcher.SetDelegate(dispatch) t.Cleanup(func() { diff --git a/internal/dispatch/graph/graph.go b/internal/dispatch/graph/graph.go index 6c9bce9512..a453fb2507 100644 --- a/internal/dispatch/graph/graph.go +++ b/internal/dispatch/graph/graph.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/rs/zerolog" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -94,6 +95,7 @@ type DispatcherParameters struct { DispatchChunkSize uint16 TypeSet *caveattypes.TypeSet RelationshipChunkCache cache.Cache[cache.StringKey, any] + PrometheusRegisterer prometheus.Registerer // QueryPlanMetadata is the shared count-stats store consulted and updated // by the receiver-side DispatchQueryPlan handler (advisor at compile, @@ -168,7 +170,7 @@ func NewLocalOnlyDispatcher(parameters DispatcherParameters) (dispatch.Dispatche concurrencyLimits := limitsOrDefaults(parameters.ConcurrencyLimits, 
defaultConcurrencyLimit) chunkSize := parameters.DispatchChunkSize - d.checker = graph.NewConcurrentChecker(d, concurrencyLimits.Check, chunkSize) + d.checker = graph.NewConcurrentChecker(d, concurrencyLimits.Check, chunkSize, parameters.PrometheusRegisterer) d.expander = graph.NewConcurrentExpander(d) d.lookupSubjectsHandler = graph.NewConcurrentLookupSubjects(d, concurrencyLimits.LookupSubjects, chunkSize) d.lookupResourcesHandler2 = graph.NewCursoredLookupResources2(d, d, typeSet, concurrencyLimits.LookupResources, chunkSize) @@ -193,7 +195,7 @@ func NewDispatcher(redispatcher dispatch.Dispatcher, parameters DispatcherParame log.Warn().Msgf("Dispatcher: dispatchChunkSize not set, defaulting to %d", chunkSize) } - checker := graph.NewConcurrentChecker(redispatcher, concurrencyLimits.Check, chunkSize) + checker := graph.NewConcurrentChecker(redispatcher, concurrencyLimits.Check, chunkSize, parameters.PrometheusRegisterer) expander := graph.NewConcurrentExpander(redispatcher) lookupSubjectsHandler := graph.NewConcurrentLookupSubjects(redispatcher, concurrencyLimits.LookupSubjects, chunkSize) lookupResourcesHandler2 := graph.NewCursoredLookupResources2(redispatcher, redispatcher, typeSet, concurrencyLimits.LookupResources, chunkSize) diff --git a/internal/dispatch/remote/cluster.go b/internal/dispatch/remote/cluster.go index e12bbbdef1..945597a383 100644 --- a/internal/dispatch/remote/cluster.go +++ b/internal/dispatch/remote/cluster.go @@ -26,6 +26,7 @@ import ( "github.com/authzed/spicedb/internal/dispatch" "github.com/authzed/spicedb/internal/dispatch/keys" log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" "github.com/authzed/spicedb/pkg/spiceerrors" @@ -77,13 +78,6 @@ const defaultTDigestCompression = float64(1000) var supportsSecondaries = []string{"check", "lookupresources", "lookupsubjects"} -func init() { - 
prometheus.MustRegister(dispatchCounter) - prometheus.MustRegister(hedgeWaitHistogram) - prometheus.MustRegister(hedgeActualWaitHistogram) - prometheus.MustRegister(primaryDispatch) -} - type ClusterClient interface { DispatchCheck(ctx context.Context, req *v1.DispatchCheckRequest, opts ...grpc.CallOption) (*v1.DispatchCheckResponse, error) DispatchExpand(ctx context.Context, req *v1.DispatchExpandRequest, opts ...grpc.CallOption) (*v1.DispatchExpandResponse, error) @@ -100,6 +94,9 @@ type ClusterDispatcherConfig struct { // DispatchOverallTimeout is the maximum duration of a dispatched request // before it should timeout. DispatchOverallTimeout time.Duration + + // Registerer is the prometheus registerer to use for metrics. If nil, prometheus.DefaultRegisterer is used. + Registerer prometheus.Registerer } // SecondaryDispatch defines a struct holding a client and its name for secondary @@ -119,6 +116,11 @@ type SecondaryDispatch struct { // NewClusterDispatcher creates a dispatcher implementation that uses the provided client // to dispatch requests to peer nodes in the cluster. 
func NewClusterDispatcher(client ClusterClient, conn *grpc.ClientConn, config ClusterDispatcherConfig, secondaryDispatch map[string]SecondaryDispatch, secondaryDispatchExprs map[string]*DispatchExpr, startingPrimaryHedgingDelay time.Duration) (dispatch.Dispatcher, error) { + unregister, err := datastore.RegisterPrometheusCollectors(config.Registerer, "failed to register cluster dispatch metrics", dispatchCounter, hedgeWaitHistogram, hedgeActualWaitHistogram, primaryDispatch) + if err != nil { + return nil, err + } + keyHandler := config.KeyHandler if keyHandler == nil { keyHandler = &keys.DirectKeyHandler{} @@ -156,6 +158,7 @@ func NewClusterDispatcher(client ClusterClient, conn *grpc.ClientConn, config Cl secondaryDispatchExprs: secondaryDispatchExprs, secondaryInitialResponseDigests: secondaryInitialResponseDigests, supportedResourceSubjectTracker: newSupportedResourceSubjectTracker(), + prometheusUnregisterFunction: unregister, }, nil } @@ -169,6 +172,7 @@ type clusterDispatcher struct { secondaryDispatchExprs map[string]*DispatchExpr secondaryInitialResponseDigests map[string]*digestAndLock supportedResourceSubjectTracker *supportedResourceSubjectTracker + prometheusUnregisterFunction func() } // digestAndLock is a struct that holds a TDigest and a lock to protect it. 
diff --git a/internal/dispatch/remote/cluster_benchmark_test.go b/internal/dispatch/remote/cluster_benchmark_test.go index 8c5e3f310f..782f8aa63e 100644 --- a/internal/dispatch/remote/cluster_benchmark_test.go +++ b/internal/dispatch/remote/cluster_benchmark_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" "google.golang.org/grpc" @@ -49,6 +50,7 @@ func BenchmarkSecondaryDispatching(b *testing.B) { config := ClusterDispatcherConfig{ KeyHandler: &keys.DirectKeyHandler{}, DispatchOverallTimeout: 30 * time.Second, + Registerer: prometheus.NewRegistry(), } parsed, err := ParseDispatchExpression("check", "['secondary']") diff --git a/internal/dispatch/remote/cluster_test.go b/internal/dispatch/remote/cluster_test.go index 766c14ae78..4be50aee52 100644 --- a/internal/dispatch/remote/cluster_test.go +++ b/internal/dispatch/remote/cluster_test.go @@ -12,6 +12,7 @@ import ( "github.com/caio/go-tdigest/v4" "github.com/ccoveille/go-safecast/v2" "github.com/dustin/go-humanize" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc" @@ -200,6 +201,7 @@ func TestDispatchTimeout(t *testing.T) { dispatcher, err := NewClusterDispatcher(v1.NewDispatchServiceClient(conn), conn, ClusterDispatcherConfig{ KeyHandler: &keys.DirectKeyHandler{}, DispatchOverallTimeout: tc.timeout, + Registerer: prometheus.NewRegistry(), }, nil, nil, 0*time.Second) require.NoError(t, err) require.True(t, dispatcher.ReadyState().IsReady) @@ -355,6 +357,7 @@ func TestCheckSecondaryDispatch(t *testing.T) { dispatcher, err := NewClusterDispatcher(v1.NewDispatchServiceClient(conn), conn, ClusterDispatcherConfig{ KeyHandler: &keys.DirectKeyHandler{}, DispatchOverallTimeout: 30 * time.Second, + Registerer: prometheus.NewRegistry(), }, map[string]SecondaryDispatch{ "secondary": {Name: "secondary", Client: 
v1.NewDispatchServiceClient(secondaryConn), MaximumPrimaryHedgingDelay: 5 * time.Millisecond}, }, map[string]*DispatchExpr{ @@ -650,6 +653,7 @@ func TestLRSecondaryDispatch(t *testing.T) { dispatcher, err := NewClusterDispatcher(v1.NewDispatchServiceClient(conn), conn, ClusterDispatcherConfig{ KeyHandler: &keys.DirectKeyHandler{}, DispatchOverallTimeout: 30 * time.Second, + Registerer: prometheus.NewRegistry(), }, map[string]SecondaryDispatch{ "secondary": {Name: "secondary", Client: v1.NewDispatchServiceClient(secondaryConn), MaximumPrimaryHedgingDelay: 5 * time.Millisecond}, "tertiary": {Name: "tertiary", Client: v1.NewDispatchServiceClient(tertiaryConn), MaximumPrimaryHedgingDelay: 5 * time.Millisecond}, @@ -687,6 +691,7 @@ func TestLRDispatchFallbackToPrimary(t *testing.T) { dispatcher, err := NewClusterDispatcher(v1.NewDispatchServiceClient(conn), conn, ClusterDispatcherConfig{ KeyHandler: &keys.DirectKeyHandler{}, DispatchOverallTimeout: 30 * time.Second, + Registerer: prometheus.NewRegistry(), }, map[string]SecondaryDispatch{ "secondary": {Name: "secondary", Client: v1.NewDispatchServiceClient(secondaryConn), MaximumPrimaryHedgingDelay: 5 * time.Millisecond}, }, map[string]*DispatchExpr{ @@ -787,6 +792,7 @@ func TestLSSecondaryDispatch(t *testing.T) { dispatcher, err := NewClusterDispatcher(v1.NewDispatchServiceClient(conn), conn, ClusterDispatcherConfig{ KeyHandler: &keys.DirectKeyHandler{}, DispatchOverallTimeout: 30 * time.Second, + Registerer: prometheus.NewRegistry(), }, map[string]SecondaryDispatch{ "secondary": {Name: "secondary", Client: v1.NewDispatchServiceClient(secondaryConn), MaximumPrimaryHedgingDelay: 5 * time.Millisecond}, "tertiary": {Name: "tertiary", Client: v1.NewDispatchServiceClient(tertiaryConn), MaximumPrimaryHedgingDelay: 5 * time.Millisecond}, @@ -824,6 +830,7 @@ func TestLSDispatchFallbackToPrimary(t *testing.T) { dispatcher, err := NewClusterDispatcher(v1.NewDispatchServiceClient(conn), conn, ClusterDispatcherConfig{ KeyHandler: 
&keys.DirectKeyHandler{}, DispatchOverallTimeout: 30 * time.Second, + Registerer: prometheus.NewRegistry(), }, map[string]SecondaryDispatch{ "secondary": {Name: "secondary", Client: v1.NewDispatchServiceClient(secondaryConn), MaximumPrimaryHedgingDelay: 5 * time.Millisecond}, }, map[string]*DispatchExpr{ @@ -865,6 +872,7 @@ func TestCheckUsesDelayByDefaultForPrimary(t *testing.T) { dispatcher, err := NewClusterDispatcher(v1.NewDispatchServiceClient(conn), conn, ClusterDispatcherConfig{ KeyHandler: &keys.DirectKeyHandler{}, DispatchOverallTimeout: 30 * time.Second, + Registerer: prometheus.NewRegistry(), }, map[string]SecondaryDispatch{ "secondary": {Name: "secondary", Client: v1.NewDispatchServiceClient(secondaryConn), MaximumPrimaryHedgingDelay: 15 * time.Millisecond}, }, map[string]*DispatchExpr{ @@ -903,6 +911,7 @@ func TestStreamingDispatchDelayByDefaultForPrimary(t *testing.T) { dispatcher, err := NewClusterDispatcher(v1.NewDispatchServiceClient(conn), conn, ClusterDispatcherConfig{ KeyHandler: &keys.DirectKeyHandler{}, DispatchOverallTimeout: 30 * time.Second, + Registerer: prometheus.NewRegistry(), }, map[string]SecondaryDispatch{ "secondary": {Name: "secondary", Client: v1.NewDispatchServiceClient(secondaryConn), MaximumPrimaryHedgingDelay: 15 * time.Millisecond}, }, map[string]*DispatchExpr{ @@ -950,6 +959,7 @@ func TestGetPrimaryWaitTime(t *testing.T) { d, err := NewClusterDispatcher(v1.NewDispatchServiceClient(conn), conn, ClusterDispatcherConfig{ KeyHandler: &keys.DirectKeyHandler{}, DispatchOverallTimeout: 30 * time.Second, + Registerer: prometheus.NewRegistry(), }, map[string]SecondaryDispatch{ "secondary": {Name: "secondary", Client: v1.NewDispatchServiceClient(secondaryConn), MaximumPrimaryHedgingDelay: 5 * time.Millisecond}, }, map[string]*DispatchExpr{ @@ -994,6 +1004,7 @@ func TestCheckUsesMaximumDelayByDefaultForPrimary(t *testing.T) { dispatcher, err := NewClusterDispatcher(v1.NewDispatchServiceClient(conn), conn, ClusterDispatcherConfig{ 
KeyHandler: &keys.DirectKeyHandler{}, DispatchOverallTimeout: 30 * time.Second, + Registerer: prometheus.NewRegistry(), }, map[string]SecondaryDispatch{ "secondary": {Name: "secondary", Client: v1.NewDispatchServiceClient(secondaryConn), MaximumPrimaryHedgingDelay: 0 * time.Millisecond}, }, map[string]*DispatchExpr{ @@ -1134,6 +1145,7 @@ func TestCheckToUnsupportedRemovesHedgingDelay(t *testing.T) { dispatcher, err := NewClusterDispatcher(v1.NewDispatchServiceClient(conn), conn, ClusterDispatcherConfig{ KeyHandler: &keys.DirectKeyHandler{}, DispatchOverallTimeout: 30 * time.Second, + Registerer: prometheus.NewRegistry(), }, map[string]SecondaryDispatch{ "secondary": {Name: "secondary", Client: v1.NewDispatchServiceClient(secondaryConn), MaximumPrimaryHedgingDelay: 5 * time.Millisecond}, }, map[string]*DispatchExpr{ @@ -1267,6 +1279,7 @@ func TestPrimaryDispatcherErrorReturned(t *testing.T) { dispatcher, err := NewClusterDispatcher(v1.NewDispatchServiceClient(conn), conn, ClusterDispatcherConfig{ KeyHandler: &keys.DirectKeyHandler{}, DispatchOverallTimeout: 30 * time.Second, + Registerer: prometheus.NewRegistry(), }, map[string]SecondaryDispatch{ "secondary": {Name: "secondary", Client: v1.NewDispatchServiceClient(secondaryConn), MaximumPrimaryHedgingDelay: 0}, // No delay so primary runs immediately }, map[string]*DispatchExpr{ @@ -1317,7 +1330,7 @@ func TestReadyStateConnecting(t *testing.T) { dispatcher, err := NewClusterDispatcher( v1.NewDispatchServiceClient(conn), conn, - ClusterDispatcherConfig{KeyHandler: &keys.DirectKeyHandler{}}, + ClusterDispatcherConfig{KeyHandler: &keys.DirectKeyHandler{}, Registerer: prometheus.NewRegistry()}, nil, nil, 0, ) t.Cleanup(func() { dispatcher.Close() }) diff --git a/internal/dispatch/singleflight/singleflight.go b/internal/dispatch/singleflight/singleflight.go index ae03f29127..8505ba05c6 100644 --- a/internal/dispatch/singleflight/singleflight.go +++ b/internal/dispatch/singleflight/singleflight.go @@ -7,7 +7,6 @@ import ( 
"strconv" "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "resenje.org/singleflight" @@ -15,11 +14,12 @@ import ( "github.com/authzed/spicedb/internal/dispatch" "github.com/authzed/spicedb/internal/dispatch/keys" log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" ) var ( - singleFlightCount = promauto.NewCounterVec(singleFlightCountConfig, []string{"method", "shared"}) + singleFlightCount = prometheus.NewCounterVec(singleFlightCountConfig, []string{"method", "shared"}) singleFlightCountConfig = prometheus.CounterOpts{ Namespace: "spicedb", Subsystem: "dispatch", @@ -28,6 +28,13 @@ var ( } ) +// RegisterMetrics registers singleflight prometheus metrics with the given registerer. +// If registerer is nil, prometheus.DefaultRegisterer is used. +func RegisterMetrics(registerer prometheus.Registerer) error { + _, err := datastore.RegisterPrometheusCollectors(registerer, "failed to register singleflight metrics", singleFlightCount) + return err +} + func New(delegate dispatch.Dispatcher, handler keys.Handler) dispatch.Dispatcher { return &Dispatcher{ delegate: delegate, diff --git a/internal/fdw/pgserver_e2e_test.go b/internal/fdw/pgserver_e2e_test.go index 410af79274..1b7bf52430 100644 --- a/internal/fdw/pgserver_e2e_test.go +++ b/internal/fdw/pgserver_e2e_test.go @@ -17,6 +17,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" @@ -1017,11 +1018,13 @@ func runSpiceDB(t *testing.T) *authzed.Client { PresharedSecureKey: []string{"sometestkey"}, DatastoreConfig: datastore.Config{ Engine: "memory", + PrometheusRegisterer: prometheus.NewRegistry(), 
RevisionQuantization: 5 * time.Second, GCWindow: 5 * time.Minute, FilterMaximumIDCount: 100, }, } + config.PrometheusRegisterer = prometheus.NewRegistry() ctx := t.Context() ctx, cancel := context.WithCancel(ctx) diff --git a/internal/gateway/gateway.go b/internal/gateway/gateway.go index f8ff2cc936..9671f63231 100644 --- a/internal/gateway/gateway.go +++ b/internal/gateway/gateway.go @@ -11,7 +11,6 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promhttp" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" grpcfilters "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/filters" @@ -30,15 +29,23 @@ import ( "github.com/authzed/grpcutil" "github.com/authzed/spicedb/internal/grpchelpers" + "github.com/authzed/spicedb/pkg/datastore" ) -var histogram = promauto.NewHistogramVec(prometheus.HistogramOpts{ +var histogram = prometheus.NewHistogramVec(prometheus.HistogramOpts{ Namespace: "spicedb", Subsystem: "rest_gateway", Name: "request_duration_seconds", Help: "A histogram of the duration spent processing requests to the SpiceDB REST Gateway.", }, []string{"method"}) +// RegisterMetrics registers the REST gateway prometheus metrics with the provided registerer. +// If registerer is nil, prometheus.DefaultRegisterer is used. +func RegisterMetrics(registerer prometheus.Registerer) error { + _, err := datastore.RegisterPrometheusCollectors(registerer, "failed to register gateway metrics", histogram) + return err +} + // NewHandler creates an REST gateway HTTP CloserHandler with the provided upstream // configuration. 
func NewHandler(ctx context.Context, upstreamAddr, upstreamTLSCertPath string) (*CloserHandler, error) { diff --git a/internal/graph/check.go b/internal/graph/check.go index 6d9a9938e1..0976ed0372 100644 --- a/internal/graph/check.go +++ b/internal/graph/check.go @@ -41,13 +41,11 @@ var dispatchChunkCountHistogram = prometheus.NewHistogram(prometheus.HistogramOp const noOriginalRelation = "" -func init() { - prometheus.MustRegister(dispatchChunkCountHistogram) -} - // NewConcurrentChecker creates an instance of ConcurrentChecker. -func NewConcurrentChecker(d dispatch.Check, concurrencyLimit uint16, dispatchChunkSize uint16) *ConcurrentChecker { - return &ConcurrentChecker{d, concurrencyLimit, dispatchChunkSize} +func NewConcurrentChecker(d dispatch.Check, concurrencyLimit uint16, dispatchChunkSize uint16, registerer prometheus.Registerer) *ConcurrentChecker { + unregister, _ := datastore.RegisterPrometheusCollectors(registerer, "failed to register dispatch metrics", dispatchChunkCountHistogram) + + return &ConcurrentChecker{d, concurrencyLimit, dispatchChunkSize, unregister} } // ConcurrentChecker exposes a method to perform Check requests, and delegates subproblems to the @@ -56,6 +54,7 @@ type ConcurrentChecker struct { d dispatch.Check concurrencyLimit uint16 dispatchChunkSize uint16 + unregister func() } // ValidatedCheckRequest represents a request after it has been validated and parsed for internal diff --git a/internal/middleware/memoryprotection/memory_protection.go b/internal/middleware/memoryprotection/memory_protection.go index b7b6312256..e5a4b9d3a5 100644 --- a/internal/middleware/memoryprotection/memory_protection.go +++ b/internal/middleware/memoryprotection/memory_protection.go @@ -8,25 +8,32 @@ import ( middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2" "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" "go.opentelemetry.io/otel" "google.golang.org/grpc" 
"google.golang.org/grpc/codes" "google.golang.org/grpc/status" log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" ) var tracer = otel.Tracer("spicedb/internal/middleware/memory_protection") // RequestsProcessed tracks requests that were processed by this middleware. -var RequestsProcessed = promauto.NewCounterVec(prometheus.CounterOpts{ +var RequestsProcessed = prometheus.NewCounterVec(prometheus.CounterOpts{ Namespace: "spicedb", Subsystem: "memory_middleware", Name: "requests_processed_total", Help: "Total requests processed by the memory protection middleware (flag --memory-protection-enabled)", }, []string{"endpoint", "accepted"}) +// RegisterMetrics registers the memory protection middleware prometheus metrics with the provided registerer. +// If registerer is nil, prometheus.DefaultRegisterer is used. +func RegisterMetrics(registerer prometheus.Registerer) error { + _, err := datastore.RegisterPrometheusCollectors(registerer, "failed to register memory protection metrics", RequestsProcessed) + return err +} + type MemoryProtectionMiddleware struct { currentMemoryUsageProvider MemoryUsageProvider } diff --git a/internal/middleware/perfinsights/perfinsights.go b/internal/middleware/perfinsights/perfinsights.go index 0eef6229b8..37db67ddcc 100644 --- a/internal/middleware/perfinsights/perfinsights.go +++ b/internal/middleware/perfinsights/perfinsights.go @@ -9,7 +9,6 @@ import ( "github.com/ccoveille/go-safecast/v2" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors" "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" "github.com/rs/zerolog/log" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/trace" @@ -18,6 +17,7 @@ import ( "github.com/authzed/ctxkey" "github.com/authzed/grpcutil" + "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/spiceerrors" ) @@ -61,7 +61,7 @@ func NoLabels() APIShapeLabels { // // To use make use of 
native histograms, a special flag must be set on Prometheus: // https://prometheus.io/docs/prometheus/latest/feature_flags/#native-histograms -var APIShapeLatency = promauto.NewHistogramVec(prometheus.HistogramOpts{ +var APIShapeLatency = prometheus.NewHistogramVec(prometheus.HistogramOpts{ Namespace: "spicedb", Subsystem: "perf_insights", Name: "api_shape_latency_seconds", @@ -72,6 +72,13 @@ var APIShapeLatency = promauto.NewHistogramVec(prometheus.HistogramOpts{ var tracer = otel.Tracer("spicedb/internal/middleware") +// RegisterMetrics registers the performance insights prometheus metrics with the provided registerer. +// If registerer is nil, prometheus.DefaultRegisterer is used. +func RegisterMetrics(registerer prometheus.Registerer) error { + _, err := datastore.RegisterPrometheusCollectors(registerer, "failed to register perfinsights metrics", APIShapeLatency) + return err +} + // ShapeBuilder is a function that returns a slice of strings representing the shape of the API call. // This is used to report the shape of the API call to Prometheus. 
type ShapeBuilder func() APIShapeLabels diff --git a/internal/middleware/usagemetrics/usagemetrics.go b/internal/middleware/usagemetrics/usagemetrics.go index 3f0e6fbe9a..e88a9e2cf3 100644 --- a/internal/middleware/usagemetrics/usagemetrics.go +++ b/internal/middleware/usagemetrics/usagemetrics.go @@ -8,7 +8,6 @@ import ( "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors" "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" "go.opentelemetry.io/otel" "google.golang.org/grpc" @@ -17,6 +16,7 @@ import ( "github.com/authzed/grpcutil" log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1" ) @@ -28,7 +28,7 @@ var ( // DispatchedCountHistogram is the metric that SpiceDB uses to keep track // of the number of downstream dispatches that are performed to answer a // single query. - DispatchedCountHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{ + DispatchedCountHistogram = prometheus.NewHistogramVec(prometheus.HistogramOpts{ Namespace: "spicedb", Subsystem: "services", Name: "dispatches", @@ -39,6 +39,13 @@ var ( tracer = otel.Tracer("spicedb/internal/middleware") ) +// RegisterMetrics registers the usagemetrics prometheus metrics with the provided registerer. +// If registerer is nil, prometheus.DefaultRegisterer is used. 
+func RegisterMetrics(registerer prometheus.Registerer) error { + _, err := datastore.RegisterPrometheusCollectors(registerer, "failed to register usagemetrics", DispatchedCountHistogram) + return err +} + type reporter struct{} func (r *reporter) ServerReporter(ctx context.Context, callMeta interceptors.CallMeta) (interceptors.Reporter, context.Context) { diff --git a/internal/services/integrationtesting/benchmark_dispatchqueryplan_test.go b/internal/services/integrationtesting/benchmark_dispatchqueryplan_test.go index 7cfb93a9a3..8df714d73a 100644 --- a/internal/services/integrationtesting/benchmark_dispatchqueryplan_test.go +++ b/internal/services/integrationtesting/benchmark_dispatchqueryplan_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" "github.com/authzed/spicedb/internal/caveats" @@ -180,7 +181,7 @@ func (h *dispatchQueryPlanHandle) newCachingDispatcher(b *testing.B) *caching.Di c, err := cache.NewStandardCache[keys.DispatchCacheKey, any](cacheConfig) require.NoError(b, err) - cd, err := caching.NewCachingDispatcher(c, false, "bench", &keys.DirectKeyHandler{}) + cd, err := caching.NewCachingDispatcher(c, false, prometheus.DefaultRegisterer, "bench", &keys.DirectKeyHandler{}) require.NoError(b, err) lpd := &localQueryPlanDispatcher{handle: h} diff --git a/internal/services/integrationtesting/certtest/cert_test.go b/internal/services/integrationtesting/certtest/cert_test.go index 4f540f0b84..b032fe32b0 100644 --- a/internal/services/integrationtesting/certtest/cert_test.go +++ b/internal/services/integrationtesting/certtest/cert_test.go @@ -16,6 +16,7 @@ import ( "testing" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" "go.uber.org/goleak" "google.golang.org/grpc" @@ -38,7 +39,12 @@ import ( func TestCertRotation(t *testing.T) { t.Cleanup(func() { - goleak.VerifyNone(t, testutil.GoLeakIgnores()...) 
+ goleak.VerifyNone(t, append(testutil.GoLeakIgnores(), + goleak.IgnoreCurrent(), + goleak.IgnoreTopFunction("github.com/outcaste-io/ristretto.(*lfuPolicy).processItems"), + goleak.IgnoreTopFunction("github.com/outcaste-io/ristretto.(*Cache).processItems"), + goleak.IgnoreTopFunction("github.com/fsnotify/fsnotify.(*inotify).readEvents"), + )...) }) const ( @@ -125,7 +131,7 @@ func TestCertRotation(t *testing.T) { require.NoError(t, err) ctx, cancel := context.WithCancel(t.Context()) - srv, err := server.NewConfigWithOptionsAndDefaults( + srvConfig := server.NewConfigWithOptionsAndDefaults( server.WithDatastore(ds), server.WithDispatcher(dispatcher), server.WithDispatchMaxDepth(50), @@ -182,7 +188,9 @@ func TestCertRotation(t *testing.T) { }, }, }), - ).Complete(ctx) + ) + srvConfig.PrometheusRegisterer = prometheus.NewRegistry() + srv, err := srvConfig.Complete(ctx) require.NoError(t, err) wait := make(chan error, 1) diff --git a/internal/services/integrationtesting/consistencytestutil/clusteranddata.go b/internal/services/integrationtesting/consistencytestutil/clusteranddata.go index 8cd8ec47bb..c28814b974 100644 --- a/internal/services/integrationtesting/consistencytestutil/clusteranddata.go +++ b/internal/services/integrationtesting/consistencytestutil/clusteranddata.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" "google.golang.org/grpc" @@ -75,7 +76,7 @@ func CreateDispatcherForTesting(t *testing.T, withCaching bool) dispatch.Dispatc dispatcher, err := graph.NewLocalOnlyDispatcher(params) require.NoError(err) if withCaching { - cachingDispatcher, err := caching.NewCachingDispatcher(nil, false, "", &keys.CanonicalKeyHandler{}) + cachingDispatcher, err := caching.NewCachingDispatcher(nil, false, prometheus.DefaultRegisterer, "", &keys.CanonicalKeyHandler{}) require.NoError(err) params2, err := graph.NewDefaultDispatcherParametersForTesting() diff --git 
a/internal/services/v1/debug_test.go b/internal/services/v1/debug_test.go index 757be9ad12..b174f65106 100644 --- a/internal/services/v1/debug_test.go +++ b/internal/services/v1/debug_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc" @@ -49,6 +50,12 @@ type debugCheckInfo struct { runDebugAssertions []rda } +func debugTestServerConfig() testserver.ServerConfig { + cfg := testserver.DefaultTestServerConfig + cfg.PrometheusRegisterer = prometheus.NewRegistry() + return cfg +} + func expectDebugFrames(permissionNames ...string) rda { return func(req *require.Assertions, debugInfo *v1.DebugInformation) { found := mapz.NewSet[string]() @@ -489,7 +496,8 @@ func TestCheckPermissionWithDebug(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - conn, cleanup, _, revision := testserver.NewTestServer(t, 5*time.Second, memdb.DisableGC, true, + conn, cleanup, _, revision := testserver.NewTestServerWithConfig(t, 5*time.Second, memdb.DisableGC, true, + debugTestServerConfig(), func(t testing.TB, ds datastore.Datastore) (datastore.Datastore, datastore.Revision) { return tf.DatastoreFromSchemaAndTestRelationships(t, ds, tc.schema, tc.relationships) }) @@ -883,7 +891,8 @@ func TestBulkCheckPermissionWithDebug(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { req := require.New(t) - conn, cleanup, _, revision := testserver.NewTestServer(t, 5*time.Second, memdb.DisableGC, true, + conn, cleanup, _, revision := testserver.NewTestServerWithConfig(t, 5*time.Second, memdb.DisableGC, true, + debugTestServerConfig(), func(t testing.TB, ds datastore.Datastore) (datastore.Datastore, datastore.Revision) { return tf.DatastoreFromSchemaAndTestRelationships(t, ds, tc.schema, tc.relationships) }) @@ -951,7 +960,8 @@ func TestLookupResourcesDebugTrace_LR3(t *testing.T) { } // NOTE: 
the default test server configuration selects LR3 - conn, cleanup, _, revision := testserver.NewTestServer(t, 5*time.Second, memdb.DisableGC, true, + conn, cleanup, _, revision := testserver.NewTestServerWithConfig(t, 5*time.Second, memdb.DisableGC, true, + debugTestServerConfig(), func(t testing.TB, ds datastore.Datastore) (datastore.Datastore, datastore.Revision) { return tf.DatastoreFromSchemaAndTestRelationships(t, ds, schema, relationships) }) @@ -1039,6 +1049,7 @@ func TestLookupResourcesDebugTrace_LR2(t *testing.T) { // NOTE: this makes LR2 the active implementation in the test server. lr2Config := testserver.DefaultTestServerConfig lr2Config.EnableExperimentalLookupResources3 = false + lr2Config.PrometheusRegisterer = prometheus.NewRegistry() conn, cleanup, _, revision := testserver.NewTestServerWithConfig(t, 5*time.Second, memdb.DisableGC, true, lr2Config, diff --git a/internal/services/v1/relationships.go b/internal/services/v1/relationships.go index a465ae5299..ddc9e40191 100644 --- a/internal/services/v1/relationships.go +++ b/internal/services/v1/relationships.go @@ -10,7 +10,6 @@ import ( "buf.build/go/protovalidate" grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate" "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" "go.opentelemetry.io/otel/trace" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/structpb" @@ -45,7 +44,7 @@ import ( "github.com/authzed/spicedb/pkg/zedtoken" ) -var writeUpdateCounter = promauto.NewHistogramVec(prometheus.HistogramOpts{ +var writeUpdateCounter = prometheus.NewHistogramVec(prometheus.HistogramOpts{ Namespace: "spicedb", Subsystem: "v1", Name: "write_relationships_updates", @@ -53,6 +52,13 @@ var writeUpdateCounter = promauto.NewHistogramVec(prometheus.HistogramOpts{ Buckets: []float64{0, 1, 2, 5, 10, 15, 25, 50, 100, 250, 500, 1000}, }, []string{"kind"}) +// RegisterMetrics registers the 
relationships service prometheus metrics with the provided registerer. +// If registerer is nil, prometheus.DefaultRegisterer is used. +func RegisterMetrics(registerer prometheus.Registerer) error { + _, err := datastore.RegisterPrometheusCollectors(registerer, "failed to register relationships service metrics", writeUpdateCounter) + return err +} + const MaximumTransactionMetadataSize = 65000 // bytes. Limited by the BLOB size used in MySQL driver // PermissionsServerConfig is configuration for the permissions server. diff --git a/internal/telemetry/metrics.go b/internal/telemetry/metrics.go index 82a3d2c7f9..69c8d2d1d0 100644 --- a/internal/telemetry/metrics.go +++ b/internal/telemetry/metrics.go @@ -12,7 +12,6 @@ import ( "github.com/jzelinskie/cobrautil/v2" "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" dto "github.com/prometheus/client_model/go" "golang.org/x/sync/errgroup" @@ -22,13 +21,20 @@ import ( "github.com/authzed/spicedb/pkg/promutil" ) -var LogicalChecks = promauto.NewCounter(prometheus.CounterOpts{ +var LogicalChecks = prometheus.NewCounter(prometheus.CounterOpts{ Namespace: "spicedb", Subsystem: "services", Name: "logical_checks_total", Help: `Count of the number of "checks" made across all APIs (e.g. each item within a CheckBulk, each item returned from a Lookup).`, }) +// RegisterMetrics registers the telemetry prometheus metrics with the provided registerer. +// If registerer is nil, prometheus.DefaultRegisterer is used. 
+func RegisterMetrics(registerer prometheus.Registerer) error { + _, err := datastore.RegisterPrometheusCollectors(registerer, "failed to register telemetry metrics", LogicalChecks) + return err +} + func SpiceDBClusterInfoCollector(ctx context.Context, subsystem, dsEngine string, ds datastore.Datastore) (promutil.CollectorFunc, error) { nodeID, err := os.Hostname() if err != nil { diff --git a/internal/testserver/cluster.go b/internal/testserver/cluster.go index 086c74fa91..a88c02f5bc 100644 --- a/internal/testserver/cluster.go +++ b/internal/testserver/cluster.go @@ -12,6 +12,7 @@ import ( "time" "github.com/cespare/xxhash/v2" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/backoff" @@ -168,6 +169,7 @@ func TestClusterWithDispatch(t testing.TB, size uint, ds datastore.Datastore, ad dispatcherOptions := []combineddispatch.Option{ combineddispatch.UpstreamAddr("test://" + prefix), + combineddispatch.PrometheusRegisterer(prometheus.NewRegistry()), combineddispatch.PrometheusSubsystem(fmt.Sprintf("%s_%d_client_dispatch", prefix, i)), combineddispatch.QueryPlanMetadata(queryPlanMetadata), combineddispatch.GrpcDialOpts( @@ -226,6 +228,7 @@ func TestClusterWithDispatch(t testing.TB, size uint, ds datastore.Datastore, ad ctx, cancel := context.WithCancel(t.Context()) cfg := server.NewConfigWithOptionsAndDefaults(serverOptions...) 
+ cfg.PrometheusRegisterer = prometheus.NewRegistry() srv, err := cfg.Complete(ctx) require.NoError(t, err) diff --git a/internal/testserver/datastore/config/config.go b/internal/testserver/datastore/config/config.go index 84d4c1ab1e..7e0b5ae6fa 100644 --- a/internal/testserver/datastore/config/config.go +++ b/internal/testserver/datastore/config/config.go @@ -3,6 +3,7 @@ package config import ( "testing" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" testdatastore "github.com/authzed/spicedb/internal/testserver/datastore" @@ -18,6 +19,9 @@ func DatastoreConfigInitFunc(t testing.TB, options ...dsconfig.ConfigOption) tes return func(engine, uri string) datastore.Datastore { ds, err := dsconfig.NewDatastore(t.Context(), append(options, + func(c *dsconfig.Config) { + c.PrometheusRegisterer = prometheus.NewRegistry() + }, dsconfig.WithEngine(engine), dsconfig.WithEnableDatastoreMetrics(false), dsconfig.WithURI(uri), diff --git a/internal/testserver/server.go b/internal/testserver/server.go index aa01f1fd0a..dc8631a61c 100644 --- a/internal/testserver/server.go +++ b/internal/testserver/server.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" "google.golang.org/grpc" @@ -30,6 +31,7 @@ type ServerConfig struct { CaveatTypeSet *caveattypes.TypeSet EnableExperimentalLookupResources3 bool DataLayerOpts []datalayer.DataLayerOption + PrometheusRegisterer prometheus.Registerer } var DefaultTestServerConfig = ServerConfig{ @@ -77,6 +79,9 @@ func NewTestServerWithConfigAndDatastore(t testing.TB, dsInitFunc DatastoreInitFunc, ) (*grpc.ClientConn, func(), datastore.Datastore, datastore.Revision) { ds, revision := dsInitFunc(t, emptyDS) + if config.PrometheusRegisterer == nil { + config.PrometheusRegisterer = prometheus.NewRegistry() + } ctx, cancel := context.WithCancel(t.Context()) cts := caveattypes.TypeSetOrDefault(config.CaveatTypeSet) @@ -95,7 +100,7 
@@ func NewTestServerWithConfigAndDatastore(t testing.TB, dispatcher, err := graph.NewLocalOnlyDispatcher(params) require.NoError(t, err) - srv, err := server.NewConfigWithOptionsAndDefaults( + srvConfig := server.NewConfigWithOptionsAndDefaults( server.WithDatastore(ds), server.WithDispatcher(dispatcher), server.WithQueryPlanMetadata(queryPlanMetadata), @@ -163,7 +168,9 @@ func NewTestServerWithConfigAndDatastore(t testing.TB, }, }, }), - ).Complete(ctx) + ) + srvConfig.PrometheusRegisterer = config.PrometheusRegisterer + srv, err := srvConfig.Complete(ctx) require.NoError(t, err) go func() { diff --git a/pkg/cache/metrics.go b/pkg/cache/metrics.go index b9b110b233..935e21da69 100644 --- a/pkg/cache/metrics.go +++ b/pkg/cache/metrics.go @@ -5,10 +5,15 @@ import ( "github.com/jzelinskie/stringz" "github.com/prometheus/client_golang/prometheus" + + "github.com/authzed/spicedb/pkg/datastore" ) -func init() { - prometheus.MustRegister(defaultCollector) +// RegisterMetrics registers the cache prometheus collector with the provided registerer. +// If registerer is nil, prometheus.DefaultRegisterer is used. 
+func RegisterMetrics(registerer prometheus.Registerer) error { + _, err := datastore.RegisterPrometheusCollectors(registerer, "failed to register cache metrics", defaultCollector) + return err } const ( diff --git a/pkg/cmd/datastore/datastore.go b/pkg/cmd/datastore/datastore.go index 9c3c3c251c..70162c8400 100644 --- a/pkg/cmd/datastore/datastore.go +++ b/pkg/cmd/datastore/datastore.go @@ -10,6 +10,7 @@ import ( "time" "github.com/ccoveille/go-safecast/v2" + "github.com/prometheus/client_golang/prometheus" "github.com/spf13/pflag" "github.com/authzed/spicedb/internal/datastore/common" @@ -114,12 +115,13 @@ type Config struct { FilterMaximumIDCount uint16 `debugmap:"hidden" default:"100"` // Options - ReadConnPool ConnPoolConfig `debugmap:"visible"` - WriteConnPool ConnPoolConfig `debugmap:"visible"` - ReadOnly bool `debugmap:"visible"` - EnableDatastoreMetrics bool `debugmap:"visible"` - DisableStats bool `debugmap:"visible"` - IncludeQueryParametersInTraces bool `debugmap:"visible"` + ReadConnPool ConnPoolConfig `debugmap:"visible"` + WriteConnPool ConnPoolConfig `debugmap:"visible"` + ReadOnly bool `debugmap:"visible"` + PrometheusRegisterer prometheus.Registerer `debugmap:"hidden"` + EnableDatastoreMetrics bool `debugmap:"visible"` + DisableStats bool `debugmap:"visible"` + IncludeQueryParametersInTraces bool `debugmap:"visible"` // Read Replicas ReadReplicaConnPool ConnPoolConfig `debugmap:"visible"` @@ -603,6 +605,7 @@ func newCRDBDatastore(ctx context.Context, opts Config) (datastore.Datastore, er crdb.WithEnableConnectionBalancing(opts.EnableConnectionBalancing), crdb.ConnectRate(opts.ConnectRate), crdb.FilterMaximumIDCount(opts.FilterMaximumIDCount), + crdb.WithPrometheusRegisterer(opts.PrometheusRegisterer), crdb.WithIntegrity(opts.RelationshipIntegrityEnabled), crdb.AllowedMigrations(opts.AllowedMigrations), crdb.WithColumnOptimization(opts.ExperimentalColumnOptimization), diff --git a/pkg/cmd/datastore/zz_generated.options.go 
b/pkg/cmd/datastore/zz_generated.options.go index 9d814ebf72..0e772c72f0 100644 --- a/pkg/cmd/datastore/zz_generated.options.go +++ b/pkg/cmd/datastore/zz_generated.options.go @@ -5,6 +5,7 @@ import ( "fmt" types "github.com/authzed/spicedb/pkg/caveats/types" defaults "github.com/creasty/defaults" + prometheus "github.com/prometheus/client_golang/prometheus" "time" ) @@ -43,6 +44,7 @@ func (c *Config) ToOption() ConfigOption { to.ReadConnPool = c.ReadConnPool to.WriteConnPool = c.WriteConnPool to.ReadOnly = c.ReadOnly + to.PrometheusRegisterer = c.PrometheusRegisterer to.EnableDatastoreMetrics = c.EnableDatastoreMetrics to.DisableStats = c.DisableStats to.IncludeQueryParametersInTraces = c.IncludeQueryParametersInTraces @@ -431,6 +433,13 @@ func WithReadOnly(readOnly bool) ConfigOption { } } +// WithPrometheusRegisterer returns an option that can set PrometheusRegisterer on a Config +func WithPrometheusRegisterer(prometheusRegisterer prometheus.Registerer) ConfigOption { + return func(c *Config) { + c.PrometheusRegisterer = prometheusRegisterer + } +} + // WithEnableDatastoreMetrics returns an option that can set EnableDatastoreMetrics on a Config func WithEnableDatastoreMetrics(enableDatastoreMetrics bool) ConfigOption { return func(c *Config) { diff --git a/pkg/cmd/datastore_test.go b/pkg/cmd/datastore_test.go index 93831676e3..66b317ba4a 100644 --- a/pkg/cmd/datastore_test.go +++ b/pkg/cmd/datastore_test.go @@ -3,6 +3,7 @@ package cmd import ( "testing" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" datastoreTest "github.com/authzed/spicedb/internal/testserver/datastore" @@ -19,7 +20,8 @@ func TestExecuteGC(t *testing.T) { name: "cockroachdb does not support garbage collection", cfgBuilder: func(t *testing.T) *datastore.Config { cfg := datastore.DefaultDatastoreConfig() - cfg.EnableDatastoreMetrics = false // avoid "duplicate metrics collector registration attempted" + cfg.EnableDatastoreMetrics = false + 
cfg.PrometheusRegisterer = prometheus.NewRegistry() cfg.Engine = "cockroachdb" runningDatastore := datastoreTest.RunDatastoreEngine(t, cfg.Engine) db := runningDatastore.NewDatabase(t) @@ -46,23 +48,22 @@ func TestExecuteRepair(t *testing.T) { expectedError string }{ { - name: "cockroachdb does not support repair", + name: "memory datastore does not support repair", cfgBuilder: func(t *testing.T) *datastore.Config { cfg := datastore.DefaultDatastoreConfig() - cfg.EnableDatastoreMetrics = false // avoid "duplicate metrics collector registration attempted" - cfg.Engine = "cockroachdb" - runningDatastore := datastoreTest.RunDatastoreEngine(t, cfg.Engine) - db := runningDatastore.NewDatabase(t) - cfg.URI = db + cfg.EnableDatastoreMetrics = false + cfg.PrometheusRegisterer = prometheus.NewRegistry() + cfg.Engine = datastore.MemoryEngine return cfg }, - expectedError: "datastore of type 'cockroachdb' does not support the repair operation", + expectedError: "datastore of type 'memory' does not support the repair operation", }, { name: "postgres supports repair", cfgBuilder: func(t *testing.T) *datastore.Config { cfg := datastore.DefaultDatastoreConfig() - cfg.EnableDatastoreMetrics = false // avoid "duplicate metrics collector registration attempted" + cfg.EnableDatastoreMetrics = false + cfg.PrometheusRegisterer = prometheus.NewRegistry() cfg.Engine = "postgres" runningDatastore := datastoreTest.RunDatastoreEngine(t, cfg.Engine) db := runningDatastore.NewDatabase(t) diff --git a/pkg/cmd/serve_test.go b/pkg/cmd/serve_test.go index 117f6614a7..4b5fdf75a4 100644 --- a/pkg/cmd/serve_test.go +++ b/pkg/cmd/serve_test.go @@ -7,6 +7,7 @@ import ( "strings" "testing" + "github.com/prometheus/client_golang/prometheus" "github.com/spf13/cobra" "github.com/stretchr/testify/require" @@ -21,27 +22,28 @@ func RunServeTest(t *testing.T, args []string, assertConfig func(t *testing.T, m err := RegisterRootFlags(cmd) require.NoError(t, err) require.NoError(t, RegisterServeFlags(cmd, 
config)) - // Disable all metrics as they are singletons - config.DispatchClusterMetricsEnabled = false - config.DispatchClientMetricsEnabled = false - config.DatastoreConfig.EnableDatastoreMetrics = false - config.DispatchCacheConfig.Metrics = false - config.ClusterDispatchCacheConfig.Metrics = false - config.NamespaceCacheConfig.Metrics = false - config.StoredSchemaCacheConfig.Metrics = false + config.PrometheusRegisterer = prometheus.NewRegistry() cmd.SetArgs(args) cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx, cancel := context.WithCancel(cmd.Context()) - t.Cleanup(cancel) + defer cancel() - _, err := config.Complete(ctx) + srv, err := config.Complete(ctx) if err != nil { return err } + + runErrCh := make(chan error, 1) + go func() { + runErrCh <- srv.Run(ctx) + }() + assertConfig(t, config) - return nil + + cancel() + return <-runErrCh } require.NoError(t, cmd.Execute()) } diff --git a/pkg/cmd/server/defaults.go b/pkg/cmd/server/defaults.go index febda4da7b..2274e99755 100644 --- a/pkg/cmd/server/defaults.go +++ b/pkg/cmd/server/defaults.go @@ -94,8 +94,12 @@ func DefaultPreRunE(programName string) cobrautil.CobraRunFunc { // metrics and pprof endpoints. func MetricsHandler(telemetryRegistry *prometheus.Registry, c *Config) http.Handler { mux := http.NewServeMux() + gatherer := prometheus.DefaultGatherer + if c != nil && c.PrometheusGatherer != nil { + gatherer = c.PrometheusGatherer + } - mux.Handle("/metrics", promhttp.HandlerFor(prometheus.DefaultGatherer, promhttp.HandlerOpts{ + mux.Handle("/metrics", promhttp.HandlerFor(gatherer, promhttp.HandlerOpts{ // Opt into OpenMetrics e.g. to support exemplars. EnableOpenMetrics: true, })) @@ -182,7 +186,7 @@ const ( DefaultInternalMiddlewareServerSpecific = "servicespecific" ) -//go:generate go run github.com/ecordell/optgen -output zz_generated.middlewareoption.go . MiddlewareOption +//go:generate go run github.com/ecordell/optgen -output zz_generated.middlewareoption.go -prefix . 
MiddlewareOption type MiddlewareOption struct { Logger zerolog.Logger `debugmap:"hidden"` AuthFunc grpcauth.AuthFunc `debugmap:"hidden"` @@ -191,6 +195,7 @@ type MiddlewareOption struct { EnableRequestLog bool `debugmap:"visible"` EnableResponseLog bool `debugmap:"visible"` DisableGRPCHistogram bool `debugmap:"visible"` + PrometheusRegisterer prometheus.Registerer `debugmap:"hidden"` MiddlewareServiceLabel string `debugmap:"visible"` MismatchingZedTokenOption consistencymw.MismatchingTokenOption `debugmap:"visible"` @@ -253,9 +258,9 @@ var gRPCMetricsStreamingInterceptor grpc.StreamServerInterceptor var serverMetricsOnce sync.Once // GRPCMetrics returns the interceptors used for the default gRPC metrics from grpc-ecosystem/go-grpc-middleware -func GRPCMetrics(disableLatencyHistogram bool) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor) { +func GRPCMetrics(registerer prometheus.Registerer, disableLatencyHistogram bool) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor) { serverMetricsOnce.Do(func() { - gRPCMetricsUnaryInterceptor, gRPCMetricsStreamingInterceptor = createServerMetrics(disableLatencyHistogram) + gRPCMetricsUnaryInterceptor, gRPCMetricsStreamingInterceptor = createServerMetrics(registerer, disableLatencyHistogram) }) return gRPCMetricsUnaryInterceptor, gRPCMetricsStreamingInterceptor @@ -277,7 +282,7 @@ func doesNotMatchRoute(route string) func(_ context.Context, c interceptors.Call // DefaultUnaryMiddleware generates the default middleware chain used for the public SpiceDB Unary gRPC methods func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryServerInterceptor], error) { - grpcMetricsUnaryInterceptor, _ := GRPCMetrics(opts.DisableGRPCHistogram) + grpcMetricsUnaryInterceptor, _ := GRPCMetrics(opts.PrometheusRegisterer, opts.DisableGRPCHistogram) memoryProtectionUnaryInterceptor := memoryprotection.New(opts.MemoryUsageProvider, "unary-middleware") chain, err := 
NewMiddlewareChain([]ReferenceableMiddleware[grpc.UnaryServerInterceptor]{ NewUnaryMiddleware(). @@ -352,7 +357,7 @@ func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryS // DefaultStreamingMiddleware generates the default middleware chain used for the public SpiceDB Streaming gRPC methods func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.StreamServerInterceptor], error) { - _, grpcMetricsStreamingInterceptor := GRPCMetrics(opts.DisableGRPCHistogram) + _, grpcMetricsStreamingInterceptor := GRPCMetrics(opts.PrometheusRegisterer, opts.DisableGRPCHistogram) memoryProtectionStreamInterceptor := memoryprotection.New(opts.MemoryUsageProvider, "stream-middleware") chain, err := NewMiddlewareChain([]ReferenceableMiddleware[grpc.StreamServerInterceptor]{ NewStreamMiddleware(). @@ -437,8 +442,8 @@ func determineEventsToLog(opts MiddlewareOption) grpclog.Option { } // DefaultDispatchMiddleware generates the default middleware chain used for the internal dispatch SpiceDB gRPC API -func DefaultDispatchMiddleware(logger zerolog.Logger, authFunc grpcauth.AuthFunc, ds datastore.Datastore, disableGRPCLatencyHistogram bool, memoryUsageProvider memoryprotection.MemoryUsageProvider, dlOpts ...datalayer.DataLayerOption) ([]grpc.UnaryServerInterceptor, []grpc.StreamServerInterceptor) { - grpcMetricsUnaryInterceptor, grpcMetricsStreamingInterceptor := GRPCMetrics(disableGRPCLatencyHistogram) +func DefaultDispatchMiddleware(logger zerolog.Logger, authFunc grpcauth.AuthFunc, ds datastore.Datastore, registerer prometheus.Registerer, disableGRPCLatencyHistogram bool, memoryUsageProvider memoryprotection.MemoryUsageProvider, dlOpts ...datalayer.DataLayerOption) ([]grpc.UnaryServerInterceptor, []grpc.StreamServerInterceptor) { + grpcMetricsUnaryInterceptor, grpcMetricsStreamingInterceptor := GRPCMetrics(registerer, disableGRPCLatencyHistogram) dispatchMemoryProtection := memoryprotection.New(memoryUsageProvider, "dispatch-middleware") dl 
:= datalayer.NewDataLayer(ds, dlOpts...) @@ -484,7 +489,11 @@ func InterceptorLogger(l zerolog.Logger) grpclog.Logger { } // initializes prometheus grpc interceptors with exemplar support enabled -func createServerMetrics(disableHistogram bool) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor) { +func createServerMetrics(registerer prometheus.Registerer, disableHistogram bool) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor) { + if registerer == nil { + registerer = prometheus.DefaultRegisterer + } + var opts []grpcprom.ServerMetricsOption if !disableHistogram { opts = append(opts, grpcprom.WithServerHandlingTimeHistogram( @@ -500,7 +509,7 @@ func createServerMetrics(disableHistogram bool) (grpc.UnaryServerInterceptor, gr srvMetrics := grpcprom.NewServerMetrics(opts...) // deliberately ignore if these metrics were already registered, so that // custom builds of SpiceDB can register these metrics with custom labels - _ = prometheus.Register(srvMetrics) + _ = registerer.Register(srvMetrics) exemplarFromContext := func(ctx context.Context) prometheus.Labels { if span := trace.SpanContextFromContext(ctx); span.IsSampled() { diff --git a/pkg/cmd/server/defaults_test.go b/pkg/cmd/server/defaults_test.go index d4e9d0ce3d..1b09df6c51 100644 --- a/pkg/cmd/server/defaults_test.go +++ b/pkg/cmd/server/defaults_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/rs/zerolog" "github.com/stretchr/testify/require" @@ -32,6 +33,7 @@ func TestWithDatastore(t *testing.T) { true, true, false, + prometheus.DefaultRegisterer, "service", consistency.TreatMismatchingTokensAsError, memoryprotection.NewNoopMemoryUsageProvider(), @@ -75,6 +77,7 @@ func TestWithDatastoreMiddleware(t *testing.T) { true, true, false, + prometheus.DefaultRegisterer, "service", consistency.TreatMismatchingTokensAsError, memoryprotection.NewNoopMemoryUsageProvider(), diff --git a/pkg/cmd/server/middleware_test.go 
b/pkg/cmd/server/middleware_test.go index 4d2d5e5431..657dbaf1a3 100644 --- a/pkg/cmd/server/middleware_test.go +++ b/pkg/cmd/server/middleware_test.go @@ -3,7 +3,9 @@ package server import ( "context" "testing" + "time" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" "go.uber.org/goleak" "google.golang.org/grpc" @@ -376,6 +378,7 @@ func TestMiddlewareOrdering(t *testing.T) { Enabled: true, }), ) + c.PrometheusRegisterer = prometheus.NewRegistry() rs, err := c.Complete(ctx) require.NoError(t, err) @@ -408,8 +411,13 @@ func TestMiddlewareOrdering(t *testing.T) { // NOTE: using WaitForReady ensures that the connection is active // and can accept RPCs, rather than returning an EOF error. - _, err = psc.CheckPermission(ctx, req, grpc.WaitForReady(true)) - require.NoError(t, err) + require.Eventually(t, func() bool { + callCtx, callCancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer callCancel() + + _, err = psc.CheckPermission(callCtx, req, grpc.WaitForReady(true)) + return err == nil + }, 5*time.Second, 50*time.Millisecond, "expected CheckPermission to succeed once server is ready") lrreq := &v1.LookupResourcesRequest{ ResourceObjectType: "resource", @@ -421,11 +429,18 @@ func TestMiddlewareOrdering(t *testing.T) { }, Permission: "read", } - lrc, err := psc.LookupResources(ctx, lrreq) - require.NoError(t, err) + require.Eventually(t, func() bool { + callCtx, callCancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer callCancel() - _, err = lrc.Recv() - require.NoError(t, err) + lrc, err := psc.LookupResources(callCtx, lrreq, grpc.WaitForReady(true)) + if err != nil { + return false + } + + _, err = lrc.Recv() + return err == nil + }, 5*time.Second, 50*time.Millisecond, "expected LookupResources to succeed once server is ready") cancel() require.NoError(t, <-errChan) @@ -487,6 +502,7 @@ func TestIncorrectOrderAssertionFails(t *testing.T) { }, }), ) + c.PrometheusRegisterer = prometheus.NewRegistry() rs, 
err := c.Complete(ctx) require.NoError(t, err) @@ -533,7 +549,7 @@ func TestIncorrectOrderAssertionFails(t *testing.T) { Permission: "read", } - lrc, err := psc.LookupResources(ctx, lrreq) + lrc, err := psc.LookupResources(ctx, lrreq, grpc.WaitForReady(true)) require.NoError(t, err) _, err = lrc.Recv() diff --git a/pkg/cmd/server/server.go b/pkg/cmd/server/server.go index ea102870da..c48f416f70 100644 --- a/pkg/cmd/server/server.go +++ b/pkg/cmd/server/server.go @@ -34,9 +34,12 @@ import ( combineddispatch "github.com/authzed/spicedb/internal/dispatch/combined" "github.com/authzed/spicedb/internal/dispatch/graph" "github.com/authzed/spicedb/internal/dispatch/keys" + "github.com/authzed/spicedb/internal/dispatch/singleflight" "github.com/authzed/spicedb/internal/gateway" log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/internal/middleware/memoryprotection" + "github.com/authzed/spicedb/internal/middleware/perfinsights" + "github.com/authzed/spicedb/internal/middleware/usagemetrics" "github.com/authzed/spicedb/internal/services" dispatchSvc "github.com/authzed/spicedb/internal/services/dispatch" "github.com/authzed/spicedb/internal/services/health" @@ -47,7 +50,7 @@ import ( "github.com/authzed/spicedb/pkg/cmd/util" "github.com/authzed/spicedb/pkg/datalayer" "github.com/authzed/spicedb/pkg/datastore" - "github.com/authzed/spicedb/pkg/middleware/consistency" + consistency "github.com/authzed/spicedb/pkg/middleware/consistency" "github.com/authzed/spicedb/pkg/middleware/requestid" "github.com/authzed/spicedb/pkg/query" "github.com/authzed/spicedb/pkg/spiceerrors" @@ -158,10 +161,12 @@ type Config struct { DispatchStreamingMiddleware []grpc.StreamServerInterceptor `debugmap:"hidden"` // Telemetry - SilentlyDisableTelemetry bool `debugmap:"visible"` - TelemetryCAOverridePath string `debugmap:"visible"` - TelemetryEndpoint string `debugmap:"visible"` - TelemetryInterval time.Duration `debugmap:"visible"` + SilentlyDisableTelemetry bool 
`debugmap:"visible"` + TelemetryCAOverridePath string `debugmap:"visible"` + TelemetryEndpoint string `debugmap:"visible"` + TelemetryInterval time.Duration `debugmap:"visible"` + PrometheusGatherer prometheus.Gatherer `debugmap:"hidden"` + PrometheusRegisterer prometheus.Registerer `debugmap:"hidden"` // Logs EnableRequestLogs bool `debugmap:"visible"` @@ -280,6 +285,7 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { dispatcher, err = combineddispatch.NewDispatcher( combineddispatch.UpstreamAddr(c.DispatchUpstreamAddr), combineddispatch.UpstreamCAPath(c.DispatchUpstreamCAPath), + combineddispatch.PrometheusRegisterer(c.PrometheusRegisterer), combineddispatch.SecondaryUpstreamAddrs(c.DispatchSecondaryUpstreamAddrs), combineddispatch.SecondaryUpstreamExprs(c.DispatchSecondaryUpstreamExprs), combineddispatch.SecondaryMaximumPrimaryHedgingDelays(c.DispatchSecondaryMaximumPrimaryHedgingDelays), @@ -497,11 +503,37 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { closeables.AddCloser(gatewayCloser) closeables.AddWithoutError(gatewayServer.Close) + registerer := c.PrometheusRegisterer + if registerer == nil { + registerer = prometheus.DefaultRegisterer + } + + // Register all prometheus metrics with the configured registerer. 
+ for _, regFn := range []func(prometheus.Registerer) error{ + schemacaching.RegisterMetrics, + singleflight.RegisterMetrics, + proxy.RegisterMetrics, + proxy.RegisterCheckingReplicatedMetrics, + proxy.RegisterStrictReplicatedMetrics, + usagemetrics.RegisterMetrics, + consistency.RegisterMetrics, + memoryprotection.RegisterMetrics, + perfinsights.RegisterMetrics, + telemetry.RegisterMetrics, + gateway.RegisterMetrics, + v1svc.RegisterMetrics, + cache.RegisterMetrics, + } { + if err := regFn(registerer); err != nil { + return nil, fmt.Errorf("failed to register prometheus metrics: %w", err) + } + } + infoCollector, err := telemetry.SpiceDBClusterInfoCollector(ctx, "environment", c.DatastoreConfig.Engine, ds) if err != nil { log.Warn().Err(err).Msg("unable to initialize info collector") } else { - if err := prometheus.Register(infoCollector); err != nil { + if err := registerer.Register(infoCollector); err != nil { log.Warn().Err(err).Msg("unable to initialize info collector") } } @@ -610,9 +642,9 @@ func (c *Config) BuildMemoryUsageProvider() memoryprotection.MemoryUsageProvider func (c *Config) buildDispatchServer(memoryUsageProvider memoryprotection.MemoryUsageProvider, ds datastore.Datastore, cachingClusterDispatch dispatch.Dispatcher, otelOpts []otelgrpc.Option, dlOpts []datalayer.DataLayerOption) (util.RunnableGRPCServer, error) { if len(c.DispatchUnaryMiddleware) == 0 && len(c.DispatchStreamingMiddleware) == 0 { if c.GRPCAuthFunc == nil { - c.DispatchUnaryMiddleware, c.DispatchStreamingMiddleware = DefaultDispatchMiddleware(log.Logger, auth.MustRequirePresharedKey(c.PresharedSecureKey), ds, c.DisableGRPCLatencyHistogram, memoryUsageProvider, dlOpts...) + c.DispatchUnaryMiddleware, c.DispatchStreamingMiddleware = DefaultDispatchMiddleware(log.Logger, auth.MustRequirePresharedKey(c.PresharedSecureKey), ds, c.PrometheusRegisterer, c.DisableGRPCLatencyHistogram, memoryUsageProvider, dlOpts...) 
} else { - c.DispatchUnaryMiddleware, c.DispatchStreamingMiddleware = DefaultDispatchMiddleware(log.Logger, c.GRPCAuthFunc, ds, c.DisableGRPCLatencyHistogram, memoryUsageProvider, dlOpts...) + c.DispatchUnaryMiddleware, c.DispatchStreamingMiddleware = DefaultDispatchMiddleware(log.Logger, c.GRPCAuthFunc, ds, c.PrometheusRegisterer, c.DisableGRPCLatencyHistogram, memoryUsageProvider, dlOpts...) } } diff --git a/pkg/cmd/server/server_test.go b/pkg/cmd/server/server_test.go index f604fb277b..04ff2bb8b1 100644 --- a/pkg/cmd/server/server_test.go +++ b/pkg/cmd/server/server_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" @@ -185,7 +186,9 @@ func TestOTelReporting(t *testing.T) { WithEnableMemoryProtectionMiddleware(false), } - srv, err := NewConfigWithOptionsAndDefaults(configOpts...).Complete(ctx) + srvCfg := NewConfigWithOptionsAndDefaults(configOpts...) + srvCfg.PrometheusRegisterer = prometheus.NewRegistry() + srv, err := srvCfg.Complete(ctx) require.NoError(t, err) conn, err := srv.GRPCDialContext(ctx) @@ -257,7 +260,9 @@ func TestDisableHealthCheckTracing(t *testing.T) { WithDatastore(ds), } - srv, err := NewConfigWithOptionsAndDefaults(configOpts...).Complete(ctx) + srvCfg2 := NewConfigWithOptionsAndDefaults(configOpts...) + srvCfg2.PrometheusRegisterer = prometheus.NewRegistry() + srv, err := srvCfg2.Complete(ctx) require.NoError(t, err) conn, err := srv.GRPCDialContext(ctx) @@ -394,7 +399,9 @@ func TestRetryPolicy(t *testing.T) { }), } - srv, err := NewConfigWithOptionsAndDefaults(configOpts...).Complete(ctx) + srvCfg3 := NewConfigWithOptionsAndDefaults(configOpts...) 
+ srvCfg3.PrometheusRegisterer = prometheus.NewRegistry() + srv, err := srvCfg3.Complete(ctx) require.NoError(t, err) conn, err := srv.GRPCDialContext(ctx, @@ -469,6 +476,7 @@ func TestServerGracefulTerminationOnError(t *testing.T) { Network: util.BufferedNetwork, }, }, WithPresharedSecureKey("psk"), WithDatastore(ds), WithEnableMemoryProtectionMiddleware(false)) + c.PrometheusRegisterer = prometheus.NewRegistry() cancel() _, err = c.Complete(ctx) require.NoError(t, err) @@ -528,7 +536,7 @@ func TestModifyUnaryMiddleware(t *testing.T) { }, }} - opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, "testing", consistency.TreatMismatchingTokensAsFullConsistency, memoryprotection.NewNoopMemoryUsageProvider(), nil, nil} + opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, prometheus.DefaultRegisterer, "testing", consistency.TreatMismatchingTokensAsFullConsistency, memoryprotection.NewNoopMemoryUsageProvider(), nil, nil} opt = opt.WithDatastore(nil) defaultMw, err := DefaultUnaryMiddleware(opt) @@ -556,7 +564,7 @@ func TestModifyStreamingMiddleware(t *testing.T) { }, }} - opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, "testing", consistency.TreatMismatchingTokensAsFullConsistency, memoryprotection.NewNoopMemoryUsageProvider(), nil, nil} + opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, prometheus.DefaultRegisterer, "testing", consistency.TreatMismatchingTokensAsFullConsistency, memoryprotection.NewNoopMemoryUsageProvider(), nil, nil} opt = opt.WithDatastore(nil) defaultMw, err := DefaultStreamingMiddleware(opt) diff --git a/pkg/cmd/server/zz_generated.middlewareoption.go b/pkg/cmd/server/zz_generated.middlewareoption.go index c4ac6fd639..2fa06fd5f7 100644 --- a/pkg/cmd/server/zz_generated.middlewareoption.go +++ b/pkg/cmd/server/zz_generated.middlewareoption.go @@ -7,6 +7,7 @@ import ( consistency "github.com/authzed/spicedb/pkg/middleware/consistency" 
defaults "github.com/creasty/defaults" auth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth" + prometheus "github.com/prometheus/client_golang/prometheus" zerolog "github.com/rs/zerolog" ) @@ -41,6 +42,7 @@ func (m *MiddlewareOption) ToOption() MiddlewareOptionOption { to.EnableRequestLog = m.EnableRequestLog to.EnableResponseLog = m.EnableResponseLog to.DisableGRPCHistogram = m.DisableGRPCHistogram + to.PrometheusRegisterer = m.PrometheusRegisterer to.MiddlewareServiceLabel = m.MiddlewareServiceLabel to.MismatchingZedTokenOption = m.MismatchingZedTokenOption to.MemoryUsageProvider = m.MemoryUsageProvider @@ -106,71 +108,78 @@ func (m *MiddlewareOption) WithOptions(opts ...MiddlewareOptionOption) *Middlewa return m } -// WithLogger returns an option that can set Logger on a MiddlewareOption -func WithLogger(logger zerolog.Logger) MiddlewareOptionOption { +// WithMiddlewareOptionLogger returns an option that can set Logger on a MiddlewareOption +func WithMiddlewareOptionLogger(logger zerolog.Logger) MiddlewareOptionOption { return func(m *MiddlewareOption) { m.Logger = logger } } -// WithAuthFunc returns an option that can set AuthFunc on a MiddlewareOption -func WithAuthFunc(authFunc auth.AuthFunc) MiddlewareOptionOption { +// WithMiddlewareOptionAuthFunc returns an option that can set AuthFunc on a MiddlewareOption +func WithMiddlewareOptionAuthFunc(authFunc auth.AuthFunc) MiddlewareOptionOption { return func(m *MiddlewareOption) { m.AuthFunc = authFunc } } -// WithEnableVersionResponse returns an option that can set EnableVersionResponse on a MiddlewareOption -func WithEnableVersionResponse(enableVersionResponse bool) MiddlewareOptionOption { +// WithMiddlewareOptionEnableVersionResponse returns an option that can set EnableVersionResponse on a MiddlewareOption +func WithMiddlewareOptionEnableVersionResponse(enableVersionResponse bool) MiddlewareOptionOption { return func(m *MiddlewareOption) { m.EnableVersionResponse = enableVersionResponse } 
} -// WithDispatcherForMiddleware returns an option that can set DispatcherForMiddleware on a MiddlewareOption -func WithDispatcherForMiddleware(dispatcherForMiddleware dispatch.Dispatcher) MiddlewareOptionOption { +// WithMiddlewareOptionDispatcherForMiddleware returns an option that can set DispatcherForMiddleware on a MiddlewareOption +func WithMiddlewareOptionDispatcherForMiddleware(dispatcherForMiddleware dispatch.Dispatcher) MiddlewareOptionOption { return func(m *MiddlewareOption) { m.DispatcherForMiddleware = dispatcherForMiddleware } } -// WithEnableRequestLog returns an option that can set EnableRequestLog on a MiddlewareOption -func WithEnableRequestLog(enableRequestLog bool) MiddlewareOptionOption { +// WithMiddlewareOptionEnableRequestLog returns an option that can set EnableRequestLog on a MiddlewareOption +func WithMiddlewareOptionEnableRequestLog(enableRequestLog bool) MiddlewareOptionOption { return func(m *MiddlewareOption) { m.EnableRequestLog = enableRequestLog } } -// WithEnableResponseLog returns an option that can set EnableResponseLog on a MiddlewareOption -func WithEnableResponseLog(enableResponseLog bool) MiddlewareOptionOption { +// WithMiddlewareOptionEnableResponseLog returns an option that can set EnableResponseLog on a MiddlewareOption +func WithMiddlewareOptionEnableResponseLog(enableResponseLog bool) MiddlewareOptionOption { return func(m *MiddlewareOption) { m.EnableResponseLog = enableResponseLog } } -// WithDisableGRPCHistogram returns an option that can set DisableGRPCHistogram on a MiddlewareOption -func WithDisableGRPCHistogram(disableGRPCHistogram bool) MiddlewareOptionOption { +// WithMiddlewareOptionDisableGRPCHistogram returns an option that can set DisableGRPCHistogram on a MiddlewareOption +func WithMiddlewareOptionDisableGRPCHistogram(disableGRPCHistogram bool) MiddlewareOptionOption { return func(m *MiddlewareOption) { m.DisableGRPCHistogram = disableGRPCHistogram } } -// WithMiddlewareServiceLabel returns an option 
that can set MiddlewareServiceLabel on a MiddlewareOption -func WithMiddlewareServiceLabel(middlewareServiceLabel string) MiddlewareOptionOption { +// WithMiddlewareOptionPrometheusRegisterer returns an option that can set PrometheusRegisterer on a MiddlewareOption +func WithMiddlewareOptionPrometheusRegisterer(prometheusRegisterer prometheus.Registerer) MiddlewareOptionOption { + return func(m *MiddlewareOption) { + m.PrometheusRegisterer = prometheusRegisterer + } +} + +// WithMiddlewareOptionMiddlewareServiceLabel returns an option that can set MiddlewareServiceLabel on a MiddlewareOption +func WithMiddlewareOptionMiddlewareServiceLabel(middlewareServiceLabel string) MiddlewareOptionOption { return func(m *MiddlewareOption) { m.MiddlewareServiceLabel = middlewareServiceLabel } } -// WithMismatchingZedTokenOption returns an option that can set MismatchingZedTokenOption on a MiddlewareOption -func WithMismatchingZedTokenOption(mismatchingZedTokenOption consistency.MismatchingTokenOption) MiddlewareOptionOption { +// WithMiddlewareOptionMismatchingZedTokenOption returns an option that can set MismatchingZedTokenOption on a MiddlewareOption +func WithMiddlewareOptionMismatchingZedTokenOption(mismatchingZedTokenOption consistency.MismatchingTokenOption) MiddlewareOptionOption { return func(m *MiddlewareOption) { m.MismatchingZedTokenOption = mismatchingZedTokenOption } } -// WithMemoryUsageProvider returns an option that can set MemoryUsageProvider on a MiddlewareOption -func WithMemoryUsageProvider(memoryUsageProvider memoryprotection.MemoryUsageProvider) MiddlewareOptionOption { +// WithMiddlewareOptionMemoryUsageProvider returns an option that can set MemoryUsageProvider on a MiddlewareOption +func WithMiddlewareOptionMemoryUsageProvider(memoryUsageProvider memoryprotection.MemoryUsageProvider) MiddlewareOptionOption { return func(m *MiddlewareOption) { m.MemoryUsageProvider = memoryUsageProvider } diff --git a/pkg/cmd/server/zz_generated.options.go 
b/pkg/cmd/server/zz_generated.options.go index 4b0cab41a6..9d77495d7d 100644 --- a/pkg/cmd/server/zz_generated.options.go +++ b/pkg/cmd/server/zz_generated.options.go @@ -11,6 +11,7 @@ import ( query "github.com/authzed/spicedb/pkg/query" defaults "github.com/creasty/defaults" auth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth" + prometheus "github.com/prometheus/client_golang/prometheus" grpc "google.golang.org/grpc" "time" ) @@ -111,6 +112,8 @@ func (c *Config) ToOption() ConfigOption { to.TelemetryCAOverridePath = c.TelemetryCAOverridePath to.TelemetryEndpoint = c.TelemetryEndpoint to.TelemetryInterval = c.TelemetryInterval + to.PrometheusGatherer = c.PrometheusGatherer + to.PrometheusRegisterer = c.PrometheusRegisterer to.EnableRequestLogs = c.EnableRequestLogs to.EnableResponseLogs = c.EnableResponseLogs to.DisableGRPCLatencyHistogram = c.DisableGRPCLatencyHistogram @@ -999,6 +1002,20 @@ func WithTelemetryInterval(telemetryInterval time.Duration) ConfigOption { } } +// WithPrometheusGatherer returns an option that can set PrometheusGatherer on a Config +func WithPrometheusGatherer(prometheusGatherer prometheus.Gatherer) ConfigOption { + return func(c *Config) { + c.PrometheusGatherer = prometheusGatherer + } +} + +// WithPrometheusRegisterer returns an option that can set PrometheusRegisterer on a Config +func WithPrometheusRegisterer(prometheusRegisterer prometheus.Registerer) ConfigOption { + return func(c *Config) { + c.PrometheusRegisterer = prometheusRegisterer + } +} + // WithEnableRequestLogs returns an option that can set EnableRequestLogs on a Config func WithEnableRequestLogs(enableRequestLogs bool) ConfigOption { return func(c *Config) { diff --git a/pkg/cmd/testserver/testserver.go b/pkg/cmd/testserver/testserver.go index 611beb156b..7013d8dc5c 100644 --- a/pkg/cmd/testserver/testserver.go +++ b/pkg/cmd/testserver/testserver.go @@ -115,13 +115,13 @@ func (c *Config) Complete(ctx context.Context) (RunnableTestServer, error) { ) 
} - noAuth := server.WithAuthFunc(func(ctx context.Context) (context.Context, error) { + noAuth := server.WithMiddlewareOptionAuthFunc(func(ctx context.Context) (context.Context, error) { // Turn off the default auth system. return ctx, nil }) opts := *server.NewMiddlewareOptionWithOptions(noAuth, - server.WithLogger(log.Logger), - server.WithMemoryUsageProvider(&memoryprotection.HarcodedMemoryUsageProvider{AcceptAllRequests: true})) + server.WithMiddlewareOptionLogger(log.Logger), + server.WithMiddlewareOptionMemoryUsageProvider(&memoryprotection.HarcodedMemoryUsageProvider{AcceptAllRequests: true})) opts = opts.WithDatastoreMiddleware(datastoreMiddleware) unaryMiddleware, err := server.DefaultUnaryMiddleware(opts) diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go index f7c605846d..4ef8eb7da4 100644 --- a/pkg/cmd/util/util.go +++ b/pkg/cmd/util/util.go @@ -1,6 +1,6 @@ package util -//go:generate go run github.com/ecordell/optgen -output zz_generated.options.go . GRPCServerConfig HTTPServerConfig +//go:generate go run github.com/ecordell/optgen -output zz_generated.options.go -prefix . 
GRPCServerConfig HTTPServerConfig import ( "cmp" @@ -15,6 +15,7 @@ import ( "github.com/jzelinskie/cobrautil/v2/cobraotel" _ "github.com/mostynb/go-grpc-compression/experimental/s2" // Register Snappy S2 compression + "github.com/prometheus/client_golang/prometheus" "github.com/rs/zerolog" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -34,15 +35,16 @@ import ( const BufferedNetwork string = "buffnet" type GRPCServerConfig struct { - Address string `debugmap:"visible"` - Network string `debugmap:"visible"` - TLSCertPath string `debugmap:"visible"` - TLSKeyPath string `debugmap:"visible"` - MaxConnAge time.Duration `debugmap:"visible"` - Enabled bool `debugmap:"visible"` - BufferSize int `debugmap:"visible"` - ClientCAPath string `debugmap:"visible"` - MaxWorkers uint32 `debugmap:"visible"` + Address string `debugmap:"visible"` + Network string `debugmap:"visible"` + TLSCertPath string `debugmap:"visible"` + TLSKeyPath string `debugmap:"visible"` + MaxConnAge time.Duration `debugmap:"visible"` + Enabled bool `debugmap:"visible"` + BufferSize int `debugmap:"visible"` + ClientCAPath string `debugmap:"visible"` + PrometheusRegisterer prometheus.Registerer `debugmap:"visible"` + MaxWorkers uint32 `debugmap:"visible"` flagPrefix string } @@ -156,7 +158,7 @@ func (c *GRPCServerConfig) tlsOpts() ([]grpc.ServerOption, *x509util.CertWatcher } // Else we've got TLS configuration and we'll construct the server options - watcher, err := x509util.NewTLSCertWatcher(c.TLSCertPath, c.TLSKeyPath) + watcher, err := x509util.NewTLSCertWatcher(c.PrometheusRegisterer, c.TLSCertPath, c.TLSKeyPath) if err != nil { return nil, nil, err } @@ -289,10 +291,11 @@ func (d *disabledGrpcServer) ForceStop() { } type HTTPServerConfig struct { - HTTPAddress string `debugmap:"visible"` - HTTPTLSCertPath string `debugmap:"visible"` - HTTPTLSKeyPath string `debugmap:"visible"` - HTTPEnabled bool `debugmap:"visible"` + HTTPAddress string `debugmap:"visible"` + HTTPTLSCertPath string 
`debugmap:"visible"` + HTTPTLSKeyPath string `debugmap:"visible"` + HTTPEnabled bool `debugmap:"visible"` + PrometheusRegisterer prometheus.Registerer `debugmap:"visible"` flagPrefix string } @@ -314,7 +317,7 @@ func (c *HTTPServerConfig) Complete(level zerolog.Level, handler http.Handler) ( } case c.HTTPTLSCertPath != "" && c.HTTPTLSKeyPath != "": - watcher, err := x509util.NewTLSCertWatcher(c.HTTPTLSCertPath, c.HTTPTLSKeyPath) + watcher, err := x509util.NewTLSCertWatcher(c.PrometheusRegisterer, c.HTTPTLSCertPath, c.HTTPTLSKeyPath) if err != nil { return nil, err } diff --git a/pkg/cmd/util/zz_generated.options.go b/pkg/cmd/util/zz_generated.options.go index 1c44dd3d6f..5c3a6535dd 100644 --- a/pkg/cmd/util/zz_generated.options.go +++ b/pkg/cmd/util/zz_generated.options.go @@ -3,6 +3,7 @@ package util import ( defaults "github.com/creasty/defaults" + prometheus "github.com/prometheus/client_golang/prometheus" "time" ) @@ -38,6 +39,7 @@ func (g *GRPCServerConfig) ToOption() GRPCServerConfigOption { to.Enabled = g.Enabled to.BufferSize = g.BufferSize to.ClientCAPath = g.ClientCAPath + to.PrometheusRegisterer = g.PrometheusRegisterer to.MaxWorkers = g.MaxWorkers } } @@ -79,6 +81,13 @@ func (g *GRPCServerConfig) DebugMap() map[string]any { } else { debugMap["ClientCAPath"] = g.ClientCAPath } + if dm, ok := any(&g.PrometheusRegisterer).(interface { + DebugMap() map[string]any + }); ok { + debugMap["PrometheusRegisterer"] = dm.DebugMap() + } else { + debugMap["PrometheusRegisterer"] = g.PrometheusRegisterer + } debugMap["MaxWorkers"] = g.MaxWorkers return debugMap } @@ -120,64 +129,71 @@ func (g *GRPCServerConfig) WithOptions(opts ...GRPCServerConfigOption) *GRPCServ return g } -// WithAddress returns an option that can set Address on a GRPCServerConfig -func WithAddress(address string) GRPCServerConfigOption { +// WithGRPCServerConfigAddress returns an option that can set Address on a GRPCServerConfig +func WithGRPCServerConfigAddress(address string) 
GRPCServerConfigOption { return func(g *GRPCServerConfig) { g.Address = address } } -// WithNetwork returns an option that can set Network on a GRPCServerConfig -func WithNetwork(network string) GRPCServerConfigOption { +// WithGRPCServerConfigNetwork returns an option that can set Network on a GRPCServerConfig +func WithGRPCServerConfigNetwork(network string) GRPCServerConfigOption { return func(g *GRPCServerConfig) { g.Network = network } } -// WithTLSCertPath returns an option that can set TLSCertPath on a GRPCServerConfig -func WithTLSCertPath(tLSCertPath string) GRPCServerConfigOption { +// WithGRPCServerConfigTLSCertPath returns an option that can set TLSCertPath on a GRPCServerConfig +func WithGRPCServerConfigTLSCertPath(tLSCertPath string) GRPCServerConfigOption { return func(g *GRPCServerConfig) { g.TLSCertPath = tLSCertPath } } -// WithTLSKeyPath returns an option that can set TLSKeyPath on a GRPCServerConfig -func WithTLSKeyPath(tLSKeyPath string) GRPCServerConfigOption { +// WithGRPCServerConfigTLSKeyPath returns an option that can set TLSKeyPath on a GRPCServerConfig +func WithGRPCServerConfigTLSKeyPath(tLSKeyPath string) GRPCServerConfigOption { return func(g *GRPCServerConfig) { g.TLSKeyPath = tLSKeyPath } } -// WithMaxConnAge returns an option that can set MaxConnAge on a GRPCServerConfig -func WithMaxConnAge(maxConnAge time.Duration) GRPCServerConfigOption { +// WithGRPCServerConfigMaxConnAge returns an option that can set MaxConnAge on a GRPCServerConfig +func WithGRPCServerConfigMaxConnAge(maxConnAge time.Duration) GRPCServerConfigOption { return func(g *GRPCServerConfig) { g.MaxConnAge = maxConnAge } } -// WithEnabled returns an option that can set Enabled on a GRPCServerConfig -func WithEnabled(enabled bool) GRPCServerConfigOption { +// WithGRPCServerConfigEnabled returns an option that can set Enabled on a GRPCServerConfig +func WithGRPCServerConfigEnabled(enabled bool) GRPCServerConfigOption { return func(g *GRPCServerConfig) { g.Enabled = 
enabled } } -// WithBufferSize returns an option that can set BufferSize on a GRPCServerConfig -func WithBufferSize(bufferSize int) GRPCServerConfigOption { +// WithGRPCServerConfigBufferSize returns an option that can set BufferSize on a GRPCServerConfig +func WithGRPCServerConfigBufferSize(bufferSize int) GRPCServerConfigOption { return func(g *GRPCServerConfig) { g.BufferSize = bufferSize } } -// WithClientCAPath returns an option that can set ClientCAPath on a GRPCServerConfig -func WithClientCAPath(clientCAPath string) GRPCServerConfigOption { +// WithGRPCServerConfigClientCAPath returns an option that can set ClientCAPath on a GRPCServerConfig +func WithGRPCServerConfigClientCAPath(clientCAPath string) GRPCServerConfigOption { return func(g *GRPCServerConfig) { g.ClientCAPath = clientCAPath } } -// WithMaxWorkers returns an option that can set MaxWorkers on a GRPCServerConfig -func WithMaxWorkers(maxWorkers uint32) GRPCServerConfigOption { +// WithGRPCServerConfigPrometheusRegisterer returns an option that can set PrometheusRegisterer on a GRPCServerConfig +func WithGRPCServerConfigPrometheusRegisterer(prometheusRegisterer prometheus.Registerer) GRPCServerConfigOption { + return func(g *GRPCServerConfig) { + g.PrometheusRegisterer = prometheusRegisterer + } +} + +// WithGRPCServerConfigMaxWorkers returns an option that can set MaxWorkers on a GRPCServerConfig +func WithGRPCServerConfigMaxWorkers(maxWorkers uint32) GRPCServerConfigOption { return func(g *GRPCServerConfig) { g.MaxWorkers = maxWorkers } @@ -211,6 +227,7 @@ func (h *HTTPServerConfig) ToOption() HTTPServerConfigOption { to.HTTPTLSCertPath = h.HTTPTLSCertPath to.HTTPTLSKeyPath = h.HTTPTLSKeyPath to.HTTPEnabled = h.HTTPEnabled + to.PrometheusRegisterer = h.PrometheusRegisterer } } @@ -233,6 +250,13 @@ func (h *HTTPServerConfig) DebugMap() map[string]any { debugMap["HTTPTLSKeyPath"] = h.HTTPTLSKeyPath } debugMap["HTTPEnabled"] = h.HTTPEnabled + if dm, ok := any(&h.PrometheusRegisterer).(interface { + 
DebugMap() map[string]any + }); ok { + debugMap["PrometheusRegisterer"] = dm.DebugMap() + } else { + debugMap["PrometheusRegisterer"] = h.PrometheusRegisterer + } return debugMap } @@ -273,30 +297,37 @@ func (h *HTTPServerConfig) WithOptions(opts ...HTTPServerConfigOption) *HTTPServ return h } -// WithHTTPAddress returns an option that can set HTTPAddress on a HTTPServerConfig -func WithHTTPAddress(hTTPAddress string) HTTPServerConfigOption { +// WithHTTPServerConfigHTTPAddress returns an option that can set HTTPAddress on a HTTPServerConfig +func WithHTTPServerConfigHTTPAddress(hTTPAddress string) HTTPServerConfigOption { return func(h *HTTPServerConfig) { h.HTTPAddress = hTTPAddress } } -// WithHTTPTLSCertPath returns an option that can set HTTPTLSCertPath on a HTTPServerConfig -func WithHTTPTLSCertPath(hTTPTLSCertPath string) HTTPServerConfigOption { +// WithHTTPServerConfigHTTPTLSCertPath returns an option that can set HTTPTLSCertPath on a HTTPServerConfig +func WithHTTPServerConfigHTTPTLSCertPath(hTTPTLSCertPath string) HTTPServerConfigOption { return func(h *HTTPServerConfig) { h.HTTPTLSCertPath = hTTPTLSCertPath } } -// WithHTTPTLSKeyPath returns an option that can set HTTPTLSKeyPath on a HTTPServerConfig -func WithHTTPTLSKeyPath(hTTPTLSKeyPath string) HTTPServerConfigOption { +// WithHTTPServerConfigHTTPTLSKeyPath returns an option that can set HTTPTLSKeyPath on a HTTPServerConfig +func WithHTTPServerConfigHTTPTLSKeyPath(hTTPTLSKeyPath string) HTTPServerConfigOption { return func(h *HTTPServerConfig) { h.HTTPTLSKeyPath = hTTPTLSKeyPath } } -// WithHTTPEnabled returns an option that can set HTTPEnabled on a HTTPServerConfig -func WithHTTPEnabled(hTTPEnabled bool) HTTPServerConfigOption { +// WithHTTPServerConfigHTTPEnabled returns an option that can set HTTPEnabled on a HTTPServerConfig +func WithHTTPServerConfigHTTPEnabled(hTTPEnabled bool) HTTPServerConfigOption { return func(h *HTTPServerConfig) { h.HTTPEnabled = hTTPEnabled } } + +// 
WithHTTPServerConfigPrometheusRegisterer returns an option that can set PrometheusRegisterer on a HTTPServerConfig +func WithHTTPServerConfigPrometheusRegisterer(prometheusRegisterer prometheus.Registerer) HTTPServerConfigOption { + return func(h *HTTPServerConfig) { + h.PrometheusRegisterer = prometheusRegisterer + } +} diff --git a/pkg/datastore/gc.go b/pkg/datastore/gc.go index f9a9433504..042ff67f99 100644 --- a/pkg/datastore/gc.go +++ b/pkg/datastore/gc.go @@ -63,21 +63,18 @@ var ( // RegisterGCMetrics registers garbage collection metrics to the default // registry and returns them (so that they be unregistered). -func RegisterGCMetrics() ([]prometheus.Collector, error) { +func RegisterGCMetrics(registerer prometheus.Registerer) (func(), error) { collectors := []prometheus.Collector{ gcDurationHistogram, gcRelationshipsCounter, + gcExpiredRelationshipsCounter, gcTransactionsCounter, gcNamespacesCounter, gcFailureCounter, } - for _, metric := range collectors { - if err := prometheus.Register(metric); err != nil { - return nil, fmt.Errorf("failed to register GC metric: %w", err) - } - } + unregister, _ := RegisterPrometheusCollectors(registerer, "failed to register GC metrics", collectors...) 
- return collectors, nil + return unregister, nil } // GarbageCollectableDatastore represents a datastore supporting external diff --git a/pkg/datastore/util.go b/pkg/datastore/util.go index e0c96359f9..54d9c50c43 100644 --- a/pkg/datastore/util.go +++ b/pkg/datastore/util.go @@ -2,6 +2,10 @@ package datastore import ( "context" + "errors" + + "github.com/prometheus/client_golang/prometheus" + "github.com/rs/zerolog/log" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" ) @@ -59,3 +63,28 @@ func DeleteAllData(ctx context.Context, ds Datastore) error { }) return err } + +func RegisterPrometheusCollectors(registerer prometheus.Registerer, errMessage string, collectors ...prometheus.Collector) (func(), error) { + if registerer == nil { + registerer = prometheus.DefaultRegisterer + } + registered := make([]prometheus.Collector, 0, len(collectors)) + for _, c := range collectors { + if err := registerer.Register(c); err != nil { + // Unwind any collectors registered earlier in this call so a + // partial failure does not leak registrations. + for _, r := range registered { + registerer.Unregister(r) + } + log.Warn().Err(err).Msg(errMessage) + if are, ok := errors.AsType[prometheus.AlreadyRegisteredError](err); ok { + return func() {}, are + } + // NOTE(review): previously non-AlreadyRegistered errors were + // silently swallowed and a nil unregister func could be returned, + // which panics for callers that ignore the error (see gc.go). + // Always return a callable no-op func alongside the error. + return func() {}, err + } + registered = append(registered, c) + } + return func() { + for _, c := range collectors { + registerer.Unregister(c) + } + }, nil +} diff --git a/pkg/middleware/consistency/consistency.go b/pkg/middleware/consistency/consistency.go index cc1d42098f..1f9d0762b8 100644 --- a/pkg/middleware/consistency/consistency.go +++ b/pkg/middleware/consistency/consistency.go @@ -6,7 +6,6 @@ import ( "strings" "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" "github.com/rs/zerolog/log" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -22,13 +21,20 @@ import ( "github.com/authzed/spicedb/pkg/zedtoken" ) -var ConsistencyCounter = promauto.NewCounterVec(prometheus.CounterOpts{ +var ConsistencyCounter = prometheus.NewCounterVec(prometheus.CounterOpts{ Namespace: "spicedb", Subsystem: "middleware", Name: "consistency_assigned_total", Help: "Count of the consistencies used per request", }, []string{"method", "source", "service"}) +// RegisterMetrics
registers the consistency middleware prometheus metrics with the provided registerer. +// If registerer is nil, prometheus.DefaultRegisterer is used. +func RegisterMetrics(registerer prometheus.Registerer) error { + _, err := datastore.RegisterPrometheusCollectors(registerer, "failed to register consistency metrics", ConsistencyCounter) + return err +} + // MismatchingTokenOption is the option specifying the behavior of the consistency middleware // when a ZedToken provided references a different datastore instance than the current // datastore instance. diff --git a/pkg/x509util/certwatcher.go b/pkg/x509util/certwatcher.go index df3efc3efe..945ae44c1e 100644 --- a/pkg/x509util/certwatcher.go +++ b/pkg/x509util/certwatcher.go @@ -20,6 +20,7 @@ import ( "bytes" "context" "crypto/tls" + "errors" "fmt" "os" "sync" @@ -69,46 +70,81 @@ type CertWatcher struct { // metrics ReadCertificateTotal prometheus.Counter ReadCertificateErrors prometheus.Counter + prometheusUnregister func() } // NewTLSCertWatcher returns a new CertWatcher watching the given certificate and key. // It registers prometheus metrics for certificate read counts and errors. -// The metrics are unregistered when Start returns. -func NewTLSCertWatcher(certPath, keyPath string) (*CertWatcher, error) { +// If the metrics are already registered, it reuses the existing collectors.
+func NewTLSCertWatcher(registerer prometheus.Registerer, certPath, keyPath string) (*CertWatcher, error) { var err error + unregister := func() {} + + if registerer == nil { + registerer = prometheus.DefaultRegisterer + } + + readTotal, err := registerOrGetCounter(registerer, ReadTotal) + if err != nil { + return nil, err + } + if readTotal != ReadTotal { + unregister = func() { + registerer.Unregister(readTotal) + } + } + + readErrors, err := registerOrGetCounter(registerer, ReadErrors) + if err != nil { + log.Warn().Err(err).Msg("failed to register certificate metrics") + return nil, err + } + if readErrors != ReadErrors { + // Snapshot the previous closure: assigning a closure that calls the + // captured variable `unregister` directly would make it call itself + // (infinite recursion) once the assignment completes. + prevUnregister := unregister + unregister = func() { + prevUnregister() + registerer.Unregister(readErrors) + } + } cw := &CertWatcher{ certPath: certPath, keyPath: keyPath, interval: defaultWatchInterval, started: make(chan error), - ReadCertificateTotal: ReadTotal, - ReadCertificateErrors: ReadErrors, + ReadCertificateTotal: readTotal, + ReadCertificateErrors: readErrors, + prometheusUnregister: unregister, } - // ignore "duplicate metric registration" errors - _ = prometheus.Register(cw.ReadCertificateTotal) - _ = prometheus.Register(cw.ReadCertificateErrors) - // Initial read of certificate and key. if err := cw.ReadCertificate(); err != nil { - cw.unregisterMetrics() return nil, err } cw.watcher, err = fsnotify.NewWatcher() if err != nil { - cw.unregisterMetrics() return nil, err } return cw, nil } -// unregisterMetrics removes the prometheus metrics registered by this CertWatcher.
-func (cw *CertWatcher) unregisterMetrics() { - prometheus.Unregister(cw.ReadCertificateTotal) - prometheus.Unregister(cw.ReadCertificateErrors) +func registerOrGetCounter(registerer prometheus.Registerer, counter prometheus.Counter) (prometheus.Counter, error) { + if err := registerer.Register(counter); err != nil { + if alreadyRegisteredError, ok := errors.AsType[prometheus.AlreadyRegisteredError](err); ok { + existingCounter, ok := alreadyRegisteredError.ExistingCollector.(prometheus.Counter) + if !ok { + return nil, alreadyRegisteredError + } + + return existingCounter, nil + } + + return nil, err + } + + return counter, nil } // WithWatchInterval sets the watch interval and returns the CertWatcher pointer @@ -148,7 +183,7 @@ func (cw *CertWatcher) Started() <-chan error { // When Start returns, it unregisters the prometheus metrics that were registered // in NewTLSCertWatcher. func (cw *CertWatcher) Start(ctx context.Context) { - defer cw.unregisterMetrics() + defer cw.prometheusUnregister() for _, f := range []string{cw.certPath, cw.keyPath} { if err := cw.watcher.Add(f); err != nil { diff --git a/pkg/x509util/certwatcher_test.go b/pkg/x509util/certwatcher_test.go index 97790aa7f1..65e9b68c54 100644 --- a/pkg/x509util/certwatcher_test.go +++ b/pkg/x509util/certwatcher_test.go @@ -32,6 +32,7 @@ import ( "testing" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -39,7 +40,7 @@ import ( func TestCertWatcherNew(t *testing.T) { t.Run("should error without cert/key", func(t *testing.T) { - _, err := NewTLSCertWatcher("", "") + _, err := NewTLSCertWatcher(prometheus.DefaultRegisterer, "", "") require.Error(t, err) }) } @@ -53,7 +54,7 @@ func TestCertWatcherSequentialMetricRegistration(t *testing.T) { // First watcher: start and stop. 
ctx1, cancel1 := context.WithCancel(t.Context()) - watcher1, err := NewTLSCertWatcher(certPath, keyPath) + watcher1, err := NewTLSCertWatcher(prometheus.DefaultRegisterer, certPath, keyPath) require.NoError(t, err) go func() { @@ -63,7 +64,7 @@ func TestCertWatcherSequentialMetricRegistration(t *testing.T) { // Second watcher: should not fail due to duplicate metric registration. ctx2, cancel2 := context.WithCancel(t.Context()) - watcher2, err := NewTLSCertWatcher(certPath, keyPath) + watcher2, err := NewTLSCertWatcher(prometheus.DefaultRegisterer, certPath, keyPath) require.NoError(t, err) go func() { @@ -86,7 +87,7 @@ func setupWatcher(t *testing.T, ip string) (certPath, keyPath string, watcher *C err := writeCerts(certPath, keyPath, ip) require.NoError(t, err) - watcher, err = NewTLSCertWatcher(certPath, keyPath) + watcher, err = NewTLSCertWatcher(prometheus.DefaultRegisterer, certPath, keyPath) require.NoError(t, err) startWatcher = func(interval time.Duration) {