140 changes: 79 additions & 61 deletions tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -204,73 +204,85 @@ def __init__(self,
def bitmask_size(self) -> int:
return math.ceil(self.vocab_size_padded / 32)

def _build(self, requests: GuidedRequests) -> None:
def _build(self, requests: GuidedRequests) -> List[Tuple[int, str]]:
"""Build the bitmask for requests with guided decoding enabled.

Specifically, this method:
- builds and advances the grammar matcher for context and generation requests, respectively;
- calls the grammar matcher to fill the bitmask on CPU;
- asynchronously copies the bitmask to GPU.
"""
failed_requests = []
self.token_mask_host[:requests.num_bitmask_tokens].fill_(0)

for req, offset in requests.valid_requests_with_offsets():
slot = req.seq_slot
self.num_advanced_tokens[slot] = 0
self.num_guided_tokens[slot] = 0
try:
self.num_advanced_tokens[slot] = 0
self.num_guided_tokens[slot] = 0

matcher_init: bool = req.require_matcher_init()
matcher_advance: bool = req.require_matcher_advance()
if not (matcher_init or matcher_advance):
continue

if matcher_init:
matcher = self.grammar_matcher_factory.create(
req.guided_decoding_params)
self.grammar_matchers[slot] = matcher

if matcher_advance:
matcher = self.grammar_matchers[slot]
# The last new token must be acceptable unless the matcher is terminated:
# 1. For the main model loop, when overlap scheduler is enabled, the matcher may have accepted the EOS token in the draft tokens at the previous iteration.
# 2. For the draft model loop, the matcher may have accepted the EOS token at the previous drafting iteration.
if matcher.is_terminated() or self.is_draft_terminated[slot]:
matcher_init: bool = req.require_matcher_init()
matcher_advance: bool = req.require_matcher_advance()
if not (matcher_init or matcher_advance):
continue
accepted = matcher.accept_token(req.new_token)
if not accepted:
if req.is_draft:
self.is_draft_terminated[slot] = True
logger.debug(
f"Draft request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}."
)

if matcher_init:
matcher = self.grammar_matcher_factory.create(
req.guided_decoding_params)
self.grammar_matchers[slot] = matcher

if matcher_advance:
matcher = self.grammar_matchers[slot]
# The last new token must be acceptable unless the matcher is terminated or None:
# 1. For the main model loop, when overlap scheduler is enabled, the matcher may have accepted the EOS token in the draft tokens at the previous iteration.
# 2. For the draft model loop, the matcher may have accepted the EOS token at the previous drafting iteration.
# 3. The matcher can be None if there was an error during its creation.
if matcher is None or matcher.is_terminated(
) or self.is_draft_terminated[slot]:
continue
# TODO: Make this an error response.
raise ValueError(
f"Request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}."
)

self.num_advanced_tokens[slot] += 1
if not matcher.is_terminated():
matcher.fill_next_token_bitmask(self.bitmask_host, offset)
self.token_mask_host[offset] = 1
self.num_guided_tokens[slot] += 1
# Process draft tokens
for i, tid in enumerate(req.draft_tokens, 1):
accepted = matcher.accept_token(tid)
accepted = matcher.accept_token(req.new_token)
if not accepted:
break
self.num_advanced_tokens[slot] += 1
if matcher.is_terminated():
break
matcher.fill_next_token_bitmask(self.bitmask_host,
offset + i)
self.token_mask_host[offset + i] = 1
self.num_guided_tokens[slot] += 1
if req.is_draft:
self.is_draft_terminated[slot] = True
logger.debug(
f"Draft request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}."
)
continue
# TODO: Make this an error response.
raise ValueError(
f"Request {req.request_id} at slot {slot} failed to accept last new token: {req.new_token}."
)

if req.is_draft:
assert len(req.draft_tokens) == 0
self.num_advanced_draft_tokens[
slot] += self.num_advanced_tokens[slot]
self.num_advanced_tokens[slot] += 1
if not matcher.is_terminated():
matcher.fill_next_token_bitmask(self.bitmask_host, offset)
self.token_mask_host[offset] = 1
self.num_guided_tokens[slot] += 1
# Process draft tokens
for i, tid in enumerate(req.draft_tokens, 1):
accepted = matcher.accept_token(tid)
if not accepted:
break
self.num_advanced_tokens[slot] += 1
if matcher.is_terminated():
break
matcher.fill_next_token_bitmask(self.bitmask_host,
offset + i)
self.token_mask_host[offset + i] = 1
self.num_guided_tokens[slot] += 1

if req.is_draft:
assert len(req.draft_tokens) == 0
self.num_advanced_draft_tokens[
slot] += self.num_advanced_tokens[slot]
except Exception as e:
error_msg = f"Guided decoding error: {str(e)}"
failed_requests.append((req.request_id, error_msg))
logger.error(
f"Request {req.request_id} at slot {slot} failed during guided decoding: {error_msg}",
exc_info=True)

return failed_requests

def _copy_bitmask(self,
requests: GuidedRequests,
@@ -306,8 +318,8 @@ def add_batch(self, scheduled_requests: ScheduledRequests) -> None:
scheduled_requests, self.max_num_draft_tokens)

@nvtx_range("GuideDecoder.build")
def build(self) -> None:
self._build(self.requests)
def build(self) -> List[Tuple[int, str]]:
return self._build(self.requests)

@nvtx_range("GuideDecoder.copy_bitmask")
def copy_bitmask(self, num_bitmask_tokens: Optional[int] = None) -> None:
@@ -325,8 +337,8 @@ def apply_bitmask(self,

def execute(self,
logits: torch.Tensor,
d2t: Optional[torch.Tensor] = None) -> None:
self.build()
d2t: Optional[torch.Tensor] = None) -> List[Tuple[int, str]]:
failed_requests = self.build()

with torch.cuda.stream(self.stream):
torch.cuda.current_stream().wait_event(self.token_event)
@@ -337,6 +349,8 @@ def execute(self,
self.apply_bitmask(logits, d2t=d2t)
self.token_event.record()

return failed_requests

def _rollback_rejected_tokens(self, requests: GuidedRequests) -> None:
"""Rollback the grammar matcher for rejected tokens.

@@ -460,23 +474,25 @@ def fetch_batch(self) -> None:
)

@hostfunc
def build(self) -> None:
self._build(self.requests_hostfunc)
def build(self) -> List[Tuple[int, str]]:
return self._build(self.requests_hostfunc)

def execute(self,
logits: torch.Tensor,
d2t: Optional[torch.Tensor] = None) -> None:
d2t: Optional[torch.Tensor] = None) -> List[Tuple[int, str]]:
with torch.cuda.stream(self.stream):
torch.cuda.current_stream().wait_event(self.token_event)
self.fetch_batch()
self.init_disagg_gen_requests()
self.build()
failed_requests = self.build()
self.copy_bitmask()
self.bitmask_event.record()

torch.cuda.current_stream().wait_event(self.bitmask_event)
self.apply_bitmask(logits, d2t=d2t)

return failed_requests

@hostfunc
def rollback_rejected_tokens(self) -> None:
self._rollback_rejected_tokens(self.requests_hostfunc)
@@ -532,13 +548,13 @@ def fetch_draft_batch(self, draft_step: int = 0) -> None:
def execute_draft_batch(self,
logits: torch.Tensor,
d2t: Optional[torch.Tensor] = None,
draft_step: int = 0) -> None:
draft_step: int = 0) -> List[Tuple[int, str]]:
with torch.cuda.stream(self.stream):
torch.cuda.current_stream().wait_event(self.token_event)
self.fetch_draft_batch(draft_step=draft_step)
if draft_step == 0:
self.rollback_rejected_tokens()
self.build()
failed_requests = self.build()
if draft_step == self.max_num_draft_tokens - 1:
self.rollback_draft_tokens()
# Overwrite num_bitmask_tokens since the request might not be updated on CUDA stream yet.
@@ -550,3 +566,5 @@ def execute_draft_batch(self,
self.apply_bitmask(logits,
d2t=d2t,
num_bitmask_tokens=len(self.requests))

return failed_requests
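
With this change, `_build` no longer aborts the whole batch when a request's grammar matcher fails: exceptions are caught per request and surfaced as `(request_id, error_msg)` tuples, which `build`, `execute`, and `execute_draft_batch` pass back to the executor. The following is a minimal, self-contained sketch of that contract; `_SketchGuidedDecoder` and its dict-based requests are illustrative stand-ins, not the real `GuidedDecoder` API.

# Illustrative only: a stand-in that mimics how _build now collects
# per-request failures instead of raising and killing the whole batch.
from typing import Any, Dict, List, Tuple


class _SketchGuidedDecoder:

    def _build(self, requests: List[Dict[str, Any]]) -> List[Tuple[int, str]]:
        failed_requests: List[Tuple[int, str]] = []
        for req in requests:
            try:
                # Grammar-matcher init/advance and bitmask filling would
                # happen here; any exception is recorded for this request
                # only, and the loop continues with the next one.
                if not req.get("acceptable", True):
                    raise ValueError(
                        f"failed to accept last new token: {req['new_token']}")
            except Exception as e:
                failed_requests.append(
                    (req["request_id"], f"Guided decoding error: {str(e)}"))
        return failed_requests


requests = [
    {"request_id": 1, "new_token": 42},
    {"request_id": 2, "new_token": 7, "acceptable": False},
]
# Only request 2 is reported; request 1 keeps decoding normally.
print(_SketchGuidedDecoder()._build(requests))

The executor consumes the returned list only after sampling, as shown in the py_executor.py changes below.
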
76 changes: 69 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -876,14 +876,22 @@ def _executor_loop_pp(self):

batch_outputs = self._forward_step(scheduled_batch)

guided_decoder_failed_requests = None
if self.guided_decoder is not None:
self.guided_decoder.add_batch(scheduled_batch)
self.guided_decoder.execute(
guided_decoder_failed_requests = self.guided_decoder.execute(
batch_outputs['logits'])

sample_state = self._sample_async(
scheduled_batch, batch_outputs)
assert sample_state is not None, "Sampling failed"

# Handle guided decoder errors after _sample_async to avoid state conflicts.
# If called before, failed requests would be marked as GENERATION_COMPLETE,
# causing _sample_async to fail when accessing context_chunk_size property.
self._handle_guided_decoder_errors(
scheduled_batch, guided_decoder_failed_requests)

self._update_request_states(scheduled_batch)

if self.enable_iter_perf_stats:
@@ -1194,11 +1202,21 @@ def _executor_loop(self):
self.guided_decoder.rollback_draft_tokens()

batch_outputs = self._forward_step(scheduled_batch)

guided_decoder_failed_requests = None
if self.guided_decoder is not None:
self.guided_decoder.execute(batch_outputs['logits'])
guided_decoder_failed_requests = self.guided_decoder.execute(
batch_outputs['logits'])

sample_state = self._sample_async(scheduled_batch,
batch_outputs)

# Handle guided decoder errors after _sample_async to avoid state conflicts.
# If called before, failed requests would be marked as GENERATION_COMPLETE,
# causing _sample_async to fail when accessing context_chunk_size property.
self._handle_guided_decoder_errors(
scheduled_batch, guided_decoder_failed_requests)

if self.drafter is not None:
self.drafter.run_drafter_post(scheduled_batch,
self.resource_manager,
@@ -1450,15 +1468,23 @@ def _executor_loop_overlap(self):
self.drafter.cleanup_previous_draft_resources()

if can_queue:
guided_decoder_failed_requests = None
if self.guided_decoder is not None:
# add_batch must be called again to have updated new tokens.
self.guided_decoder.add_batch(scheduled_batch)
self.guided_decoder.execute(batch_outputs['logits'])
guided_decoder_failed_requests = self.guided_decoder.execute(
batch_outputs['logits'])

sample_state = self._sample_async(scheduled_batch,
batch_outputs)
assert sample_state is not None, "Sampling failed"

# Handle guided decoder errors after _sample_async to avoid state conflicts.
# If called before, failed requests would be marked as GENERATION_COMPLETE,
# causing _sample_async to fail when accessing context_chunk_size property.
self._handle_guided_decoder_errors(
scheduled_batch, guided_decoder_failed_requests)

self._update_request_states(scheduled_batch)

ctx_transmission_reqs = self._send_disagg_ctx_cache(
@@ -2138,18 +2164,28 @@ def _update_requests(self,
self._handle_errors(error_msg)

def _handle_errors(self,
error_msg: Optional[str] = None,
error_msg: Optional[Union[str, List[str]]] = None,
*,
requests: Optional[List[LlmRequest]] = None):
error_responses: Dict[int, LlmResponse] = {}
error_msg = error_msg or "error"
failed_requests = requests if requests is not None else self.active_requests
for request in failed_requests:

error_msg = error_msg or ["error"] * len(failed_requests)
if isinstance(error_msg, str):
error_msg = [error_msg] * len(failed_requests)
elif len(error_msg) != len(failed_requests):
logger.warning(
f"Length mismatch: error_msg has {len(error_msg)} items, "
f"but there are {len(failed_requests)} requests. "
f"Falling back to default error message.")
error_msg = ["error"] * len(failed_requests)

for request, err_msg in zip(failed_requests, error_msg):
req_id = request.py_request_id
request.state = LlmRequestState.GENERATION_COMPLETE
error_responses[req_id] = LlmResponse(
request_id=req_id,
error_msg=error_msg,
error_msg=err_msg,
client_id=request.py_client_id)
if requests is None:
self.active_requests.clear()
@@ -2541,6 +2577,32 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors,
def reset_prefix_cache(self):
self.kv_cache_manager.reset_reuse_state()

def _handle_guided_decoder_errors(
self, scheduled_batch: ScheduledRequests,
failed_requests: Optional[List[Tuple[int, str]]]):
"""Handle errors that occurred during guided decoding.

Args:
scheduled_batch: The current batch of scheduled requests
failed_requests: List of (request_id, error_message) tuples for failed requests,
or None if no failures occurred
"""
if not failed_requests:
return

failed_req_id_to_err = {req_id: err for req_id, err in failed_requests}
errors = []
failed_llm_requests = []

for request in scheduled_batch.all_requests():
if request.py_request_id not in failed_req_id_to_err:
continue
errors.append(failed_req_id_to_err[request.py_request_id])
failed_llm_requests.append(request)

if failed_llm_requests:
self._handle_errors(errors, requests=failed_llm_requests)


class DisaggPPTerminationHandler:
"""Handles termination synchronization across pipeline parallel ranks under disaggregated serving.