Commit a396d06

fix ci
Signed-off-by: junq <[email protected]>
1 parent d14f687

2 files changed (+11, -8 lines)

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 6 additions & 4 deletions
@@ -80,7 +80,7 @@ def __init__(self, config: CUDAGraphRunnerConfig):
 
     def _create_shared_static_tensors(self):
         """Allocates static tensors sized for the largest possible batch."""
-        max_draft_len = self.config.original_max_draft_len if self.config.is_spec_decode else 0
+        max_draft_len = self.config.original_max_draft_len if self.config.spec_config is not None else 0
         token_per_request = max_draft_len + 1
         max_total_tokens = (self.max_supported_batch_size *
                             self.max_beam_width * token_per_request)
@@ -192,7 +192,7 @@ def capture(self,
                 key: Tuple[int, int, int],
                 forward_fn: Callable,
                 initial_inputs: Dict[str, Any],
-                enable_spec_decode: bool,
+                enable_spec_decode: bool = False,
                 postprocess_fn: Optional[Callable] = None):
         """Captures the forward pass for a given batch size."""
         batch_size = key[0]
@@ -358,8 +358,10 @@ def _round_up_batch_size(self, batch_size: int) -> int:
         return self.supported_batch_sizes[idx]
 
     @contextlib.contextmanager
-    def pad_batch(self, scheduled_requests: ScheduledRequests,
-                  resource_manager: ResourceManager, runtime_draft_len: int):
+    def pad_batch(self,
+                  scheduled_requests: ScheduledRequests,
+                  resource_manager: ResourceManager,
+                  runtime_draft_len: int = 0):
         """Context manager to pad a batch to a graph-compatible size."""
         padding_size = self._get_padded_batch(scheduled_requests,
                                               resource_manager,
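
With both defaults in place, call sites that do not exercise speculative decoding can drop the extra arguments entirely. A minimal sketch of such a call site, assuming a constructed runner plus the scheduled_requests, resource_manager, forward_fn, and input objects from the surrounding executor code (none of which appear in this diff):

# Illustrative only: `runner`, `scheduled_requests`, `resource_manager`,
# `forward_fn`, and `inputs` are assumed from the surrounding executor
# code and are not part of this diff.
with runner.pad_batch(scheduled_requests, resource_manager):
    # runtime_draft_len now defaults to 0, so it is omitted when no
    # draft tokens are in flight.
    runner.capture(
        key=(8, 1, 0),  # key[0] is the batch size (per the diff); the
                        # meaning of the remaining fields is assumed here
        forward_fn=forward_fn,
        initial_inputs=inputs)
    # enable_spec_decode now defaults to False, so it is omitted too.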

tests/unittest/_torch/helpers.py

Lines changed: 5 additions & 4 deletions
@@ -173,16 +173,17 @@ def create_mock_cuda_graph_runner(batch_size: int, use_mrope: bool = False):
     config = CUDAGraphRunnerConfig(
         use_cuda_graph=True,
         cuda_graph_padding_enabled=False,
-        supported_batch_sizes=[batch_size],
-        max_supported_batch_size=batch_size,
-        max_batch_size=batch_size,
+        cuda_graph_batch_sizes=[batch_size],
+        max_cuda_graph_batch_size=batch_size,
+        batch_size=batch_size,
         max_beam_width=1,
-        max_draft_len=0,
         max_num_tokens=1,
         use_mrope=use_mrope,
         spec_config=None,
         cuda_graph_mem_pool=None,
         enable_attention_dp=False,
+        original_max_draft_len=0,
+        is_draft_model=False,
         mapping=Mapping(),
         dist=None,
         kv_cache_manager_key=ResourceManagerType.KV_CACHE_MANAGER)
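
The mock helper replaces supported_batch_sizes, max_supported_batch_size, and max_batch_size with cuda_graph_batch_sizes, max_cuda_graph_batch_size, and batch_size, and supplies the new original_max_draft_len and is_draft_model fields explicitly. Since spec_config=None, the updated gate in _create_shared_static_tensors resolves max_draft_len to 0, so the mock sizes its static tensors for one token per request. A hedged sketch of test usage; the helper's return value is assumed to wrap this config in a CUDAGraphRunner, as the return statement falls outside the hunk:

# Hypothetical unit-test usage. Only create_mock_cuda_graph_runner is
# visible in this diff; its return value (assumed to be a runner built
# from the config above) lies outside the hunk.
runner = create_mock_cuda_graph_runner(batch_size=1, use_mrope=False)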
