Skip to content

Commit bc80dc9

Browse files
Fixes expected output matching in examples
1 parent 562a42a commit bc80dc9

File tree

3 files changed

+53
-66
lines changed

3 files changed

+53
-66
lines changed

tripy/examples/nanogpt/README.md

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,13 @@ for expected accuracy.
3434
python3 example.py --input-text "What is the answer to life, the universe, and everything?" --seed=0
3535
```
3636

37-
<!-- Tripy: TEST: EXPECTED_STDOUT Start -->
3837
<!--
38+
Tripy: TEST: EXPECTED_STDOUT Start
3939
```
40-
What is the answer to life, the universe, and everything\? How can we know what's real\? How can
41-
====
42-
(?s).*?
43-
What is the answer to life, the universe, and everything\? The answer to the questions that
40+
(?s).*?What is the answer to life, the universe, and everything\? (How can we know what's real\? How can|The answer to the questions that are asked of us)
4441
```
45-
-->
46-
<!-- Tripy: TEST: EXPECTED_STDOUT End -->
42+
Tripy: TEST: EXPECTED_STDOUT End
43+
-->
4744
4845
### Running with Quantization
4946
@@ -60,33 +57,26 @@ To run with a quantization mode, pass `--quant-mode` to `example.py`. The suppor
6057
```bash
6158
python3 example.py --input-text "What is the answer to life, the universe, and everything?" --seed=0 --quant-mode int8-weight-only
6259
```
63-
<!-- Tripy: TEST: EXPECTED_STDOUT Start -->
6460
<!--
61+
Tripy: TEST: EXPECTED_STDOUT Start
6562
```
66-
(?s).*?
67-
What is the answer to life, the universe, and everything\? How is life possible, what is the meaning of
68-
====
69-
(?s).*?
70-
What is the answer to life, the universe, and everything\? The answer to the questions that
71-
====
72-
(?s).*?
73-
What is the answer to life, the universe, and everything\? How can
63+
(?s).*?What is the answer to life, the universe, and everything\? (The answer to the questions that|How is life possible, what is the meaning of|How can)
7464
```
75-
-->
76-
<!-- Tripy: TEST: EXPECTED_STDOUT End -->
65+
Tripy: TEST: EXPECTED_STDOUT End
66+
-->
7767
7868
2. Weight-only int4 quantization:
7969
70+
*Note: `int4` quantization may result in poor accuracy for this model.*
71+
*We include it here primarily to demonstrate the workflow.*
72+
8073
```bash
8174
python3 example.py --input-text "What is the answer to life, the universe, and everything?" --seed=0 --quant-mode int4-weight-only
8275
```
83-
<!-- Tripy: TEST: EXPECTED_STDOUT Start -->
8476
<!--
77+
Tripy: TEST: EXPECTED_STDOUT Start
8578
```
86-
(?s).*?
87-
What is the answer to life, the universe, and everything\? What is what is what is what is what is
79+
(?s).*?What is the answer to life, the universe, and everything\? What is what is what is what is what is
8880
```
89-
-->
90-
<!-- Tripy: TEST: EXPECTED_STDOUT End -->
91-
92-
*Note: `int4` quantization may result in poor accuracy. We include it here primarily to demonstrate the workflow.*
81+
Tripy: TEST: EXPECTED_STDOUT End
82+
-->

tripy/examples/segment-anything-model-v2/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ This is an implementation of SAM2 model ([original repository](https://github.co
2828
python3 image_demo.py
2929
```
3030

31-
<!-- Tripy: TEST: EXPECTED_STDOUT Start -->
3231
<!--
32+
Tripy: TEST: EXPECTED_STDOUT Start
3333
```
3434
Scores for each prediction: {0.78759766~5%} {0.640625~5%} {0.05099487~5%}
3535
```
36-
-->
37-
<!-- Tripy: TEST: EXPECTED_STDOUT End -->
36+
Tripy: TEST: EXPECTED_STDOUT End
37+
-->
3838

3939
### Video segmentation pipeline
4040

tripy/tests/test_examples.py

Lines changed: 35 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -93,30 +93,32 @@ def __str__(self):
9393
@pytest.mark.l1_release_package
9494
@pytest.mark.parametrize("example", EXAMPLES, ids=lambda case: str(case))
9595
def test_examples(example, sandboxed_install_run):
96-
97-
def test_with_tolerance(expected, actual, tolerance):
98-
return (abs(float(actual) - float(expected)) / float(expected)) * 100 <= float(tolerance)
99-
10096
def process_tolerances(expected_output):
101-
specs = []
102-
placeholder_regex = r"{(\d+\.?\d*)~(\d+)%}"
103-
pattern = expected_output
97+
# Adjusts the expected output into a regex that will be more lenient when matching
98+
# values with tolerances. The actual tolerance checks are done separately.
99+
tolerance_specs = []
100+
tolerance_regex = r"{(\d+\.?\d*)~(\d+)%}"
104101

105102
# Replace each tolerance pattern with a more flexible capture group
106-
matches = list(re.finditer(placeholder_regex, pattern))
103+
matches = list(re.finditer(tolerance_regex, expected_output))
104+
105+
if not matches:
106+
# If there are no tolerance values, don't modify the expected output:
107+
return expected_output, tolerance_specs
108+
107109
for match in matches:
108-
specs.append((match.group(1), match.group(2)))
109-
pattern = pattern.replace(match.group(0), r"(\d+\.?\d*)", 1)
110+
tolerance_specs.append((match.group(1), match.group(2)))
111+
expected_output = expected_output.replace(match.group(0), r"(\d+\.?\d*)", 1)
110112

111113
# Escape literal parentheses, but not the capture groups we just inserted
112-
pattern = pattern.replace("(", r"\(")
113-
pattern = pattern.replace(")", r"\)")
114-
pattern = pattern.replace(r"\(\d+\.?\d*\)", r"(\d+\.?\d*)")
114+
expected_output = expected_output.replace("(", r"\(")
115+
expected_output = expected_output.replace(")", r"\)")
116+
expected_output = expected_output.replace(r"\(\d+\.?\d*\)", r"(\d+\.?\d*)")
115117

116118
# Make whitespace flexible
117-
pattern = pattern.replace(" ", r"\s+")
119+
expected_output = expected_output.replace(" ", r"\s+")
118120

119-
return pattern.strip(), specs
121+
return expected_output.strip(), tolerance_specs
120122

121123
with open(example.readme, "r", encoding="utf-8") as f:
122124
contents = f.read()
@@ -133,31 +135,26 @@ def process_tolerances(expected_output):
133135

134136
code = str(block)
135137
if block.has_marker("test: expected_stdout"):
136-
out = statuses[-1].stdout.strip()
137-
# expected = dedent(code).strip()
138-
expected_outs = dedent(code).split("====")
139-
for expected in expected_outs:
140-
pattern, specs = process_tolerances(expected)
141-
142-
# Apply the DOTALL flag to allow `.` to match newlines
143-
compiled_pattern = re.compile(pattern, re.DOTALL)
144-
match = compiled_pattern.search(out)
145-
146-
# match = re.search(pattern, out)
147-
if match and specs:
148-
# Check if captured numbers are within tolerance
149-
matched = all(
150-
test_with_tolerance(expected, actual, tolerance)
151-
for (expected, tolerance), actual in zip(specs, match.groups())
152-
)
153-
else:
154-
matched = bool(match)
155-
156-
if matched:
157-
break
138+
print("Checking command output against expected output: ", end="")
139+
actual = statuses[-1].stdout.strip()
140+
expected = dedent(code).strip()
141+
142+
expected, tolerance_specs = process_tolerances(expected)
143+
# Apply the DOTALL flag to allow `.` to match newlines
144+
expected = re.compile(expected, re.DOTALL)
145+
match = expected.search(actual)
146+
147+
# We always want to check if the text matched what we expected:
148+
matched = bool(match)
149+
# Additionally, check that numbers are within tolerance values if they were specified:
150+
if tolerance_specs:
151+
matched = matched and all(
152+
(abs(float(actual) - float(expected)) / float(expected)) * 100 <= float(tolerance)
153+
for (expected, tolerance), actual in zip(tolerance_specs, match.groups())
154+
)
158155

159156
print("matched!" if matched else "did not match!")
160-
print(f"==== STDOUT ====\n{out}")
157+
print(f"==== STDOUT ====\n{actual}")
161158
assert matched
162159
else:
163160
status = example.run(code, sandboxed_install_run)

0 commit comments

Comments
 (0)