Skip to content

Commit ee30299

Browse files
authored
Change topK to 10 in nanoGPT sample (#363)
Note: top K helps stablize the result a bit but not much, the result of `int8-weight-only` can vary among several outputs on CI machine, but cannot be reproduced locally. Worth looking into this issue later on. --------- Signed-off-by: yizhuoz004 <[email protected]>
1 parent 087e0d5 commit ee30299

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

tripy/examples/nanogpt/README.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ for expected accuracy.
3838
<!--
3939
```
4040
(?s).*?
41-
What is the answer to life, the universe, and everything\? How can we get back at these questions\? And
41+
What is the answer to life, the universe, and everything\? How can we know what's real\? How can
42+
====
43+
(?s).*?
44+
What is the answer to life, the universe, and everything\? The answer to the questions that
4245
```
4346
-->
4447
<!-- Tripy: TEST: EXPECTED_STDOUT End -->
@@ -62,7 +65,13 @@ To run with a quantization mode, pass `--quant-mode` to `example.py`. The suppor
6265
<!--
6366
```
6467
(?s).*?
68+
What is the answer to life, the universe, and everything\? The answer to the questions that
69+
====
70+
(?s).*?
6571
What is the answer to life, the universe, and everything\? How is life possible, what is the meaning of
72+
====
73+
(?s).*?
74+
What is the answer to life, the universe, and everything\? How can
6675
```
6776
-->
6877
<!-- Tripy: TEST: EXPECTED_STDOUT End -->

tripy/examples/nanogpt/example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def main():
8282
input_ids = encoder.encode(args.input_text, allowed_special={"<|endoftext|>"})
8383

8484
TEMPERATURE = 0.8
85-
TOP_K = 200
85+
TOP_K = 5
8686

8787
padded_seq_len = len(input_ids) + args.max_new_tokens
8888

tripy/tests/test_examples.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,12 @@ def test_examples(example, sandboxed_install_run):
103103
if block.has_marker("test: expected_stdout"):
104104
print("Checking command output against expected output: ", end="")
105105
out = statuses[-1].stdout.strip()
106-
matched = re.match(dedent(block_text).strip(), out)
106+
matched = False
107+
expected_outs = dedent(block_text).split("====")
108+
for expected in expected_outs:
109+
if re.match(expected.strip(), out):
110+
matched = True
111+
break
107112
print("matched!" if matched else "did not match!")
108113
print(f"==== STDOUT ====\n{out}")
109114
assert matched

0 commit comments

Comments
 (0)