Skip to content

Commit 24e8048

Browse files
authored
Add --ignore-eos flag (#24)
* First commit
* Added a note that the flag is vLLM-specific
1 parent 1330e4d commit 24e8048

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

benchmark_serving.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -157,6 +157,7 @@ async def send_stream_request(
157157
prompt: str,
158158
prompt_len: int,
159159
output_len: int,
160+
ignore_eos: bool,
160161
best_of: int,
161162
use_beam_search: bool,
162163
top_k: int,
@@ -180,7 +181,7 @@ async def send_stream_request(
180181
"temperature": 0.0 if use_beam_search else 1.0,
181182
"top_p": 1.0,
182183
"max_tokens": output_len,
183-
"ignore_eos": True,
184+
"ignore_eos": ignore_eos,
184185
"stream": True,
185186
}
186187
elif backend == "jetstream":
@@ -264,6 +265,7 @@ async def send_request(
264265
prompt: str,
265266
prompt_len: int,
266267
output_len: int,
268+
ignore_eos: bool,
267269
best_of: int,
268270
use_beam_search: bool,
269271
top_k: int,
@@ -287,7 +289,7 @@ async def send_request(
287289
"temperature": 0.0 if use_beam_search else 1.0,
288290
"top_p": 1.0,
289291
"max_tokens": output_len,
290-
"ignore_eos": False,
292+
"ignore_eos": ignore_eos,
291293
"stream": False,
292294
}
293295
elif backend == "tgi":
@@ -418,11 +420,11 @@ async def run_single_request(args: argparse.Namespace, api_url: str, tokenizer:
418420
prompt: str, prompt_len: int, output_len: int, chosen_model: str) -> Tuple[str, Tuple]:
419421
if args.stream_request:
420422
result = await send_stream_request(
421-
args.backend, api_url, prompt, prompt_len, output_len,
423+
args.backend, api_url, prompt, prompt_len, output_len, args.ignore_eos,
422424
args.best_of, args.use_beam_search, args.top_k, tokenizer, args.sax_model, chosen_model, args.request_timeout,)
423425
else:
424426
result = await send_request(
425-
args.backend, api_url, prompt, prompt_len, output_len,
427+
args.backend, api_url, prompt, prompt_len, output_len, args.ignore_eos,
426428
args.best_of, args.use_beam_search, args.top_k, tokenizer, args.sax_model, chosen_model, args.request_timeout,)
427429
return chosen_model, result
428430

@@ -973,6 +975,14 @@ def parse_traffic_split(arg):
973975
"Maximum number of input tokens for filtering the benchmark dataset."
974976
),
975977
)
978+
parser.add_argument(
979+
"--ignore-eos",
980+
action="store_true",
981+
help=(
982+
"If set and model server is vllm, the generation process will ignore the end-of-sequence (EOS) token, "
983+
"allowing output to continue until reaching --max-output-length or another stopping condition."
984+
),
985+
)
976986
parser.add_argument(
977987
"--top-k",
978988
type=int,

Comments (0)