diff --git a/README.md b/README.md index 93ac8dca9..7e1db2c3c 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,7 @@ Features | **[GPU Telemetry](docs/tutorials/gpu-telemetry.md)** | Real-time GPU metrics collection via DCGM (power, utilization, memory, temperature, etc) | Performance optimization, resource monitoring, multi-node telemetry | | **[Template Endpoint](docs/tutorials/template-endpoint.md)** | Benchmark custom APIs with flexible Jinja2 request templates | Custom API formats, rapid prototyping, non-standard endpoints | | **[SGLang Image Generation](docs/tutorials/sglang-image-generation.md)** | Benchmark image generation APIs using SGLang with FLUX.1-dev model | Image generation testing, text-to-image benchmarking, extracting generated images | +| **[Server Metrics](docs/tutorials/server-metrics.md)** | Collect Prometheus-compatible server metrics during benchmarking | Performance optimization, resource monitoring, multi-node telemetry | ### Working with Benchmark Data - **[Profile Exports](docs/tutorials/working-with-profile-exports.md)** - Parse and analyze `profile_export.jsonl` with Pydantic models, custom metrics, and async processing diff --git a/docs/cli_options.md b/docs/cli_options.md index b2728d6de..37da70da9 100644 --- a/docs/cli_options.md +++ b/docs/cli_options.md @@ -410,6 +410,12 @@ The delay in seconds before cancelling requests. This is used when --request-can Enable GPU telemetry console display and optionally specify: (1) 'dashboard' for realtime dashboard mode, (2) custom DCGM exporter URLs (e.g., http://node1:9401/metrics), (3) custom metrics CSV file (e.g., custom_gpu_metrics.csv). Default endpoints localhost:9400 and localhost:9401 are always attempted. Example: --gpu-telemetry dashboard node1:9400 custom.csv. +## Server Metrics Options + +#### `--server-metrics` `` + +Server metrics collection (ENABLED BY DEFAULT). Automatically collects from inference endpoint base_url + `/metrics`. Optionally specify additional custom Prometheus-compatible endpoint URLs (e.g., http://node1:8081/metrics, http://node2:9090/metrics). Use AIPERF_SERVER_METRICS_ENABLED=false to disable. Example: `--server-metrics node1:8081 node2:9090/metrics` for additional endpoints. + ## ZMQ Communication Options #### `--zmq-host` `` diff --git a/docs/environment_variables.md b/docs/environment_variables.md index c56d66dfb..4213a9864 100644 --- a/docs/environment_variables.md +++ b/docs/environment_variables.md @@ -39,6 +39,7 @@ GPU telemetry collection configuration. Controls GPU metrics collection frequenc |----------------------|---------|-------------|-------------| | `AIPERF_GPU_COLLECTION_INTERVAL` | `0.33` | ≥ 0.01, ≤ 300.0 | GPU telemetry metrics collection interval in seconds (default: 330ms, ~3Hz) | | `AIPERF_GPU_DEFAULT_DCGM_ENDPOINTS` | `['http://localhost:9400/metrics', 'http://localhost:9401/metrics']` | — | Default DCGM endpoint URLs to check for GPU telemetry (comma-separated string or JSON array) | +| `AIPERF_GPU_EXPORT_BATCH_SIZE` | `100` | ≥ 1, ≤ 1000000 | Batch size for telemetry record export results processor | | `AIPERF_GPU_REACHABILITY_TIMEOUT` | `5` | ≥ 1, ≤ 300 | Timeout in seconds for checking GPU telemetry endpoint reachability during init | | `AIPERF_GPU_SHUTDOWN_DELAY` | `5.0` | ≥ 1.0, ≤ 300.0 | Delay in seconds before shutting down GPU telemetry service to allow command response transmission | | `AIPERF_GPU_THREAD_JOIN_TIMEOUT` | `5.0` | ≥ 1.0, ≤ 300.0 | Timeout in seconds for joining GPU telemetry collection threads during shutdown | @@ -89,6 +90,20 @@ Record processing and export configuration. Controls batch sizes, processor scal | `AIPERF_RECORD_PROCESSOR_SCALE_FACTOR` | `4` | ≥ 1, ≤ 100 | Scale factor for number of record processors to spawn based on worker count. Formula: 1 record processor for every X workers | | `AIPERF_RECORD_PROGRESS_REPORT_INTERVAL` | `2.0` | ≥ 0.1, ≤ 600.0 | Interval in seconds between records progress report messages | +## SERVERMETRICS + +Server metrics collection configuration. Controls server metrics collection frequency, endpoint detection, and shutdown behavior. Metrics are collected from Prometheus-compatible endpoints at the specified interval. + +| Environment Variable | Default | Constraints | Description | +|----------------------|---------|-------------|-------------| +| `AIPERF_SERVER_METRICS_ENABLED` | `True` | — | Enable server metrics collection (set to false to disable entirely) | +| `AIPERF_SERVER_METRICS_COLLECTION_FLUSH_PERIOD` | `2.0` | ≥ 0.0, ≤ 30.0 | Time in seconds to continue collecting metrics after profiling completes, allowing server-side metrics to flush/finalize before shutting down (default: 2.0s) | +| `AIPERF_SERVER_METRICS_COLLECTION_INTERVAL` | `0.33` | ≥ 0.01, ≤ 300.0 | Server metrics collection interval in seconds (default: 330ms, ~3Hz) | +| `AIPERF_SERVER_METRICS_DEFAULT_BACKEND_PORTS` | `[]` | — | Default backend ports to check on inference endpoint hostname (comma-separated string or JSON array) | +| `AIPERF_SERVER_METRICS_EXPORT_BATCH_SIZE` | `100` | ≥ 1, ≤ 1000000 | Batch size for server metrics jsonl writer export results processor | +| `AIPERF_SERVER_METRICS_REACHABILITY_TIMEOUT` | `5` | ≥ 1, ≤ 300 | Timeout in seconds for checking server metrics endpoint reachability during init | +| `AIPERF_SERVER_METRICS_SHUTDOWN_DELAY` | `5.0` | ≥ 1.0, ≤ 300.0 | Delay in seconds before shutting down server metrics service to allow command response transmission | + ## SERVICE Service lifecycle and inter-service communication configuration. Controls timeouts for service registration, startup, shutdown, command handling, connection probing, heartbeats, and profile operations. diff --git a/docs/tutorials/server-metrics.md b/docs/tutorials/server-metrics.md new file mode 100644 index 000000000..12757e65d --- /dev/null +++ b/docs/tutorials/server-metrics.md @@ -0,0 +1,632 @@ + + +# Server Metrics Collection with AIPerf + +This guide shows you how to use AIPerf's automatic server metrics collection feature. Server metrics provide insights into LLM inference server performance, including request counts, latencies, cache utilization, and custom application metrics. + +## Overview + +AIPerf **automatically collects metrics by default** from Prometheus-compatible endpoints exposed by LLM inference servers like vLLM, SGLang, TRT-LLM, and others. These metrics complement AIPerf's client-side measurements with server-side observability data. + +**What You'll Learn:** +- How automatic server metrics collection works (enabled by default) +- Configure additional custom Prometheus endpoints +- Understand the output files and data format +- Use server metrics for performance analysis + +## Quick Start + +### Basic Usage + +Server metrics are **automatically collected** for the inference endpoint port - just run AIPerf normally: + +```bash +aiperf profile \ + --model Qwen/Qwen3-0.6B \ + --endpoint-type chat \ + --endpoint /v1/chat/completions \ + --url localhost:8000 \ + --concurrency 4 \ + --request-count 100 +``` + +**What happens automatically:** +1. AIPerf queries the Prometheus `/metrics` endpoint on your inference server (checks `--url` port) +2. Collects metrics every `AIPERF_SERVER_METRICS_COLLECTION_INTERVAL` (configurable) +3. Exports time-series data to `server_metrics_export.jsonl` file +4. Saves metadata about collected metrics to `server_metrics_metadata.json` file + +> [!TIP] +> No flag needed! Server metrics are collected by default. Use `--server-metrics ` to add additional endpoints, or set `AIPERF_SERVER_METRICS_ENABLED=false` to disable. + +### Automatic Endpoint + +By default, AIPerf automatically discovers and queries the Prometheus `/metrics` endpoint on your inference server (checks `--url` port) and any additional ports specified via `AIPERF_SERVER_METRICS_DEFAULT_BACKEND_PORTS` (comma-separated string or JSON array). + +> [!TIP] +> **Default Port Handling:** When your inference URL has no explicit port (e.g., `https://api.example.com/v1/chat`), AIPerf uses the default port for the scheme (443 for HTTPS, 80 for HTTP) before checking additional ports from `AIPERF_SERVER_METRICS_DEFAULT_BACKEND_PORTS`. + +### Custom Endpoint URLs + +Specify additional custom Prometheus endpoints explicitly: + +```bash +# Single custom endpoint +aiperf profile --model MODEL ... --server-metrics http://localhost:8081 + +# Multiple endpoints (multi-node or multiple services) +aiperf profile --model MODEL ... --server-metrics \ + http://node1:8081 \ + http://node2:8081 \ + http://monitoring:9090 +``` + +> [!NOTE] +> URLs can be specified with or without the `http://` prefix and `/metrics` suffix. AIPerf normalizes them automatically: +> - `localhost:8000` → `http://localhost:8000/metrics` +> - `http://server:9090` → `http://server:9090/metrics` +> - `localhost:8081/metrics` → `http://localhost:8081/metrics` + +### Disabling Server Metrics + +To disable automatic server metrics collection: + +```bash +export AIPERF_SERVER_METRICS_ENABLED=false +``` + +This completely disables server metrics collection for the run. + +## Understanding Server Metrics + +### What Metrics Are Collected? + +AIPerf collects **all metrics** exposed by Prometheus-compatible endpoints, with automatic filtering: + +- **Collected:** All counter, gauge, histogram, and summary metrics +- **Automatically Filtered:** Metrics ending with `_created` (internal Prometheus timestamps) + +Common metrics from LLM inference servers include: + +#### vLLM Metrics Examples +- **Queue Metrics:** `vllm:num_requests_running`, `vllm:num_requests_waiting`, `vllm:num_requests_swapped` +- **Cache Utilization:** `vllm:gpu_cache_usage_perc`, `vllm:cpu_cache_usage_perc` +- **Prefix Cache:** `vllm:gpu_prefix_cache_hit_rate`, `vllm:cpu_prefix_cache_hit_rate` +- **Throughput:** `vllm:avg_prompt_throughput_toks_per_s`, `vllm:avg_generation_throughput_toks_per_s` +- **Latency Histograms:** `vllm:time_to_first_token_seconds`, `vllm:time_per_output_token_seconds`, `vllm:e2e_request_latency_seconds` +- **Token Distributions:** `vllm:request_prompt_tokens`, `vllm:request_generation_tokens` +- **Counters:** `vllm:request_success_total`, `vllm:prompt_tokens_total`, `vllm:generation_tokens_total`, `vllm:num_preemptions_total` + +#### Dynamo Frontend Metrics Examples +- **Request Metrics:** `dynamo_frontend_requests` +- **Latency Distributions:** `dynamo_frontend_time_to_first_token_seconds`, `dynamo_frontend_request_duration_seconds`, `dynamo_frontend_inter_token_latency_seconds` +- **Queue Metrics:** `dynamo_frontend_queued_requests`, `dynamo_frontend_inflight_requests` +- **Token Metrics:** `dynamo_frontend_input_sequence_tokens`, `dynamo_frontend_output_sequence_tokens` +- **Model Config:** `dynamo_frontend_model_context_length`, `dynamo_frontend_model_total_kv_blocks` + +#### Dynamo Component Metrics Examples +- **Request Metrics:** `dynamo_component_requests_total`, `dynamo_component_errors_total` +- **Data Transfer:** `dynamo_component_request_bytes_total`, `dynamo_component_response_bytes_total` +- **Task Metrics:** `dynamo_component_tasks_issued_total`, `dynamo_component_tasks_success_total`, `dynamo_component_tasks_failed_total` +- **Performance:** `dynamo_component_request_duration_seconds`, `dynamo_component_inflight_requests` +- **System:** `dynamo_component_uptime_seconds` +- **NATS Metrics:** `dynamo_component_nats_client_in_messages`, `dynamo_component_nats_service_requests_total` + +#### Prometheus Metric Types Supported +- **Counter:** Cumulative values (e.g., total requests, total tokens) +- **Gauge:** Point-in-time values (e.g., cache utilization %) +- **Histogram:** Distribution with buckets (e.g., latency percentiles) +- **Summary:** Pre-computed quantiles (e.g., p50, p90, p99) + +### Output Files + +AIPerf generates two files per benchmark run: + +#### 1. Time-Series Data: `server_metrics_export.jsonl` + +Line-delimited JSON with metrics snapshots collected over time (from real Dynamo run): + +```json +{"endpoint_url":"http://localhost:8081/metrics","timestamp_ns":1763591215213919629,"endpoint_latency_ns":712690779,"metrics":{"dynamo_component_requests":[{"labels":{"dynamo_component":"prefill","dynamo_endpoint":"generate","model":"openai/gpt-oss-20b"},"value":360.0}],"dynamo_component_nats_client_in_messages":[{"value":59284.0}],"dynamo_component_request_duration_seconds":[{"labels":{"dynamo_component":"prefill","dynamo_endpoint":"generate","model":"openai/gpt-oss-20b"},"histogram":{"0.005":0.0,"0.01":0.0,"0.025":123.0,"0.05":327.0,"0.1":348.0,"0.25":360.0,"+Inf":360.0},"sum":12.232215459,"count":360.0}]}} +{"endpoint_url":"http://localhost:8000/metrics","timestamp_ns":1763591215220757503,"endpoint_latency_ns":719764167,"metrics":{"dynamo_frontend_requests":[{"labels":{"endpoint":"chat_completions","model":"openai/gpt-oss-20b","status":"success"},"value":1000.0}],"dynamo_frontend_queued_requests":[{"labels":{"model":"openai/gpt-oss-20b"},"value":0.0}],"dynamo_frontend_time_to_first_token_seconds":[{"labels":{"model":"openai/gpt-oss-20b"},"histogram":{"0":0.0,"0.0047":0.0,"0.1":835.0,"0.47":892.0,"1":899.0,"10":1000.0,"+Inf":1000.0},"sum":765.15823571,"count":1000.0}]}} +``` + +**Each line contains:** +- `endpoint_url`: Source Prometheus endpoint +- `timestamp_ns`: Collection timestamp (nanoseconds since epoch) +- `endpoint_latency_ns`: Time to fetch metrics from endpoint (nanoseconds) +- `metrics`: Dictionary of metric families with samples + - **Counter/Gauge:** `{"value": 42.0}` or `{"labels": {...}, "value": 42.0}` + - **Histogram:** `{"histogram": {"le": count, ...}, "sum": X, "count": N}` (le = bucket upper bounds) + - **Summary:** `{"summary": {"quantile": value, ...}, "sum": X, "count": N}` (quantile = percentile labels) + +**Space Optimization with Deduplication:** +The file is automatically **deduplicated** per endpoint to reduce file size while preserving accurate timeline information: + +1. **First occurrence** of metrics → always written (marks start of period) +2. **Consecutive identical metrics** → skipped and counted +3. **Change detected** → last duplicate written (marks end of period), then new record written (marks start of new period) + +**Example:** Input `A,A,A,B,B,C,D,D,D,D` → Output `A,A,B,B,C,D,D` + +This ensures you have actual timestamp observations for when metrics changed, enabling accurate duration calculations and time-series analysis. Deduplication uses equality comparison on the metrics dictionary for each endpoint separately. + +#### 2. Metadata: `server_metrics_metadata.json` + +Pretty-printed JSON with metric schemas, info metrics, and documentation: + +```json +{ + "endpoints": { + "http://localhost:8000/metrics": { + "endpoint_url": "http://localhost:8000/metrics", + "info_metrics": { + "python_info": { + "description": "Python platform information", + "labels": [ + { + "implementation": "CPython", + "major": "3", + "minor": "12", + "patchlevel": "10", + "version": "3.12.10" + } + ] + } + }, + "metric_schemas": { + "dynamo_frontend_inflight_requests": { + "type": "gauge", + "description": "Number of inflight requests" + }, + "dynamo_frontend_queued_requests": { + "type": "gauge", + "description": "Number of queued requests" + }, + "dynamo_frontend_time_to_first_token_seconds": { + "type": "histogram", + "description": "Time to first token in seconds" + }, + "dynamo_frontend_request_duration_seconds": { + "type": "histogram", + "description": "Request duration in seconds" + }, + "dynamo_frontend_requests": { + "type": "counter", + "description": "Total number of requests processed" + } + } + } + } +} +``` + +**Contains:** +- Metric names and types (counter, gauge, histogram, summary) +- Description text explaining what each metric measures +- Endpoint URLs and display names + +**Info Metrics:** +- Info metrics (ending in _info) contain static system information that doesn't change over time. +- We store only the labels (not values) since the labels contain the actual information and values are typically just 1.0. +- The description text explains what the info metric measures. + +> [!TIP] +> Use the metadata file to understand what metrics are available and how to interpret the JSONL data. + +> [!NOTE] +> **Output Directory Structure:** Files are created in `artifacts/{run-name}/` where `{run-name}` is automatically generated from your model, endpoint type, schedule, and concurrency (e.g., `Qwen_Qwen3-0.6B-openai-chat-concurrency4`). +> +> **Custom Filenames:** When using `--profile-export-prefix custom_name`, files become: +> - `artifacts/{run-name}/custom_name_server_metrics.jsonl` +> - `artifacts/{run-name}/custom_name_server_metrics_metadata.json` + +## Configuration Options + +### Environment Variables + +Customize collection behavior with environment variables: + +**Available Settings:** + +| Environment Variable | Default | Range | Description | +|---------------------|---------|-------|-------------| +| `AIPERF_SERVER_METRICS_COLLECTION_INTERVAL` | 0.1s | 0.01s - 300s | Metrics collection frequency | +| `AIPERF_SERVER_METRICS_COLLECTION_FLUSH_PERIOD` | 2.0s | 0.0s - 30s | Wait time for final metrics after benchmark | +| `AIPERF_SERVER_METRICS_DEFAULT_BACKEND_PORTS` | empty list | comma-separated | Additional ports to check during auto-discovery (beyond inference endpoint port) | +| `AIPERF_SERVER_METRICS_REACHABILITY_TIMEOUT` | 5s | 1s - 300s | Timeout for endpoint reachability tests | +| `AIPERF_SERVER_METRICS_SHUTDOWN_DELAY` | 5.0s | 1.0s - 300s | Delay before shutting down to allow final command response transmission | + +## Multi-Node Server Metrics + +For distributed LLM deployments (tensor parallelism, pipeline parallelism), collect metrics from all nodes: + +```bash +# Example: 4-node distributed Dynamo deployment +aiperf profile \ + --model meta-llama/Llama-3.1-70B \ + --endpoint-type chat \ + --endpoint /v1/chat/completions \ + --url http://node-0:8000 \ + --concurrency 16 \ + --request-count 500 \ + --server-metrics \ + node-0:8000 \ + node-1:8081 \ + node-2:8081 \ + node-3:8081 +``` + +**Output Structure:** +Each endpoint's metrics are stored separately in the JSONL file with its `endpoint_url` field, allowing you to: +- Analyze per-node performance +- Detect load imbalances +- Monitor distributed system health + +## Understanding the Data Format + +### JSONL Record Structure + +Each line in `server_metrics_export.jsonl` is a JSON object containing ALL metrics from one endpoint at one point in time. + +**Example from Dynamo frontend:** + +```json +{ + "endpoint_url": "http://localhost:8000/metrics", + "timestamp_ns": 1763591215220757503, + "endpoint_latency_ns": 719764167, + "metrics": { + "dynamo_frontend_requests": [ + { + "labels": { + "endpoint": "chat_completions", + "model": "openai/gpt-oss-20b", + "request_type": "unary", + "status": "success" + }, + "value": 1000.0 + } + ], + "dynamo_frontend_queued_requests": [ + { + "labels": {"model": "openai/gpt-oss-20b"}, + "value": 0.0 + } + ], + "dynamo_frontend_time_to_first_token_seconds": [ + { + "labels": {"model": "openai/gpt-oss-20b"}, + "histogram": { + "0": 0.0, + "0.0022": 0.0, + "0.0047": 0.0, + "0.01": 0.0, + "0.022": 0.0, + "0.047": 0.0, + "0.1": 835.0, + "0.22": 888.0, + "0.47": 892.0, + "1": 899.0, + "2.2": 900.0, + "4.7": 900.0, + "10": 1000.0, + "22": 1000.0, + "48": 1000.0, + "100": 1000.0, + "220": 1000.0, + "480": 1000.0, + "+Inf": 1000.0 + }, + "sum": 765.1582357100003, + "count": 1000.0 + } + ], + "dynamo_frontend_request_duration_seconds": [ + { + "labels": {"model": "openai/gpt-oss-20b"}, + "histogram": { + "0": 0.0, + "1.9": 0.0, + "3.4": 10.0, + "6.3": 212.0, + "12": 554.0, + "22": 969.0, + "40": 1000.0, + "75": 1000.0, + "140": 1000.0, + "260": 1000.0, + "+Inf": 1000.0 + }, + "sum": 11336.94603903602, + "count": 1000.0 + } + ] + } +} +``` + +**Top-Level Fields:** +- `endpoint_url`: Source Prometheus endpoint URL +- `timestamp_ns`: Unix timestamp in nanoseconds when metrics were collected +- `endpoint_latency_ns`: Time taken to fetch metrics from endpoint (nanoseconds) +- `metrics`: Dictionary containing ALL metrics from this endpoint at this timestamp + +**Sample Structure by Type:** +- **Counter/Gauge:** `{"labels": {...}, "value": N}` +- **Histogram:** `{"labels": {...}, "histogram": {"le": count, ...}, "sum": N, "count": N}` +- **Summary:** `{"labels": {...}, "summary": {"quantile": value, ...}, "sum": N, "count": N}` + +### Multi-Endpoint Data (Interleaved) + +When collecting from multiple endpoints, records are **interleaved by write time** (when deduplication completes), not strictly by collection time. + +**Example from Dynamo with component (8081) and frontend (8000) endpoints:** + +```jsonl +{"endpoint_url":"http://localhost:8081/metrics","timestamp_ns":1763591215213919629,"endpoint_latency_ns":712690779,"metrics":{...}} +{"endpoint_url":"http://localhost:8000/metrics","timestamp_ns":1763591215220757503,"endpoint_latency_ns":719764167,"metrics":{...}} +{"endpoint_url":"http://localhost:8081/metrics","timestamp_ns":1763591215313945146,"endpoint_latency_ns":100025517,"metrics":{...}} +{"endpoint_url":"http://localhost:8000/metrics","timestamp_ns":1763591215320776712,"endpoint_latency_ns":100831209,"metrics":{...}} +{"endpoint_url":"http://localhost:8081/metrics","timestamp_ns":1763591215414013463,"endpoint_latency_ns":100068317,"metrics":{...}} +{"endpoint_url":"http://localhost:8000/metrics","timestamp_ns":1763591215421601721,"endpoint_latency_ns":100825009,"metrics":{...}} +``` + +**Key Points:** +- Records are **NOT** strictly alternating between endpoints +- **Deduplication** causes multiple consecutive records from same endpoint (first occurrence + last duplicate before change) +- Use `endpoint_url` field to filter/group by endpoint during analysis +- Each endpoint is collected and deduplicated independently +- Timestamps reflect actual collection times, not write order + +### Example: vLLM Metrics Data + +**Example from vLLM inference server:** + +```json +{ + "endpoint_url": "http://localhost:8000/metrics", + "timestamp_ns": 1763591240123456789, + "endpoint_latency_ns": 42134567, + "metrics": { + "vllm:num_requests_running": [ + {"value": 12.0} + ], + "vllm:num_requests_waiting": [ + {"value": 3.0} + ], + "vllm:gpu_cache_usage_perc": [ + {"value": 0.72} + ], + "vllm:gpu_prefix_cache_hit_rate": [ + {"value": 0.85} + ], + "vllm:avg_prompt_throughput_toks_per_s": [ + {"value": 1523.4} + ], + "vllm:avg_generation_throughput_toks_per_s": [ + {"value": 892.1} + ], + "vllm:request_success_total": [ + {"value": 1500.0} + ], + "vllm:prompt_tokens_total": [ + {"value": 75000.0} + ], + "vllm:generation_tokens_total": [ + {"value": 45000.0} + ], + "vllm:time_to_first_token_seconds": [ + { + "histogram": { + "0.001": 0.0, + "0.005": 12.0, + "0.01": 145.0, + "0.02": 789.0, + "0.04": 1234.0, + "0.06": 1456.0, + "0.08": 1489.0, + "0.1": 1498.0, + "0.25": 1500.0, + "0.5": 1500.0, + "1.0": 1500.0, + "+Inf": 1500.0 + }, + "sum": 32.456, + "count": 1500.0 + } + ], + "vllm:e2e_request_latency_seconds": [ + { + "histogram": { + "0.01": 0.0, + "0.025": 5.0, + "0.05": 23.0, + "0.075": 78.0, + "0.1": 234.0, + "0.15": 567.0, + "0.2": 890.0, + "0.3": 1234.0, + "0.4": 1456.0, + "0.5": 1489.0, + "1.0": 1500.0, + "2.5": 1500.0, + "+Inf": 1500.0 + }, + "sum": 245.678, + "count": 1500.0 + } + ] + } +} +``` + +**Key vLLM Metrics Shown:** +- **Gauges:** Current queue sizes, cache utilization percentages, throughput rates +- **Counters:** Cumulative success counts and token totals +- **Histograms:** Latency distributions with bucket counts showing percentile breakdowns + +### Example: Dynamo Component Metrics Data + +**Example from Dynamo component (work handler) endpoint:** + +```json +{ + "endpoint_url": "http://localhost:8081/metrics", + "timestamp_ns": 1763591240234567890, + "endpoint_latency_ns": 38765432, + "metrics": { + "dynamo_component_requests_total": [ + {"value": 2340.0} + ], + "dynamo_component_errors_total": [ + {"value": 12.0} + ], + "dynamo_component_request_bytes_total": [ + {"value": 1170000.0} + ], + "dynamo_component_response_bytes_total": [ + {"value": 4656000.0} + ], + "dynamo_component_tasks_issued_total": [ + {"value": 2340.0} + ], + "dynamo_component_tasks_success_total": [ + {"value": 2328.0} + ], + "dynamo_component_tasks_failed_total": [ + {"value": 12.0} + ], + "dynamo_component_inflight_requests": [ + {"value": 8.0} + ], + "dynamo_component_uptime_seconds": [ + {"value": 3456.0} + ], + "dynamo_component_request_duration_seconds": [ + { + "histogram": { + "0.005": 0.0, + "0.01": 23.0, + "0.025": 456.0, + "0.05": 1234.0, + "0.1": 2012.0, + "0.25": 2298.0, + "0.5": 2328.0, + "1.0": 2340.0, + "+Inf": 2340.0 + }, + "sum": 287.654, + "count": 2340.0 + } + ] + } +} +``` + +**Key Dynamo Component Metrics Shown:** +- **Counters:** Request counts, error counts, data transfer bytes, task completion status +- **Gauges:** Concurrent request count, uptime tracking +- **Histograms:** Request duration distributions + +### Metadata File Structure + +The `server_metrics_metadata.json` file describes all collected metrics: + +```json +{ + "endpoints": { + "http://localhost:8000/metrics": { + "endpoint_url": "http://localhost:8000/metrics", + "info_metrics": { + "python_info": { + "description": "Python platform information", + "labels": [ + { + "implementation": "CPython", + "major": "3", + "minor": "12", + "patchlevel": "10", + "version": "3.12.10" + } + ] + } + }, + "metric_schemas": { + "vllm:num_requests_running": { + "type": "gauge", + "description": "Number of requests currently running on GPU." + }, + "vllm:num_requests_waiting": { + "type": "gauge", + "description": "Number of requests waiting to be processed." + }, + "vllm:gpu_cache_usage_perc": { + "type": "gauge", + "description": "GPU KV-cache usage. 1 means 100 percent usage." + }, + "vllm:gpu_prefix_cache_hit_rate": { + "type": "gauge", + "description": "GPU prefix cache block hit rate." + }, + "vllm:request_success_total": { + "type": "counter", + "description": "Count of successfully processed requests." + }, + "vllm:time_to_first_token_seconds": { + "type": "histogram", + "description": "Histogram of time to first token in seconds." + }, + "vllm:e2e_request_latency_seconds": { + "type": "histogram", + "description": "Histogram of end to end request latency in seconds." + } + } + }, + "http://localhost:8081/metrics": { + "endpoint_url": "http://localhost:8081/metrics", + "metric_schemas": { + "dynamo_component_requests_total": { + "type": "counter", + "description": "Total component requests processed." + }, + "dynamo_component_errors_total": { + "type": "counter", + "description": "Total processing errors." + }, + "dynamo_component_inflight_requests": { + "type": "gauge", + "description": "Concurrent requests being processed." + }, + "dynamo_component_request_duration_seconds": { + "type": "histogram", + "description": "Component request duration in seconds." + } + } + } + } +} +``` + +**Use the metadata file to:** +- Discover what metrics are available from each endpoint +- Understand metric types (counter, gauge, histogram, summary) +- Read metric descriptions and understand what they measure +- View info metrics with system/platform information + +## Next Steps + +- **GPU Telemetry:** Combine server metrics with [GPU telemetry](gpu-telemetry.md) for comprehensive observability
+ +## Summary + +Server metrics collection in AIPerf provides: + +✅ **Enabled by default** - automatic discovery of Prometheus endpoints (checks inference endpoint port)
+✅ **Comprehensive collection** of all exposed metrics (counters, gauges, histograms, summaries)
+✅ **Efficient storage** with automatic deduplication (per endpoint)
+✅ **Multi-node support** for distributed deployments
+✅ **Easy analysis** with JSONL format and metadata schemas
diff --git a/src/aiperf/common/config/config_defaults.py b/src/aiperf/common/config/config_defaults.py index e252fa248..04e09bf79 100644 --- a/src/aiperf/common/config/config_defaults.py +++ b/src/aiperf/common/config/config_defaults.py @@ -149,6 +149,8 @@ class OutputDefaults: PROFILE_EXPORT_JSONL_FILE = Path("profile_export.jsonl") PROFILE_EXPORT_RAW_JSONL_FILE = Path("profile_export_raw.jsonl") PROFILE_EXPORT_GPU_TELEMETRY_JSONL_FILE = Path("gpu_telemetry_export.jsonl") + SERVER_METRICS_EXPORT_JSONL_FILE = Path("server_metrics_export.jsonl") + SERVER_METRICS_METADATA_JSON_FILE = Path("server_metrics_metadata.json") EXPORT_LEVEL = ExportLevel.RECORDS SLICE_DURATION = None diff --git a/src/aiperf/common/config/groups.py b/src/aiperf/common/config/groups.py index c65e68b7f..9206b9e4e 100644 --- a/src/aiperf/common/config/groups.py +++ b/src/aiperf/common/config/groups.py @@ -23,6 +23,7 @@ class Groups: IMAGE_INPUT = Group.create_ordered("Image Input") VIDEO_INPUT = Group.create_ordered("Video Input") SERVICE = Group.create_ordered("Service") + SERVER_METRICS = Group.create_ordered("Server Metrics") TELEMETRY = Group.create_ordered("Telemetry") UI = Group.create_ordered("UI") WORKERS = Group.create_ordered("Workers") diff --git a/src/aiperf/common/config/output_config.py b/src/aiperf/common/config/output_config.py index e90d8f334..b032be1a2 100644 --- a/src/aiperf/common/config/output_config.py +++ b/src/aiperf/common/config/output_config.py @@ -75,6 +75,12 @@ class OutputConfig(BaseConfig): _profile_export_gpu_telemetry_jsonl_file: Path = ( OutputDefaults.PROFILE_EXPORT_GPU_TELEMETRY_JSONL_FILE ) + _server_metrics_export_jsonl_file: Path = ( + OutputDefaults.SERVER_METRICS_EXPORT_JSONL_FILE + ) + _server_metrics_metadata_json_file: Path = ( + OutputDefaults.SERVER_METRICS_METADATA_JSON_FILE + ) @model_validator(mode="after") def set_export_filenames(self) -> Self: @@ -85,10 +91,14 @@ def set_export_filenames(self) -> Self: base_path = self.profile_export_prefix base_str = str(base_path) + # Check complex suffixes first (longest to shortest) to avoid double-suffixing + # e.g., if user passes "foo_raw.jsonl", we want "foo" not "foo_raw" suffixes_to_strip = [ + "_server_metrics_metadata.json", + "_server_metrics.jsonl", + "_gpu_telemetry.jsonl", "_timeslices.csv", "_timeslices.json", - "_gpu_telemetry.jsonl", "_raw.jsonl", ".csv", ".json", @@ -108,7 +118,12 @@ def set_export_filenames(self) -> Self: self._profile_export_gpu_telemetry_jsonl_file = Path( f"{base_str}_gpu_telemetry.jsonl" ) - + self._server_metrics_export_jsonl_file = Path( + f"{base_str}_server_metrics.jsonl" + ) + self._server_metrics_metadata_json_file = Path( + f"{base_str}_server_metrics_metadata.json" + ) return self slice_duration: Annotated[ @@ -149,3 +164,11 @@ def profile_export_raw_jsonl_file(self) -> Path: @property def profile_export_gpu_telemetry_jsonl_file(self) -> Path: return self.artifact_directory / self._profile_export_gpu_telemetry_jsonl_file + + @property + def server_metrics_export_jsonl_file(self) -> Path: + return self.artifact_directory / self._server_metrics_export_jsonl_file + + @property + def server_metrics_metadata_json_file(self) -> Path: + return self.artifact_directory / self._server_metrics_metadata_json_file diff --git a/src/aiperf/common/config/user_config.py b/src/aiperf/common/config/user_config.py index f4d615861..8f532386d 100644 --- a/src/aiperf/common/config/user_config.py +++ b/src/aiperf/common/config/user_config.py @@ -213,7 +213,6 @@ def _count_dataset_entries(self) -> int: gpu_telemetry: Annotated[ list[str] | None, Field( - default=None, description=( "Enable GPU telemetry console display and optionally specify: " "(1) 'dashboard' for realtime dashboard mode, " @@ -229,7 +228,7 @@ def _count_dataset_entries(self) -> int: consume_multiple=True, group=Groups.TELEMETRY, ), - ] + ] = None _gpu_telemetry_mode: GPUTelemetryMode = GPUTelemetryMode.SUMMARY _gpu_telemetry_urls: list[str] = [] @@ -286,6 +285,55 @@ def gpu_telemetry_metrics_file(self) -> Path | None: """Get the path to custom GPU metrics CSV file.""" return self._gpu_telemetry_metrics_file + server_metrics: Annotated[ + list[str] | None, + Field( + description=( + "Server metrics collection (ENABLED BY DEFAULT). " + "Automatically collects from inference endpoint base_url + `/metrics`. " + "Optionally specify additional custom Prometheus-compatible endpoint URLs " + "(e.g., http://node1:8081/metrics, http://node2:9090/metrics). " + "Use AIPERF_SERVER_METRICS_ENABLED=false to disable. " + "Example: `--server-metrics node1:8081 node2:9090/metrics` for additional endpoints" + ), + ), + BeforeValidator(parse_str_or_list), + CLIParameter( + name=("--server-metrics",), + consume_multiple=True, + group=Groups.SERVER_METRICS, + ), + ] = None + + _server_metrics_urls: list[str] = [] + + @model_validator(mode="after") + def _parse_server_metrics_config(self) -> Self: + """Parse server_metrics list into URLs. + + Check Environment.SERVER_METRICS.ENABLED to see if collection is enabled. + Empty list [] means enabled with automatic discovery only. + Non-empty list means enabled with custom URLs. + """ + from aiperf.common.metric_utils import normalize_metrics_endpoint_url + + urls: list[str] = [] + + for item in self.server_metrics or []: + # Check for URLs (anything with : or starting with http) + if item.startswith("http") or ":" in item: + normalized_url = item if item.startswith("http") else f"http://{item}" + normalized_url = normalize_metrics_endpoint_url(normalized_url) + urls.append(normalized_url) + + self._server_metrics_urls = urls + return self + + @property + def server_metrics_urls(self) -> list[str]: + """Get the parsed server metrics Prometheus endpoint URLs.""" + return self._server_metrics_urls + @model_validator(mode="after") def _compute_config(self) -> Self: """Compute additional configuration. diff --git a/src/aiperf/common/duplicate_tracker.py b/src/aiperf/common/duplicate_tracker.py new file mode 100644 index 000000000..1ed11d003 --- /dev/null +++ b/src/aiperf/common/duplicate_tracker.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Duplicate tracker for deduplicating records.""" + +import asyncio +from collections import defaultdict +from collections.abc import Callable +from typing import Any, Generic, TypeVar + +from aiperf.common.mixins import AIPerfLoggerMixin + +TRecord = TypeVar("TRecord", bound=Any) + + +class AsyncKeyedDuplicateTracker(AIPerfLoggerMixin, Generic[TRecord]): + """Tracker for deduplicating records by key and value. + + Args: + key_function: A function that takes a record and returns a key for tracking duplicates. This is used to group records by key. + value_function: A function that takes a record and returns a value for comparison. This is used to compare the current value to the previous value. + + Notes: + The key_function and value_function are used to group records by key and compare values. This is useful for cases where + the record itself contains timestamps or other metadata that is not relevant to the value being compared for deduplication. + + Tracks the previous record for each key and detects duplicates. + + Deduplication logic: + Consecutive identical values are suppressed to save + storage while preserving complete timeline information. The strategy: + + 1. First occurrence → always written (marks start of period) + 2. Duplicates → skipped and counted + 3. Change detected → last duplicate written, then new record + (provides end timestamp of previous period + start of new period) + + Example: Input A,A,A,B,B,C,D,D,D,D → Output A,A,B,B,C,D,D + + Why write the last occurrence? Time-series data needs actual observations: + Without: A@t1, B@t4 ← You could guess A ended at ~t3, but no proof + With: A@t1, A@t3, B@t4 ← A was observed until t3 + + Without the last occurrence, you'd rely on interpolation/assumptions rather + than actual measured data. This enables accurate duration calculations, + timeline visualization (Grafana), and time-weighted averages. Essential + for metrics requiring precise change detection. + + Deduplication uses equality (==) on the metrics dictionary for each separate endpoint. + """ + + def __init__( + self, + key_function: Callable[[TRecord], str], + value_function: Callable[[TRecord], Any] = lambda x: x, + **kwargs, + ) -> None: + super().__init__(**kwargs) + # Lock for safe access to creating dynamic locks for deduplication. + self._lock_creation_lock = asyncio.Lock() + self._dupe_locks: dict[str, asyncio.Lock] = {} + self._dupe_counts: dict[str, int] = defaultdict(int) + # Keep track of the previous record for each endpoint to detect duplicates. + self._previous_records: dict[str, TRecord] = {} + self._key_function = key_function + self._value_function = value_function + + async def deduplicate_record(self, record: TRecord) -> list[TRecord]: + """Deduplicate a record and return the records to write. + + Args: + record: The record to deduplicate. + + Returns: + A list of records to write containing either an empty list, the current record, or the current and previous records. + """ + records_to_write: list[TRecord] = [record] + + key = self._key_function(record) + value = self._value_function(record) + + if key not in self._dupe_locks: + # Create a lock for this key if it doesn't exist + async with self._lock_creation_lock: + # Double check inside the lock to avoid race conditions + if key not in self._dupe_locks: + self.trace(lambda: f"Creating lock for key: {key}") + self._dupe_locks[key] = asyncio.Lock() + + # Check for duplicates and update the records to write + async with self._dupe_locks[key]: + if key in self._previous_records: + if self._value_function(self._previous_records[key]) == value: + self._dupe_counts[key] += 1 + self.trace( + lambda: f"Duplicate found for key: {key}, incrementing dupe count to {self._dupe_counts[key]}" + ) + # Clear the list instead of return so the previous record is still updated down below + records_to_write.clear() + + # If we have duplicates, we need to write the previous record before the current record, + # in order to know when the change actually occurs. + elif self._dupe_counts[key] > 0: + self._dupe_counts[key] = 0 + self.trace( + lambda: f"New change detected for key: {key}, writing previous record and resetting dupe count" + ) + records_to_write.insert(0, self._previous_records[key]) + + self._previous_records[key] = record + + return records_to_write + + async def flush_remaining_duplicates(self) -> list[TRecord]: + """Flush remaining duplicates for all keys on shutdown. + + When the system is stopping, there may be pending duplicates that haven't + been written yet (because we're still in a duplicate sequence). This method + returns the last occurrence for each key that has pending duplicates. + + Returns: + A list of records that need to be flushed. + """ + records_to_flush: list[TRecord] = [] + + # Iterate through all keys that have pending duplicates + for key in list(self._dupe_counts.keys()): + if self._dupe_counts[key] > 0 and key in self._dupe_locks: + async with self._dupe_locks[key]: + if self._dupe_counts[key] > 0 and key in self._previous_records: + self.trace( + lambda key=key: f"Flushing {self._dupe_counts[key]} remaining duplicates for key: {key}" + ) + records_to_flush.append(self._previous_records[key]) + self._dupe_counts[key] = 0 + + return records_to_flush diff --git a/src/aiperf/common/enums/__init__.py b/src/aiperf/common/enums/__init__.py index 3f61b0b62..f2ffb1e71 100644 --- a/src/aiperf/common/enums/__init__.py +++ b/src/aiperf/common/enums/__init__.py @@ -84,6 +84,9 @@ RecordProcessorType, ResultsProcessorType, ) +from aiperf.common.enums.prometheus_enums import ( + PrometheusMetricType, +) from aiperf.common.enums.service_enums import ( LifecycleState, ServiceRegistrationStatus, @@ -155,6 +158,7 @@ "ModelSelectionStrategy", "PowerMetricUnit", "PowerMetricUnitInfo", + "PrometheusMetricType", "PromptSource", "PublicDatasetType", "RecordProcessorType", diff --git a/src/aiperf/common/enums/message_enums.py b/src/aiperf/common/enums/message_enums.py index 3ba913cef..734f69777 100644 --- a/src/aiperf/common/enums/message_enums.py +++ b/src/aiperf/common/enums/message_enums.py @@ -47,5 +47,7 @@ class MessageType(CaseInsensitiveStrEnum): STATUS = "status" TELEMETRY_RECORDS = "telemetry_records" TELEMETRY_STATUS = "telemetry_status" + SERVER_METRICS_RECORDS = "server_metrics_records" + SERVER_METRICS_STATUS = "server_metrics_status" WORKER_HEALTH = "worker_health" WORKER_STATUS_SUMMARY = "worker_status_summary" diff --git a/src/aiperf/common/enums/post_processor_enums.py b/src/aiperf/common/enums/post_processor_enums.py index 6d5617451..5e753cf48 100644 --- a/src/aiperf/common/enums/post_processor_enums.py +++ b/src/aiperf/common/enums/post_processor_enums.py @@ -39,6 +39,9 @@ class ResultsProcessorType(CaseInsensitiveStrEnum): """Processor that exports per-record metrics to JSONL files with display unit conversion and filtering. Only enabled when export_level is set to RECORDS.""" + SERVER_METRICS_JSONL_WRITER = "server_metrics_jsonl_writer" + """Processor that exports server metrics data to JSONL files.""" + TELEMETRY_EXPORT = "telemetry_export" """Processor that exports per-record GPU telemetry data to JSONL files. Writes each TelemetryRecord as it arrives from the TelemetryManager.""" diff --git a/src/aiperf/common/enums/prometheus_enums.py b/src/aiperf/common/enums/prometheus_enums.py new file mode 100644 index 000000000..9bc1efde0 --- /dev/null +++ b/src/aiperf/common/enums/prometheus_enums.py @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +from typing_extensions import Self + +from aiperf.common.enums.base_enums import CaseInsensitiveStrEnum + + +class PrometheusMetricType(CaseInsensitiveStrEnum): + """Prometheus metric types as defined in the Prometheus exposition format. + + See: https://prometheus.io/docs/concepts/metric_types/ + """ + + COUNTER = "counter" + """Counter: A cumulative metric that represents a single monotonically increasing counter.""" + + GAUGE = "gauge" + """Gauge: A metric that represents a single numerical value that can arbitrarily go up and down.""" + + HISTOGRAM = "histogram" + """Histogram: Samples observations and counts them in configurable buckets.""" + + SUMMARY = "summary" + """Summary: Similar to histogram, samples observations and provides quantiles.""" + + UNKNOWN = "unknown" + """Unknown: Untyped metric (prometheus_client uses 'unknown' instead of 'untyped').""" + + @classmethod + def _missing_(cls, value: Any) -> Self: + """ + Handles cases where a value is not directly found in the enumeration. + + This method is called when an attempt is made to access an enumeration + member using a value that does not directly match any of the defined + members. It provides custom logic to handle such cases. + + Returns: + The matching enumeration member if a case-insensitive match is found + for string values; otherwise, returns PrometheusMetricType.UNKNOWN. + """ + try: + return super()._missing_(value) + except ValueError: + return cls.UNKNOWN diff --git a/src/aiperf/common/enums/service_enums.py b/src/aiperf/common/enums/service_enums.py index 64e0fd6c5..a2c3ac08b 100644 --- a/src/aiperf/common/enums/service_enums.py +++ b/src/aiperf/common/enums/service_enums.py @@ -45,9 +45,7 @@ class ServiceType(CaseInsensitiveStrEnum): WORKER_MANAGER = "worker_manager" WORKER = "worker" TELEMETRY_MANAGER = "telemetry_manager" - - # For testing purposes only - TEST = "test_service" + SERVER_METRICS_MANAGER = "server_metrics_manager" class ServiceRegistrationStatus(CaseInsensitiveStrEnum): diff --git a/src/aiperf/common/environment.py b/src/aiperf/common/environment.py index 72cfde83f..ab30fed78 100644 --- a/src/aiperf/common/environment.py +++ b/src/aiperf/common/environment.py @@ -7,17 +7,18 @@ All settings can be configured via environment variables with the AIPERF_ prefix. Structure: - Environment.DATASET.* - Dataset management - Environment.DEV.* - Development and debugging settings - Environment.GPU.* - GPU telemetry collection - Environment.HTTP.* - HTTP client socket and connection settings - Environment.LOGGING.* - Logging configuration - Environment.METRICS.* - Metrics collection and storage - Environment.RECORD.* - Record processing - Environment.SERVICE.* - Service lifecycle and communication - Environment.UI.* - User interface settings - Environment.WORKER.* - Worker management and scaling - Environment.ZMQ.* - ZMQ communication settings + Environment.DATASET.* - Dataset management + Environment.DEV.* - Development and debugging settings + Environment.GPU.* - GPU telemetry collection + Environment.HTTP.* - HTTP client socket and connection settings + Environment.LOGGING.* - Logging configuration + Environment.METRICS.* - Metrics collection and storage + Environment.RECORD.* - Record processing + Environment.SERVER_METRICS.* - Server metrics collection + Environment.SERVICE.* - Service lifecycle and communication + Environment.UI.* - User interface settings + Environment.WORKER.* - Worker management and scaling + Environment.ZMQ.* - ZMQ communication settings Examples: # Via environment variables: @@ -141,6 +142,12 @@ class _GPUSettings(BaseSettings): default=["http://localhost:9400/metrics", "http://localhost:9401/metrics"], description="Default DCGM endpoint URLs to check for GPU telemetry (comma-separated string or JSON array)", ) + EXPORT_BATCH_SIZE: int = Field( + ge=1, + le=1000000, + default=100, + description="Batch size for telemetry record export results processor", + ) REACHABILITY_TIMEOUT: int = Field( ge=1, le=300, @@ -318,6 +325,62 @@ class _RecordSettings(BaseSettings): ) +class _ServerMetricsSettings(BaseSettings): + """Server metrics collection configuration. + + Controls server metrics collection frequency, endpoint detection, and shutdown behavior. + Metrics are collected from Prometheus-compatible endpoints at the specified interval. + """ + + model_config = SettingsConfigDict( + env_prefix="AIPERF_SERVER_METRICS_", + env_parse_enums=True, + ) + + ENABLED: bool = Field( + default=True, + description="Enable server metrics collection (set to false to disable entirely)", + ) + COLLECTION_FLUSH_PERIOD: float = Field( + ge=0.0, + le=30.0, + default=2.0, + description="Time in seconds to continue collecting metrics after profiling completes, " + "allowing server-side metrics to flush/finalize before shutting down (default: 2.0s)", + ) + COLLECTION_INTERVAL: float = Field( + ge=0.01, + le=300.0, + default=0.33, + description="Server metrics collection interval in seconds (default: 330ms, ~3Hz)", + ) + DEFAULT_BACKEND_PORTS: Annotated[ + str | list[int], + BeforeValidator(parse_str_or_csv_list), + ] = Field( + default=[], + description="Default backend ports to check on inference endpoint hostname (comma-separated string or JSON array)", + ) + EXPORT_BATCH_SIZE: int = Field( + ge=1, + le=1000000, + default=100, + description="Batch size for server metrics jsonl writer export results processor", + ) + REACHABILITY_TIMEOUT: int = Field( + ge=1, + le=300, + default=5, + description="Timeout in seconds for checking server metrics endpoint reachability during init", + ) + SHUTDOWN_DELAY: float = Field( + ge=1.0, + le=300.0, + default=5.0, + description="Delay in seconds before shutting down server metrics service to allow command response transmission", + ) + + class _ServiceSettings(BaseSettings): """Service lifecycle and inter-service communication configuration. @@ -649,6 +712,10 @@ class _Environment(BaseSettings): default_factory=_RecordSettings, description="Record processing and export settings", ) + SERVER_METRICS: _ServerMetricsSettings = Field( + default_factory=_ServerMetricsSettings, + description="Server metrics collection settings", + ) SERVICE: _ServiceSettings = Field( default_factory=_ServiceSettings, description="Service lifecycle and communication settings", diff --git a/src/aiperf/common/messages/__init__.py b/src/aiperf/common/messages/__init__.py index 5551e2c65..ddc350c68 100644 --- a/src/aiperf/common/messages/__init__.py +++ b/src/aiperf/common/messages/__init__.py @@ -66,6 +66,10 @@ ProfileResultsMessage, RecordsProcessingStatsMessage, ) +from aiperf.common.messages.server_metrics_messages import ( + ServerMetricsRecordsMessage, + ServerMetricsStatusMessage, +) from aiperf.common.messages.service_messages import ( BaseServiceErrorMessage, BaseServiceMessage, @@ -134,6 +138,8 @@ "RegisterServiceCommand", "RegistrationMessage", "RequiresRequestNSMixin", + "ServerMetricsRecordsMessage", + "ServerMetricsStatusMessage", "ShutdownCommand", "ShutdownWorkersCommand", "SpawnWorkersCommand", diff --git a/src/aiperf/common/messages/server_metrics_messages.py b/src/aiperf/common/messages/server_metrics_messages.py new file mode 100644 index 000000000..331c789f1 --- /dev/null +++ b/src/aiperf/common/messages/server_metrics_messages.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +from pydantic import Field + +from aiperf.common.enums import MessageType +from aiperf.common.messages.service_messages import BaseServiceMessage +from aiperf.common.models import ErrorDetails, ServerMetricsRecord +from aiperf.common.types import MessageTypeT + + +class ServerMetricsRecordsMessage(BaseServiceMessage): + """Message from the server metrics data collector to the records manager to notify it + of the server metrics records for a batch of server samples. + + Contains full server metrics records with all metadata. + """ + + message_type: MessageTypeT = MessageType.SERVER_METRICS_RECORDS + + collector_id: str = Field( + description="The ID of the server metrics data collector that collected the records" + ) + records: list[ServerMetricsRecord] = Field( + default_factory=list, + description="The server metrics records", + ) + error: ErrorDetails | None = Field( + default=None, + description="The error details if the server metrics record collection failed.", + ) + + @property + def valid(self) -> bool: + """Whether server metrics collection succeeded (empty response is valid).""" + return self.error is None + + @property + def has_data(self) -> bool: + """Whether any metrics were collected.""" + return len(self.records) > 0 + + +class ServerMetricsStatusMessage(BaseServiceMessage): + """Message from ServerMetricsManager to SystemController indicating server metrics availability.""" + + message_type: MessageTypeT = MessageType.SERVER_METRICS_STATUS + + enabled: bool = Field( + description="Whether server metrics collection is enabled and will produce results" + ) + reason: str | None = Field( + default=None, + description="Reason why server metrics is disabled (if enabled=False)", + ) + endpoints_configured: list[str] = Field( + default_factory=list, + description="List of Prometheus endpoint URLs configured", + ) + endpoints_reachable: list[str] = Field( + default_factory=list, + description="List of Prometheus endpoint URLs that were reachable and will provide data", + ) diff --git a/src/aiperf/common/metric_utils.py b/src/aiperf/common/metric_utils.py new file mode 100644 index 000000000..da626f752 --- /dev/null +++ b/src/aiperf/common/metric_utils.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +from urllib.parse import urlparse + + +def normalize_metrics_endpoint_url(url: str) -> str: + """Ensure metrics endpoint URL ends with /metrics suffix. + + Works with Prometheus, DCGM, and other compatible endpoints. + This utility is used by both TelemetryManager and ServerMetricsManager + to ensure consistent URL formatting. + + Args: + url: Base URL or full metrics URL (e.g., "http://localhost:9400" or + "http://localhost:9400/metrics") + + Returns: + URL ending with /metrics with trailing slashes removed + (e.g., "http://localhost:9400/metrics") + + Raises: + ValueError: If URL is empty or whitespace-only + + Examples: + >>> normalize_metrics_endpoint_url("http://localhost:9400") + "http://localhost:9400/metrics" + >>> normalize_metrics_endpoint_url("http://localhost:9400/") + "http://localhost:9400/metrics" + >>> normalize_metrics_endpoint_url("http://localhost:9400/metrics") + "http://localhost:9400/metrics" + """ + if not url or not url.strip(): + raise ValueError("URL cannot be empty or whitespace-only") + + url = url.rstrip("/") + if not url.endswith("/metrics"): + url = f"{url}/metrics" + return url + + +def build_hostname_aware_prometheus_endpoints( + inference_endpoint_url: str, + default_ports: list[int], + include_inference_port: bool = True, +) -> list[str]: + """Build hostname-aware Prometheus/DCGM endpoint URLs based on inference endpoint. + + Extracts hostname and scheme from the inference endpoint URL and generates + Prometheus-compatible URLs for the specified ports on the same hostname. + This enables zero-config telemetry for distributed deployments. + + Args: + inference_endpoint_url: The inference endpoint URL (e.g., http://myserver:8000/v1/chat) + default_ports: List of ports to check on the same hostname (e.g., [9400, 9401]) + include_inference_port: Whether to include the inference endpoint port in the list of ports to check + + Returns: + List of Prometheus endpoint URLs with /metrics suffix + + Examples: + >>> build_hostname_aware_prometheus_endpoints("http://localhost:8000", [9400, 9401]) + ['http://localhost:9400/metrics', 'http://localhost:9401/metrics'] + >>> build_hostname_aware_prometheus_endpoints("http://gpu-server:8000", [8081, 6880]) + ['http://gpu-server:8081/metrics', 'http://gpu-server:6880/metrics'] + """ + if not inference_endpoint_url.startswith("http"): + inference_endpoint_url = f"http://{inference_endpoint_url}" + parsed = urlparse(inference_endpoint_url) + hostname = parsed.hostname or "localhost" + scheme = parsed.scheme or "http" + + ports_to_check = list(default_ports) + if include_inference_port: + ports_to_check.insert(0, parsed.port or (443 if scheme == "https" else 80)) + + # Build endpoints and deduplicate while preserving order + endpoints = [f"{scheme}://{hostname}:{port}/metrics" for port in ports_to_check] + return list(dict.fromkeys(endpoints)) diff --git a/src/aiperf/common/mixins/__init__.py b/src/aiperf/common/mixins/__init__.py index 17a3d71a0..68e6d6977 100644 --- a/src/aiperf/common/mixins/__init__.py +++ b/src/aiperf/common/mixins/__init__.py @@ -14,11 +14,18 @@ from aiperf.common.mixins.aiperf_logger_mixin import ( AIPerfLoggerMixin, ) +from aiperf.common.mixins.base_metrics_collector_mixin import ( + BaseMetricsCollectorMixin, + TErrorCallback, + TRecord, + TRecordCallback, +) from aiperf.common.mixins.base_mixin import ( BaseMixin, ) from aiperf.common.mixins.buffered_jsonl_writer_mixin import ( BufferedJSONLWriterMixin, + BufferedJSONLWriterMixinWithDeduplication, ) from aiperf.common.mixins.command_handler_mixin import ( CommandHandlerMixin, @@ -60,8 +67,10 @@ __all__ = [ "AIPerfLifecycleMixin", "AIPerfLoggerMixin", + "BaseMetricsCollectorMixin", "BaseMixin", "BufferedJSONLWriterMixin", + "BufferedJSONLWriterMixinWithDeduplication", "CommandHandlerMixin", "CommunicationMixin", "HooksMixin", @@ -72,6 +81,9 @@ "RealtimeMetricsMixin", "RealtimeTelemetryMetricsMixin", "ReplyClientMixin", + "TErrorCallback", + "TRecord", + "TRecordCallback", "TaskManagerMixin", "WorkerTrackerMixin", ] diff --git a/src/aiperf/common/mixins/base_metrics_collector_mixin.py b/src/aiperf/common/mixins/base_metrics_collector_mixin.py new file mode 100644 index 000000000..f3a9f368d --- /dev/null +++ b/src/aiperf/common/mixins/base_metrics_collector_mixin.py @@ -0,0 +1,238 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Base mixin for async HTTP metrics data collectors. + +This mixin provides common functionality for collecting metrics from HTTP endpoints, +used by both GPU telemetry and server metrics systems. +""" + +import asyncio +from abc import ABC, abstractmethod +from collections.abc import Awaitable, Callable +from typing import Generic, TypeVar + +import aiohttp + +from aiperf.common.hooks import background_task, on_init, on_stop +from aiperf.common.mixins import AIPerfLifecycleMixin +from aiperf.common.models import ErrorDetails + +# Type variables for records returned by collectors +TRecord = TypeVar("TRecord") +TRecordCallback = TypeVar( + "TRecordCallback", bound=Callable[[list[TRecord], str], Awaitable[None]] +) +TErrorCallback = TypeVar( + "TErrorCallback", bound=Callable[[ErrorDetails, str], Awaitable[None]] +) + + +class BaseMetricsCollectorMixin(AIPerfLifecycleMixin, ABC, Generic[TRecord]): + """Mixin providing async HTTP collection for metrics endpoints. + + This mixin encapsulates the pattern of periodically fetching metrics from + HTTP endpoints, parsing them, and delivering them via callbacks. + + Common patterns: + - aiohttp session management + - Reachability testing + - Background collection task with error handling + - Callback-based delivery + + Used by: + - TelemetryDataCollector (DCGM metrics) + - ServerMetricsDataCollector (Prometheus metrics) + """ + + def __init__( + self, + endpoint_url: str, + collection_interval: float, + reachability_timeout: float, + record_callback: TRecordCallback | None = None, + error_callback: TErrorCallback | None = None, + **kwargs, + ) -> None: + """Initialize the metrics collector. + + Args: + endpoint_url: URL of the metrics endpoint + collection_interval: Interval in seconds between collections + reachability_timeout: Timeout in seconds for reachability checks + record_callback: Optional callback to receive collected records + error_callback: Optional callback to receive collection errors + **kwargs: Additional arguments passed to super().__init__() + """ + self._endpoint_url = endpoint_url + self._collection_interval = collection_interval + self._reachability_timeout = reachability_timeout + self._record_callback = record_callback + self._error_callback = error_callback + self._session: aiohttp.ClientSession | None = None + super().__init__(**kwargs) + + @property + def endpoint_url(self) -> str: + """Get the metrics endpoint URL.""" + return self._endpoint_url + + @property + def collection_interval(self) -> float: + """Get the collection interval in seconds.""" + return self._collection_interval + + @on_init + async def _initialize_http_client(self) -> None: + """Initialize the aiohttp client session. + + Called automatically during initialization phase. + Creates an aiohttp ClientSession with appropriate timeout settings. + Uses connect timeout only (no total timeout) to allow long-running scrapes. + """ + timeout = aiohttp.ClientTimeout( + total=None, # No total timeout for ongoing scrapes + connect=self._reachability_timeout, # Fast connection timeout only + ) + self._session = aiohttp.ClientSession(timeout=timeout) + + @on_stop + async def _cleanup_http_client(self) -> None: + """Clean up the aiohttp client session. + + Called automatically during shutdown phase. + """ + if self._session: + await self._session.close() + self._session = None + + async def is_url_reachable(self) -> bool: + """Check if metrics endpoint is accessible. + + Attempts HEAD request first for efficiency, falls back to GET if HEAD is not supported. + Uses existing session if available, otherwise creates a temporary session. + + Returns: + True if endpoint responds with HTTP 200, False otherwise + """ + if not self._endpoint_url: + return False + + # Use existing session if available, otherwise create a temporary one + if self._session: + return await self._check_reachability_with_session(self._session) + else: + # Create a temporary session for reachability check + timeout = aiohttp.ClientTimeout(total=self._reachability_timeout) + async with aiohttp.ClientSession(timeout=timeout) as temp_session: + return await self._check_reachability_with_session(temp_session) + + async def _check_reachability_with_session( + self, session: aiohttp.ClientSession + ) -> bool: + """Check reachability using a specific session. + + Args: + session: aiohttp session to use for the check + + Returns: + True if endpoint is reachable with HTTP 200 + """ + try: + # Try HEAD first for efficiency + async with session.head( + self._endpoint_url, allow_redirects=False + ) as response: + if response.status == 200: + return True + # Fall back to GET if HEAD is not supported + async with session.get(self._endpoint_url) as response: + return response.status == 200 + except (aiohttp.ClientError, asyncio.TimeoutError): + return False + + @background_task(immediate=True, interval=lambda self: self._collection_interval) + async def _collect_metrics_task(self) -> None: + """Background task for collecting metrics at regular intervals. + + This uses the @background_task decorator which automatically handles + lifecycle management and stopping when the collector is stopped. + + Errors during collection are caught and sent via error_callback if configured. + CancelledError is propagated to allow graceful shutdown. + """ + try: + await self._collect_and_process_metrics() + except asyncio.CancelledError: + raise + except Exception as e: + if self._error_callback: + try: + await self._error_callback( + ErrorDetails.from_exception(e), + self.id, + ) + except Exception as callback_error: + self.error(f"Failed to send error via callback: {callback_error}") + else: + self.error(f"Metrics collection error: {e}") + + @abstractmethod + async def _collect_and_process_metrics(self) -> None: + """Collect metrics from endpoint and process them into records. + + Subclasses must implement this to: + 1. Fetch raw metrics data from the endpoint + 2. Parse data into record objects + 3. Send records via callback (if configured) + """ + pass + + async def _fetch_metrics_text(self) -> str: + """Fetch raw metrics text from the HTTP endpoint. + + Performs safety checks before making HTTP request: + - Verifies stop_requested flag to allow graceful shutdown + - Checks session is initialized and not closed + - Handles concurrent session closure gracefully + + Returns: + Raw metrics text from the endpoint + + Raises: + RuntimeError: If HTTP session is not initialized + aiohttp.ClientError: If HTTP request fails + asyncio.CancelledError: If collector is being stopped or session is closed + """ + if self.stop_requested: + raise asyncio.CancelledError + + # Snapshot session to avoid race with _cleanup_http_client setting it to None + session = self._session + if not session: + raise RuntimeError("HTTP session not initialized") + + try: + if session.closed: + raise asyncio.CancelledError + + async with session.get(self._endpoint_url) as response: + response.raise_for_status() + return await response.text() + except (aiohttp.ClientConnectionError, RuntimeError) as e: + # Convert connection errors during shutdown to CancelledError + if self.stop_requested or session.closed: + raise asyncio.CancelledError from e + raise + + async def _send_records_via_callback(self, records: list[TRecord]) -> None: + """Send records to the callback if configured. + + Args: + records: List of records to send + """ + if records and self._record_callback: + try: + await self._record_callback(records, self.id) + except Exception as e: + self.error(f"Failed to send records via callback: {e!r}", exc_info=True) diff --git a/src/aiperf/common/mixins/buffered_jsonl_writer_mixin.py b/src/aiperf/common/mixins/buffered_jsonl_writer_mixin.py index dde09f27e..e21457f52 100644 --- a/src/aiperf/common/mixins/buffered_jsonl_writer_mixin.py +++ b/src/aiperf/common/mixins/buffered_jsonl_writer_mixin.py @@ -3,12 +3,14 @@ """Mixin for buffered JSONL writing with automatic flushing.""" import asyncio +from collections.abc import Callable from pathlib import Path -from typing import Generic +from typing import Any, Generic import aiofiles import orjson +from aiperf.common.duplicate_tracker import AsyncKeyedDuplicateTracker from aiperf.common.environment import Environment from aiperf.common.hooks import on_init, on_stop from aiperf.common.mixins.aiperf_lifecycle_mixin import AIPerfLifecycleMixin @@ -53,6 +55,16 @@ def __init__( self._batch_size = batch_size self._buffer_lock = asyncio.Lock() + try: + # Create the output file directory if it doesn't exist and clear the file + self.output_file.parent.mkdir(parents=True, exist_ok=True) + self.output_file.unlink(missing_ok=True) + except Exception as e: + self.exception( + f"Failed to create output file directory or clear file: {self.output_file}: {e!r}" + ) + raise + @on_init async def _open_file(self) -> None: """Open the file handle for writing in binary mode (called automatically on initialization).""" @@ -97,6 +109,16 @@ async def buffered_write(self, record: BaseModelT) -> None: except Exception as e: self.error(f"Failed to write record: {e!r}") + async def flush_buffer(self) -> None: + """Flush the buffer to disk.""" + async with self._buffer_lock: + buffer_to_flush = self._buffer + self._buffer = [] + if buffer_to_flush: + await self._flush_buffer(buffer_to_flush) + else: + self.debug(lambda: f"No buffer to flush: {self.output_file}") + async def _flush_buffer(self, buffer_to_flush: list[bytes]) -> None: """Write buffered records to disk using bulk write. @@ -166,3 +188,41 @@ async def _close_file(self) -> None: self.debug( f"{self.__class__.__name__}: {self.lines_written} JSONL lines written to {self.output_file}" ) + + +class BufferedJSONLWriterMixinWithDeduplication(BufferedJSONLWriterMixin[BaseModelT]): + """Mixin for buffered JSONL writing with automatic flushing and deduplication. + + + Args: + dedupe_key_function: A function that takes a record and returns a key for deduplication. + dedupe_value_function: A function that takes a record and returns a value for deduplication. + + See the AsyncKeyedDuplicateTracker class for more details on the deduplication logic. + """ + + def __init__( + self, + dedupe_key_function: Callable[[BaseModelT], str], + dedupe_value_function: Callable[[BaseModelT], Any] = lambda x: x, + **kwargs, + ): + super().__init__(**kwargs) + self._duplicate_tracker = AsyncKeyedDuplicateTracker[BaseModelT]( + key_function=dedupe_key_function, + value_function=dedupe_value_function, + ) + + async def buffered_write(self, record: BaseModelT) -> None: + """Write a Pydantic model to the buffer with automatic flushing and deduplication.""" + records_to_write = await self._duplicate_tracker.deduplicate_record(record) + for record in records_to_write: + await super().buffered_write(record) + + @on_stop + async def _flush_remaining_duplicates(self) -> None: + """Flush remaining duplicates on shutdown.""" + records_to_flush = await self._duplicate_tracker.flush_remaining_duplicates() + for record in records_to_flush: + await super().buffered_write(record) + await super().flush_buffer() diff --git a/src/aiperf/common/models/__init__.py b/src/aiperf/common/models/__init__.py index 5363bed61..e6a3799d3 100644 --- a/src/aiperf/common/models/__init__.py +++ b/src/aiperf/common/models/__init__.py @@ -3,7 +3,7 @@ ######################################################################## ## 🚩 mkinit flags 🚩 ## ######################################################################## -__ignore__ = [] +__ignore__ = ["logger"] ######################################################################## ## ⚠️ This file is auto-generated by mkinit ⚠️ ## ## ⚠️ Do not edit below this line ⚠️ ## @@ -100,7 +100,19 @@ SequenceLengthPair, create_balanced_distribution, create_uniform_distribution, - logger, +) +from aiperf.common.models.server_metrics_models import ( + HistogramData, + InfoMetricData, + MetricFamily, + MetricSample, + MetricSchema, + ServerMetricsMetadata, + ServerMetricsMetadataFile, + ServerMetricsRecord, + ServerMetricsSlimRecord, + SlimMetricSample, + SummaryData, ) from aiperf.common.models.service_models import ( ServiceRunInfo, @@ -149,17 +161,22 @@ "GpuSummary", "GpuTelemetryData", "GpuTelemetrySnapshot", + "HistogramData", "IOCounters", "Image", "ImageDataItem", "ImageResponseData", + "InfoMetricData", "InputsFile", "JsonExportData", "JsonMetricResult", "Media", + "MetricFamily", "MetricRecordInfo", "MetricRecordMetadata", "MetricResult", + "MetricSample", + "MetricSchema", "MetricValue", "ModelEndpointInfo", "ModelInfo", @@ -183,9 +200,15 @@ "SSEMessage", "SequenceLengthDistribution", "SequenceLengthPair", + "ServerMetricsMetadata", + "ServerMetricsMetadataFile", + "ServerMetricsRecord", + "ServerMetricsSlimRecord", "ServiceRunInfo", "SessionPayloads", + "SlimMetricSample", "StatsProtocol", + "SummaryData", "TelemetryExportData", "TelemetryHierarchy", "TelemetryMetrics", @@ -205,5 +228,4 @@ "WorkerTaskStats", "create_balanced_distribution", "create_uniform_distribution", - "logger", ] diff --git a/src/aiperf/common/models/server_metrics_models.py b/src/aiperf/common/models/server_metrics_models.py new file mode 100644 index 000000000..14559dea0 --- /dev/null +++ b/src/aiperf/common/models/server_metrics_models.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from pydantic import Field + +from aiperf.common.enums import PrometheusMetricType +from aiperf.common.models.base_models import AIPerfBaseModel + + +class HistogramData(AIPerfBaseModel): + """Structured histogram data with buckets, sum, and count.""" + + buckets: dict[str, float] = Field( + description="Bucket upper bounds to counts {le: value}" + ) + sum: float | None = Field(default=None, description="Sum of all observed values") + count: float | None = Field( + default=None, description="Total number of observations" + ) + + +class SummaryData(AIPerfBaseModel): + """Structured summary data with quantiles, sum, and count.""" + + quantiles: dict[str, float] = Field( + description="Quantile to value {quantile: value}" + ) + sum: float | None = Field(default=None, description="Sum of all observed values") + count: float | None = Field( + default=None, description="Total number of observations" + ) + + +class SlimMetricSample(AIPerfBaseModel): + """Slim metric sample with minimal data using dictionary-based format. + + Optimized for JSONL export. Uses dictionary format for + histogram/summary data for clarity: + - Type and help text are in schema + - Histogram bucket labels (le values) map to their counts + - Summary quantile labels map to their values + - sum/count are optional fields at sample level (used for histogram/summary) + + Format examples: + - Counter/Gauge: {"value": 42.0} or {"labels": {...}, "value": 42.0} + - Histogram: {"histogram": {"0.01": 10, "0.1": 25, "1.0": 50, ...}, "sum": 100.0, "count": 50} + - Summary: {"summary": {"0.5": 0.1, "0.9": 0.5, "0.99": 1.0, ...}, "sum": 100.0, "count": 50} + """ + + labels: dict[str, str] | None = Field( + default=None, + description="Metric labels (excluding histogram/summary special labels). None if no labels.", + ) + value: float | None = Field( + default=None, description="Simple metric value (counter/gauge)" + ) + histogram: dict[str, float] | None = Field( + default=None, + description="Histogram bucket upper bounds (le) to counts mapping", + ) + summary: dict[str, float] | None = Field( + default=None, + description="Summary quantile to value mapping", + ) + sum: float | None = Field( + default=None, + description="Sum of all observed values (for histogram/summary)", + ) + count: float | None = Field( + default=None, + description="Total number of observations (for histogram/summary)", + ) + + +class MetricSample(AIPerfBaseModel): + """Single metric sample with labels and value.""" + + labels: dict[str, str] | None = Field( + default=None, + description="Metric labels (excluding histogram/summary special labels). None if no labels.", + ) + value: float | None = Field( + default=None, description="Simple metric value (counter/gauge)" + ) + histogram: HistogramData | None = Field( + default=None, description="Histogram data if metric is histogram type" + ) + summary: SummaryData | None = Field( + default=None, description="Summary data if metric is summary type" + ) + + def to_slim(self) -> SlimMetricSample: + """Convert to slim metric sample format. + + For histograms and summaries, converts to dictionary format where + bucket/quantile labels map to their counts/values. + + Returns: + SlimMetricSample with dictionary-based histogram/summary data + """ + if self.histogram: + return SlimMetricSample( + labels=self.labels, + value=self.value, + histogram=self.histogram.buckets, + sum=self.histogram.sum, + count=self.histogram.count, + ) + + if self.summary: + return SlimMetricSample( + labels=self.labels, + value=self.value, + summary=self.summary.quantiles, + sum=self.summary.sum, + count=self.summary.count, + ) + + return SlimMetricSample( + labels=self.labels, + value=self.value, + ) + + +class MetricFamily(AIPerfBaseModel): + """Group of related metrics with same name and type.""" + + type: PrometheusMetricType = Field(description="Metric type as enum") + description: str = Field(description="Metric description from HELP text") + samples: list[MetricSample] = Field( + description="Metric samples grouped by base labels" + ) + + +class MetricSchema(AIPerfBaseModel): + """Schema information for a metric (type and help text). + + Provides documentation for each metric collected from Prometheus endpoints. + Sent once per metric in ServerMetricsMetadata to avoid repeating in every record. + """ + + type: PrometheusMetricType = Field(description="Metric type as enum") + description: str = Field(description="Metric description from HELP text") + + +class InfoMetricData(AIPerfBaseModel): + """Complete data for an info metric including label data. + + Info metrics (ending in _info) contain static system information that doesn't + change over time. We store only the labels (not values) since the labels contain + the actual information and values are typically just 1.0. + """ + + description: str = Field(description="Metric description from HELP text") + labels: list[dict[str, str]] = Field( + description="List of label keys and values as reported by the Prometheus endpoint" + ) + + +class ServerMetricsSlimRecord(AIPerfBaseModel): + """Slim server metrics record containing only time-varying data. + + This record excludes static metadata (endpoint_url, metric types, help text) + to reduce JSONL file size. The metadata and schemas are stored separately in the + ServerMetricsMetadataFile. + """ + + endpoint_url: str = Field( + description="Source Prometheus metrics endpoint URL (e.g., 'http://localhost:8081/metrics')" + ) + timestamp_ns: int = Field( + description="Nanosecond wall-clock timestamp when metrics were collected (time_ns)" + ) + endpoint_latency_ns: int = Field( + description="Nanoseconds it took to collect the metrics from the endpoint" + ) + metrics: dict[str, list[SlimMetricSample]] = Field( + description="Metrics grouped by family name, mapping directly to slim sample list" + ) + + +class ServerMetricsMetadata(AIPerfBaseModel): + """Metadata for a server metrics endpoint that doesn't change over time. + + Includes metric schemas (type and help text) to avoid sending them in every record. + Info metrics (ending in _info) are stored with their complete label keys and values, + since they represent static information that doesn't change over time. + """ + + endpoint_url: str = Field(description="Prometheus metrics endpoint URL") + info_metrics: dict[str, InfoMetricData] = Field( + default_factory=dict, + description="Info metrics (ending in _info) with complete label keys and values as reported by the Prometheus endpoint", + ) + metric_schemas: dict[str, MetricSchema] = Field( + default_factory=dict, + description="Metric schemas (name, type, and help) as reported by the Prometheus endpoint", + ) + + +class ServerMetricsMetadataFile(AIPerfBaseModel): + """Container for all server metrics endpoint metadata. + + This model represents the complete server_metrics_metadata.json file structure, + mapping endpoint URLs to their metadata. + """ + + endpoints: dict[str, ServerMetricsMetadata] = Field( + default_factory=dict, + description="Dict mapping endpoint_url to ServerMetricsMetadata", + ) + + +class ServerMetricsRecord(AIPerfBaseModel): + """Single server metrics data point from Prometheus endpoint. + + This record contains all metrics scraped from one Prometheus endpoint at one point in time. + Used for hierarchical storage: endpoint_url -> time series data. + """ + + endpoint_url: str = Field( + description="Source Prometheus metrics endpoint URL (e.g., 'http://localhost:8081/metrics')" + ) + timestamp_ns: int = Field( + description="Nanosecond wall-clock timestamp when metrics were collected (time_ns)" + ) + endpoint_latency_ns: int = Field( + description="Nanoseconds it took to collect the metrics from the endpoint" + ) + metrics: dict[str, MetricFamily] = Field( + description="Metrics grouped by family name" + ) + + def to_slim(self) -> ServerMetricsSlimRecord: + """Convert to slim record using array-based format for histograms/summaries. + + Creates flat structure where metrics map directly to slim sample lists. + For histograms and summaries, uses array format with bucket counts/quantile values + and sum/count at the sample level. + + Excludes metrics ending in _info as they are stored separately in metadata. + + Returns: + ServerMetricsSlimRecord with only timestamp and slim samples (flat structure) + """ + slim_metrics = { + name: [sample.to_slim() for sample in family.samples] + for name, family in self.metrics.items() + if not name.endswith("_info") + } + + return ServerMetricsSlimRecord( + timestamp_ns=self.timestamp_ns, + endpoint_latency_ns=self.endpoint_latency_ns, + endpoint_url=self.endpoint_url, + metrics=slim_metrics, + ) + + def extract_metadata(self) -> ServerMetricsMetadata: + """Extract metadata from this record. + + Extracts metric schemas (type and description) and separates _info metrics + with their complete label data. + + Returns: + ServerMetricsMetadata with schemas and info metrics + """ + metric_schemas: dict[str, MetricSchema] = {} + info_metrics: dict[str, InfoMetricData] = {} + + for metric_name, metric_family in self.metrics.items(): + if metric_name.endswith("_info"): + labels_list = [ + sample.labels if sample.labels else {} + for sample in metric_family.samples + ] + info_metrics[metric_name] = InfoMetricData( + description=metric_family.description, + labels=labels_list, + ) + else: + metric_schemas[metric_name] = MetricSchema( + type=metric_family.type, + description=metric_family.description, + ) + + return ServerMetricsMetadata( + endpoint_url=self.endpoint_url, + metric_schemas=metric_schemas, + info_metrics=info_metrics, + ) diff --git a/src/aiperf/common/protocols.py b/src/aiperf/common/protocols.py index ff6584ac3..46fe36438 100644 --- a/src/aiperf/common/protocols.py +++ b/src/aiperf/common/protocols.py @@ -42,6 +42,7 @@ from aiperf.common.models.metadata import EndpointMetadata, TransportMetadata from aiperf.common.models.model_endpoint_info import ModelEndpointInfo from aiperf.common.models.record_models import MetricResult + from aiperf.common.models.server_metrics_models import ServerMetricsRecord from aiperf.dataset.loader.models import CustomDatasetT from aiperf.exporters.exporter_config import ExporterConfig, FileExportInfo from aiperf.metrics.metric_dicts import MetricRecordDict @@ -617,6 +618,37 @@ async def process_result(self, record_data: "MetricRecordsData") -> None: ... async def summarize(self) -> list["MetricResult"]: ... +@runtime_checkable +class ServerMetricsResultsProcessorProtocol(Protocol): + """Protocol for server metrics results processors that handle ServerMetricsRecord objects. + + This protocol is separate from ResultsProcessorProtocol because server metrics data + has fundamentally different structure (hierarchical Prometheus snapshots) compared + to inference metrics (flat key-value pairs). + + Processors implementing this protocol can handle both server metrics records + (time-varying data) and metadata messages (static endpoint information). + """ + + async def process_server_metrics_record( + self, record: "ServerMetricsRecord" + ) -> None: + """Process individual server metrics record with complete Prometheus snapshot. + + Args: + record: ServerMetricsRecord containing Prometheus metrics snapshot and metadata + """ + ... + + async def summarize(self) -> list["MetricResult"]: + """Generate list of MetricResult for server metrics data. + + Returns: + list of MetricResult containing the server metrics data and hierarchy. + """ + ... + + @runtime_checkable class TelemetryResultsProcessorProtocol(Protocol): """Protocol for telemetry results processors that handle TelemetryRecord objects. diff --git a/src/aiperf/controller/system_controller.py b/src/aiperf/controller/system_controller.py index 6b2285964..5ace82320 100644 --- a/src/aiperf/controller/system_controller.py +++ b/src/aiperf/controller/system_controller.py @@ -39,6 +39,7 @@ ProfileStartCommand, RealtimeMetricsCommand, RegisterServiceCommand, + ServerMetricsStatusMessage, ShutdownCommand, ShutdownWorkersCommand, SpawnWorkersCommand, @@ -132,6 +133,8 @@ def __init__( self._shutdown_lock = asyncio.Lock() self._endpoints_configured: list[str] = [] self._endpoints_reachable: list[str] = [] + self._server_metrics_endpoints_configured: list[str] = [] + self._server_metrics_endpoints_reachable: list[str] = [] self.debug("System Controller created") async def request_realtime_metrics(self) -> None: @@ -183,6 +186,9 @@ async def _start_services(self) -> None: self.debug("Starting optional TelemetryManager service") await self.service_manager.run_service(ServiceType.TELEMETRY_MANAGER, 1) + self.debug("Starting optional ServerMetricsManager service") + await self.service_manager.run_service(ServiceType.SERVER_METRICS_MANAGER, 1) + async with self.try_operation_or_stop("Register Services"): await self.service_manager.wait_for_all_services_registration( stop_event=self._stop_requested_event, @@ -378,6 +384,26 @@ async def _on_telemetry_status_message( f"GPU telemetry enabled - {len(message.endpoints_reachable)}/{len(message.endpoints_configured)} endpoint(s) reachable" ) + @on_message(MessageType.SERVER_METRICS_STATUS) + async def _on_server_metrics_status_message( + self, message: ServerMetricsStatusMessage + ) -> None: + """Handle server metrics status from ServerMetricsManager. + + ServerMetricsStatusMessage informs SystemController if server metrics results will be available. + """ + + self._server_metrics_endpoints_configured = message.endpoints_configured + self._server_metrics_endpoints_reachable = message.endpoints_reachable + + if not message.enabled: + reason_msg = f" - {message.reason}" if message.reason else "" + self.info(f"Server metrics disabled{reason_msg}") + else: + self.info( + f"Server metrics enabled - {len(message.endpoints_reachable)}/{len(message.endpoints_configured)} endpoint(s) reachable" + ) + @on_message(MessageType.COMMAND_RESPONSE) async def _process_command_response_message(self, message: CommandResponse) -> None: """Process a command response message.""" diff --git a/src/aiperf/gpu_telemetry/telemetry_data_collector.py b/src/aiperf/gpu_telemetry/telemetry_data_collector.py index 726f62197..b741a3be5 100644 --- a/src/aiperf/gpu_telemetry/telemetry_data_collector.py +++ b/src/aiperf/gpu_telemetry/telemetry_data_collector.py @@ -1,17 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import asyncio import time -from collections.abc import Awaitable, Callable -import aiohttp from prometheus_client.parser import text_string_to_metric_families from aiperf.common.environment import Environment -from aiperf.common.hooks import background_task, on_init, on_stop -from aiperf.common.mixins.aiperf_lifecycle_mixin import AIPerfLifecycleMixin -from aiperf.common.models import ErrorDetails, TelemetryMetrics, TelemetryRecord +from aiperf.common.mixins import ( + BaseMetricsCollectorMixin, + TErrorCallback, + TRecordCallback, +) +from aiperf.common.models import TelemetryMetrics, TelemetryRecord from aiperf.gpu_telemetry.constants import ( DCGM_TO_FIELD_MAPPING, SCALING_FACTORS, @@ -20,21 +20,19 @@ __all__ = ["TelemetryDataCollector"] -class TelemetryDataCollector(AIPerfLifecycleMixin): +class TelemetryDataCollector(BaseMetricsCollectorMixin[TelemetryRecord]): """Collects telemetry metrics from DCGM metrics endpoint using async architecture. Modern async collector that fetches GPU metrics from DCGM exporter and converts them to - TelemetryRecord objects. Uses AIPerf lifecycle management and background tasks. - - Extends AIPerfLifecycleMixin for proper lifecycle management - - Uses aiohttp for async HTTP requests + TelemetryRecord objects. Uses BaseMetricsCollectorMixin for HTTP collection patterns. + - Extends BaseMetricsCollectorMixin for async HTTP collection - Uses prometheus_client for robust metric parsing - - Uses @background_task for periodic collection - Sends TelemetryRecord list via callback function - - No local storage (follows centralized architecture) Args: dcgm_url: URL of the DCGM metrics endpoint (e.g., "http://localhost:9400/metrics") - collection_interval: Interval in seconds between metric collections (default: 1.0) + collection_interval: Interval in seconds between metric collections (default: from Environment) + reachability_timeout: Timeout in seconds for reachability checks (default: from Environment) record_callback: Optional async callback to receive collected records. Signature: async (records: list[TelemetryRecord], collector_id: str) -> None error_callback: Optional async callback to receive collection errors. @@ -45,175 +43,38 @@ class TelemetryDataCollector(AIPerfLifecycleMixin): def __init__( self, dcgm_url: str, - collection_interval: float | None = None, - record_callback: Callable[[list[TelemetryRecord], str], Awaitable[None]] - | None = None, - error_callback: Callable[[ErrorDetails, str], Awaitable[None]] | None = None, + collection_interval: float = Environment.GPU.COLLECTION_INTERVAL, + reachability_timeout: float = Environment.GPU.REACHABILITY_TIMEOUT, + record_callback: TRecordCallback | None = None, + error_callback: TErrorCallback | None = None, collector_id: str = "telemetry_collector", ) -> None: - self._dcgm_url = dcgm_url - self._collection_interval = ( - collection_interval - if collection_interval is not None - else Environment.GPU.COLLECTION_INTERVAL - ) - self._record_callback = record_callback - self._error_callback = error_callback self._scaling_factors = SCALING_FACTORS - self._session: aiohttp.ClientSession | None = None - - super().__init__(id=collector_id) - - @on_init - async def _initialize_http_client(self) -> None: - """Initialize the aiohttp client session. - - Called automatically by AIPerfLifecycleMixin during initialization phase. - Creates an aiohttp ClientSession with appropriate timeout settings. - """ - timeout = aiohttp.ClientTimeout(total=Environment.GPU.REACHABILITY_TIMEOUT) - self._session = aiohttp.ClientSession(timeout=timeout) - - @on_stop - async def _cleanup_http_client(self) -> None: - """Clean up the aiohttp client session. - - Called automatically by AIPerfLifecycleMixin during shutdown phase. - Race conditions with background tasks are handled by checking - self.stop_requested in the background task itself. - - Raises: - Exception: Any exception from session.close() is allowed to propagate - """ - if self._session: - await self._session.close() - self._session = None - - async def is_url_reachable(self) -> bool: - """Check if DCGM metrics endpoint is accessible. - - Attempts HEAD request first for efficiency, falls back to GET if HEAD is not supported. - Uses existing session if available, otherwise creates a temporary session. - - Returns: - bool: True if endpoint responds with HTTP 200, False for any error or other status - """ - if not self._dcgm_url: - return False - - # Use existing session if available, otherwise create a temporary one - if self._session: - try: - # Try HEAD first for efficiency - async with self._session.head( - self._dcgm_url, allow_redirects=False - ) as response: - if response.status == 200: - return True - # Fall back to GET if HEAD is not supported - async with self._session.get(self._dcgm_url) as response: - return response.status == 200 - except (aiohttp.ClientError, asyncio.TimeoutError): - return False - else: - # Create a temporary session for reachability check - timeout = aiohttp.ClientTimeout(total=Environment.GPU.REACHABILITY_TIMEOUT) - async with aiohttp.ClientSession(timeout=timeout) as temp_session: - try: - # Try HEAD first for efficiency - async with temp_session.head( - self._dcgm_url, allow_redirects=False - ) as response: - if response.status == 200: - return True - # Fall back to GET if HEAD is not supported - async with temp_session.get(self._dcgm_url) as response: - return response.status == 200 - except (aiohttp.ClientError, asyncio.TimeoutError): - return False - - @background_task(immediate=True, interval=lambda self: self._collection_interval) - async def _collect_telemetry_task(self) -> None: - """Background task for collecting telemetry data at regular intervals. - - This uses the @background_task decorator which automatically handles - lifecycle management and stopping when the collector is stopped. - The interval is set to the collection_interval so this runs periodically. - - Errors during collection are caught and sent via error_callback if configured. - CancelledError is propagated to allow graceful shutdown. - - Raises: - asyncio.CancelledError: Propagated to signal task cancellation during shutdown - """ - try: - await self._collect_and_process_metrics() - except asyncio.CancelledError: - raise - except Exception as e: - if self._error_callback: - try: - await self._error_callback(ErrorDetails.from_exception(e), self.id) - except Exception as callback_error: - self.error(f"Failed to send error via callback: {callback_error}") - else: - self.error(f"Telemetry collection error: {e}") + super().__init__( + endpoint_url=dcgm_url, + collection_interval=collection_interval, + reachability_timeout=reachability_timeout, + record_callback=record_callback, + error_callback=error_callback, + id=collector_id, + ) async def _collect_and_process_metrics(self) -> None: """Collect metrics from DCGM endpoint and process them into TelemetryRecord objects. + Implements the abstract method from BaseMetricsCollectorMixin. + Orchestrates the full collection flow: - 1. Fetches raw metrics data from DCGM endpoint + 1. Fetches raw metrics data from DCGM endpoint (via mixin's _fetch_metrics_text) 2. Parses Prometheus-format data into TelemetryRecord objects - 3. Sends records via callback (if configured and records are not empty) - - Callback failures are caught and logged as warnings without stopping collection. + 3. Sends records via callback (via mixin's _send_records_via_callback) Raises: Exception: Any exception from fetch or parse is logged and re-raised """ - try: - metrics_data = await self._fetch_metrics() - records = self._parse_metrics_to_records(metrics_data) - - if records and self._record_callback: - try: - await self._record_callback(records, self.id) - except Exception as e: - self.warning(f"Failed to send telemetry records via callback: {e}") - - except Exception as e: - self.error(f"Error collecting and processing metrics: {e}") - raise - - async def _fetch_metrics(self) -> str: - """Fetch raw metrics data from DCGM endpoint using aiohttp. - - Performs safety checks before making HTTP request: - - Verifies stop_requested flag to allow graceful shutdown - - Checks session is initialized and not closed - - Returns: - str: Raw metrics text in Prometheus exposition format - - Raises: - RuntimeError: If HTTP session is not initialized - aiohttp.ClientError: If HTTP request fails (4xx, 5xx, network errors) - asyncio.CancelledError: If collector is being stopped or session is closed - """ - if self.stop_requested: - raise asyncio.CancelledError - - if not self._session: - raise RuntimeError("HTTP session not initialized. Call initialize() first.") - - if self._session.closed: - raise asyncio.CancelledError - - async with self._session.get(self._dcgm_url) as response: - response.raise_for_status() - text = await response.text() - return text + metrics_data = await self._fetch_metrics_text() + records = self._parse_metrics_to_records(metrics_data) + await self._send_records_via_callback(records) def _parse_metrics_to_records(self, metrics_data: str) -> list[TelemetryRecord]: """Parse DCGM metrics text into TelemetryRecord objects using prometheus_client. @@ -286,7 +147,7 @@ def _parse_metrics_to_records(self, metrics_data: str) -> list[TelemetryRecord]: record = TelemetryRecord( timestamp_ns=current_timestamp, - dcgm_url=self._dcgm_url, + dcgm_url=self.endpoint_url, gpu_index=gpu_index, gpu_uuid=metadata.get("uuid", f"unknown-gpu-{gpu_index}"), gpu_model_name=metadata.get("model_name", f"GPU {gpu_index}"), diff --git a/src/aiperf/post_processors/__init__.py b/src/aiperf/post_processors/__init__.py index 680dcd11a..937673874 100644 --- a/src/aiperf/post_processors/__init__.py +++ b/src/aiperf/post_processors/__init__.py @@ -24,6 +24,9 @@ from aiperf.post_processors.record_export_results_processor import ( RecordExportResultsProcessor, ) +from aiperf.post_processors.server_metrics_export_results_processor import ( + ServerMetricsExportResultsProcessor, +) from aiperf.post_processors.telemetry_export_results_processor import ( TelemetryExportResultsProcessor, ) @@ -41,6 +44,7 @@ "RawRecordAggregator", "RawRecordWriterProcessor", "RecordExportResultsProcessor", + "ServerMetricsExportResultsProcessor", "TelemetryExportResultsProcessor", "TelemetryResultsProcessor", "TimesliceMetricResultsProcessor", diff --git a/src/aiperf/post_processors/server_metrics_export_results_processor.py b/src/aiperf/post_processors/server_metrics_export_results_processor.py new file mode 100644 index 000000000..7b995ed14 --- /dev/null +++ b/src/aiperf/post_processors/server_metrics_export_results_processor.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import asyncio + +import orjson + +from aiperf.common.config import UserConfig +from aiperf.common.decorators import implements_protocol +from aiperf.common.enums import ResultsProcessorType +from aiperf.common.environment import Environment +from aiperf.common.factories import ResultsProcessorFactory +from aiperf.common.mixins import BufferedJSONLWriterMixinWithDeduplication +from aiperf.common.models.record_models import MetricResult +from aiperf.common.models.server_metrics_models import ( + ServerMetricsMetadata, + ServerMetricsMetadataFile, + ServerMetricsRecord, + ServerMetricsSlimRecord, +) +from aiperf.common.protocols import ServerMetricsResultsProcessorProtocol +from aiperf.post_processors.base_metrics_processor import BaseMetricsProcessor + + +@implements_protocol(ServerMetricsResultsProcessorProtocol) +@ResultsProcessorFactory.register(ResultsProcessorType.SERVER_METRICS_JSONL_WRITER) +class ServerMetricsExportResultsProcessor( + BaseMetricsProcessor, + BufferedJSONLWriterMixinWithDeduplication[ServerMetricsSlimRecord], +): + """Exports per-record server metrics data to JSONL files in slim format. + + This processor converts full ServerMetricsRecord objects to slim format before writing, + excluding static metadata (metric types, description text) to minimize file size. + Writes one JSON line per collection cycle. + + Each line contains: + - timestamp_ns: Collection timestamp in nanoseconds + - endpoint_latency_ns: Time taken to collect the metrics from the endpoint + - endpoint_url: Source Prometheus metrics endpoint URL (e.g., 'http://localhost:8081/metrics') + - metrics: Dict mapping metric names to sample lists (flat structure) + """ + + def __init__( + self, + user_config: UserConfig, + **kwargs, + ) -> None: + output_file = user_config.output.server_metrics_export_jsonl_file + + super().__init__( + user_config=user_config, + output_file=output_file, + batch_size=Environment.SERVER_METRICS.EXPORT_BATCH_SIZE, + dedupe_key_function=lambda x: x.endpoint_url, + dedupe_value_function=lambda x: x.metrics, + **kwargs, + ) + + self._metadata_file = user_config.output.server_metrics_metadata_json_file + self._metadata_file.unlink(missing_ok=True) + + # Keep track of metadata for all endpoints over time. Note that the metadata fields + # can occasionally change, so we need to keep track of it over time. + self._metadata_file_model = ServerMetricsMetadataFile() + self._metadata_file_lock = asyncio.Lock() + + self.info(f"Server metrics jsonl writer export enabled: {self.output_file}") + self.info(f"Server metrics metadata file: {self._metadata_file}") + + async def process_server_metrics_record(self, record: ServerMetricsRecord) -> None: + """Process individual server metrics record by converting to slim and writing to JSONL. + + Converts full record to slim format to reduce file size by excluding static metadata. + On first record from each endpoint, extracts metadata and writes metadata file. + + Args: + record: ServerMetricsRecord containing Prometheus metrics snapshot and metadata + """ + url = record.endpoint_url + metadata = record.extract_metadata() + + # Check without lock to avoid unnecessary locking + should_write = False + if ( + url not in self._metadata_file_model.endpoints + or self._should_update_metadata( + record, self._metadata_file_model.endpoints[url] + ) + ): + should_write = True + + if should_write: + async with self._metadata_file_lock: + if url not in self._metadata_file_model.endpoints: + # First time seeing this endpoint, set the metadata + self._metadata_file_model.endpoints[url] = metadata + else: + # Merge new metadata with existing metadata + existing = self._metadata_file_model.endpoints[url] + existing.metric_schemas.update(metadata.metric_schemas) + existing.info_metrics.update(metadata.info_metrics) + await self._write_metadata_file() + + # Convert to slim format before writing to reduce file size (will be deduplicated) + slim_record = record.to_slim() + await self.buffered_write(slim_record) + + def _should_update_metadata( + self, record: ServerMetricsRecord, existing_metadata: ServerMetricsMetadata + ) -> bool: + """Check if metadata should be updated based on record changes. + + Detects new metric names not in existing metadata. + + Args: + record: ServerMetricsRecord to check + existing_metadata: Existing metadata for the endpoint + + Returns: + True if metadata needs updating, False otherwise + """ + existing_schemas = existing_metadata.metric_schemas + + # Check for new metric names + new_metrics = set(record.metrics.keys()) - set(existing_schemas.keys()) + if new_metrics: + self.debug( + lambda: f"Detected new metrics for {record.endpoint_url}: {sorted(new_metrics)}" + ) + return True + + return False + + async def _write_metadata_file(self) -> None: + """Write the complete metadata file for all seen endpoints atomically. + + Re-writes the entire metadata file with all endpoints seen so far. + Uses Pydantic model serialization with orjson for efficient JSON writing. + Uses atomic temp-file + rename pattern to prevent corruption on crash. + """ + metadata_json = orjson.dumps( + self._metadata_file_model.model_dump(exclude_none=True, mode="json"), + option=orjson.OPT_INDENT_2, + ) + + # Write to temp file and atomically rename to prevent corruption + temp_file = self._metadata_file.with_suffix(".tmp") + try: + temp_file.write_bytes(metadata_json) + temp_file.replace(self._metadata_file) # Atomic on POSIX + self.debug( + lambda: f"Wrote metadata file with {len(self._metadata_file_model.endpoints)} endpoints" + ) + except Exception as e: + self.error(f"Failed to write metadata file: {e}") + temp_file.unlink(missing_ok=True) + raise + + async def summarize(self) -> list[MetricResult]: + """Summarize the results. + + Returns: + Empty list (export processors don't generate metric results). + """ + return [] diff --git a/src/aiperf/post_processors/telemetry_export_results_processor.py b/src/aiperf/post_processors/telemetry_export_results_processor.py index 01dc93a4a..d298976be 100644 --- a/src/aiperf/post_processors/telemetry_export_results_processor.py +++ b/src/aiperf/post_processors/telemetry_export_results_processor.py @@ -8,7 +8,7 @@ from aiperf.common.enums import ResultsProcessorType from aiperf.common.environment import Environment from aiperf.common.factories import ResultsProcessorFactory -from aiperf.common.mixins import BufferedJSONLWriterMixin +from aiperf.common.mixins import BufferedJSONLWriterMixinWithDeduplication from aiperf.common.models import MetricResult from aiperf.common.models.telemetry_models import TelemetryRecord from aiperf.common.protocols import TelemetryResultsProcessorProtocol @@ -18,7 +18,7 @@ @implements_protocol(TelemetryResultsProcessorProtocol) @ResultsProcessorFactory.register(ResultsProcessorType.TELEMETRY_EXPORT) class TelemetryExportResultsProcessor( - BaseMetricsProcessor, BufferedJSONLWriterMixin[TelemetryRecord] + BaseMetricsProcessor, BufferedJSONLWriterMixinWithDeduplication[TelemetryRecord] ): """Exports per-record GPU telemetry data to JSONL files. @@ -42,13 +42,13 @@ def __init__( **kwargs, ): output_file: Path = user_config.output.profile_export_gpu_telemetry_jsonl_file - output_file.parent.mkdir(parents=True, exist_ok=True) - output_file.unlink(missing_ok=True) super().__init__( - output_file=output_file, - batch_size=Environment.RECORD.EXPORT_BATCH_SIZE, user_config=user_config, + output_file=output_file, + batch_size=Environment.GPU.EXPORT_BATCH_SIZE, + dedupe_key_function=lambda x: (x.dcgm_url, x.gpu_uuid), + dedupe_value_function=lambda x: x.telemetry_data, **kwargs, ) diff --git a/src/aiperf/records/__init__.py b/src/aiperf/records/__init__.py index ced5b2392..c9e0087d3 100644 --- a/src/aiperf/records/__init__.py +++ b/src/aiperf/records/__init__.py @@ -23,7 +23,10 @@ RecordProcessor, ) from aiperf.records.records_manager import ( + ErrorTrackingState, + MetricTrackingState, RecordsManager, + ServerMetricsTrackingState, TelemetryTrackingState, ) @@ -31,11 +34,14 @@ "AllRequestsProcessedCondition", "CompletionReason", "DurationTimeoutCondition", + "ErrorTrackingState", "InferenceResultParser", + "MetricTrackingState", "PhaseCompletionChecker", "PhaseCompletionCondition", "PhaseCompletionContext", "RecordProcessor", "RecordsManager", + "ServerMetricsTrackingState", "TelemetryTrackingState", ] diff --git a/src/aiperf/records/records_manager.py b/src/aiperf/records/records_manager.py index 516c13d1c..adea8f46d 100644 --- a/src/aiperf/records/records_manager.py +++ b/src/aiperf/records/records_manager.py @@ -27,38 +27,39 @@ from aiperf.common.messages import ( AllRecordsReceivedMessage, CreditPhaseCompleteMessage, + CreditPhaseSendingCompleteMessage, CreditPhaseStartMessage, + MetricRecordsData, MetricRecordsMessage, ProcessRecordsCommand, ProcessRecordsResultMessage, ProcessTelemetryResultMessage, ProfileCancelCommand, + RealtimeMetricsCommand, RealtimeMetricsMessage, RealtimeTelemetryMetricsMessage, RecordsProcessingStatsMessage, + ServerMetricsRecordsMessage, StartRealtimeTelemetryCommand, TelemetryRecordsMessage, ) -from aiperf.common.messages.command_messages import RealtimeMetricsCommand -from aiperf.common.messages.credit_messages import CreditPhaseSendingCompleteMessage -from aiperf.common.messages.inference_messages import MetricRecordsData from aiperf.common.mixins import PullClientMixin from aiperf.common.models import ( ErrorDetails, ErrorDetailsCount, + MetricResult, ProcessingStats, ProcessRecordsResult, - ProfileResults, -) -from aiperf.common.models.record_models import MetricResult -from aiperf.common.models.telemetry_models import ( ProcessTelemetryResult, + ProfileResults, + ServerMetricsRecord, TelemetryHierarchy, TelemetryRecord, TelemetryResults, ) from aiperf.common.protocols import ( ResultsProcessorProtocol, + ServerMetricsResultsProcessorProtocol, ServiceProtocol, TelemetryResultsProcessorProtocol, ) @@ -66,12 +67,11 @@ @dataclass -class TelemetryTrackingState: - """ - Tracks telemetry-related state and performance metrics. +class ErrorTrackingState: + """Base class for tracking errors with counts and thread-safe access. - Consolidates error tracking, warnings, endpoint status, and performance - statistics for GPU telemetry collection and processing. + Provides common error tracking functionality for all metrics subsystems + (telemetry, server metrics, regular metrics). """ error_counts: dict[ErrorDetails, int] = field( @@ -79,6 +79,16 @@ class TelemetryTrackingState: ) error_counts_lock: asyncio.Lock = field(default_factory=asyncio.Lock) + +@dataclass +class TelemetryTrackingState(ErrorTrackingState): + """ + Tracks telemetry-related state and performance metrics. + + Consolidates error tracking, warnings, endpoint status, and performance + statistics for GPU telemetry collection and processing. + """ + task_runs: int = 0 total_gen_time_ms: float = 0.0 total_pub_time_ms: float = 0.0 @@ -87,6 +97,11 @@ class TelemetryTrackingState: last_metric_values: dict[str, float | None] | None = None +# Type aliases for clarity - all use the base ErrorTrackingState +ServerMetricsTrackingState = ErrorTrackingState +MetricTrackingState = ErrorTrackingState + + @implements_protocol(ServiceProtocol) @ServiceFactory.register(ServiceType.RECORDS_MANAGER) class RecordsManager(PullClientMixin, BaseComponentService): @@ -134,10 +149,15 @@ def __init__( self._previous_realtime_records: int | None = None self._telemetry_state = TelemetryTrackingState() + self._server_metrics_state = ServerMetricsTrackingState() + self._metric_state = MetricTrackingState() self._telemetry_enable_event = asyncio.Event() self._metric_results_processors: list[ResultsProcessorProtocol] = [] self._telemetry_results_processors: list[TelemetryResultsProcessorProtocol] = [] + self._server_metrics_results_processors: list[ + ServerMetricsResultsProcessorProtocol + ] = [] self._telemetry_accumulator: TelemetryResultsProcessorProtocol | None = None for results_processor_type in ResultsProcessorFactory.get_all_class_types(): @@ -156,6 +176,10 @@ def __init__( # Store the accumulating processor separately for hierarchy access if results_processor_type == ResultsProcessorType.TELEMETRY_RESULTS: self._telemetry_accumulator = results_processor + elif isinstance( + results_processor, ServerMetricsResultsProcessorProtocol + ): + self._server_metrics_results_processors.append(results_processor) else: self._metric_results_processors.append(results_processor) @@ -244,6 +268,26 @@ async def _on_telemetry_records(self, message: TelemetryRecordsMessage) -> None: async with self._telemetry_state.error_counts_lock: self._telemetry_state.error_counts[message.error] += 1 + @on_pull_message(MessageType.SERVER_METRICS_RECORDS) + async def _on_server_metrics_records( + self, message: ServerMetricsRecordsMessage + ) -> None: + """Handle server metrics records message from Server Metrics Manager. + + Forwards full records to results processors. + + Args: + message: Batch of server metrics records from a Prometheus collector + """ + if message.valid: + # Forward full records to results processors + for record in message.records: + await self._send_server_metrics_to_results_processors(record) + else: + if message.error: + async with self._server_metrics_state.error_counts_lock: + self._server_metrics_state.error_counts[message.error] += 1 + def _should_include_request_by_duration( self, record_data: MetricRecordsData ) -> bool: @@ -351,6 +395,29 @@ async def _send_telemetry_to_results_processors( ] ) + async def _send_server_metrics_to_results_processors( + self, record: ServerMetricsRecord + ) -> None: + """Send individual server metrics records to server metrics results processors only. + + Args: + record: ServerMetricsRecord from single collection cycle + """ + errors = await asyncio.gather( + *[ + processor.process_server_metrics_record(record) + for processor in self._server_metrics_results_processors + ], + return_exceptions=True, + ) + for error in errors: + if isinstance(error, BaseException): + self.exception(f"Failed to process server metrics record: {error!r}") + async with self._server_metrics_state.error_counts_lock: + self._server_metrics_state.error_counts[ + ErrorDetails.from_exception(error) + ] += 1 + @on_message(MessageType.CREDIT_PHASE_START) async def _on_credit_phase_start( self, phase_start_msg: CreditPhaseStartMessage diff --git a/src/aiperf/server_metrics/__init__.py b/src/aiperf/server_metrics/__init__.py new file mode 100644 index 000000000..09ee99390 --- /dev/null +++ b/src/aiperf/server_metrics/__init__.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from aiperf.server_metrics.server_metrics_data_collector import ( + ServerMetricsDataCollector, +) +from aiperf.server_metrics.server_metrics_manager import ( + ServerMetricsManager, +) + +__all__ = ["ServerMetricsDataCollector", "ServerMetricsManager"] diff --git a/src/aiperf/server_metrics/server_metrics_data_collector.py b/src/aiperf/server_metrics/server_metrics_data_collector.py new file mode 100644 index 000000000..dcb9f96fd --- /dev/null +++ b/src/aiperf/server_metrics/server_metrics_data_collector.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import time +from collections import defaultdict +from collections.abc import Awaitable, Callable + +from prometheus_client.metrics_core import Metric +from prometheus_client.parser import text_string_to_metric_families + +from aiperf.common.enums import PrometheusMetricType +from aiperf.common.environment import Environment +from aiperf.common.mixins import BaseMetricsCollectorMixin +from aiperf.common.models import ErrorDetails +from aiperf.common.models.server_metrics_models import ( + HistogramData, + MetricFamily, + MetricSample, + ServerMetricsRecord, + SummaryData, +) + +__all__ = ["ServerMetricsDataCollector"] + + +class ServerMetricsDataCollector(BaseMetricsCollectorMixin[ServerMetricsRecord]): + """Collects server metrics from Prometheus endpoint using async architecture. + + Modern async collector that fetches metrics from Prometheus-compatible endpoints + and converts them to ServerMetricsRecord objects. Uses BaseMetricsCollectorMixin + for HTTP collection patterns. + - Extends BaseMetricsCollectorMixin for async HTTP collection + - Uses prometheus_client for robust metric parsing + - Sends ServerMetricsRecord list via callback function + + Args: + endpoint_url: URL of the Prometheus metrics endpoint (e.g., "http://localhost:8081/metrics") + collection_interval: Interval in seconds between metric collections (default: 1.0) + record_callback: Optional async callback to receive collected records. + Signature: async (records: list[ServerMetricsRecord], collector_id: str) -> None + error_callback: Optional async callback to receive collection errors. + Signature: async (error: ErrorDetails, collector_id: str) -> None + collector_id: Unique identifier for this collector instance + """ + + def __init__( + self, + endpoint_url: str, + collection_interval: float | None = None, + reachability_timeout: float | None = None, + record_callback: Callable[[list[ServerMetricsRecord], str], Awaitable[None]] + | None = None, + error_callback: Callable[[ErrorDetails, str], Awaitable[None]] | None = None, + collector_id: str = "server_metrics_collector", + ) -> None: + super().__init__( + endpoint_url=endpoint_url, + collection_interval=collection_interval + or Environment.SERVER_METRICS.COLLECTION_INTERVAL, + reachability_timeout=reachability_timeout + or Environment.SERVER_METRICS.REACHABILITY_TIMEOUT, + record_callback=record_callback, + error_callback=error_callback, + id=collector_id, + ) + + async def _collect_and_process_metrics(self) -> None: + """Collect metrics from Prometheus endpoint and process them into ServerMetricsRecord objects. + + Implements the abstract method from BaseMetricsCollectorMixin. + + Orchestrates the full collection flow: + 1. Fetches raw metrics data from Prometheus endpoint (via mixin's _fetch_metrics_text) + 2. Parses Prometheus-format data into ServerMetricsRecord objects + 3. Sends records via callback (via mixin's _send_records_via_callback) + + Raises: + Exception: Any exception from fetch or parse is logged and re-raised + """ + start_perf_ns = time.perf_counter_ns() + metrics_data = await self._fetch_metrics_text() + latency_ns = time.perf_counter_ns() - start_perf_ns + records = self._parse_metrics_to_records(metrics_data, latency_ns) + await self._send_records_via_callback(records) + + def _parse_metrics_to_records( + self, metrics_data: str, latency_ns: int + ) -> list[ServerMetricsRecord]: + """Parse Prometheus metrics text into ServerMetricsRecord objects. + + Processes Prometheus exposition format metrics: + 1. Parses metric families using prometheus_client parser + 2. Groups metrics by type (counter, gauge, histogram, summary) + 3. De-duplicates by label combination (last value wins) + 4. Structures histogram and summary data + + Args: + metrics_data: Raw metrics text from Prometheus endpoint in Prometheus format + latency_ns: Nanoseconds it took to collect the metrics from the endpoint + + Returns: + list[ServerMetricsRecord]: List with single ServerMetricsRecord containing complete snapshot. + Returns empty list if metrics_data is empty or parsing fails. + """ + if not metrics_data.strip(): + return [] + + current_timestamp_ns = time.time_ns() + metrics_dict: dict[str, MetricFamily] = {} + + try: + for family in text_string_to_metric_families(metrics_data): + # Skip _created metrics - these are timestamps indicating when the parent + # histogram/summary/counter was created, not actual metric data + if family.name.endswith("_created"): + continue + + metric_type = PrometheusMetricType(family.type) + match metric_type: + case PrometheusMetricType.HISTOGRAM: + samples = self._process_histogram_family(family) + case PrometheusMetricType.SUMMARY: + samples = self._process_summary_family(family) + case ( + PrometheusMetricType.COUNTER + | PrometheusMetricType.GAUGE + | PrometheusMetricType.UNKNOWN + ): + samples = self._process_simple_family(family) + case _: + self.warning(f"Unsupported metric type: {metric_type}") + continue + + # Only add metric family if it has samples (skip empty after validation) + if samples: + metrics_dict[family.name] = MetricFamily( + type=metric_type, + description=family.documentation or "", + samples=samples, + ) + except ValueError as e: + self.warning(f"Failed to parse Prometheus metrics - invalid format: {e}") + raise + + # Suppress empty snapshots to reduce I/O noise + if not metrics_dict: + return [] + + record = ServerMetricsRecord( + timestamp_ns=current_timestamp_ns, + endpoint_latency_ns=latency_ns, + endpoint_url=self._endpoint_url, + metrics=metrics_dict, + ) + + return [record] + + def _process_simple_family(self, family: Metric) -> list[MetricSample]: + """Process counter, gauge, or untyped metrics with de-duplication. + + Args: + family: Prometheus metric family + + Returns: + List of MetricSample objects with de-duplicated values (last wins) + """ + samples_by_labels: dict[tuple, float] = {} + + for sample in family.samples: + label_key = tuple(sorted(sample.labels.items())) + samples_by_labels[label_key] = sample.value + + return [ + MetricSample(labels=dict(label_tuple) if label_tuple else None, value=value) + for label_tuple, value in samples_by_labels.items() + ] + + def _process_histogram_family(self, family: Metric) -> list[MetricSample]: + """Process histogram metrics into structured format. + + Args: + family: Prometheus histogram metric family + + Returns: + List of MetricSample objects with HistogramData + """ + histograms: dict[tuple, HistogramData] = defaultdict( + lambda: HistogramData(buckets={}, sum=None, count=None) + ) + + for sample in family.samples: + base_labels = {k: v for k, v in sample.labels.items() if k != "le"} + label_key = tuple(sorted(base_labels.items())) + + if sample.name.endswith("_bucket"): + le_value = sample.labels.get("le", "+Inf") + histograms[label_key].buckets[le_value] = sample.value + elif sample.name.endswith("_sum"): + histograms[label_key].sum = sample.value + elif sample.name.endswith("_count"): + histograms[label_key].count = sample.value + + samples = [] + for label_tuple, hist_data in histograms.items(): + # Skip histograms missing required fields or with no buckets + if ( + hist_data.sum is None + or hist_data.count is None + or not hist_data.buckets + ): + self.debug( + lambda hist=hist_data: f"Skipping incomplete histogram (missing sum, count, or buckets): {hist}" + ) + continue + + samples.append( + MetricSample( + labels=dict(label_tuple) if label_tuple else None, + histogram=hist_data, + ) + ) + + return samples + + def _process_summary_family(self, family: Metric) -> list[MetricSample]: + """Process summary metrics into structured format. + + Args: + family: Prometheus summary metric family + + Returns: + List of MetricSample objects with SummaryData + """ + summaries: dict[tuple, SummaryData] = defaultdict( + lambda: SummaryData(quantiles={}, sum=None, count=None) + ) + + for sample in family.samples: + base_labels = {k: v for k, v in sample.labels.items() if k != "quantile"} + label_key = tuple(sorted(base_labels.items())) + + if sample.name == family.name: + quantile = sample.labels.get("quantile", "0") + summaries[label_key].quantiles[quantile] = sample.value + elif sample.name.endswith("_sum"): + summaries[label_key].sum = sample.value + elif sample.name.endswith("_count"): + summaries[label_key].count = sample.value + + samples = [] + for label_tuple, summary_data in summaries.items(): + # Skip summaries missing required fields or with no quantiles + if ( + summary_data.sum is None + or summary_data.count is None + or not summary_data.quantiles + ): + self.debug( + lambda summary=summary_data: f"Skipping incomplete summary (missing sum, count, or quantiles): {summary}" + ) + continue + + samples.append( + MetricSample( + labels=dict(label_tuple) if label_tuple else None, + summary=summary_data, + ) + ) + + return samples diff --git a/src/aiperf/server_metrics/server_metrics_manager.py b/src/aiperf/server_metrics/server_metrics_manager.py new file mode 100644 index 000000000..993221cda --- /dev/null +++ b/src/aiperf/server_metrics/server_metrics_manager.py @@ -0,0 +1,352 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import asyncio + +from aiperf.common.base_component_service import BaseComponentService +from aiperf.common.config import ServiceConfig, UserConfig +from aiperf.common.decorators import implements_protocol +from aiperf.common.enums import ( + CommAddress, + CommandType, + ServiceType, +) +from aiperf.common.environment import Environment +from aiperf.common.factories import ServiceFactory +from aiperf.common.hooks import on_command, on_stop +from aiperf.common.messages import ( + ProfileCancelCommand, + ProfileConfigureCommand, + ProfileStartCommand, +) +from aiperf.common.messages.server_metrics_messages import ( + ServerMetricsRecordsMessage, + ServerMetricsStatusMessage, +) +from aiperf.common.metric_utils import ( + build_hostname_aware_prometheus_endpoints, + normalize_metrics_endpoint_url, +) +from aiperf.common.models import ErrorDetails +from aiperf.common.models.server_metrics_models import ServerMetricsRecord +from aiperf.common.protocols import ( + PushClientProtocol, + ServiceProtocol, +) +from aiperf.server_metrics.server_metrics_data_collector import ( + ServerMetricsDataCollector, +) + +__all__ = ["ServerMetricsManager"] + + +@implements_protocol(ServiceProtocol) +@ServiceFactory.register(ServiceType.SERVER_METRICS_MANAGER) +class ServerMetricsManager(BaseComponentService): + """Coordinates multiple ServerMetricsDataCollector instances for server metrics collection. + + The ServerMetricsManager coordinates multiple ServerMetricsDataCollector instances + to collect server metrics from multiple Prometheus endpoints and send unified + ServerMetricsRecordsMessage to RecordsManager. + + This service: + - Manages lifecycle of ServerMetricsDataCollector instances + - Collects metrics from multiple Prometheus endpoints + - Sends ServerMetricsRecordsMessage to RecordsManager via message system + - Handles errors gracefully with ErrorDetails + - Follows centralized architecture patterns + + Args: + service_config: Service-level configuration (logging, communication, etc.) + user_config: User-provided configuration including server_metrics endpoints + service_id: Optional unique identifier for this service instance + """ + + def __init__( + self, + service_config: ServiceConfig, + user_config: UserConfig, + service_id: str | None = None, + ) -> None: + super().__init__( + service_config=service_config, + user_config=user_config, + service_id=service_id, + ) + + self.records_push_client: PushClientProtocol = self.comms.create_push_client( + CommAddress.RECORDS, + ) + + self._collectors: dict[str, ServerMetricsDataCollector] = {} + self._server_metrics_endpoints = build_hostname_aware_prometheus_endpoints( + inference_endpoint_url=user_config.endpoint.url, + default_ports=Environment.SERVER_METRICS.DEFAULT_BACKEND_PORTS, + include_inference_port=True, + ) + self.info( + f"Server Metrics: Discovered {len(self._server_metrics_endpoints)} endpoints: {self._server_metrics_endpoints}" + ) + + # Add user-specified URLs if provided + if user_config.server_metrics_urls: + # Add user URLs, avoiding duplicates + for url in user_config.server_metrics_urls: + normalized_url = normalize_metrics_endpoint_url(url) + if normalized_url not in self._server_metrics_endpoints: + self._server_metrics_endpoints.append(normalized_url) + + # Use server metrics collection interval + self._collection_interval = Environment.SERVER_METRICS.COLLECTION_INTERVAL + + @on_command(CommandType.PROFILE_CONFIGURE) + async def _profile_configure_command( + self, message: ProfileConfigureCommand + ) -> None: + """Configure the server metrics collectors but don't start them yet. + + Creates ServerMetricsDataCollector instances for each configured endpoint, + tests reachability, and sends status message to RecordsManager. + If no endpoints are reachable, disables metrics collection and stops the service. + + Args: + message: Profile configuration command from SystemController + """ + # Check if server metrics are disabled via environment variable + if not Environment.SERVER_METRICS.ENABLED: + await self._send_server_metrics_status( + enabled=False, + reason="disabled via AIPERF_SERVER_METRICS_ENABLED=false", + endpoints_configured=[], + endpoints_reachable=[], + ) + return + + self._collectors.clear() + + for endpoint_url in self._server_metrics_endpoints: + self.debug(f"Server Metrics: Testing reachability of {endpoint_url}") + collector = ServerMetricsDataCollector( + endpoint_url=endpoint_url, + collection_interval=self._collection_interval, + record_callback=self._on_server_metrics_records, + error_callback=self._on_server_metrics_error, + collector_id=endpoint_url, + ) + + try: + is_reachable = await collector.is_url_reachable() + if is_reachable: + self._collectors[endpoint_url] = collector + self.debug( + f"Server Metrics: Prometheus endpoint {endpoint_url} is reachable" + ) + else: + self.debug( + f"Server Metrics: Prometheus endpoint {endpoint_url} is not reachable" + ) + except Exception as e: + self.error(f"Server Metrics: Exception testing {endpoint_url}: {e}") + + reachable_endpoints = list(self._collectors.keys()) + + if not self._collectors: + # Server metrics manager shutdown occurs in _on_start_profiling to prevent hang + await self._send_server_metrics_status( + enabled=False, + reason="no Prometheus endpoints reachable", + endpoints_configured=self._server_metrics_endpoints, + endpoints_reachable=[], + ) + return + + await self._send_server_metrics_status( + enabled=True, + reason=None, + endpoints_configured=self._server_metrics_endpoints, + endpoints_reachable=reachable_endpoints, + ) + + @on_command(CommandType.PROFILE_START) + async def _on_start_profiling(self, message: ProfileStartCommand) -> None: + """Start all server metrics collectors. + + Initializes and starts each configured collector. + If no collectors start successfully, sends disabled status to SystemController. + + Args: + message: Profile start command from SystemController + """ + if not self._collectors: + # Server metrics disabled status already sent in _profile_configure_command, only shutdown here + await self.stop() + return + + started_count = 0 + for endpoint_url, collector in self._collectors.items(): + try: + await collector.initialize() + await collector.start() + started_count += 1 + except Exception as e: + self.error(f"Failed to start collector for {endpoint_url}: {e}") + + if started_count == 0: + self.warning("No server metrics collectors successfully started") + await self._send_server_metrics_status( + enabled=False, + reason="all collectors failed to start", + endpoints_configured=self._server_metrics_endpoints, + endpoints_reachable=[], + ) + await self.stop() + return + + @on_command(CommandType.PROFILE_CANCEL) + async def _handle_profile_cancel_command( + self, message: ProfileCancelCommand + ) -> None: + """Stop all server metrics collectors when profiling is cancelled. + + Called when user cancels profiling or an error occurs during profiling. + Waits for flush period to allow metrics to finalize, then stops collectors. + + Args: + message: Profile cancel command from SystemController + """ + flush_period = Environment.SERVER_METRICS.COLLECTION_FLUSH_PERIOD + if flush_period > 0: + self.info( + f"Server Metrics: Waiting {flush_period}s flush period for final server metrics to finalize" + ) + await asyncio.sleep(flush_period) + + await self._stop_all_collectors() + + @on_stop + async def _server_metrics_manager_stop(self) -> None: + """Stop all server metrics collectors during service shutdown. + + Called automatically by BaseComponentService lifecycle management via @on_stop hook. + Ensures all collectors are properly stopped and cleaned up even if shutdown + command was not received. + """ + await self._stop_all_collectors() + + async def _stop_all_collectors(self) -> None: + """Stop all server metrics collectors. + + Attempts to stop each collector gracefully, logging errors but continuing with + remaining collectors to ensure all resources are released. Does nothing if no + collectors are configured. + + Errors during individual collector shutdown do not prevent other collectors + from being stopped. + """ + if not self._collectors: + return + + for endpoint_url, collector in self._collectors.items(): + try: + await collector.stop() + except Exception as e: + self.error(f"Failed to stop collector for {endpoint_url}: {e}") + + async def _on_server_metrics_records( + self, records: list[ServerMetricsRecord], collector_id: str + ) -> None: + """Async callback for receiving server metrics records from collectors. + + Sends full records with all metadata to RecordsManager. + Empty record lists are ignored. + + Args: + records: List of ServerMetricsRecord objects from a collector + collector_id: Unique identifier of the collector that sent the records + """ + if not records: + return + + try: + message = ServerMetricsRecordsMessage( + service_id=self.service_id, + collector_id=collector_id, + records=records, + error=None, + ) + + await self.records_push_client.push(message) + + except Exception as e: + self.error(f"Failed to send server metrics records: {e}") + # Send error message to RecordsManager to track the failure + try: + error_message = ServerMetricsRecordsMessage( + service_id=self.service_id, + collector_id=collector_id, + records=[], + error=ErrorDetails.from_exception(e), + ) + await self.records_push_client.push(error_message) + except Exception as nested_error: + self.error( + f"Failed to send error message after record send failure: {nested_error}" + ) + + async def _on_server_metrics_error( + self, error: ErrorDetails, collector_id: str + ) -> None: + """Async callback for receiving server metrics errors from collectors. + + Sends error ServerMetricsRecordsMessage to RecordsManager via message system. + The message contains an empty records list and the error details. + + Args: + error: ErrorDetails describing the collection error + collector_id: Unique identifier of the collector that encountered the error + """ + try: + error_message = ServerMetricsRecordsMessage( + service_id=self.service_id, + collector_id=collector_id, + records=[], + error=error, + ) + + await self.records_push_client.push(error_message) + + except Exception as e: + self.error(f"Failed to send server metrics error message: {e}") + + async def _send_server_metrics_status( + self, + enabled: bool, + reason: str | None = None, + endpoints_configured: list[str] | None = None, + endpoints_reachable: list[str] | None = None, + ) -> None: + """Send server metrics status message to SystemController. + + Publishes ServerMetricsStatusMessage to inform SystemController about metrics + availability and endpoint reachability. Used during configuration phase and + when metrics are disabled due to errors. + + Args: + enabled: Whether server metrics collection is enabled/available + reason: Optional human-readable reason for status (e.g., "no Prometheus endpoints reachable") + endpoints_configured: List of Prometheus endpoint URLs configured + endpoints_reachable: List of Prometheus endpoint URLs that are accessible + """ + try: + status_message = ServerMetricsStatusMessage( + service_id=self.service_id, + enabled=enabled, + reason=reason, + endpoints_configured=endpoints_configured or [], + endpoints_reachable=endpoints_reachable or [], + ) + + await self.publish(status_message) + + except Exception as e: + self.error(f"Failed to send server metrics status message: {e}") diff --git a/src/aiperf/zmq/dealer_request_client.py b/src/aiperf/zmq/dealer_request_client.py index a19428c71..cdf280279 100644 --- a/src/aiperf/zmq/dealer_request_client.py +++ b/src/aiperf/zmq/dealer_request_client.py @@ -146,7 +146,12 @@ async def request( future = asyncio.Future[Message]() async def callback(response_message: Message) -> None: - future.set_result(response_message) + if not future.done(): + future.set_result(response_message) + else: + self.warning( + f"Received response for request {message.request_id} after it was already completed. Ignoring." + ) await self.request_async(message, callback) return await asyncio.wait_for(future, timeout=timeout) diff --git a/tests/unit/common/test_duplicate_tracker.py b/tests/unit/common/test_duplicate_tracker.py new file mode 100644 index 000000000..c9f4b777c --- /dev/null +++ b/tests/unit/common/test_duplicate_tracker.py @@ -0,0 +1,497 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +from dataclasses import dataclass + +import pytest + +from aiperf.common.duplicate_tracker import AsyncKeyedDuplicateTracker + + +@dataclass +class SampleRecord: + """Simple test record for deduplication testing.""" + + key: str + value: int + metadata: str = "test" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SampleRecord): + return False + return self.value == other.value and self.metadata == other.metadata + + +@pytest.fixture +def tracker() -> AsyncKeyedDuplicateTracker[SampleRecord]: + """Create a fresh tracker for each test.""" + return AsyncKeyedDuplicateTracker[SampleRecord]( + key_function=lambda record: record.key, + value_function=lambda record: record.value, + ) + + +# Helper functions +async def write_sequence( + tracker: AsyncKeyedDuplicateTracker[SampleRecord], key: str, values: list[int] +) -> list[list[SampleRecord]]: + """Write a sequence of values and return all results.""" + results = [] + for value in values: + record = SampleRecord(key=key, value=value) + result = await tracker.deduplicate_record(record) + results.append(result) + return results + + +def flatten_results(results: list[list[SampleRecord]]) -> list[SampleRecord]: + """Flatten a list of results into a single list.""" + return [record for result in results for record in result] + + +def get_values(records: list[SampleRecord]) -> list[int]: + """Extract values from a list of records.""" + return [record.value for record in records] + + +class TestAsyncKeyedDuplicateTrackerBasicDeduplication: + """Test basic deduplication functionality.""" + + @pytest.mark.asyncio + async def test_first_record_always_written( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord] + ): + """Test that the first record is always written.""" + record = SampleRecord(key="key1", value=1) + result = await tracker.deduplicate_record(record) + + assert len(result) == 1 + assert result[0] == record + + @pytest.mark.asyncio + @pytest.mark.parametrize("num_duplicates", [2, 3, 5, 10]) + async def test_consecutive_duplicates_suppressed( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord], num_duplicates: int + ): + """Test that consecutive identical records are suppressed.""" + results = await write_sequence(tracker, "key1", [1] * num_duplicates) + + # First should be written, rest suppressed + assert len(results[0]) == 1 + for i in range(1, num_duplicates): + assert len(results[i]) == 0 + + # Dupe count should be num_duplicates - 1 + assert tracker._dupe_counts["key1"] == num_duplicates - 1 + + @pytest.mark.asyncio + async def test_change_writes_previous_and_new( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord] + ): + """Test that when value changes, both previous and new records are written. + + Input: A, A, A, B + Expected: A (first), [], [], [A, B] (last A before change + new B) + """ + results = await write_sequence(tracker, "key1", [1, 1, 1, 2]) + + assert len(results[0]) == 1 # First A written + assert len(results[1]) == 0 # Duplicate suppressed + assert len(results[2]) == 0 # Duplicate suppressed + assert len(results[3]) == 2 # Both previous A and new B + assert results[3][0].value == 1 # Previous record + assert results[3][1].value == 2 # New record + + # Dupe count should be reset + assert tracker._dupe_counts["key1"] == 0 + + @pytest.mark.asyncio + async def test_no_deduplication_when_values_differ( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord] + ): + """Test that different values are not deduplicated.""" + results = await write_sequence(tracker, "key1", [1, 2, 3]) + + for i, result in enumerate(results): + # Each should write (possibly with previous on change) + assert len(result) >= 1 + assert result[-1].value == i + 1 # Last item should be current record + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "input_sequence,expected_output", + [ + # A,A,A,B,B,C → A (first), A,B (last A + change to B), B,C (last B + change to C) + ([1, 1, 1, 2, 2, 3], [1, 1, 2, 2, 3]), + # A,A,A → A (only first during execution) + ([1, 1, 1], [1]), + # A,B,C → A, B, C (no duplicates, just write each) + ([1, 2, 3], [1, 2, 3]), + # A,A,B,B,C,C → A, A,B (last A + B), B,C (last B + C) + ([1, 1, 2, 2, 3, 3], [1, 1, 2, 2, 3]), + # A,B,A,B,A → No duplicates, just write each = A,B,A,B,A + ([1, 2, 1, 2, 1], [1, 2, 1, 2, 1]), + ], + ) # fmt: skip + async def test_deduplication_sequences( + self, + tracker: AsyncKeyedDuplicateTracker[SampleRecord], + input_sequence: list[int], + expected_output: list[int], + ): + """Test various deduplication sequences.""" + results = await write_sequence(tracker, "key1", input_sequence) + all_written = flatten_results(results) + values = get_values(all_written) + + assert values == expected_output + + +class TestAsyncKeyedDuplicateTrackerPerKey: + """Test that deduplication is tracked independently per key.""" + + @pytest.mark.asyncio + async def test_deduplication_per_key( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord] + ): + """Test that deduplication is tracked independently per key.""" + record_key1 = SampleRecord(key="key1", value=1) + record_key2 = SampleRecord(key="key2", value=1) + + # Write same value to two different keys + result1_key1 = await tracker.deduplicate_record(record_key1) + result1_key2 = await tracker.deduplicate_record(record_key2) + + # Both should write (first for each key) + assert len(result1_key1) == 1 + assert len(result1_key2) == 1 + + # Write duplicates + result2_key1 = await tracker.deduplicate_record(record_key1) + result2_key2 = await tracker.deduplicate_record(record_key2) + + # Both should suppress + assert len(result2_key1) == 0 + assert len(result2_key2) == 0 + + # Each key should have its own dupe count + assert tracker._dupe_counts["key1"] == 1 + assert tracker._dupe_counts["key2"] == 1 + + @pytest.mark.asyncio + async def test_keys_maintain_independent_state( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord] + ): + """Test that different keys maintain completely independent state.""" + record_a_key1 = SampleRecord(key="key1", value=1) + record_b_key2 = SampleRecord(key="key2", value=2) + + # key1: Write A three times + await tracker.deduplicate_record(record_a_key1) + await tracker.deduplicate_record(record_a_key1) + await tracker.deduplicate_record(record_a_key1) + + # key2: Write B three times + await tracker.deduplicate_record(record_b_key2) + await tracker.deduplicate_record(record_b_key2) + await tracker.deduplicate_record(record_b_key2) + + # Each should have 2 duplicates + assert tracker._dupe_counts["key1"] == 2 + assert tracker._dupe_counts["key2"] == 2 + + # Previous records should be different by value + assert tracker._previous_records["key1"].value == 1 + assert tracker._previous_records["key2"].value == 2 + + +class TestAsyncKeyedDuplicateTrackerEquality: + """Test equality comparison for deduplication.""" + + @pytest.mark.asyncio + async def test_equality_uses_value_function( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord] + ): + """Test that deduplication uses the value_function for comparison.""" + record1 = SampleRecord(key="key1", value=1, metadata="test") + record2 = SampleRecord(key="key1", value=1, metadata="different") + record3 = SampleRecord(key="key1", value=2, metadata="test") + + # record1 and record2 have same value (even though metadata differs) + result1 = await tracker.deduplicate_record(record1) + assert len(result1) == 1 + + result2 = await tracker.deduplicate_record(record2) + assert len(result2) == 0 # Duplicate (same value) + + # record3 has different value + result3 = await tracker.deduplicate_record(record3) + assert len(result3) == 2 # Previous + new + + @pytest.mark.asyncio + async def test_complex_equality(self): + """Test deduplication with complex dictionary objects.""" + tracker_dict: AsyncKeyedDuplicateTracker[dict] = AsyncKeyedDuplicateTracker[ + dict + ]( + key_function=lambda record: record.get("key", "default"), + value_function=lambda record: record, + ) + + dict1 = {"key": "key1", "a": 1, "b": {"c": 2}} + dict2 = {"key": "key1", "a": 1, "b": {"c": 2}} + dict3 = {"key": "key1", "a": 1, "b": {"c": 3}} + + result1 = await tracker_dict.deduplicate_record(dict1) + assert len(result1) == 1 + + result2 = await tracker_dict.deduplicate_record(dict2) + assert len(result2) == 0 # Equal dicts + + result3 = await tracker_dict.deduplicate_record(dict3) + assert len(result3) == 2 # Different + + +class TestAsyncKeyedDuplicateTrackerConcurrency: + """Test concurrent access to the tracker.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize("num_tasks,writes_per_task", [(3, 10), (5, 5), (10, 3)]) + async def test_concurrent_writes_to_same_key( + self, + tracker: AsyncKeyedDuplicateTracker[SampleRecord], + num_tasks: int, + writes_per_task: int, + ): + """Test that concurrent writes to the same key are handled safely.""" + results = await write_sequence( + tracker, "key1", [1] * (num_tasks * writes_per_task) + ) + + # At least first record should be written + total_written = sum(len(r) for r in results) + assert total_written >= 1 + + # Total written should be much less than total due to deduplication + assert total_written < num_tasks * writes_per_task + + @pytest.mark.asyncio + @pytest.mark.parametrize("num_keys", [2, 3, 5]) + async def test_concurrent_writes_to_different_keys( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord], num_keys: int + ): + """Test that different keys can be written concurrently.""" + + async def write_to_key(key: str, count: int) -> int: + total = 0 + for _ in range(count): + record = SampleRecord(key=key, value=1) + result = await tracker.deduplicate_record(record) + total += len(result) + return total + + # Write to different keys concurrently + keys = [f"key{i}" for i in range(num_keys)] + results = await asyncio.gather(*[write_to_key(key, 5) for key in keys]) + + # Each key should have written at least once (first record) + for key_result in results: + assert key_result >= 1 + + # Each key should have its own lock + for key in keys: + assert key in tracker._dupe_locks + + @pytest.mark.asyncio + async def test_lock_creation_race_condition( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord] + ): + """Test that lock creation handles race conditions correctly.""" + + async def write_first_record() -> list[SampleRecord]: + record = SampleRecord(key="new_key", value=1) + return await tracker.deduplicate_record(record) + + # Try to create locks for the same key concurrently + results = await asyncio.gather( + write_first_record(), + write_first_record(), + write_first_record(), + ) + + # Should have created exactly one lock + assert "new_key" in tracker._dupe_locks + + # At least one should have written + total_written = sum(len(r) for r in results) + assert total_written >= 1 + + +class TestAsyncKeyedDuplicateTrackerEdgeCases: + """Test edge cases and special scenarios.""" + + @pytest.mark.asyncio + async def test_single_record( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord] + ): + """Test that a single record is written without issues.""" + record = SampleRecord(key="key1", value=1) + result = await tracker.deduplicate_record(record) + + assert len(result) == 1 + assert result[0] == record + assert tracker._dupe_counts["key1"] == 0 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "key", ["", "key-with-dashes", "key.with.dots", "key/with/slashes"] + ) + async def test_special_key_strings( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord], key: str + ): + """Test that various key formats work correctly.""" + record = SampleRecord(key=key, value=1) + result = await tracker.deduplicate_record(record) + + assert len(result) == 1 + assert key in tracker._previous_records + + @pytest.mark.asyncio + @pytest.mark.parametrize("num_keys", [10, 50, 100]) + async def test_many_keys( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord], num_keys: int + ): + """Test handling many different keys.""" + for i in range(num_keys): + record = SampleRecord(key=f"key{i}", value=1) + result = await tracker.deduplicate_record(record) + assert len(result) == 1 + + # Should have num_keys locks and previous records + assert len(tracker._dupe_locks) == num_keys + assert len(tracker._previous_records) == num_keys + + @pytest.mark.asyncio + @pytest.mark.parametrize("num_alternations", [3, 6, 10]) + async def test_alternating_values( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord], num_alternations: int + ): + """Test alternating between two values (no duplication - each is different).""" + # Alternate A, B, A, B, ... + sequence = [1 if i % 2 == 0 else 2 for i in range(num_alternations)] + results = await write_sequence(tracker, "key1", sequence) + + # Each value is different from previous, so all are written + total_written = sum(len(r) for r in results) + assert total_written == num_alternations + + @pytest.mark.asyncio + @pytest.mark.parametrize("num_duplicates", [10, 100, 1000]) + async def test_long_duplicate_sequence( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord], num_duplicates: int + ): + """Test handling a very long sequence of duplicates.""" + results = await write_sequence(tracker, "key1", [1] * num_duplicates) + + # First write should succeed, all others suppressed + assert len(results[0]) == 1 + for i in range(1, num_duplicates): + assert len(results[i]) == 0 + + # Should have num_duplicates - 1 duplicates + assert tracker._dupe_counts["key1"] == num_duplicates - 1 + + @pytest.mark.asyncio + async def test_none_values(self): + """Test handling None as record values.""" + + @dataclass + class NullableRecord: + key: str + value: int | None + + tracker_nullable: AsyncKeyedDuplicateTracker[NullableRecord] = ( + AsyncKeyedDuplicateTracker[NullableRecord]( + key_function=lambda record: record.key, + value_function=lambda record: record.value, + ) + ) + + record1 = NullableRecord(key="key1", value=None) + result1 = await tracker_nullable.deduplicate_record(record1) + assert len(result1) == 1 + assert result1[0].value is None + + record2 = NullableRecord(key="key1", value=None) + result2 = await tracker_nullable.deduplicate_record(record2) + assert len(result2) == 0 # Duplicate + + record3 = NullableRecord(key="key1", value=1) + result3 = await tracker_nullable.deduplicate_record(record3) + assert len(result3) == 2 # Previous None + new 1 + + +class TestAsyncKeyedDuplicateTrackerFlush: + """Test flushing remaining duplicates.""" + + @pytest.mark.asyncio + async def test_flush_with_no_pending_duplicates( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord] + ): + """Test that flush returns empty list when no pending duplicates.""" + await write_sequence(tracker, "key1", [1, 2, 3]) + + # No duplicates, so nothing to flush + to_flush = await tracker.flush_remaining_duplicates() + assert len(to_flush) == 0 + + @pytest.mark.asyncio + async def test_flush_with_pending_duplicates( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord] + ): + """Test that flush returns pending duplicates.""" + # Write A, A, A (2 pending duplicates) + await write_sequence(tracker, "key1", [1, 1, 1]) + + # Should flush the last A + to_flush = await tracker.flush_remaining_duplicates() + assert len(to_flush) == 1 + assert to_flush[0].value == 1 # Record value + + # Dupe count should be reset + assert tracker._dupe_counts["key1"] == 0 + + @pytest.mark.asyncio + async def test_flush_multiple_keys( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord] + ): + """Test flushing pending duplicates from multiple keys.""" + # key1: A, A, A + await write_sequence(tracker, "key1", [1, 1, 1]) + # key2: B, B + await write_sequence(tracker, "key2", [2, 2]) + # key3: C (no duplicates) + await write_sequence(tracker, "key3", [3]) + + # Should flush key1 and key2, but not key3 + to_flush = await tracker.flush_remaining_duplicates() + assert len(to_flush) == 2 + + flushed_values = {record.value for record in to_flush} + assert flushed_values == {1, 2} + + @pytest.mark.asyncio + async def test_flush_idempotent( + self, tracker: AsyncKeyedDuplicateTracker[SampleRecord] + ): + """Test that calling flush multiple times doesn't duplicate records.""" + await write_sequence(tracker, "key1", [1, 1, 1]) + + # First flush + to_flush1 = await tracker.flush_remaining_duplicates() + assert len(to_flush1) == 1 + + # Second flush should return nothing + to_flush2 = await tracker.flush_remaining_duplicates() + assert len(to_flush2) == 0 diff --git a/tests/unit/common/test_metric_utils.py b/tests/unit/common/test_metric_utils.py new file mode 100644 index 000000000..e5cafe0a3 --- /dev/null +++ b/tests/unit/common/test_metric_utils.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from aiperf.common.metric_utils import ( + build_hostname_aware_prometheus_endpoints, + normalize_metrics_endpoint_url, +) + + +class TestNormalizeMetricsEndpointUrl: + """Test URL normalization for metrics endpoints.""" + + @pytest.mark.parametrize( + "input_url,expected", + [ + ("http://localhost:9400", "http://localhost:9400/metrics"), + ("http://localhost:9400/", "http://localhost:9400/metrics"), + ("http://localhost:9400/metrics", "http://localhost:9400/metrics"), + ("http://localhost:9400/metrics/", "http://localhost:9400/metrics"), + ("http://node1:8081", "http://node1:8081/metrics"), + ("https://secure:443", "https://secure:443/metrics"), + ("http://10.0.0.1:9090", "http://10.0.0.1:9090/metrics"), + ], + ) # fmt: skip + def test_normalize_url_variations(self, input_url: str, expected: str): + """Test URL normalization handles various input formats correctly.""" + assert normalize_metrics_endpoint_url(input_url) == expected + + def test_normalize_preserves_scheme(self): + """Test that URL normalization preserves https scheme.""" + result = normalize_metrics_endpoint_url("https://secure:9400") + assert result.startswith("https://") + assert result == "https://secure:9400/metrics" + + def test_normalize_removes_trailing_slashes(self): + """Test that multiple trailing slashes are removed.""" + result = normalize_metrics_endpoint_url("http://localhost:9400///") + assert result == "http://localhost:9400/metrics" + + +class TestBuildHostnameAwarePrometheusEndpoints: + """Test hostname-aware Prometheus endpoint URL generation.""" + + def test_basic_endpoint_generation(self): + """Test generating endpoints from inference URL and default ports.""" + endpoints = build_hostname_aware_prometheus_endpoints( + "http://localhost:8000/v1/chat", [9400, 9401], include_inference_port=False + ) + assert len(endpoints) == 2 + assert "http://localhost:9400/metrics" in endpoints + assert "http://localhost:9401/metrics" in endpoints + + def test_preserves_scheme(self): + """Test that generated endpoints use the same scheme as inference endpoint.""" + endpoints = build_hostname_aware_prometheus_endpoints( + "https://secure-server:8000", [8081], include_inference_port=False + ) + assert len(endpoints) == 1 + assert endpoints[0] == "https://secure-server:8081/metrics" + + def test_extracts_hostname_from_url(self): + """Test hostname extraction from various URL formats.""" + test_cases = [ + ("http://node1:8000/v1/chat", [8081], "http://node1:8081/metrics"), + ("http://10.0.0.5:8000", [9090], "http://10.0.0.5:9090/metrics"), + ( + "https://api.example.com:443/inference", + [8081], + "https://api.example.com:8081/metrics", + ), + ] + for inference_url, ports, expected_endpoint in test_cases: + endpoints = build_hostname_aware_prometheus_endpoints( + inference_url, ports, include_inference_port=False + ) + assert expected_endpoint in endpoints + + def test_multiple_ports(self): + """Test generating multiple endpoints from multiple ports.""" + endpoints = build_hostname_aware_prometheus_endpoints( + "http://server:8000", [8081, 6880, 9090], include_inference_port=False + ) + assert len(endpoints) == 3 + assert "http://server:8081/metrics" in endpoints + assert "http://server:6880/metrics" in endpoints + assert "http://server:9090/metrics" in endpoints + + def test_empty_ports_list(self): + """Test behavior with empty ports list and no inference port.""" + endpoints = build_hostname_aware_prometheus_endpoints( + "http://localhost:8000", [], include_inference_port=False + ) + assert len(endpoints) == 0 + + def test_url_without_port(self): + """Test handling URLs without explicit port.""" + endpoints = build_hostname_aware_prometheus_endpoints( + "http://localhost/v1/chat", [8081], include_inference_port=False + ) + assert len(endpoints) == 1 + assert "http://localhost:8081/metrics" in endpoints + + def test_include_inference_port_with_explicit_port(self): + """Test including inference endpoint port when explicitly specified.""" + endpoints = build_hostname_aware_prometheus_endpoints( + "http://localhost:8000/v1/chat", [9400], include_inference_port=True + ) + assert len(endpoints) == 2 + assert "http://localhost:8000/metrics" in endpoints + assert "http://localhost:9400/metrics" in endpoints + # Inference port should be first + assert endpoints[0] == "http://localhost:8000/metrics" + + def test_include_inference_port_without_explicit_port(self): + """Test including default port when inference URL has no port.""" + endpoints = build_hostname_aware_prometheus_endpoints( + "http://localhost/v1/chat", [9400], include_inference_port=True + ) + assert len(endpoints) == 2 + assert "http://localhost:80/metrics" in endpoints + assert "http://localhost:9400/metrics" in endpoints + + def test_include_inference_port_https_default(self): + """Test including HTTPS default port 443.""" + endpoints = build_hostname_aware_prometheus_endpoints( + "https://secure/v1/chat", [9400], include_inference_port=True + ) + assert len(endpoints) == 2 + assert "https://secure:443/metrics" in endpoints + assert "https://secure:9400/metrics" in endpoints diff --git a/tests/unit/gpu_telemetry/test_telemetry_data_collector.py b/tests/unit/gpu_telemetry/test_telemetry_data_collector.py index 74a297cb7..6c08aa612 100644 --- a/tests/unit/gpu_telemetry/test_telemetry_data_collector.py +++ b/tests/unit/gpu_telemetry/test_telemetry_data_collector.py @@ -38,8 +38,8 @@ def test_collector_initialization_complete(self): collector_id="test_collector", ) - assert collector._dcgm_url == "http://localhost:9401/metrics" - assert collector._collection_interval == 0.1 + assert collector.endpoint_url == "http://localhost:9401/metrics" + assert collector.collection_interval == 0.1 assert collector.id == "test_collector" assert collector._session is None # Not initialized yet assert not collector.was_initialized @@ -54,8 +54,8 @@ def test_collector_initialization_minimal(self): """ collector = TelemetryDataCollector("http://localhost:9401/metrics") - assert collector._dcgm_url == "http://localhost:9401/metrics" - assert collector._collection_interval == 0.33 # Default collection interval + assert collector.endpoint_url == "http://localhost:9401/metrics" + assert collector.collection_interval == 0.33 # Default collection interval assert collector.id == "telemetry_collector" # Default ID assert collector._record_callback is None assert collector._error_callback is None @@ -274,7 +274,7 @@ async def test_metrics_fetching(self, sample_dcgm_data): await collector.initialize() - result = await collector._fetch_metrics() + result = await collector._fetch_metrics_text() assert result == sample_dcgm_data await collector.stop() @@ -291,7 +291,7 @@ async def test_fetch_metrics_session_closed(self): # Should raise CancelledError due to closed session with pytest.raises(asyncio.CancelledError): - await collector._fetch_metrics() + await collector._fetch_metrics_text() @pytest.mark.asyncio async def test_fetch_metrics_when_stop_requested(self): @@ -305,7 +305,7 @@ async def test_fetch_metrics_when_stop_requested(self): # Should raise CancelledError with pytest.raises(asyncio.CancelledError): - await collector._fetch_metrics() + await collector._fetch_metrics_text() # Clean up collector.stop_requested = False @@ -318,7 +318,7 @@ async def test_fetch_metrics_no_session(self): # Don't initialize - session is None with pytest.raises(RuntimeError, match="HTTP session not initialized"): - await collector._fetch_metrics() + await collector._fetch_metrics_text() class TestCollectionLifecycle: @@ -375,8 +375,8 @@ async def test_error_handling_in_collection_loop(self): await collector.initialize() - await collector._collect_telemetry_task() - await collector._collect_telemetry_task() + await collector._collect_metrics_task() + await collector._collect_metrics_task() await collector.stop() @@ -575,7 +575,7 @@ async def test_error_callback_exception_handling(self): mock_get.side_effect = aiohttp.ClientError("Connection failed") await collector.initialize() - await collector._collect_telemetry_task() + await collector._collect_metrics_task() await collector.stop() mock_error_callback.assert_called_once() diff --git a/tests/unit/post_processors/test_server_metrics_export_results_processor.py b/tests/unit/post_processors/test_server_metrics_export_results_processor.py new file mode 100644 index 000000000..8ec1830c7 --- /dev/null +++ b/tests/unit/post_processors/test_server_metrics_export_results_processor.py @@ -0,0 +1,1159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +from pathlib import Path + +import orjson +import pytest + +from aiperf.common.config import EndpointConfig, OutputConfig, ServiceConfig, UserConfig +from aiperf.common.enums import EndpointType, PrometheusMetricType +from aiperf.common.models.server_metrics_models import ( + HistogramData, + MetricFamily, + MetricSample, + ServerMetricsRecord, +) +from aiperf.post_processors.server_metrics_export_results_processor import ( + ServerMetricsExportResultsProcessor, +) +from tests.unit.post_processors.conftest import aiperf_lifecycle + + +@pytest.fixture +def user_config_server_metrics_export(tmp_artifact_dir: Path) -> UserConfig: + """Create UserConfig for server metrics export testing.""" + return UserConfig( + endpoint=EndpointConfig( + model_names=["test-model"], + type=EndpointType.CHAT, + ), + output=OutputConfig( + artifact_directory=tmp_artifact_dir, + ), + ) + + +@pytest.fixture +def sample_server_metrics_record_for_export() -> ServerMetricsRecord: + """Create sample ServerMetricsRecord for export testing.""" + return ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "requests_total": MetricFamily( + type=PrometheusMetricType.COUNTER, + description="Total requests", + samples=[ + MetricSample( + labels={"status": "success"}, + value=100.0, + ) + ], + ), + }, + ) + + +class TestServerMetricsExportResultsProcessorInitialization: + """Test ServerMetricsExportResultsProcessor initialization.""" + + def test_initialization( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + ): + """Test processor initializes with correct file paths.""" + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + assert ( + processor.output_file + == user_config_server_metrics_export.output.server_metrics_export_jsonl_file + ) + assert ( + processor._metadata_file + == user_config_server_metrics_export.output.server_metrics_metadata_json_file + ) + + def test_files_cleared_on_initialization( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + tmp_artifact_dir: Path, + ): + """Test that output files are cleared on initialization.""" + jsonl_file = tmp_artifact_dir / "server_metrics_export.jsonl" + metadata_file = tmp_artifact_dir / "server_metrics_metadata.json" + + jsonl_file.write_text("old data") + metadata_file.write_text("old metadata") + + ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + assert not jsonl_file.exists() or jsonl_file.stat().st_size == 0 + + +class TestServerMetricsRecordProcessing: + """Test processing ServerMetricsRecord objects.""" + + @pytest.mark.asyncio + async def test_process_single_record( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + sample_server_metrics_record_for_export: ServerMetricsRecord, + ): + """Test processing single server metrics record.""" + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + await processor.process_server_metrics_record( + sample_server_metrics_record_for_export + ) + + output_file = ( + user_config_server_metrics_export.output.server_metrics_export_jsonl_file + ) + assert output_file.exists() + + lines = output_file.read_text().strip().split("\n") + assert len(lines) == 1 + + data = orjson.loads(lines[0]) + assert data["endpoint_url"] == "http://localhost:8081/metrics" + assert data["timestamp_ns"] == 1_000_000_000 + assert data["endpoint_latency_ns"] == 5_000_000 + assert "metrics" in data + + @pytest.mark.asyncio + async def test_process_multiple_records( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + ): + """Test processing multiple server metrics records with different metrics.""" + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + for i in range(5): + record = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000 + i * 1_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "counter": MetricFamily( + type=PrometheusMetricType.COUNTER, + description="Test counter", + samples=[ + MetricSample( + labels={}, + value=float( + i + ), # Different values to avoid deduplication + ) + ], + ), + }, + ) + await processor.process_server_metrics_record(record) + + output_file = ( + user_config_server_metrics_export.output.server_metrics_export_jsonl_file + ) + lines = output_file.read_text().strip().split("\n") + assert len(lines) == 5 + + @pytest.mark.asyncio + async def test_record_converted_to_slim_format( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + sample_server_metrics_record_for_export: ServerMetricsRecord, + ): + """Test that records are converted to slim format before writing.""" + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + await processor.process_server_metrics_record( + sample_server_metrics_record_for_export + ) + + output_file = ( + user_config_server_metrics_export.output.server_metrics_export_jsonl_file + ) + data = orjson.loads(output_file.read_text().strip()) + + assert "metrics" in data + assert "requests_total" in data["metrics"] + + +class TestMetadataExtraction: + """Test metadata extraction and writing.""" + + @pytest.mark.asyncio + async def test_metadata_extracted_on_first_record( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + sample_server_metrics_record_for_export: ServerMetricsRecord, + ): + """Test that metadata is extracted and written on first record from endpoint.""" + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + await processor.process_server_metrics_record( + sample_server_metrics_record_for_export + ) + + metadata_file = ( + user_config_server_metrics_export.output.server_metrics_metadata_json_file + ) + assert metadata_file.exists() + + metadata_content = orjson.loads(metadata_file.read_bytes()) + assert "endpoints" in metadata_content + assert "http://localhost:8081/metrics" in metadata_content["endpoints"] + + @pytest.mark.asyncio + async def test_metadata_contains_metric_schemas( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + sample_server_metrics_record_for_export: ServerMetricsRecord, + ): + """Test that metadata includes metric schemas (type, help).""" + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + await processor.process_server_metrics_record( + sample_server_metrics_record_for_export + ) + + metadata_file = ( + user_config_server_metrics_export.output.server_metrics_metadata_json_file + ) + metadata_content = orjson.loads(metadata_file.read_bytes()) + + endpoint_metadata = metadata_content["endpoints"][ + "http://localhost:8081/metrics" + ] + assert "metric_schemas" in endpoint_metadata + assert "requests_total" in endpoint_metadata["metric_schemas"] + + schema = endpoint_metadata["metric_schemas"]["requests_total"] + assert schema["type"] == "counter" + assert schema["description"] == "Total requests" + + @pytest.mark.asyncio + async def test_histogram_schema_includes_bucket_labels( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + ): + """Test that histogram schemas are exported correctly.""" + record = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "ttft": MetricFamily( + type=PrometheusMetricType.HISTOGRAM, + description="Time to first token", + samples=[ + MetricSample( + labels={"model": "test"}, + histogram=HistogramData( + buckets={"0.01": 5.0, "0.1": 15.0, "+Inf": 50.0}, + sum=5.5, + count=50.0, + ), + ) + ], + ) + }, + ) + + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + await processor.process_server_metrics_record(record) + + metadata_file = ( + user_config_server_metrics_export.output.server_metrics_metadata_json_file + ) + metadata_content = orjson.loads(metadata_file.read_bytes()) + + schema = metadata_content["endpoints"]["http://localhost:8081/metrics"][ + "metric_schemas" + ]["ttft"] + assert schema["type"] == "histogram" + assert schema["description"] == "Time to first token" + + @pytest.mark.asyncio + async def test_metadata_includes_unique_label_values( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + ): + """Test that metadata includes schema for metrics with multiple samples.""" + record = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "requests_total": MetricFamily( + type=PrometheusMetricType.COUNTER, + description="Total requests", + samples=[ + MetricSample( + labels={"status": "success", "endpoint": "chat"}, + value=100.0, + ), + MetricSample( + labels={"status": "error", "endpoint": "chat"}, + value=10.0, + ), + MetricSample( + labels={"status": "success", "endpoint": "completions"}, + value=50.0, + ), + ], + ) + }, + ) + + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + await processor.process_server_metrics_record(record) + + metadata_file = ( + user_config_server_metrics_export.output.server_metrics_metadata_json_file + ) + metadata_content = orjson.loads(metadata_file.read_bytes()) + + schema = metadata_content["endpoints"]["http://localhost:8081/metrics"][ + "metric_schemas" + ]["requests_total"] + assert schema["type"] == "counter" + assert schema["description"] == "Total requests" + + @pytest.mark.asyncio + async def test_unique_label_values_respects_cardinality_limit( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + monkeypatch, + ): + """Test that metadata handles metrics with multiple label values.""" + # Create record with 3 unique label values + record = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "requests_total": MetricFamily( + type=PrometheusMetricType.COUNTER, + description="Total requests", + samples=[ + MetricSample(labels={"status": "success"}, value=100.0), + MetricSample(labels={"status": "error"}, value=10.0), + MetricSample(labels={"status": "timeout"}, value=5.0), + ], + ) + }, + ) + + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + await processor.process_server_metrics_record(record) + + metadata_file = ( + user_config_server_metrics_export.output.server_metrics_metadata_json_file + ) + metadata_content = orjson.loads(metadata_file.read_bytes()) + + schema = metadata_content["endpoints"]["http://localhost:8081/metrics"][ + "metric_schemas" + ]["requests_total"] + assert schema["type"] == "counter" + assert schema["description"] == "Total requests" + + @pytest.mark.asyncio + async def test_metadata_updated_for_multiple_endpoints( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + ): + """Test that metadata file contains all endpoints.""" + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + for endpoint in ["http://node1:8081/metrics", "http://node2:8081/metrics"]: + record = ServerMetricsRecord( + endpoint_url=endpoint, + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={}, + ) + await processor.process_server_metrics_record(record) + + metadata_file = ( + user_config_server_metrics_export.output.server_metrics_metadata_json_file + ) + metadata_content = orjson.loads(metadata_file.read_bytes()) + + assert len(metadata_content["endpoints"]) == 2 + assert "http://node1:8081/metrics" in metadata_content["endpoints"] + assert "http://node2:8081/metrics" in metadata_content["endpoints"] + + @pytest.mark.asyncio + async def test_metadata_only_written_once_per_endpoint( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + ): + """Test that metadata is only extracted on first record per endpoint.""" + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + for _ in range(3): + record = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={}, + ) + await processor.process_server_metrics_record(record) + + assert ( + "http://localhost:8081/metrics" in processor._metadata_file_model.endpoints + ) + assert len(processor._metadata_file_model.endpoints) == 1 + + +class TestSummarizeMethod: + """Test summarize method behavior.""" + + @pytest.mark.asyncio + async def test_summarize_returns_empty_list( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + ): + """Test that summarize returns empty list (export processors don't summarize).""" + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + results = await processor.summarize() + + assert results == [] + + +class TestMetadataReconciliation: + """Test metadata reconciliation for evolving metrics.""" + + @pytest.mark.asyncio + async def test_new_metrics_appearing_later( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + ): + """Test that new metrics appearing in later records are captured.""" + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + # First record with metric A + record1 = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "metric_a": MetricFamily( + type=PrometheusMetricType.COUNTER, + description="Metric A", + samples=[MetricSample(labels={}, value=100.0)], + ), + }, + ) + await processor.process_server_metrics_record(record1) + + # Second record with metrics A and B + record2 = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=2_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "metric_a": MetricFamily( + type=PrometheusMetricType.COUNTER, + description="Metric A", + samples=[MetricSample(labels={}, value=101.0)], + ), + "metric_b": MetricFamily( + type=PrometheusMetricType.GAUGE, + description="Metric B", + samples=[MetricSample(labels={}, value=50.0)], + ), + }, + ) + await processor.process_server_metrics_record(record2) + + # Verify metadata includes both metrics + metadata_file = ( + user_config_server_metrics_export.output.server_metrics_metadata_json_file + ) + metadata_content = orjson.loads(metadata_file.read_bytes()) + endpoint_metadata = metadata_content["endpoints"][ + "http://localhost:8081/metrics" + ] + + assert "metric_a" in endpoint_metadata["metric_schemas"] + assert "metric_b" in endpoint_metadata["metric_schemas"] + assert endpoint_metadata["metric_schemas"]["metric_a"]["type"] == "counter" + assert endpoint_metadata["metric_schemas"]["metric_b"]["type"] == "gauge" + + @pytest.mark.asyncio + async def test_same_count_different_metrics( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + ): + """Test that different metrics with same count are detected.""" + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + # First record with metrics A, B, C + record1 = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "metric_a": MetricFamily( + type=PrometheusMetricType.COUNTER, + description="Metric A", + samples=[MetricSample(labels={}, value=1.0)], + ), + "metric_b": MetricFamily( + type=PrometheusMetricType.COUNTER, + description="Metric B", + samples=[MetricSample(labels={}, value=2.0)], + ), + "metric_c": MetricFamily( + type=PrometheusMetricType.COUNTER, + description="Metric C", + samples=[MetricSample(labels={}, value=3.0)], + ), + }, + ) + await processor.process_server_metrics_record(record1) + + # Second record with metrics B, C, D (same count, but D is new) + record2 = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=2_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "metric_b": MetricFamily( + type=PrometheusMetricType.COUNTER, + description="Metric B", + samples=[MetricSample(labels={}, value=2.0)], + ), + "metric_c": MetricFamily( + type=PrometheusMetricType.COUNTER, + description="Metric C", + samples=[MetricSample(labels={}, value=3.0)], + ), + "metric_d": MetricFamily( + type=PrometheusMetricType.GAUGE, + description="Metric D", + samples=[MetricSample(labels={}, value=4.0)], + ), + }, + ) + await processor.process_server_metrics_record(record2) + + # Verify metadata includes all metrics (A, B, C, D) + metadata_file = ( + user_config_server_metrics_export.output.server_metrics_metadata_json_file + ) + metadata_content = orjson.loads(metadata_file.read_bytes()) + endpoint_metadata = metadata_content["endpoints"][ + "http://localhost:8081/metrics" + ] + + assert len(endpoint_metadata["metric_schemas"]) == 4 + assert "metric_a" in endpoint_metadata["metric_schemas"] + assert "metric_b" in endpoint_metadata["metric_schemas"] + assert "metric_c" in endpoint_metadata["metric_schemas"] + assert "metric_d" in endpoint_metadata["metric_schemas"] + + @pytest.mark.asyncio + async def test_histogram_bucket_changes( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + ): + """Test that new histogram buckets are detected and merged.""" + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + # First record with histogram with 3 buckets + record1 = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "request_duration": MetricFamily( + type=PrometheusMetricType.HISTOGRAM, + description="Request duration", + samples=[ + MetricSample( + labels={}, + histogram=HistogramData( + buckets={"0.1": 10.0, "0.5": 50.0, "+Inf": 100.0}, + sum=25.5, + count=100.0, + ), + ) + ], + ), + }, + ) + await processor.process_server_metrics_record(record1) + + # Second record with 5 buckets (added 0.01 and 1.0) + record2 = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=2_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "request_duration": MetricFamily( + type=PrometheusMetricType.HISTOGRAM, + description="Request duration", + samples=[ + MetricSample( + labels={}, + histogram=HistogramData( + buckets={ + "0.01": 5.0, + "0.1": 15.0, + "0.5": 60.0, + "1.0": 80.0, + "+Inf": 120.0, + }, + sum=35.5, + count=120.0, + ), + ) + ], + ), + }, + ) + await processor.process_server_metrics_record(record2) + + # Verify metadata includes schema + metadata_file = ( + user_config_server_metrics_export.output.server_metrics_metadata_json_file + ) + metadata_content = orjson.loads(metadata_file.read_bytes()) + schema = metadata_content["endpoints"]["http://localhost:8081/metrics"][ + "metric_schemas" + ]["request_duration"] + + assert schema["type"] == "histogram" + assert schema["description"] == "Request duration" + + @pytest.mark.asyncio + async def test_summary_quantile_changes( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + ): + """Test that new summary quantiles are detected and merged.""" + from aiperf.common.models.server_metrics_models import SummaryData + + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + # First record with 3 quantiles + record1 = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "response_time": MetricFamily( + type=PrometheusMetricType.SUMMARY, + description="Response time", + samples=[ + MetricSample( + labels={}, + summary=SummaryData( + quantiles={"0.5": 0.1, "0.9": 0.5, "0.99": 1.0}, + sum=50.0, + count=100.0, + ), + ) + ], + ), + }, + ) + await processor.process_server_metrics_record(record1) + + # Second record with 5 quantiles (added 0.25 and 0.95) + record2 = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=2_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "response_time": MetricFamily( + type=PrometheusMetricType.SUMMARY, + description="Response time", + samples=[ + MetricSample( + labels={}, + summary=SummaryData( + quantiles={ + "0.25": 0.05, + "0.5": 0.12, + "0.9": 0.55, + "0.95": 0.75, + "0.99": 1.1, + }, + sum=60.0, + count=120.0, + ), + ) + ], + ), + }, + ) + await processor.process_server_metrics_record(record2) + + # Verify metadata includes all quantiles (union) + metadata_file = ( + user_config_server_metrics_export.output.server_metrics_metadata_json_file + ) + metadata_content = orjson.loads(metadata_file.read_bytes()) + schema = metadata_content["endpoints"]["http://localhost:8081/metrics"][ + "metric_schemas" + ]["response_time"] + + assert schema["type"] == "summary" + assert schema["description"] == "Response time" + + @pytest.mark.asyncio + async def test_no_update_for_identical_metadata( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + ): + """Test that metadata file is not rewritten when metrics don't change.""" + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + # First record + record1 = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "metric_a": MetricFamily( + type=PrometheusMetricType.COUNTER, + description="Metric A", + samples=[MetricSample(labels={}, value=100.0)], + ), + }, + ) + await processor.process_server_metrics_record(record1) + + metadata_file = user_config_server_metrics_export.output.server_metrics_metadata_json_file + first_mtime = metadata_file.stat().st_mtime_ns + + # Wait a bit to ensure timestamp would change if file is rewritten + await asyncio.sleep(0.01) + + # Second record with same metrics (different values) + record2 = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=2_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "metric_a": MetricFamily( + type=PrometheusMetricType.COUNTER, + description="Metric A", + samples=[MetricSample(labels={}, value=105.0)], + ), + }, + ) + await processor.process_server_metrics_record(record2) + + second_mtime = metadata_file.stat().st_mtime_ns + + # Metadata file should not be rewritten + assert first_mtime == second_mtime + + @pytest.mark.asyncio + async def test_metadata_merge_is_idempotent( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + ): + """Test that merging the same metadata multiple times produces same result.""" + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + # Record with histogram + record = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "duration": MetricFamily( + type=PrometheusMetricType.HISTOGRAM, + description="Duration", + samples=[ + MetricSample( + labels={}, + histogram=HistogramData( + buckets={"0.1": 10.0, "0.5": 50.0, "+Inf": 100.0}, + sum=25.5, + count=100.0, + ), + ) + ], + ), + }, + ) + + # Process same record 3 times + await processor.process_server_metrics_record(record) + await processor.process_server_metrics_record(record) + await processor.process_server_metrics_record(record) + + # Verify schema is present + metadata_file = ( + user_config_server_metrics_export.output.server_metrics_metadata_json_file + ) + metadata_content = orjson.loads(metadata_file.read_bytes()) + schema = metadata_content["endpoints"]["http://localhost:8081/metrics"][ + "metric_schemas" + ]["duration"] + + assert schema["type"] == "histogram" + assert schema["description"] == "Duration" + + +class TestInfoMetricsHandling: + """Test that _info metrics are properly separated in metadata and excluded from slim records.""" + + @pytest.mark.asyncio + async def test_info_metrics_stored_separately_in_metadata( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + ): + """Test that metrics ending in _info are stored in info_metrics field.""" + record = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "python_info": MetricFamily( + type=PrometheusMetricType.GAUGE, + description="Python platform information", + samples=[ + MetricSample( + labels={"version": "3.10.0", "implementation": "CPython"}, + value=1.0, + ) + ], + ), + "requests_total": MetricFamily( + type=PrometheusMetricType.COUNTER, + description="Total requests", + samples=[ + MetricSample( + labels={"status": "success"}, + value=100.0, + ) + ], + ), + }, + ) + + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + await processor.process_server_metrics_record(record) + + metadata_file = ( + user_config_server_metrics_export.output.server_metrics_metadata_json_file + ) + metadata_content = orjson.loads(metadata_file.read_bytes()) + + endpoint_metadata = metadata_content["endpoints"][ + "http://localhost:8081/metrics" + ] + + # Verify python_info is in info_metrics, not metric_schemas + assert "python_info" in endpoint_metadata["info_metrics"] + assert "python_info" not in endpoint_metadata["metric_schemas"] + + # Verify regular metrics are still in metric_schemas + assert "requests_total" in endpoint_metadata["metric_schemas"] + assert "requests_total" not in endpoint_metadata["info_metrics"] + + # Verify the info_metric contains description + labels (no values or type) + info_data = endpoint_metadata["info_metrics"]["python_info"] + assert "type" not in info_data + assert info_data["description"] == "Python platform information" + + # Verify labels are stored as list of dicts (values omitted) + assert "labels" in info_data + assert len(info_data["labels"]) == 1 + labels = info_data["labels"][0] + assert labels == {"version": "3.10.0", "implementation": "CPython"} + # Verify no value field + assert "value" not in info_data + + @pytest.mark.asyncio + async def test_info_metrics_excluded_from_slim_records( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + ): + """Test that metrics ending in _info are excluded from slim JSONL records.""" + record = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "python_info": MetricFamily( + type=PrometheusMetricType.GAUGE, + description="Python platform information", + samples=[ + MetricSample( + labels={"version": "3.10.0"}, + value=1.0, + ) + ], + ), + "process_info": MetricFamily( + type=PrometheusMetricType.GAUGE, + description="Process information", + samples=[ + MetricSample( + labels={"pid": "1234"}, + value=1.0, + ) + ], + ), + "requests_total": MetricFamily( + type=PrometheusMetricType.COUNTER, + description="Total requests", + samples=[ + MetricSample( + labels={"status": "success"}, + value=100.0, + ) + ], + ), + }, + ) + + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + await processor.process_server_metrics_record(record) + + jsonl_file = ( + user_config_server_metrics_export.output.server_metrics_export_jsonl_file + ) + lines = jsonl_file.read_text().strip().split("\n") + + # Should have 1 line + assert len(lines) == 1 + + slim_record = orjson.loads(lines[0]) + + # Verify _info metrics are NOT in the slim record + assert "python_info" not in slim_record["metrics"] + assert "process_info" not in slim_record["metrics"] + + # Verify regular metrics ARE in the slim record + assert "requests_total" in slim_record["metrics"] + + @pytest.mark.asyncio + async def test_mixed_info_and_regular_metrics( + self, + user_config_server_metrics_export: UserConfig, + service_config: ServiceConfig, + ): + """Test handling of multiple _info metrics alongside regular metrics.""" + record = ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "python_info": MetricFamily( + type=PrometheusMetricType.GAUGE, + description="Python info", + samples=[MetricSample(labels={}, value=1.0)], + ), + "server_info": MetricFamily( + type=PrometheusMetricType.GAUGE, + description="Server info", + samples=[MetricSample(labels={}, value=1.0)], + ), + "cpu_usage": MetricFamily( + type=PrometheusMetricType.GAUGE, + description="CPU usage", + samples=[MetricSample(labels={}, value=42.0)], + ), + "memory_usage": MetricFamily( + type=PrometheusMetricType.GAUGE, + description="Memory usage", + samples=[MetricSample(labels={}, value=1024.0)], + ), + }, + ) + + processor = ServerMetricsExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_server_metrics_export, + ) + + async with aiperf_lifecycle(processor): + await processor.process_server_metrics_record(record) + + # Check metadata file + metadata_file = ( + user_config_server_metrics_export.output.server_metrics_metadata_json_file + ) + metadata_content = orjson.loads(metadata_file.read_bytes()) + endpoint_metadata = metadata_content["endpoints"][ + "http://localhost:8081/metrics" + ] + + # Verify correct classification + assert len(endpoint_metadata["info_metrics"]) == 2 + assert "python_info" in endpoint_metadata["info_metrics"] + assert "server_info" in endpoint_metadata["info_metrics"] + + assert len(endpoint_metadata["metric_schemas"]) == 2 + assert "cpu_usage" in endpoint_metadata["metric_schemas"] + assert "memory_usage" in endpoint_metadata["metric_schemas"] + + # Verify info metrics have labels (no values) + python_info = endpoint_metadata["info_metrics"]["python_info"] + assert "labels" in python_info + assert len(python_info["labels"]) == 1 + assert isinstance(python_info["labels"][0], dict) + assert "value" not in python_info + + server_info = endpoint_metadata["info_metrics"]["server_info"] + assert "labels" in server_info + assert len(server_info["labels"]) == 1 + assert isinstance(server_info["labels"][0], dict) + assert "value" not in server_info + + # Check JSONL file + jsonl_file = ( + user_config_server_metrics_export.output.server_metrics_export_jsonl_file + ) + slim_record = orjson.loads(jsonl_file.read_text().strip()) + + # Only regular metrics in slim record + assert len(slim_record["metrics"]) == 2 + assert "cpu_usage" in slim_record["metrics"] + assert "memory_usage" in slim_record["metrics"] + assert "python_info" not in slim_record["metrics"] + assert "server_info" not in slim_record["metrics"] diff --git a/tests/unit/post_processors/test_telemetry_export_results_processor.py b/tests/unit/post_processors/test_telemetry_export_results_processor.py index c348831fa..38f83cf32 100644 --- a/tests/unit/post_processors/test_telemetry_export_results_processor.py +++ b/tests/unit/post_processors/test_telemetry_export_results_processor.py @@ -361,7 +361,7 @@ async def test_buffer_auto_flush_at_batch_size( gpu_uuid="GPU-test", gpu_model_name="Test GPU", hostname="node1", - telemetry_data=TelemetryMetrics(gpu_power_usage=100.0), + telemetry_data=TelemetryMetrics(gpu_power_usage=100.0 + i), ) await processor.process_telemetry_record(record) @@ -471,7 +471,7 @@ async def test_records_written_count( gpu_uuid="GPU-test", gpu_model_name="Test GPU", hostname="node1", - telemetry_data=TelemetryMetrics(gpu_power_usage=100.0), + telemetry_data=TelemetryMetrics(gpu_power_usage=100.0 + i), ) await processor.process_telemetry_record(record) @@ -858,7 +858,7 @@ async def test_flush_on_shutdown( gpu_uuid="GPU-test", gpu_model_name="Test GPU", hostname="node1", - telemetry_data=TelemetryMetrics(gpu_power_usage=100.0), + telemetry_data=TelemetryMetrics(gpu_power_usage=100.0 + i), ) await processor.process_telemetry_record(record) @@ -891,7 +891,7 @@ async def test_wait_for_async_tasks( gpu_uuid="GPU-test", gpu_model_name="Test GPU", hostname="node1", - telemetry_data=TelemetryMetrics(gpu_power_usage=100.0), + telemetry_data=TelemetryMetrics(gpu_power_usage=100.0 + i), ) await processor.process_telemetry_record(record) @@ -928,7 +928,7 @@ async def test_statistics_logged_on_shutdown( gpu_uuid="GPU-test", gpu_model_name="Test GPU", hostname="node1", - telemetry_data=TelemetryMetrics(gpu_power_usage=100.0), + telemetry_data=TelemetryMetrics(gpu_power_usage=100.0 + i), ) await processor.process_telemetry_record(record) @@ -1041,7 +1041,7 @@ async def test_large_batch_processing( gpu_uuid="GPU-test", gpu_model_name="Test GPU", hostname="node1", - telemetry_data=TelemetryMetrics(gpu_power_usage=100.0), + telemetry_data=TelemetryMetrics(gpu_power_usage=100.0 + i), ) await processor.process_telemetry_record(record) diff --git a/tests/unit/server_metrics/__init__.py b/tests/unit/server_metrics/__init__.py new file mode 100644 index 000000000..1a8431c3e --- /dev/null +++ b/tests/unit/server_metrics/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/unit/server_metrics/conftest.py b/tests/unit/server_metrics/conftest.py new file mode 100644 index 000000000..865d1d4db --- /dev/null +++ b/tests/unit/server_metrics/conftest.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from aiperf.common.enums import PrometheusMetricType +from aiperf.common.models.server_metrics_models import ( + HistogramData, + MetricFamily, + MetricSample, + ServerMetricsRecord, + SummaryData, +) + + +@pytest.fixture +def sample_prometheus_metrics() -> str: + """Sample Prometheus metrics text from vLLM endpoint.""" + return """# HELP vllm:request_success_total Total number of successful requests +# TYPE vllm:request_success_total counter +vllm:request_success_total{model_name="meta-llama/Llama-3.1-8B-Instruct"} 150.0 + +# HELP vllm:time_to_first_token_seconds Time to first token +# TYPE vllm:time_to_first_token_seconds histogram +vllm:time_to_first_token_seconds_bucket{model_name="meta-llama/Llama-3.1-8B-Instruct",le="0.001"} 0.0 +vllm:time_to_first_token_seconds_bucket{model_name="meta-llama/Llama-3.1-8B-Instruct",le="0.005"} 5.0 +vllm:time_to_first_token_seconds_bucket{model_name="meta-llama/Llama-3.1-8B-Instruct",le="0.01"} 15.0 +vllm:time_to_first_token_seconds_bucket{model_name="meta-llama/Llama-3.1-8B-Instruct",le="+Inf"} 150.0 +vllm:time_to_first_token_seconds_sum{model_name="meta-llama/Llama-3.1-8B-Instruct"} 125.5 +vllm:time_to_first_token_seconds_count{model_name="meta-llama/Llama-3.1-8B-Instruct"} 150.0 + +# HELP vllm:gpu_cache_usage_perc GPU KV-cache usage percentage +# TYPE vllm:gpu_cache_usage_perc gauge +vllm:gpu_cache_usage_perc{model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.42 +""" + + +@pytest.fixture +def sample_counter_metric() -> MetricFamily: + """Sample counter metric family.""" + return MetricFamily( + type=PrometheusMetricType.COUNTER, + description="Total number of successful requests", + samples=[ + MetricSample( + labels={"model_name": "meta-llama/Llama-3.1-8B-Instruct"}, + value=150.0, + ) + ], + ) + + +@pytest.fixture +def sample_gauge_metric() -> MetricFamily: + """Sample gauge metric family.""" + return MetricFamily( + type=PrometheusMetricType.GAUGE, + description="GPU KV-cache usage percentage", + samples=[ + MetricSample( + labels={"model_name": "meta-llama/Llama-3.1-8B-Instruct"}, + value=0.42, + ) + ], + ) + + +@pytest.fixture +def sample_histogram_metric() -> MetricFamily: + """Sample histogram metric family.""" + return MetricFamily( + type=PrometheusMetricType.HISTOGRAM, + description="Time to first token", + samples=[ + MetricSample( + labels={"model_name": "meta-llama/Llama-3.1-8B-Instruct"}, + histogram=HistogramData( + buckets={"0.001": 0.0, "0.005": 5.0, "0.01": 15.0, "+Inf": 150.0}, + sum=125.5, + count=150.0, + ), + ) + ], + ) + + +@pytest.fixture +def sample_summary_metric() -> MetricFamily: + """Sample summary metric family.""" + return MetricFamily( + type=PrometheusMetricType.SUMMARY, + description="Request latency quantiles", + samples=[ + MetricSample( + labels={"model_name": "test-model"}, + summary=SummaryData( + quantiles={"0.5": 0.1, "0.9": 0.5, "0.99": 1.0}, + sum=50.0, + count=100.0, + ), + ) + ], + ) + + +@pytest.fixture +def sample_server_metrics_record( + sample_counter_metric: MetricFamily, + sample_gauge_metric: MetricFamily, + sample_histogram_metric: MetricFamily, +) -> ServerMetricsRecord: + """Sample ServerMetricsRecord with multiple metric types.""" + return ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "vllm:request_success_total": sample_counter_metric, + "vllm:gpu_cache_usage_perc": sample_gauge_metric, + "vllm:time_to_first_token_seconds": sample_histogram_metric, + }, + ) diff --git a/tests/unit/server_metrics/test_malformed_metrics_validation.py b/tests/unit/server_metrics/test_malformed_metrics_validation.py new file mode 100644 index 000000000..495f84d3f --- /dev/null +++ b/tests/unit/server_metrics/test_malformed_metrics_validation.py @@ -0,0 +1,227 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for validation of malformed histogram and summary metrics.""" + +import pytest + +from aiperf.server_metrics.server_metrics_data_collector import ( + ServerMetricsDataCollector, +) + + +class TestMalformedHistogramValidation: + """Test that malformed histograms are properly rejected.""" + + @pytest.mark.asyncio + async def test_histogram_with_no_buckets_skipped(self): + """Test that histograms with sum/count but no buckets are skipped.""" + collector = ServerMetricsDataCollector( + endpoint_url="http://localhost:8080/metrics", + collection_interval=1.0, + ) + + # Malformed histogram: has sum and count but no bucket samples + malformed_prometheus = """# HELP http_request_duration_seconds Request duration +# TYPE http_request_duration_seconds histogram +http_request_duration_seconds_sum 100.0 +http_request_duration_seconds_count 50 +""" + records = collector._parse_metrics_to_records(malformed_prometheus, 1000) + + # Should return empty list because histogram is incomplete (empty snapshots are suppressed) + assert len(records) == 0 + + @pytest.mark.asyncio + async def test_histogram_with_only_buckets_skipped(self): + """Test that histograms with only buckets (no sum/count) are skipped.""" + collector = ServerMetricsDataCollector( + endpoint_url="http://localhost:8080/metrics", + collection_interval=1.0, + ) + + # Malformed histogram: has buckets but no sum/count + malformed_prometheus = """# HELP http_request_duration_seconds Request duration +# TYPE http_request_duration_seconds histogram +http_request_duration_seconds_bucket{le="0.1"} 10 +http_request_duration_seconds_bucket{le="1.0"} 25 +http_request_duration_seconds_bucket{le="+Inf"} 50 +""" + records = collector._parse_metrics_to_records(malformed_prometheus, 1000) + + # Should return empty list because histogram is incomplete (empty snapshots are suppressed) + assert len(records) == 0 + + @pytest.mark.asyncio + async def test_histogram_with_only_sum_skipped(self): + """Test that histograms with only sum (no count/buckets) are skipped.""" + collector = ServerMetricsDataCollector( + endpoint_url="http://localhost:8080/metrics", + collection_interval=1.0, + ) + + # Malformed histogram: has only sum + malformed_prometheus = """# HELP http_request_duration_seconds Request duration +# TYPE http_request_duration_seconds histogram +http_request_duration_seconds_sum 100.0 +""" + records = collector._parse_metrics_to_records(malformed_prometheus, 1000) + + # Should return empty list because histogram is incomplete (empty snapshots are suppressed) + assert len(records) == 0 + + @pytest.mark.asyncio + async def test_valid_histogram_accepted(self): + """Test that valid histograms with all required fields are accepted.""" + collector = ServerMetricsDataCollector( + endpoint_url="http://localhost:8080/metrics", + collection_interval=1.0, + ) + + # Valid histogram: has buckets, sum, and count + valid_prometheus = """# HELP http_request_duration_seconds Request duration +# TYPE http_request_duration_seconds histogram +http_request_duration_seconds_bucket{le="0.1"} 10 +http_request_duration_seconds_bucket{le="1.0"} 25 +http_request_duration_seconds_bucket{le="+Inf"} 50 +http_request_duration_seconds_sum 100.0 +http_request_duration_seconds_count 50 +""" + records = collector._parse_metrics_to_records(valid_prometheus, 1000) + + # Should successfully parse the complete histogram + assert len(records) == 1 + record = records[0] + assert "http_request_duration_seconds" in record.metrics + metric_family = record.metrics["http_request_duration_seconds"] + assert len(metric_family.samples) == 1 + assert metric_family.samples[0].histogram is not None + assert len(metric_family.samples[0].histogram.buckets) == 3 + assert metric_family.samples[0].histogram.sum == 100.0 + assert metric_family.samples[0].histogram.count == 50 + + +class TestMalformedSummaryValidation: + """Test that malformed summaries are properly rejected.""" + + @pytest.mark.asyncio + async def test_summary_with_no_quantiles_skipped(self): + """Test that summaries with sum/count but no quantiles are skipped.""" + collector = ServerMetricsDataCollector( + endpoint_url="http://localhost:8080/metrics", + collection_interval=1.0, + ) + + # Malformed summary: has sum and count but no quantile samples + malformed_prometheus = """# HELP http_request_duration_seconds Request duration +# TYPE http_request_duration_seconds summary +http_request_duration_seconds_sum 100.0 +http_request_duration_seconds_count 50 +""" + records = collector._parse_metrics_to_records(malformed_prometheus, 1000) + + # Should return empty list because summary is incomplete (empty snapshots are suppressed) + assert len(records) == 0 + + @pytest.mark.asyncio + async def test_summary_with_only_quantiles_skipped(self): + """Test that summaries with only quantiles (no sum/count) are skipped.""" + collector = ServerMetricsDataCollector( + endpoint_url="http://localhost:8080/metrics", + collection_interval=1.0, + ) + + # Malformed summary: has quantiles but no sum/count + malformed_prometheus = """# HELP http_request_duration_seconds Request duration +# TYPE http_request_duration_seconds summary +http_request_duration_seconds{quantile="0.5"} 0.1 +http_request_duration_seconds{quantile="0.9"} 0.5 +http_request_duration_seconds{quantile="0.99"} 1.0 +""" + records = collector._parse_metrics_to_records(malformed_prometheus, 1000) + + # Should return empty list because summary is incomplete (empty snapshots are suppressed) + assert len(records) == 0 + + @pytest.mark.asyncio + async def test_summary_with_only_sum_skipped(self): + """Test that summaries with only sum (no count/quantiles) are skipped.""" + collector = ServerMetricsDataCollector( + endpoint_url="http://localhost:8080/metrics", + collection_interval=1.0, + ) + + # Malformed summary: has only sum + malformed_prometheus = """# HELP http_request_duration_seconds Request duration +# TYPE http_request_duration_seconds summary +http_request_duration_seconds_sum 100.0 +""" + records = collector._parse_metrics_to_records(malformed_prometheus, 1000) + + # Should return empty list because summary is incomplete (empty snapshots are suppressed) + assert len(records) == 0 + + @pytest.mark.asyncio + async def test_valid_summary_accepted(self): + """Test that valid summaries with all required fields are accepted.""" + collector = ServerMetricsDataCollector( + endpoint_url="http://localhost:8080/metrics", + collection_interval=1.0, + ) + + # Valid summary: has quantiles, sum, and count + valid_prometheus = """# HELP http_request_duration_seconds Request duration +# TYPE http_request_duration_seconds summary +http_request_duration_seconds{quantile="0.5"} 0.1 +http_request_duration_seconds{quantile="0.9"} 0.5 +http_request_duration_seconds{quantile="0.99"} 1.0 +http_request_duration_seconds_sum 100.0 +http_request_duration_seconds_count 50 +""" + records = collector._parse_metrics_to_records(valid_prometheus, 1000) + + # Should successfully parse the complete summary + assert len(records) == 1 + record = records[0] + assert "http_request_duration_seconds" in record.metrics + metric_family = record.metrics["http_request_duration_seconds"] + assert len(metric_family.samples) == 1 + assert metric_family.samples[0].summary is not None + assert len(metric_family.samples[0].summary.quantiles) == 3 + assert metric_family.samples[0].summary.sum == 100.0 + assert metric_family.samples[0].summary.count == 50 + + +class TestMixedValidAndInvalidMetrics: + """Test behavior when some metrics are valid and others are malformed.""" + + @pytest.mark.asyncio + async def test_valid_metrics_preserved_when_invalid_skipped(self): + """Test that valid metrics are still processed when invalid ones are skipped.""" + collector = ServerMetricsDataCollector( + endpoint_url="http://localhost:8080/metrics", + collection_interval=1.0, + ) + + # Mix of valid counter, invalid histogram, and valid gauge + mixed_prometheus = """# HELP requests_total Total requests +# TYPE requests_total counter +requests_total 1000 + +# HELP http_duration_seconds Request duration (malformed - no buckets) +# TYPE http_duration_seconds histogram +http_duration_seconds_sum 100.0 +http_duration_seconds_count 50 + +# HELP active_connections Active connections +# TYPE active_connections gauge +active_connections 42 +""" + records = collector._parse_metrics_to_records(mixed_prometheus, 1000) + + # Should have valid counter and gauge, but not the malformed histogram + assert len(records) == 1 + record = records[0] + # Note: prometheus_client strips _total suffix from counter names + assert "requests" in record.metrics + assert "active_connections" in record.metrics + assert "http_duration_seconds" not in record.metrics diff --git a/tests/unit/server_metrics/test_server_metrics_data_collector.py b/tests/unit/server_metrics/test_server_metrics_data_collector.py new file mode 100644 index 000000000..61bfd84c7 --- /dev/null +++ b/tests/unit/server_metrics/test_server_metrics_data_collector.py @@ -0,0 +1,392 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import AsyncMock, patch + +import aiohttp +import pytest + +from aiperf.common.enums import PrometheusMetricType +from aiperf.common.models import ErrorDetails +from aiperf.common.models.server_metrics_models import ServerMetricsRecord +from aiperf.server_metrics.server_metrics_data_collector import ( + ServerMetricsDataCollector, +) + + +class TestServerMetricsDataCollectorInitialization: + """Test ServerMetricsDataCollector initialization.""" + + def test_initialization_complete(self): + """Test collector initialization with all parameters.""" + collector = ServerMetricsDataCollector( + endpoint_url="http://localhost:8081/metrics", + collection_interval=0.5, + reachability_timeout=10.0, + collector_id="test_collector", + ) + + assert collector._endpoint_url == "http://localhost:8081/metrics" + assert collector._collection_interval == 0.5 + assert collector._reachability_timeout == 10.0 + assert collector.id == "test_collector" + assert collector._session is None + assert not collector.was_initialized + + def test_initialization_with_defaults(self): + """Test collector uses default values when not specified.""" + collector = ServerMetricsDataCollector("http://localhost:8081/metrics") + + assert collector._endpoint_url == "http://localhost:8081/metrics" + assert collector._collection_interval == 0.33 + assert collector.id == "server_metrics_collector" + + +class TestPrometheusMetricParsing: + """Test Prometheus metric parsing functionality.""" + + def test_parse_counter_metrics(self): + """Test parsing simple counter metrics.""" + metrics_text = """# HELP requests_total Total requests +# TYPE requests_total counter +requests_total{status="success"} 100.0 +requests_total{status="error"} 5.0 +""" + collector = ServerMetricsDataCollector("http://localhost:8081/metrics") + records = collector._parse_metrics_to_records(metrics_text, 1_000_000) + + assert len(records) == 1 + record = records[0] + assert "requests" in record.metrics + assert record.metrics["requests"].type == PrometheusMetricType.COUNTER + assert len(record.metrics["requests"].samples) == 2 + + def test_parse_gauge_metrics(self): + """Test parsing gauge metrics.""" + metrics_text = """# HELP gpu_utilization GPU utilization percentage +# TYPE gpu_utilization gauge +gpu_utilization{gpu="0"} 0.85 +gpu_utilization{gpu="1"} 0.92 +""" + collector = ServerMetricsDataCollector("http://localhost:8081/metrics") + records = collector._parse_metrics_to_records(metrics_text, 1_000_000) + + assert len(records) == 1 + record = records[0] + assert "gpu_utilization" in record.metrics + assert record.metrics["gpu_utilization"].type == PrometheusMetricType.GAUGE + assert len(record.metrics["gpu_utilization"].samples) == 2 + + def test_parse_histogram_metrics(self, sample_prometheus_metrics): + """Test parsing histogram metrics with buckets.""" + collector = ServerMetricsDataCollector("http://localhost:8081/metrics") + records = collector._parse_metrics_to_records( + sample_prometheus_metrics, 1_000_000 + ) + + assert len(records) == 1 + record = records[0] + assert "vllm:time_to_first_token_seconds" in record.metrics + + histogram_metric = record.metrics["vllm:time_to_first_token_seconds"] + assert histogram_metric.type == PrometheusMetricType.HISTOGRAM + assert len(histogram_metric.samples) == 1 + + sample = histogram_metric.samples[0] + assert sample.histogram is not None + assert len(sample.histogram.buckets) == 4 + assert sample.histogram.sum == 125.5 + assert sample.histogram.count == 150.0 + + def test_parse_summary_metrics(self): + """Test parsing summary metrics with quantiles.""" + metrics_text = """# HELP request_duration_seconds Request duration +# TYPE request_duration_seconds summary +request_duration_seconds{quantile="0.5"} 0.1 +request_duration_seconds{quantile="0.9"} 0.5 +request_duration_seconds{quantile="0.99"} 1.0 +request_duration_seconds_sum 50.0 +request_duration_seconds_count 100.0 +""" + collector = ServerMetricsDataCollector("http://localhost:8081/metrics") + records = collector._parse_metrics_to_records(metrics_text, 1_000_000) + + assert len(records) == 1 + record = records[0] + assert "request_duration_seconds" in record.metrics + + summary_metric = record.metrics["request_duration_seconds"] + assert summary_metric.type == PrometheusMetricType.SUMMARY + assert len(summary_metric.samples) == 1 + + sample = summary_metric.samples[0] + assert sample.summary is not None + assert len(sample.summary.quantiles) == 3 + assert sample.summary.sum == 50.0 + assert sample.summary.count == 100.0 + + def test_parse_mixed_metric_types(self, sample_prometheus_metrics): + """Test parsing response containing multiple metric types.""" + collector = ServerMetricsDataCollector("http://localhost:8081/metrics") + records = collector._parse_metrics_to_records( + sample_prometheus_metrics, 1_000_000 + ) + + assert len(records) == 1 + record = records[0] + + assert "vllm:request_success" in record.metrics + assert "vllm:gpu_cache_usage_perc" in record.metrics + assert "vllm:time_to_first_token_seconds" in record.metrics + + assert ( + record.metrics["vllm:request_success"].type == PrometheusMetricType.COUNTER + ) + assert ( + record.metrics["vllm:gpu_cache_usage_perc"].type + == PrometheusMetricType.GAUGE + ) + assert ( + record.metrics["vllm:time_to_first_token_seconds"].type + == PrometheusMetricType.HISTOGRAM + ) + + def test_skip_created_metrics(self): + """Test that _created metrics are skipped during parsing.""" + metrics_text = """# HELP requests_total Total requests +# TYPE requests_total counter +requests_total 100.0 +requests_total_created 1704067200.0 + +# HELP histogram_seconds Histogram metric +# TYPE histogram_seconds histogram +histogram_seconds_bucket{le="+Inf"} 50.0 +histogram_seconds_sum 5.0 +histogram_seconds_count 50.0 +histogram_seconds_created 1704067200.0 +""" + collector = ServerMetricsDataCollector("http://localhost:8081/metrics") + records = collector._parse_metrics_to_records(metrics_text, 1_000_000) + + assert len(records) == 1 + record = records[0] + + assert "requests" in record.metrics + assert "requests_created" not in record.metrics + assert "histogram_seconds" in record.metrics + assert "histogram_seconds_created" not in record.metrics + + def test_parse_metrics_with_labels(self): + """Test parsing metrics with multiple label combinations.""" + metrics_text = """# HELP http_requests_total Total HTTP requests +# TYPE http_requests_total counter +http_requests_total{method="GET",status="200"} 150.0 +http_requests_total{method="POST",status="200"} 75.0 +http_requests_total{method="GET",status="404"} 5.0 +""" + collector = ServerMetricsDataCollector("http://localhost:8081/metrics") + records = collector._parse_metrics_to_records(metrics_text, 1_000_000) + + assert len(records) == 1 + record = records[0] + assert "http_requests" in record.metrics + assert len(record.metrics["http_requests"].samples) == 3 + + def test_parse_empty_response(self): + """Test parsing empty or whitespace-only responses.""" + collector = ServerMetricsDataCollector("http://localhost:8081/metrics") + + empty_cases = ["", " \n\n "] + + for empty_data in empty_cases: + records = collector._parse_metrics_to_records(empty_data, 1_000_000) + assert len(records) == 0 + + def test_parse_invalid_format_raises_error(self): + """Test that invalid Prometheus format raises ValueError.""" + collector = ServerMetricsDataCollector("http://localhost:8081/metrics") + + # Invalid TYPE directive without metric name + invalid_format = "# HELP comment\n# TYPE comment" + + with pytest.raises(ValueError): + collector._parse_metrics_to_records(invalid_format, 1_000_000) + + def test_parse_incomplete_histogram(self): + """Test that incomplete histograms (missing sum/count) are skipped and result in empty snapshots.""" + metrics_text = """# HELP incomplete_histogram Incomplete histogram +# TYPE incomplete_histogram histogram +incomplete_histogram_bucket{le="0.01"} 5.0 +incomplete_histogram_bucket{le="+Inf"} 10.0 +""" + collector = ServerMetricsDataCollector("http://localhost:8081/metrics") + records = collector._parse_metrics_to_records(metrics_text, 1_000_000) + + # Empty snapshots are suppressed to reduce I/O noise + assert len(records) == 0 + + def test_record_metadata_populated(self): + """Test that ServerMetricsRecord metadata is correctly populated.""" + metrics_text = """# HELP test_metric Test metric +# TYPE test_metric counter +test_metric 1.0 +""" + collector = ServerMetricsDataCollector("http://localhost:8081/metrics") + records = collector._parse_metrics_to_records(metrics_text, 5_000_000) + + assert len(records) == 1 + record = records[0] + + assert record.endpoint_url == "http://localhost:8081/metrics" + assert record.endpoint_latency_ns == 5_000_000 + assert record.timestamp_ns > 0 + + +class TestMetricDeduplication: + """Test metric sample deduplication logic.""" + + def test_duplicate_counter_values_last_wins(self): + """Test that duplicate counter samples keep last value.""" + metrics_text = """# HELP test_counter Test counter +# TYPE test_counter counter +test_counter{label="a"} 10.0 +test_counter{label="a"} 20.0 +test_counter{label="a"} 30.0 +""" + collector = ServerMetricsDataCollector("http://localhost:8081/metrics") + records = collector._parse_metrics_to_records(metrics_text, 1_000_000) + + assert len(records) == 1 + samples = records[0].metrics["test_counter"].samples + + assert len(samples) == 1 + assert samples[0].value == 30.0 + + +class TestAsyncLifecycle: + """Test async lifecycle management.""" + + @pytest.mark.asyncio + async def test_initialization_creates_session(self): + """Test that initialization creates aiohttp session.""" + collector = ServerMetricsDataCollector("http://localhost:8081/metrics") + + await collector.initialize() + + assert collector._session is not None + assert isinstance(collector._session, aiohttp.ClientSession) + + await collector.stop() + + @pytest.mark.asyncio + async def test_stop_closes_session(self): + """Test that stop closes aiohttp session.""" + collector = ServerMetricsDataCollector("http://localhost:8081/metrics") + + await collector.initialize() + session = collector._session + + await collector.stop() + + assert session.closed + + @pytest.mark.asyncio + async def test_reachability_check_success(self): + """Test URL reachability check with successful response.""" + collector = ServerMetricsDataCollector("http://localhost:8081/metrics") + + with patch.object( + collector, "_check_reachability_with_session", new_callable=AsyncMock + ) as mock_check: + mock_check.return_value = True + + await collector.initialize() + is_reachable = await collector.is_url_reachable() + + assert is_reachable + mock_check.assert_called_once() + + await collector.stop() + + @pytest.mark.asyncio + async def test_reachability_check_failure(self): + """Test URL reachability check with failed response.""" + collector = ServerMetricsDataCollector("http://localhost:8081/metrics") + + with patch.object( + collector, "_check_reachability_with_session", new_callable=AsyncMock + ) as mock_check: + mock_check.return_value = False + + await collector.initialize() + is_reachable = await collector.is_url_reachable() + + assert not is_reachable + + await collector.stop() + + +class TestCallbackFunctionality: + """Test callback mechanisms for records and errors.""" + + @pytest.mark.asyncio + async def test_record_callback_invoked(self): + """Test that record callback is invoked with collected records.""" + record_callback = AsyncMock() + collector = ServerMetricsDataCollector( + "http://localhost:8081/metrics", + record_callback=record_callback, + collector_id="test_collector", + ) + + test_records = [ + ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={}, + ) + ] + + await collector._send_records_via_callback(test_records) + + record_callback.assert_called_once_with(test_records, "test_collector") + + @pytest.mark.asyncio + async def test_error_callback_invoked(self): + """Test that error callback is invoked on collection errors.""" + error_callback = AsyncMock() + collector = ServerMetricsDataCollector( + "http://localhost:8081/metrics", + error_callback=error_callback, + collector_id="test_collector", + ) + + await collector.initialize() + + with patch.object( + collector, + "_collect_and_process_metrics", + side_effect=ValueError("Test error"), + ): + await collector._collect_metrics_task() + + error_callback.assert_called_once() + args = error_callback.call_args[0] + assert isinstance(args[0], ErrorDetails) + assert args[1] == "test_collector" + + await collector.stop() + + @pytest.mark.asyncio + async def test_no_callback_on_empty_records(self): + """Test that record callback is not invoked for empty record list.""" + record_callback = AsyncMock() + collector = ServerMetricsDataCollector( + "http://localhost:8081/metrics", + record_callback=record_callback, + ) + + await collector._send_records_via_callback([]) + + record_callback.assert_not_called() diff --git a/tests/unit/server_metrics/test_server_metrics_manager.py b/tests/unit/server_metrics/test_server_metrics_manager.py new file mode 100644 index 000000000..64ce33006 --- /dev/null +++ b/tests/unit/server_metrics/test_server_metrics_manager.py @@ -0,0 +1,397 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import AsyncMock, patch + +import pytest + +from aiperf.common.config import EndpointConfig, ServiceConfig, UserConfig +from aiperf.common.enums import CommandType, EndpointType +from aiperf.common.messages import ProfileConfigureCommand, ProfileStartCommand +from aiperf.common.messages.server_metrics_messages import ServerMetricsRecordsMessage +from aiperf.common.models import ErrorDetails +from aiperf.common.models.server_metrics_models import ServerMetricsRecord +from aiperf.server_metrics.server_metrics_manager import ServerMetricsManager + + +@pytest.fixture +def user_config_with_endpoint() -> UserConfig: + """Create UserConfig with inference endpoint.""" + return UserConfig( + endpoint=EndpointConfig( + model_names=["test-model"], + type=EndpointType.CHAT, + url="http://localhost:8000/v1/chat", + ), + ) + + +@pytest.fixture +def user_config_with_server_metrics_urls() -> UserConfig: + """Create UserConfig with custom server metrics URLs.""" + return UserConfig( + endpoint=EndpointConfig( + model_names=["test-model"], + type=EndpointType.CHAT, + url="http://localhost:8000/v1/chat", + ), + server_metrics=[ + "http://custom-endpoint:9400/metrics", + "http://another-endpoint:8081", + ], + ) + + +class TestServerMetricsManagerInitialization: + """Test ServerMetricsManager initialization and endpoint discovery.""" + + def test_initialization_basic( + self, + service_config: ServiceConfig, + user_config_with_endpoint: UserConfig, + ): + """Test basic initialization with inference endpoint.""" + manager = ServerMetricsManager( + service_config=service_config, + user_config=user_config_with_endpoint, + ) + + assert manager._collectors == {} + # Should include inference port by default + assert manager._server_metrics_endpoints == ["http://localhost:8000/metrics"] + assert manager._collection_interval == 0.33 + + def test_endpoint_discovery_from_inference_url( + self, + service_config: ServiceConfig, + user_config_with_endpoint: UserConfig, + ): + """Test that inference endpoint port is discovered by default.""" + manager = ServerMetricsManager( + service_config=service_config, + user_config=user_config_with_endpoint, + ) + + # Should include inference port (localhost:8000) by default + assert len(manager._server_metrics_endpoints) == 1 + assert "localhost:8000" in manager._server_metrics_endpoints[0] + + def test_custom_server_metrics_urls_added( + self, + service_config: ServiceConfig, + user_config_with_server_metrics_urls: UserConfig, + ): + """Test that user-specified server metrics URLs are added to endpoint list.""" + manager = ServerMetricsManager( + service_config=service_config, + user_config=user_config_with_server_metrics_urls, + ) + + assert ( + "http://custom-endpoint:9400/metrics" in manager._server_metrics_endpoints + ) + assert ( + "http://another-endpoint:8081/metrics" in manager._server_metrics_endpoints + ) + + def test_duplicate_urls_avoided( + self, + service_config: ServiceConfig, + user_config_with_server_metrics_urls: UserConfig, + ): + """Test that duplicate URLs are deduplicated.""" + manager = ServerMetricsManager( + service_config=service_config, + user_config=user_config_with_server_metrics_urls, + ) + + endpoint_counts = {} + for endpoint in manager._server_metrics_endpoints: + endpoint_counts[endpoint] = endpoint_counts.get(endpoint, 0) + 1 + + for count in endpoint_counts.values(): + assert count == 1 + + +class TestProfileConfigureCommand: + """Test profile configuration and endpoint reachability checking.""" + + @pytest.mark.asyncio + async def test_configure_with_reachable_endpoints( + self, + service_config: ServiceConfig, + user_config_with_server_metrics_urls: UserConfig, + ): + """Test configuration when all endpoints are reachable.""" + manager = ServerMetricsManager( + service_config=service_config, + user_config=user_config_with_server_metrics_urls, + ) + + with patch( + "aiperf.server_metrics.server_metrics_manager.ServerMetricsDataCollector" + ) as mock_collector_class: + mock_collector = AsyncMock() + mock_collector.is_url_reachable = AsyncMock(return_value=True) + mock_collector_class.return_value = mock_collector + + await manager._profile_configure_command( + ProfileConfigureCommand( + service_id=manager.id, + command=CommandType.PROFILE_CONFIGURE, + config={}, + ) + ) + + assert len(manager._collectors) > 0 + + @pytest.mark.asyncio + async def test_configure_with_unreachable_endpoints( + self, + service_config: ServiceConfig, + user_config_with_endpoint: UserConfig, + ): + """Test configuration when no endpoints are reachable.""" + manager = ServerMetricsManager( + service_config=service_config, + user_config=user_config_with_endpoint, + ) + + with patch( + "aiperf.server_metrics.server_metrics_manager.ServerMetricsDataCollector" + ) as mock_collector_class: + mock_collector = AsyncMock() + mock_collector.is_url_reachable = AsyncMock(return_value=False) + mock_collector_class.return_value = mock_collector + + await manager._profile_configure_command( + ProfileConfigureCommand( + service_id=manager.id, + command=CommandType.PROFILE_CONFIGURE, + config={}, + ) + ) + + assert len(manager._collectors) == 0 + + @pytest.mark.asyncio + async def test_configure_clears_existing_collectors( + self, + service_config: ServiceConfig, + user_config_with_endpoint: UserConfig, + ): + """Test that configuration clears previous collectors.""" + manager = ServerMetricsManager( + service_config=service_config, + user_config=user_config_with_endpoint, + ) + + manager._collectors["old_collector"] = AsyncMock() + + with patch( + "aiperf.server_metrics.server_metrics_manager.ServerMetricsDataCollector" + ) as mock_collector_class: + mock_collector = AsyncMock() + mock_collector.is_url_reachable = AsyncMock(return_value=True) + mock_collector_class.return_value = mock_collector + + await manager._profile_configure_command( + ProfileConfigureCommand( + service_id=manager.id, + command=CommandType.PROFILE_CONFIGURE, + config={}, + ) + ) + + assert "old_collector" not in manager._collectors + + +class TestProfileStartCommand: + """Test profile start functionality.""" + + @pytest.mark.asyncio + async def test_start_initializes_and_starts_collectors( + self, + service_config: ServiceConfig, + user_config_with_endpoint: UserConfig, + ): + """Test that start command initializes and starts all collectors.""" + manager = ServerMetricsManager( + service_config=service_config, + user_config=user_config_with_endpoint, + ) + + mock_collector = AsyncMock() + manager._collectors["http://localhost:8081/metrics"] = mock_collector + + await manager._on_start_profiling( + ProfileStartCommand( + service_id=manager.id, command=CommandType.PROFILE_START + ) + ) + + mock_collector.initialize.assert_called_once() + mock_collector.start.assert_called_once() + + @pytest.mark.asyncio + async def test_start_with_no_collectors( + self, + service_config: ServiceConfig, + user_config_with_endpoint: UserConfig, + ): + """Test start command when no collectors configured.""" + manager = ServerMetricsManager( + service_config=service_config, + user_config=user_config_with_endpoint, + ) + + await manager._on_start_profiling( + ProfileStartCommand( + service_id=manager.id, command=CommandType.PROFILE_START + ) + ) + + @pytest.mark.asyncio + async def test_start_handles_initialization_failure( + self, + service_config: ServiceConfig, + user_config_with_endpoint: UserConfig, + ): + """Test start command handles collector initialization failures.""" + manager = ServerMetricsManager( + service_config=service_config, + user_config=user_config_with_endpoint, + ) + + mock_collector = AsyncMock() + mock_collector.initialize.side_effect = Exception("Initialization failed") + manager._collectors["http://localhost:8081/metrics"] = mock_collector + + await manager._on_start_profiling( + ProfileStartCommand( + service_id=manager.id, command=CommandType.PROFILE_START + ) + ) + + +class TestCallbackFunctionality: + """Test callback handling for records and errors.""" + + @pytest.mark.asyncio + async def test_record_callback_sends_message( + self, + service_config: ServiceConfig, + user_config_with_endpoint: UserConfig, + ): + """Test that record callback sends ServerMetricsRecordsMessage.""" + manager = ServerMetricsManager( + service_config=service_config, + user_config=user_config_with_endpoint, + ) + + manager.records_push_client.push = AsyncMock() + + test_records = [ + ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={}, + ) + ] + + await manager._on_server_metrics_records(test_records, "test_collector") + + manager.records_push_client.push.assert_called_once() + call_args = manager.records_push_client.push.call_args[0][0] + assert isinstance(call_args, ServerMetricsRecordsMessage) + assert call_args.records == test_records + + @pytest.mark.asyncio + async def test_error_callback_logs_error( + self, + service_config: ServiceConfig, + user_config_with_endpoint: UserConfig, + ): + """Test that error callback logs the error.""" + manager = ServerMetricsManager( + service_config=service_config, + user_config=user_config_with_endpoint, + ) + + test_error = ErrorDetails.from_exception(ValueError("Test error")) + + await manager._on_server_metrics_error(test_error, "test_collector") + + @pytest.mark.asyncio + async def test_record_callback_handles_send_failure( + self, + service_config: ServiceConfig, + user_config_with_endpoint: UserConfig, + ): + """Test that record callback handles message send failures gracefully.""" + manager = ServerMetricsManager( + service_config=service_config, + user_config=user_config_with_endpoint, + ) + + manager.records_push_client.push = AsyncMock( + side_effect=Exception("Send failed") + ) + + test_records = [ + ServerMetricsRecord( + endpoint_url="http://localhost:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={}, + ) + ] + + await manager._on_server_metrics_records(test_records, "test_collector") + + +class TestStopAllCollectors: + """Test stopping all collectors.""" + + @pytest.mark.asyncio + async def test_stop_all_collectors_calls_stop( + self, + service_config: ServiceConfig, + user_config_with_endpoint: UserConfig, + ): + """Test that stop_all_collectors stops each collector.""" + manager = ServerMetricsManager( + service_config=service_config, + user_config=user_config_with_endpoint, + ) + + mock_collector1 = AsyncMock() + mock_collector2 = AsyncMock() + manager._collectors = { + "endpoint1": mock_collector1, + "endpoint2": mock_collector2, + } + + await manager._stop_all_collectors() + + mock_collector1.stop.assert_called_once() + mock_collector2.stop.assert_called_once() + + @pytest.mark.asyncio + async def test_stop_all_collectors_handles_failure( + self, + service_config: ServiceConfig, + user_config_with_endpoint: UserConfig, + ): + """Test that stop_all_collectors handles failures gracefully.""" + manager = ServerMetricsManager( + service_config=service_config, + user_config=user_config_with_endpoint, + ) + + mock_collector = AsyncMock() + mock_collector.stop.side_effect = Exception("Stop failed") + manager._collectors = {"endpoint1": mock_collector} + + await manager._stop_all_collectors() diff --git a/tests/unit/server_metrics/test_server_metrics_models.py b/tests/unit/server_metrics/test_server_metrics_models.py new file mode 100644 index 000000000..c589f9b9c --- /dev/null +++ b/tests/unit/server_metrics/test_server_metrics_models.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +from aiperf.common.enums import PrometheusMetricType +from aiperf.common.models.server_metrics_models import ( + HistogramData, + MetricFamily, + MetricSample, + MetricSchema, + ServerMetricsRecord, + SlimMetricSample, + SummaryData, +) + + +class TestMetricSampleConversion: + """Test MetricSample to SlimMetricSample conversion.""" + + def test_counter_to_slim(self): + """Test converting simple counter metric to slim format.""" + sample = MetricSample( + labels={"model": "test-model", "status": "success"}, + value=100.0, + ) + slim = sample.to_slim() + + assert slim.labels == {"model": "test-model", "status": "success"} + assert slim.value == 100.0 + assert slim.histogram is None + assert slim.summary is None + + def test_gauge_to_slim(self): + """Test converting gauge metric to slim format.""" + sample = MetricSample(labels=None, value=0.75) + slim = sample.to_slim() + + assert slim.labels is None + assert slim.value == 0.75 + assert slim.histogram is None + + def test_histogram_to_slim(self): + """Test converting histogram metric to slim dict format.""" + sample = MetricSample( + labels={"model": "test"}, + histogram=HistogramData( + buckets={"0.01": 5.0, "0.1": 15.0, "1.0": 50.0, "+Inf": 100.0}, + sum=125.5, + count=100.0, + ), + ) + slim = sample.to_slim() + + assert slim.labels == {"model": "test"} + assert slim.value is None + assert slim.histogram == {"0.01": 5.0, "0.1": 15.0, "1.0": 50.0, "+Inf": 100.0} + assert slim.sum == 125.5 + assert slim.count == 100.0 + + def test_summary_to_slim(self): + """Test converting summary metric to slim dict format.""" + sample = MetricSample( + labels={"endpoint": "/v1/chat"}, + summary=SummaryData( + quantiles={"0.5": 0.1, "0.9": 0.5, "0.99": 1.0}, + sum=50.0, + count=100.0, + ), + ) + slim = sample.to_slim() + + assert slim.labels == {"endpoint": "/v1/chat"} + assert slim.summary == {"0.5": 0.1, "0.9": 0.5, "0.99": 1.0} + assert slim.sum == 50.0 + assert slim.count == 100.0 + + +class TestServerMetricsRecordConversion: + """Test ServerMetricsRecord to slim format conversion.""" + + def test_full_record_to_slim( + self, + sample_counter_metric: MetricFamily, + sample_histogram_metric: MetricFamily, + ): + """Test converting complete record with multiple metric types.""" + record = ServerMetricsRecord( + endpoint_url="http://node1:8081/metrics", + timestamp_ns=1_000_000_000, + endpoint_latency_ns=5_000_000, + metrics={ + "requests_total": sample_counter_metric, + "ttft": sample_histogram_metric, + }, + ) + + slim = record.to_slim() + + assert slim.endpoint_url == "http://node1:8081/metrics" + assert slim.timestamp_ns == 1_000_000_000 + assert slim.endpoint_latency_ns == 5_000_000 + assert len(slim.metrics) == 2 + assert "requests_total" in slim.metrics + assert "ttft" in slim.metrics + + assert isinstance(slim.metrics["requests_total"][0], SlimMetricSample) + assert slim.metrics["requests_total"][0].value == 150.0 + + assert isinstance(slim.metrics["ttft"][0], SlimMetricSample) + assert slim.metrics["ttft"][0].histogram is not None + + def test_slim_record_preserves_endpoint_url(self, sample_server_metrics_record): + """Test that endpoint_url is preserved in slim format.""" + slim = sample_server_metrics_record.to_slim() + assert slim.endpoint_url == sample_server_metrics_record.endpoint_url + + +class TestMetricSchema: + """Test MetricSchema model for metadata.""" + + def test_counter_schema(self): + """Test schema for counter metric.""" + schema = MetricSchema( + type=PrometheusMetricType.COUNTER, + description="Total number of requests", + ) + + assert schema.type == PrometheusMetricType.COUNTER + assert schema.description == "Total number of requests" + + def test_histogram_schema_with_buckets(self): + """Test schema for histogram.""" + schema = MetricSchema( + type=PrometheusMetricType.HISTOGRAM, + description="Request duration histogram", + ) + + assert schema.type == PrometheusMetricType.HISTOGRAM + assert schema.description == "Request duration histogram" + + def test_summary_schema_with_quantiles(self): + """Test schema for summary.""" + schema = MetricSchema( + type=PrometheusMetricType.SUMMARY, + description="Request latency quantiles", + ) + + assert schema.type == PrometheusMetricType.SUMMARY + assert schema.description == "Request latency quantiles" + + +class TestHistogramAndSummaryData: + """Test HistogramData and SummaryData models.""" + + def test_histogram_data_structure(self): + """Test HistogramData contains buckets, sum, and count.""" + hist = HistogramData( + buckets={"0.01": 10.0, "0.1": 25.0, "+Inf": 50.0}, + sum=5.5, + count=50.0, + ) + + assert len(hist.buckets) == 3 + assert hist.buckets["0.01"] == 10.0 + assert hist.sum == 5.5 + assert hist.count == 50.0 + + def test_summary_data_structure(self): + """Test SummaryData contains quantiles, sum, and count.""" + summary = SummaryData( + quantiles={"0.5": 0.1, "0.9": 0.5, "0.99": 1.0}, + sum=50.0, + count=100.0, + ) + + assert len(summary.quantiles) == 3 + assert summary.quantiles["0.5"] == 0.1 + assert summary.sum == 50.0 + assert summary.count == 100.0 + + def test_histogram_optional_fields(self): + """Test histogram with optional fields as None.""" + hist = HistogramData( + buckets={"0.01": 10.0}, + sum=None, + count=None, + ) + + assert hist.buckets == {"0.01": 10.0} + assert hist.sum is None + assert hist.count is None