Commit 8e830a1

Fix dynamic context syntax and remove redundant tensors (#2336)
1 parent 7e18da2 commit 8e830a1

1 file changed: +1 -31

megatron/core/inference/contexts/dynamic_context.py

Lines changed: 1 addition & 31 deletions
@@ -337,7 +337,7 @@ def __init__(
         self.kv_reduced_dim = kv_lora_rank + qk_pos_emb_head_dim
         self.block_size_bytes = (
             dtype_size_bytes
-            * num_attention_layers
+            * self.num_attention_layers
             * self.block_size_tokens
             * self.kv_reduced_dim
         )
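This first hunk is the syntax fix named in the commit title: the bare name `num_attention_layers` is replaced with the instance attribute `self.num_attention_layers`. A minimal sketch of the pattern, assuming (as the fix implies) that no parameter or local of that name is in scope where the block size is computed; the class and values below are illustrative, not the Megatron source:

# Hypothetical reduction of the fixed expression; names and values are
# illustrative only.
class ContextSketch:
    def __init__(self, num_layers: int):
        self.num_attention_layers = num_layers  # stored earlier in the real __init__
        self.block_size_tokens = 256
        self.kv_reduced_dim = 576
        dtype_size_bytes = 2  # e.g. bf16
        # Before the fix, the bare name `num_attention_layers` was used here.
        # With no parameter or local of that name in scope, Python raises
        # NameError; reading the stored attribute is the robust reference.
        self.block_size_bytes = (
            dtype_size_bytes
            * self.num_attention_layers
            * self.block_size_tokens
            * self.kv_reduced_dim
        )

ctx = ContextSketch(num_layers=32)
assert ctx.block_size_bytes == 2 * 32 * 256 * 576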
@@ -604,32 +604,6 @@ def allocate_memory_buffer():
             device=torch.cuda.current_device(),
         )
 
-        # `*_cudagraph_only` tensors are for use with cuda graphs to maintain
-        # consistent input shapes, which is required to use cuda graphs.
-        # During these steps, the `*_cudagraph_only`
-        # tensors are used, otherwise their same-name but un-suffixed
-        # corresponding tensors are used.
-
-        self.query_seq_lengths_cudagraph_only = torch.full(
-            (self.max_total_requests,), 0, dtype=torch.int32, device=torch.cuda.current_device()
-        )
-        self.cu_query_seq_lengths_cudagraph_only = torch.full(
-            (self.max_total_requests + 1,), 0, dtype=torch.int32, device=torch.cuda.current_device()
-        )
-        self.kv_seq_lengths_cudagraph_only = torch.full(
-            (self.max_total_requests,), 0, dtype=torch.int32, device=torch.cuda.current_device()
-        )
-        self.cu_kv_seq_lengths_cudagraph_only = torch.full(
-            (self.max_total_requests + 1,), 0, dtype=torch.int32, device=torch.cuda.current_device()
-        )
-
-        self.request_to_kv_block_ids_cudagraph_only = torch.full(
-            (self.max_total_requests, self.max_kv_block_count),
-            0,
-            dtype=torch.int,
-            device=torch.cuda.current_device(),
-        )
-
         # Optional state tensors for hybrid models
         def allocate_mamba_states():
             """Allocate Mamba states. This function is called below within
@@ -1495,10 +1469,6 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T
             (Tensor) Newly paused request IDs.
         """
 
-        # If tensor state is deallocated, do not add request.
-        if not self.is_tensor_state_allocated:
-            raise TensorStateDeallocatedError(req.request_id)
-
         # 1. The active token mask tells us which requests are still active and which are completed
         # active_request_count -> This corresponds to requests that have not reached EOD or max length
         # finished_request_count are requests that have reached the termination criterion
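The third hunk removes a guard that could never execute correctly: `update_requests` takes only `active_requests_mask` and `new_tokens`, so the reference to `req.request_id` is an unresolved name, presumably pasted from a per-request method such as the one the comment mentions ("do not add request"). If such a guard were wanted here, a scope-correct version would gate on the flag alone. A hypothetical sketch, not the commit's change (which simply deletes the dead code); the error type comes from the deleted lines, and its message argument is an assumption:

# Hypothetical, scope-correct variant of the removed guard.
def update_requests(self, active_requests_mask, new_tokens):
    # No per-request object exists in this scope, so the guard can only
    # report the deallocated state itself (message argument is assumed).
    if not self.is_tensor_state_allocated:
        raise TensorStateDeallocatedError("tensor state is deallocated")
    ...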
