Skip to content

Commit 4bc902a

Browse files
committed
fix
Signed-off-by: junq <[email protected]>
1 parent e4fea88 commit 4bc902a

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

tests/unittest/_torch/modeling/test_modeling_multimodal.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
 from typing import Dict, List, Optional, Tuple, Type
 
 import torch
-from _torch.helpers import create_mock_engine
+from _torch.helpers import create_mock_cuda_graph_runner
 from transformers import AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
 from utils.llm_data import llm_models_root
 
@@ -17,7 +17,6 @@
 from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
 from tensorrt_llm._torch.metadata import KVCacheParams
 from tensorrt_llm._torch.model_config import ModelConfig
-from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
 from tensorrt_llm._utils import str_dtype_to_torch
 from tensorrt_llm.bindings.executor import KvCacheConfig
@@ -425,8 +424,7 @@ def run_trtllm_forward(self, trtllm_inputs, use_cuda_graph: bool = False):
             trtllm_inputs["attn_metadata"].prepare()
             return self.trtllm_model.forward(**trtllm_inputs)
         else:
-            mock_engine = create_mock_engine(1)
-            graph_runner = CUDAGraphRunner(mock_engine)
+            graph_runner = create_mock_cuda_graph_runner(1)
             trtllm_inputs["attn_metadata"] = trtllm_inputs[
                 "attn_metadata"
             ].create_cuda_graph_metadata(1)

0 commit comments

Comments
 (0)