
Commit df41f22

[TRTLLM-8831][feat] Enable early exit with overlap scheduler (#8587)
Signed-off-by: Robin Kobus <[email protected]>
1 parent 6151a4c commit df41f22

File tree: 16 files changed, +195 / -145 lines changed


cpp/tensorrt_llm/batch_manager/llmRequest.cpp

Lines changed: 3 additions & 1 deletion
@@ -69,7 +69,9 @@ void LlmRequest::createSerializedResult(
 /// Note that there is some dependency on the order of operations in this method. Modify with care!
 std::optional<executor::Result> LlmRequest::createResult(bool useFastLogits, int32_t mpiWorldRank)
 {
-    if (!(isFinished() || (mIsStreaming && mState == LlmRequestState::kGENERATION_IN_PROGRESS)))
+    auto const streamingInProgress = mIsStreaming
+        && (mState == LlmRequestState::kGENERATION_IN_PROGRESS || mState == LlmRequestState::kGENERATION_TO_COMPLETE);
+    if (!(isFinished() || streamingInProgress))
     {
         return std::nullopt;
     }
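
Note: a minimal Python sketch of the updated guard in createResult, using hypothetical stand-ins (RequestState, Request) rather than the real LlmRequestState/LlmRequest bindings. It illustrates that a streaming request in GENERATION_TO_COMPLETE now also yields a result, instead of only GENERATION_IN_PROGRESS as before.

# Hedged sketch; RequestState and Request are hypothetical stand-ins,
# not the actual TensorRT-LLM API.
from dataclasses import dataclass
from enum import Enum, auto


class RequestState(Enum):
    GENERATION_IN_PROGRESS = auto()
    GENERATION_TO_COMPLETE = auto()
    GENERATION_COMPLETE = auto()


@dataclass
class Request:
    state: RequestState
    is_streaming: bool

    def is_finished(self) -> bool:
        return self.state == RequestState.GENERATION_COMPLETE

    def should_create_result(self) -> bool:
        # New guard: GENERATION_TO_COMPLETE also counts as streaming in progress,
        # so the final streaming chunk is emitted when a request exits early.
        streaming_in_progress = self.is_streaming and self.state in (
            RequestState.GENERATION_IN_PROGRESS,
            RequestState.GENERATION_TO_COMPLETE,
        )
        return self.is_finished() or streaming_in_progress


# The case the old condition rejected and the new one accepts:
assert Request(RequestState.GENERATION_TO_COMPLETE, True).should_create_result()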

cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
             LlmRequestState>(),
         nb::arg("ctx_chunk_config") = std::nullopt, nb::arg("max_context_length") = std::nullopt,
         nb::arg("no_schedule_until_state") = LlmRequestState::kCONTEXT_INIT,
-        nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_COMPLETE)
+        nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_TO_COMPLETE)
         .def("__call__", &MicroBatchScheduler::operator(), nb::arg("active_requests"), nb::arg("inflight_req_ids"),
             nb::arg("max_batch_size_runtime"), nb::arg("max_num_tokens_runtime"))
         .def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; });

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp

Lines changed: 1 addition & 0 deletions
@@ -103,6 +103,7 @@ void initBindings(nb::module_& m)
         .def("get_last_tokens", nb::overload_cast<>(&GenLlmReq::getLastTokens))
         .def("get_beam_width_by_iter", &GenLlmReq::getBeamWidthByIter, nb::arg("for_next_iteration") = false)
         .def_prop_ro("max_num_generated_tokens", &GenLlmReq::getMaxNumGeneratedTokens)
+        .def("will_complete_next_iteration", &GenLlmReq::willCompleteNextIteration)
         .def("add_new_token", &GenLlmReq::addNewToken, nb::arg("token"), nb::arg("beam"))
         .def("add_new_tokens", &GenLlmReq::addNewTokens, nb::arg("beam_tokens"))
         .def_prop_ro("num_draft_tokens", &GenLlmReq::getNumDraftTokens)

cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp

Lines changed: 2 additions & 2 deletions
@@ -65,8 +65,8 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
             LlmRequestState>(),
         py::arg("ctx_chunk_config") = std::nullopt, py::arg("max_context_length") = std::nullopt,
         py::arg_v("no_schedule_until_state", LlmRequestState::kCONTEXT_INIT, "LlmRequestState.CONTEXT_INIT"),
-        py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_COMPLETE,
-            "LlmRequestState.GENERATION_COMPLETE"))
+        py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_TO_COMPLETE,
+            "LlmRequestState.GENERATION_TO_COMPLETE"))
         .def("__call__", &MicroBatchScheduler::operator(), py::arg("active_requests"), py::arg("inflight_req_ids"),
             py::arg("max_batch_size_runtime"), py::arg("max_num_tokens_runtime"))
         .def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; });

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp

Lines changed: 1 addition & 0 deletions
@@ -107,6 +107,7 @@ void initBindings(pybind11::module_& m)
         .def("get_last_tokens", py::overload_cast<>(&GenLlmReq::getLastTokens))
         .def("get_beam_width_by_iter", &GenLlmReq::getBeamWidthByIter, py::arg("for_next_iteration") = false)
         .def_property_readonly("max_num_generated_tokens", &GenLlmReq::getMaxNumGeneratedTokens)
+        .def("will_complete_next_iteration", &GenLlmReq::willCompleteNextIteration)
         .def("add_new_token", &GenLlmReq::addNewToken, py::arg("token"), py::arg("beam"))
         .def("add_new_tokens", &GenLlmReq::addNewTokens, py::arg("beam_tokens"))
         .def_property_readonly("num_draft_tokens", &GenLlmReq::getNumDraftTokens)

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 0 additions & 1 deletion
@@ -153,7 +153,6 @@ def __init__(
         self.llm_args.batch_wait_timeout_iters = 0
         self.llm_args.batch_wait_max_tokens_ratio = 0.0
         self.llm_args.max_num_tokens = seq_info.max_num_tokens
-        self.iter_counter = 0
 
         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
         self.max_beam_width = max_beam_width

tensorrt_llm/_torch/expert_statistic.py

Lines changed: 11 additions & 8 deletions
@@ -29,11 +29,15 @@ def create(rank_id: int):
                 rank_id, start, stop)
 
     @staticmethod
-    def set_iter(iter_id: int) -> bool:
+    def should_record() -> bool:
         if ExpertStatistic.expert_statistic_obj is not None:
-            return ExpertStatistic.expert_statistic_obj._set_iter(iter_id)
-        else:
-            return False
+            return ExpertStatistic.expert_statistic_obj._should_record
+        return False
+
+    @staticmethod
+    def set_iter(iter_id: int) -> None:
+        if ExpertStatistic.expert_statistic_obj is not None:
+            ExpertStatistic.expert_statistic_obj._set_iter(iter_id)
 
     @staticmethod
     def set_layer(layer_id: int) -> None:
@@ -57,10 +61,10 @@ def __init__(self, rank_id: int, start: int, stop: int) -> None:
         self._records = {}
 
     @property
-    def should_record(self) -> bool:
+    def _should_record(self) -> bool:
         return self.current_iter_id is not None and self.start <= self.current_iter_id < self.stop
 
-    def _set_iter(self, iter_id: int) -> bool:
+    def _set_iter(self, iter_id: int) -> None:
         self.current_iter_id = iter_id
         if iter_id == self.stop:
             logger.info(
@@ -74,14 +78,13 @@ def _set_iter(self, iter_id: int) -> bool:
                 json.dump(self._meta_info, f)
             safetensors.torch.save_file(
                 self._records, f"{path}/rank{self.rank_id}.safetensors")
-        return self.should_record
 
     def _set_layer(self, layer: int) -> None:
         self.current_layer = layer
 
     def _maybe_add_info(self, expert_count: int,
                         token_selected_experts: torch.Tensor) -> None:
-        if not self.should_record:
+        if not self._should_record:
             return
 
         if self._meta_info is None:
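
Note: the net effect of this refactor is that ExpertStatistic.set_iter(iter_id) now only advances the iteration counter (returning None), while the decision of whether statistics are being collected moves to the separate ExpertStatistic.should_record() query. A rough usage sketch under that assumption; it should be safe even when no statistic object was created via ExpertStatistic.create, in which case should_record() simply returns False.

from tensorrt_llm._torch.expert_statistic import ExpertStatistic


def run_iterations(num_iterations: int) -> None:
    for iter_id in range(num_iterations):
        # Advance the statistic's iteration counter once per executor step.
        ExpertStatistic.set_iter(iter_id)

        # Elsewhere (e.g. maybe_get_cuda_graph in cuda_graph_runner.py below),
        # the decision is queried without needing the iteration counter itself.
        if ExpertStatistic.should_record():
            pass  # e.g. skip CUDA graph capture while statistics are recorded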

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 1 addition & 2 deletions
@@ -164,7 +164,6 @@ def __del__(self):
     def maybe_get_cuda_graph(
         self,
         batch: ScheduledRequests,
-        iter_counter: int,
         enable_spec_decode: bool,
         attn_metadata: Any,
         spec_metadata: Optional[Any] = None,
@@ -180,7 +179,7 @@ def maybe_get_cuda_graph(
             - The key for the graph, if applicable.
         """
         # disable when doing statistic
-        if ExpertStatistic.set_iter(iter_counter):
+        if ExpertStatistic.should_record():
             return None, None, None
 
         can_run_cuda_graph = batch.can_run_cuda_graph

tensorrt_llm/_torch/pyexecutor/handle_additional_outputs.py

Lines changed: 14 additions & 12 deletions
@@ -3,7 +3,8 @@
 
 import torch
 
-from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
+from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
+                                                        LlmRequestState)
 from tensorrt_llm._utils import nvtx_range
 from tensorrt_llm.logger import logger
 
@@ -92,18 +93,19 @@ def __call__(
                          (1, beam_width, 1)))
 
         for llm_req in generation_requests:
-            additional_outputs = llm_req.py_additional_outputs
+            if llm_req.state != LlmRequestState.GENERATION_COMPLETE:
+                additional_outputs = llm_req.py_additional_outputs
 
-            for name in additional_outputs:
-                outputs_begin = (output_index_with_context
-                                 if gather_context[name] else
-                                 output_index_without_context)
-                outputs_end = outputs_begin + beam_width
-
-                output_device_view = outputs[name][
-                    outputs_begin:outputs_end].reshape(1, beam_width, -1)
-                llm_req.py_result.append_additional_generation_outputs(
-                    name, output_device_view)
+                for name in additional_outputs:
+                    outputs_begin = (output_index_with_context
+                                     if gather_context[name] else
+                                     output_index_without_context)
+                    outputs_end = outputs_begin + beam_width
+
+                    output_device_view = outputs[name][
+                        outputs_begin:outputs_end].reshape(1, beam_width, -1)
+                    llm_req.py_result.append_additional_generation_outputs(
+                        name, output_device_view)
 
             output_index_with_context += beam_width
             output_index_without_context += beam_width

tensorrt_llm/_torch/pyexecutor/handle_logits.py

Lines changed: 5 additions & 1 deletion
@@ -3,7 +3,8 @@
 
 import torch
 
-from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
+from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
+                                                        LlmRequestState)
 from tensorrt_llm._utils import nvtx_range
 from tensorrt_llm.logger import logger
 
@@ -72,6 +73,9 @@ def __call__(
 
         total_context_logits = num_context_logits_prefix_sum[-1]
         for batch_index, llm_req in enumerate(generation_requests):
+            if llm_req.state == LlmRequestState.GENERATION_COMPLETE:
+                continue
+
             logits_begin = total_context_logits + batch_index * beam_width
             logits_end = logits_begin + beam_width
 
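
Note: together with the handle_additional_outputs.py change above, both generation-output handlers now skip requests whose state is already GENERATION_COMPLETE, which presumably can still appear in generation_requests when the overlap scheduler lets a request exit early. A minimal stand-alone sketch of the filtering pattern, using hypothetical stand-ins for LlmRequest/LlmRequestState:

# Hypothetical stand-ins; the real handlers check llm_req.state against
# LlmRequestState.GENERATION_COMPLETE as shown in the diffs above.
from dataclasses import dataclass
from enum import Enum, auto


class State(Enum):
    GENERATION_IN_PROGRESS = auto()
    GENERATION_COMPLETE = auto()


@dataclass
class Req:
    request_id: int
    state: State


def requests_needing_outputs(generation_requests):
    # Completed requests may still sit in the batch; skip them rather than
    # appending stale logits or additional outputs to their results.
    return [r for r in generation_requests if r.state != State.GENERATION_COMPLETE]


reqs = [Req(0, State.GENERATION_IN_PROGRESS), Req(1, State.GENERATION_COMPLETE)]
assert [r.request_id for r in requests_needing_outputs(reqs)] == [0]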
