Skip to content

Commit ad94c87

Browse files
committed
Re-enable the activation fusion
Signed-off-by: Chenghao Zhang <[email protected]>
1 parent 06e2fb3 commit ad94c87

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py

Lines changed: 9 additions and 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,17 @@ def _apply(
8585
) -> Tuple[GraphModule, TransformInfo]:
8686
graph = gm.graph
8787

88+
# Import wrapper to match against
89+
# We use the wrapper because the underlying op returns None (void) to avoid aliasing,
90+
# but the wrapper returns the tensor to maintain graph data flow.
91+
from ...custom_ops.mamba.cuda_backend_causal_conv import cuda_cached_causal_conv1d_wrapper
92+
93+
target_op = cuda_cached_causal_conv1d_wrapper
94+
8895
# Step 1: Identify causal_conv + activation pattern
8996
matches = _match_causal_conv_activation_pattern(
9097
graph,
91-
target_op=torch.ops.auto_deploy.cuda_cached_causal_conv1d,
98+
target_op=target_op,
9299
)
93100

94101
# Step 2: Replace matched patterns with fused version
@@ -98,7 +105,7 @@ def _apply(
98105
# Replace the last arg (activation=None) with activation_name
99106
new_args = list(conv_node.args[:-1]) + [activation_name]
100107
fused_node = graph.call_function(
101-
torch.ops.auto_deploy.cuda_cached_causal_conv1d,
108+
target_op,
102109
args=tuple(new_args),
103110
)
104111

0 commit comments

Comments (0)