Skip to content

Commit 56e4f5e

Browse files
committed
tests: add unit tests, rename, fixes
1 parent a9a9e80 commit 56e4f5e

File tree

3 files changed

+127
-4
lines changed

3 files changed

+127
-4
lines changed

src/aiperf/exporters/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
## ⚠️ This file is auto-generated by mkinit ⚠️ ##
99
## ⚠️ Do not edit below this line ⚠️ ##
1010
########################################################################
11-
from aiperf.exporters.console_api_error_insight_exporter import (
12-
ConsoleApiErrorInsightExporter,
11+
from aiperf.exporters.console_api_error_exporter import (
12+
ConsoleApiErrorExporter,
1313
ErrorInsight,
1414
MaxCompletionTokensDetector,
1515
)
@@ -60,7 +60,7 @@
6060
)
6161

6262
__all__ = [
63-
"ConsoleApiErrorInsightExporter",
63+
"ConsoleApiErrorExporter",
6464
"ConsoleErrorExporter",
6565
"ConsoleExperimentalMetricsExporter",
6666
"ConsoleInternalMetricsExporter",

src/aiperf/exporters/console_api_error_insight_exporter.py renamed to src/aiperf/exporters/console_api_error_exporter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def detect(error_summary):
7979
fixes=[
8080
"Remove --output-tokens-mean.",
8181
'Or use --extra-inputs "max_tokens:<value>".',
82+
"Or run AIPerf with '--use-legacy-max-tokens' to force use of the legacy 'max_tokens' field instead of 'max_completion_tokens'.",
8283
],
8384
)
8485

@@ -87,7 +88,7 @@ def detect(error_summary):
8788

8889
@implements_protocol(ConsoleExporterProtocol)
8990
@ConsoleExporterFactory.register(ConsoleExporterType.API_ERRORS)
90-
class ConsoleApiErrorInsightExporter(AIPerfLoggerMixin):
91+
class ConsoleApiErrorExporter(AIPerfLoggerMixin):
9192
"""Displays helpful diagnostic panels for known API error patterns."""
9293

9394
DETECTORS = [
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
import json
4+
from unittest.mock import MagicMock
5+
6+
import pytest
7+
from rich.console import Console
8+
9+
from aiperf.exporters.console_api_error_exporter import (
10+
ConsoleApiErrorExporter,
11+
MaxCompletionTokensDetector,
12+
)
13+
from aiperf.exporters.exporter_config import ExporterConfig
14+
15+
16+
class MockErrorDetails:
    """Lightweight stand-in for an error-details record used by these tests.

    Each constructor argument is stored unchanged on the attribute of the
    same name: ``code``, ``type``, ``message``, ``cause`` and ``details``.
    """

    def __init__(
        self, code=400, type="Bad Request", message="", cause=None, details=None
    ):
        # ``type`` shadows the builtin, but it is part of the keyword interface
        # callers use, so the name is kept as-is.
        self.code, self.type, self.message = code, type, message
        self.cause, self.details = cause, details
25+
26+
27+
class MockErrorDetailsCount:
    """Test double pairing an error-details object with an occurrence count."""

    def __init__(self, error_details, count):
        # Both values are stored verbatim; no validation or copying is done.
        self.error_details, self.count = error_details, count
31+
32+
33+
def make_summary(err):
    """Return a one-entry error summary wrapping ``err`` with a count of 1."""
    single_entry = MockErrorDetailsCount(err, 1)
    return [single_entry]
35+
36+
37+
@pytest.fixture
def basic_error_payload():
    """JSON string whose ``message`` field reports ``max_completion_tokens`` as a forbidden extra input."""
    validation_message = (
        "[{'type': 'extra_forbidden','loc': ('body','max_completion_tokens'),"
        "'msg': 'Extra inputs are not permitted'}]"
    )
    return json.dumps({"message": validation_message})
45+
46+
47+
def test_detector_detects_max_completion_tokens_error(basic_error_payload):
    """Detector should return an ErrorInsight when TRT-LLM style error appears."""
    error_summary = make_summary(MockErrorDetails(message=basic_error_payload))

    insight = MaxCompletionTokensDetector().detect(error_summary)

    assert insight is not None
    # The problem statement must name both the rejected and the supported field.
    for field_name in ("max_completion_tokens", "max_tokens"):
        assert field_name in insight.problem
    assert any("max_completion_tokens" in cause for cause in insight.causes)
59+
60+
61+
def test_detector_returns_none_for_unrelated_error():
    """Detector should return None for an error message that does not mention max_completion_tokens."""
    unrelated_error = MockErrorDetails(
        message='{"message": "context_length_exceeded"}'
    )

    result = MaxCompletionTokensDetector().detect(make_summary(unrelated_error))

    assert result is None
67+
68+
69+
def test_detector_returns_none_when_no_errors():
    """Detector should return None for both a missing (None) and an empty error summary."""
    detector = MaxCompletionTokensDetector()
    # Same detector instance is reused for both inputs, in the original order.
    for empty_summary in (None, []):
        assert detector.detect(empty_summary) is None
73+
74+
75+
def test_exporter_prints_panel_for_detected_error(basic_error_payload):
    """Exporter should print a Rich panel when the detector returns an insight.

    Builds an exporter config whose ``results.error_summary`` holds one
    ``max_completion_tokens`` rejection, runs ``export`` against a mocked
    console, and checks the title and text of the printed panel.
    """
    import asyncio  # function-local so the test does not depend on a module-level import

    mock_console = MagicMock(spec=Console)

    error_summary = make_summary(MockErrorDetails(message=basic_error_payload))

    exporter_config = MagicMock(spec=ExporterConfig)
    exporter_config.results = MagicMock()
    exporter_config.results.error_summary = error_summary

    exporter = ConsoleApiErrorExporter(exporter_config)
    asyncio.run(exporter.export(mock_console))

    assert mock_console.print.call_count >= 2

    # Find the panel by shape (it exposes both ``renderable`` and ``title``)
    # rather than by hard-coded call position, so an extra or reordered
    # print in the exporter yields a clear assertion failure instead of an
    # IndexError or an AttributeError on the wrong object.
    panels = [
        call.args[0]
        for call in mock_console.print.call_args_list
        if call.args
        and hasattr(call.args[0], "renderable")
        and hasattr(call.args[0], "title")
    ]
    assert panels, "exporter did not print a panel-like renderable"
    panel = panels[0]

    panel_text = str(panel.renderable)
    panel_title = str(panel.title)

    assert "Unsupported Parameter: max_completion_tokens" in panel_title

    assert "The backend rejected 'max_completion_tokens'" in panel_text
    assert "This backend only supports 'max_tokens'." in panel_text

    assert "--use-legacy-max-tokens" in panel_text
107+
108+
109+
def test_exporter_skips_when_no_insight():
    """Exporter should print nothing when the error summary is empty."""
    import asyncio

    console = MagicMock(spec=Console)

    config = MagicMock(spec=ExporterConfig)
    config.results = MagicMock()
    config.results.error_summary = []

    asyncio.run(ConsoleApiErrorExporter(config).export(console))

    assert console.print.call_count == 0

0 commit comments

Comments
 (0)