
Commit ada0a8d

evezhier authored and Ria Jain committed
[TRTLLM-5271][feat] best_of/n for pytorch workflow (NVIDIA#5997)
Signed-off-by: Olya Kozlova <[email protected]>
1 parent 424786b · commit ada0a8d

File tree

9 files changed: +332 -56 lines changed

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 6 additions & 0 deletions
@@ -467,6 +467,9 @@ class GenericLlmRequest
         initialize(req.getInputTokenIds(), req.getOutputConfig().returnLogProbs);
     }

+    GenericLlmRequest(GenericLlmRequest&& request) = default;
+    GenericLlmRequest(GenericLlmRequest const& request) = default;
+
     void setExcludeInputFromOutput(bool exclude)
     {
         mExcludeInputFromOutput = exclude;
@@ -2318,6 +2321,9 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
         mKvCacheRetentionConfig = request.getKvCacheRetentionConfig();
     }

+    LlmRequest(LlmRequest&& request) = default;
+    LlmRequest(LlmRequest const& request) = default;
+
     /// @brief Create a Response from the current state of the request
     /// @details Note that there is some dependency on the order of operations in this method. Modify with care!
     /// @return An optional Response

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp

Lines changed: 4 additions & 0 deletions
@@ -187,6 +187,8 @@ void initBindings(nb::module_& m)
         .def_prop_ro("missed_blocks", &GenLlmReq::getMissedBlocksPerRequest)
         .def_prop_ro("kv_cache_hit_rate", &GenLlmReq::getKVCacheHitRatePerRequest)
         .def_prop_ro("llm_request_type", &GenLlmReq::getLlmRequestType)
+        .def_prop_ro("parent_request_id", &GenLlmReq::getParentRequestId)
+        .def_prop_ro("is_child", &GenLlmReq::isChild)
         .def_prop_ro("multimodal_hashes",
             [](GenLlmReq& self)
             {
@@ -351,11 +353,13 @@ void initBindings(nb::module_& m)
             nb::arg("return_perf_metrics") = false, nb::arg("guided_decoding_params") = std::nullopt,
             nb::arg("language_adapter_uid") = std::nullopt, nb::arg("allotted_time_ms") = std::nullopt,
             nb::arg("context_phase_params") = std::nullopt)
+        .def(nb::init<tb::LlmRequest const&>())
         .def("validate", &tb::LlmRequest::validate, nb::arg("max_input_len"), nb::arg("max_seq_len"),
             nb::arg("max_draft_len"), nb::arg("vocab_size_padded"), nb::arg("max_endocer_input_len") = std::nullopt,
             nb::arg("enable_kv_cache_reuse") = false)
         .def("create_response", &tb::LlmRequest::createResponse, nb::arg("use_fast_logits") = false,
             nb::arg("mpi_world_rank") = 0)
+        .def("create_child_request", &tb::LlmRequest::createChildRequest, nb::arg("child_id"))
         .def("create_result", &tb::LlmRequest::createResult, nb::arg("use_fast_logits") = false,
             nb::arg("mpi_world_rank") = 0)
         .def("create_serialized_result",

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp

Lines changed: 5 additions & 1 deletion
@@ -192,6 +192,8 @@ void initBindings(pybind11::module_& m)
         .def_property_readonly("missed_blocks", &GenLlmReq::getMissedBlocksPerRequest)
         .def_property_readonly("kv_cache_hit_rate", &GenLlmReq::getKVCacheHitRatePerRequest)
         .def_property_readonly("llm_request_type", &GenLlmReq::getLlmRequestType)
+        .def_property_readonly("parent_request_id", &GenLlmReq::getParentRequestId)
+        .def_property_readonly("is_child", &GenLlmReq::isChild)
         .def_property_readonly("multimodal_hashes",
             [](GenLlmReq& self)
             {
@@ -254,7 +256,7 @@ void initBindings(pybind11::module_& m)
         .def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics);

     py::classh<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", pybind11::dynamic_attr())
-        .def(py::init(
+        .def(py::init<>(
             [](tb::LlmRequest::RequestIdType request_id, tb::LlmRequest::SizeType32 max_new_tokens,
                 std::vector<tb::LlmRequest::TokenIdType> input_tokens, runtime::SamplingConfig sampling_config,
                 bool is_streaming, std::optional<tb::LlmRequest::SizeType32> end_id,
@@ -357,11 +359,13 @@ void initBindings(pybind11::module_& m)
             py::arg("return_perf_metrics") = false, py::arg("guided_decoding_params") = std::nullopt,
             py::arg("language_adapter_uid") = std::nullopt, py::arg("allotted_time_ms") = std::nullopt,
             py::arg("context_phase_params") = std::nullopt)
+        .def(py::init<tb::LlmRequest const&>())
         .def("validate", &tb::LlmRequest::validate, py::arg("max_input_len"), py::arg("max_seq_len"),
             py::arg("max_draft_len"), py::arg("vocab_size_padded"), py::arg("max_endocer_input_len") = std::nullopt,
             py::arg("enable_kv_cache_reuse") = false)
         .def("create_response", &tb::LlmRequest::createResponse, py::arg("use_fast_logits") = false,
             py::arg("mpi_world_rank") = 0)
+        .def("create_child_request", &tb::LlmRequest::createChildRequest, py::arg("child_id"))
         .def("create_result", &tb::LlmRequest::createResult, py::arg("use_fast_logits") = false,
             py::arg("mpi_world_rank") = 0)
         .def("create_serialized_result",

examples/llm-api/quickstart_advanced.py

Lines changed: 22 additions & 9 deletions
@@ -107,6 +107,8 @@ def add_llm_args(parser):
     parser.add_argument("--top_k", type=int, default=None)
     parser.add_argument("--top_p", type=float, default=None)
     parser.add_argument('--load_format', type=str, default='auto')
+    parser.add_argument('--n', type=int, default=1)
+    parser.add_argument('--best_of', type=int, default=None)
     parser.add_argument('--max_beam_width', type=int, default=1)

     # Speculative decoding
@@ -193,6 +195,7 @@ def setup_llm(args, **kwargs):
         batch_sizes=args.cuda_graph_batch_sizes,
         enable_padding=args.cuda_graph_padding_enabled,
     ) if args.use_cuda_graph else None
+
     llm = LLM(
         model=args.model_dir,
         backend='pytorch',
@@ -228,6 +231,15 @@
         **kwargs,
     )

+    use_beam_search = args.max_beam_width > 1
+    best_of = args.best_of or args.n
+    if use_beam_search:
+        if args.n == 1 and args.best_of is None:
+            args.n = args.max_beam_width
+        assert best_of <= args.max_beam_width, f"beam width: {best_of}, should be less or equal to max_beam_width: {args.max_beam_width}"
+
+    assert best_of >= args.n, f"In sampling mode best_of value: {best_of} should be less or equal to n: {args.n}"
+
     sampling_params = SamplingParams(
         max_tokens=args.max_tokens,
         temperature=args.temperature,
@@ -236,8 +248,9 @@
         return_context_logits=args.return_context_logits,
         return_generation_logits=args.return_generation_logits,
         logprobs=args.logprobs,
-        n=args.max_beam_width,
-        use_beam_search=args.max_beam_width > 1)
+        n=args.n,
+        best_of=best_of,
+        use_beam_search=use_beam_search)
     return llm, sampling_params


@@ -250,23 +263,23 @@ def main():

     for i, output in enumerate(outputs):
         prompt = output.prompt
-        for beam_idx, beam in enumerate(output.outputs):
-            generated_text = beam.text
+        for sequence_idx, sequence in enumerate(output.outputs):
+            generated_text = sequence.text
             # Skip printing the beam_idx if no beam search was used
-            beam_id_text = f"[{beam_idx}]" if args.max_beam_width > 1 else ""
+            sequence_id_text = f"[{sequence_idx}]" if args.max_beam_width > 1 or args.n > 1 else ""
             print(
-                f"[{i}]{beam_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
+                f"[{i}]{sequence_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
             )
             if args.return_context_logits:
                 print(
-                    f"[{i}]{beam_id_text} Context logits: {output.context_logits}"
+                    f"[{i}]{sequence_id_text} Context logits: {output.context_logits}"
                 )
             if args.return_generation_logits:
                 print(
-                    f"[{i}]{beam_id_text} Generation logits: {beam.generation_logits}"
+                    f"[{i}]{sequence_id_text} Generation logits: {sequence.generation_logits}"
                 )
             if args.logprobs:
-                print(f"[{i}]{beam_id_text} Logprobs: {beam.logprobs}")
+                print(f"[{i}]{sequence_id_text} Logprobs: {sequence.logprobs}")


 if __name__ == '__main__':
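
Note: the constraint the quickstart now enforces is n <= best_of, and with beam search additionally best_of <= max_beam_width. As a reference point, a minimal sketch of the equivalent direct SamplingParams usage for a pure-sampling run that returns the best 2 of 4 generated sequences; the model path, prompt, and token budget are placeholders:

from tensorrt_llm import LLM, SamplingParams

# Mirrors `--n 2 --best_of 4 --max_beam_width 1` from the flags added above.
sampling_params = SamplingParams(
    max_tokens=64,  # placeholder budget
    n=2,            # sequences returned per prompt
    best_of=4,      # sequences generated and ranked internally
    use_beam_search=False)

if __name__ == '__main__':
    llm = LLM(model='/path/to/model', backend='pytorch')  # placeholder path
    for i, output in enumerate(
            llm.generate(["Hello, my name is"], sampling_params)):
        for sequence_idx, sequence in enumerate(output.outputs):
            print(f"[{i}][{sequence_idx}] {sequence.text!r}")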

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 65 additions & 17 deletions
@@ -22,6 +22,7 @@
 class RequestQueueItem:
     id: int
     request: Optional[ExecutorRequest] = None
+    child_req_ids: Optional[list] = None
     is_canceled_request: bool = False
     query: Optional[list] = None  # only used in `StarAttention`

@@ -83,6 +84,12 @@ def _get_from_request_queue(
                 pass
         return items

+    @staticmethod
+    def _get_num_child_requests(request: ExecutorRequest) -> int:
+        sampling_config = request.sampling_config
+        return 0 if sampling_config.beam_width > 1 else (
+            sampling_config.num_return_sequences or 1) - 1
+
     def _get_from_waiting_queue(
         self,
         waiting_queue: deque[RequestQueueItem],
@@ -111,14 +118,19 @@ def _get_from_waiting_queue(
         scheduling_all_ranks_num_active_requests = all_ranks_num_active_requests.copy(
         ) if enable_attention_dp else None
         while req_count < max_req_count and waiting_queue:
+            req_item = waiting_queue[0]
+            num_children = len(
+                req_item.child_req_ids) if req_item.child_req_ids else 0
+            if (req_count + 1 + num_children) > max_req_count:
+                break
             req_item = waiting_queue.popleft()
             can_process = self._can_process_attention_dp_request(
                 req_item, scheduling_all_ranks_num_active_requests
             ) if enable_attention_dp else True

             if can_process:
                 items.append(req_item)
-                req_count += 1
+                req_count += 1 + num_children
             else:
                 pending_requests.append(req_item)

@@ -149,17 +161,43 @@ def _can_process_attention_dp_request(

         return False

+    def _get_request_id(self):
+        # (next_request_id + 1) % UINT64_MAX
+        current_id = self.next_request_id
+        self.next_request_id = (self.next_request_id + 1) & ((1 << 64) - 1)
+        return current_id
+
+    def _generate_child_request_ids(
+            self, request: ExecutorRequest) -> List[int] | None:
+        """ Generate child request IDs if needed. """
+        child_req_ids = None
+        num_children = self._get_num_child_requests(request)
+        if num_children > 0:
+            child_req_ids = []
+            for _ in range(num_children):
+                child_req_id = self._get_request_id()
+                if self.enable_iter_perf_stats:
+                    self.start_times[child_req_id] = time.time()
+                child_req_ids.append(child_req_id)
+
+        return child_req_ids
+
     def enqueue_requests(self, requests: List[ExecutorRequest]):
         req_ids = []
         try:
             self.enqueue_lock.acquire()
-            start_time = time.time()
             for request in requests:
-                self.start_times[self.next_request_id] = start_time
+                req_id = self._get_request_id()
+
+                if self.enable_iter_perf_stats:
+                    self.start_times[req_id] = time.time()
+
+                child_req_ids = self._generate_child_request_ids(request)
                 self.request_queue.put(
-                    RequestQueueItem(self.next_request_id, request))
-                req_ids.append(self.next_request_id)
-                self.next_request_id += 1
+                    RequestQueueItem(req_id, request, child_req_ids,
+                                     query=None))
+
+                req_ids.append(req_id)
         finally:
             self.enqueue_lock.release()
         return req_ids
@@ -186,15 +224,18 @@ def enqueue_request(self,
         try:
             self.enqueue_lock.acquire()
             assert self.active, "PyExecutor has already been shutdown."
-            req_id = self.next_request_id
+            req_id = self._get_request_id()
             if self.enable_iter_perf_stats:
                 self.start_times[req_id] = time.time()

-            if query is not None:
-                self.request_queue.put(RequestQueueItem(req_id, request, query))
-            else:
-                self.request_queue.put(RequestQueueItem(req_id, request))
-            self.next_request_id += 1
+            child_req_ids = self._generate_child_request_ids(request)
+            self.request_queue.put(
+                RequestQueueItem(
+                    req_id,
+                    request,
+                    child_req_ids=child_req_ids,
+                    query=query,
+                ))
         finally:
             self.enqueue_lock.release()

@@ -530,6 +571,10 @@ def _update_new_active_requests_queue_latency(
             if req_item.id in self.start_times:
                 self.new_active_requests_queue_latency_ms += now - self.start_times.pop(
                     req_item.id)
+            if req_item.child_req_ids:
+                for child_id in req_item.child_req_ids:
+                    self.new_active_requests_queue_latency_ms += now - self.start_times.pop(
+                        child_id)

     @nvtx_range("_merge_requests")
     def _merge_requests(self, new_requests: list[RequestQueueItem]):
@@ -543,12 +588,15 @@ def _merge_requests(self, new_requests: list[RequestQueueItem]):
             else:
                 raise NotImplementedError(f'unsupport cp type {cp_type}')
         else:
-            return [
-                executor_request_to_llm_request(
-                    req_item.id, req_item.request,
+            req_with_children = []
+            for req_item in new_requests:
+                req = executor_request_to_llm_request(
+                    req_item.id, req_item.request, req_item.child_req_ids,
                     self._should_exclude_last_generation_logits())
-                for req_item in new_requests
-            ]
+                req_with_children.append(req)
+                if req.child_requests:
+                    req_with_children.extend(req.child_requests)
+            return req_with_children

     def _merge_star_attention_requests(self,
                                        new_requests: list[RequestQueueItem]):
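
Note: to summarize the bookkeeping above, a sampling request (no beam search) with num_return_sequences > 1 reserves one extra 64-bit request id per child sequence, while beam-search requests keep all beams inside a single request and spawn no children. A self-contained sketch of that arithmetic, with a simplified stand-in for the executor sampling config:

from dataclasses import dataclass

UINT64_MAX = (1 << 64) - 1  # request ids wrap around, as in _get_request_id


@dataclass
class FakeSamplingConfig:
    # Stand-in for the executor SamplingConfig fields used here.
    beam_width: int = 1
    num_return_sequences: int | None = None


def num_child_requests(cfg: FakeSamplingConfig) -> int:
    # Beam search keeps every beam inside one request, so no children;
    # otherwise the request spawns (num_return_sequences - 1) children.
    return 0 if cfg.beam_width > 1 else (cfg.num_return_sequences or 1) - 1


def next_id(current: int) -> int:
    # 64-bit wraparound, matching _get_request_id.
    return (current + 1) & UINT64_MAX


assert num_child_requests(FakeSamplingConfig(num_return_sequences=4)) == 3
assert num_child_requests(FakeSamplingConfig(beam_width=4, num_return_sequences=4)) == 0
assert next_id(UINT64_MAX) == 0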
