@@ -80,7 +80,7 @@ def __init__(self, config: CUDAGraphRunnerConfig):
8080
8181 def _create_shared_static_tensors (self ):
8282 """Allocates static tensors sized for the largest possible batch."""
83- max_draft_len = self .config .original_max_draft_len if self .config .is_spec_decode else 0
83+ max_draft_len = self .config .original_max_draft_len if self .config .spec_config is not None else 0
8484 token_per_request = max_draft_len + 1
8585 max_total_tokens = (self .max_supported_batch_size *
8686 self .max_beam_width * token_per_request )
@@ -192,7 +192,7 @@ def capture(self,
192192 key : Tuple [int , int , int ],
193193 forward_fn : Callable ,
194194 initial_inputs : Dict [str , Any ],
195- enable_spec_decode : bool ,
195+ enable_spec_decode : bool = False ,
196196 postprocess_fn : Optional [Callable ] = None ):
197197 """Captures the forward pass for a given batch size."""
198198 batch_size = key [0 ]
@@ -358,8 +358,10 @@ def _round_up_batch_size(self, batch_size: int) -> int:
358358 return self .supported_batch_sizes [idx ]
359359
360360 @contextlib .contextmanager
361- def pad_batch (self , scheduled_requests : ScheduledRequests ,
362- resource_manager : ResourceManager , runtime_draft_len : int ):
361+ def pad_batch (self ,
362+ scheduled_requests : ScheduledRequests ,
363+ resource_manager : ResourceManager ,
364+ runtime_draft_len : int = 0 ):
363365 """Context manager to pad a batch to a graph-compatible size."""
364366 padding_size = self ._get_padded_batch (scheduled_requests ,
365367 resource_manager ,
0 commit comments