File tree Expand file tree Collapse file tree 1 file changed +9
-2
lines changed
tensorrt_llm/_torch/auto_deploy/transform/library Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Original file line number Diff line number Diff line change @@ -85,10 +85,17 @@ def _apply(
8585 ) -> Tuple [GraphModule , TransformInfo ]:
8686 graph = gm .graph
8787
88+ # Import the wrapper so the pattern matcher targets it rather than the underlying op.
89+ # The underlying op returns None (void) to avoid aliasing, whereas the wrapper
90+ # returns the tensor and therefore maintains graph data flow.
91+ from ...custom_ops .mamba .cuda_backend_causal_conv import cuda_cached_causal_conv1d_wrapper
92+
93+ target_op = cuda_cached_causal_conv1d_wrapper
94+
8895 # Step 1: Identify causal_conv + activation pattern
8996 matches = _match_causal_conv_activation_pattern (
9097 graph ,
91- target_op = torch . ops . auto_deploy . cuda_cached_causal_conv1d ,
98+ target_op = target_op ,
9299 )
93100
94101 # Step 2: Replace matched patterns with fused version
@@ -98,7 +105,7 @@ def _apply(
98105 # Replace the last arg (activation=None) with activation_name
99106 new_args = list (conv_node .args [:- 1 ]) + [activation_name ]
100107 fused_node = graph .call_function (
101- torch . ops . auto_deploy . cuda_cached_causal_conv1d ,
108+ target_op ,
102109 args = tuple (new_args ),
103110 )
104111
You can’t perform that action at this time.
0 commit comments