@@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CUDA-backed cached causal conv1d custom ops and attention descriptor.

This mirrors `torch_backend_causal_conv.py` but reuses existing TRT-LLM CUDA
@@ -94,7 +109,7 @@ def cuda_causal_conv_prepare_metadata_fake(
)


@torch.library.custom_op("auto_deploy::cuda_cached_causal_conv1d", mutates_args={})
@torch.library.custom_op("auto_deploy::cuda_cached_causal_conv1d", mutates_args={"input"})
def _cuda_cached_causal_conv1d(
# INPUTS (dense but may be flattened across sequences)
input: torch.Tensor, # [b, s, c_in]
@@ -114,13 +129,15 @@ def _cuda_cached_causal_conv1d(
groups: int,
padding_mode: str,
activation: Optional[str],
) -> torch.Tensor:
) -> None:
"""Flattened cached causal conv that respects slot-indexed state caches (CUDA backend).

Supports two layouts from the attention interface:
- Generate-only: input is [b, 1, c_in]. We'll gather caches using slot_idx[:b].
- Flattened context/mixed: input is [1, total_s, c_in] and seq_len/seq_start
describe per-sequence segments. We'll process each segment and scatter final states to caches.

NOTE: This op modifies `input` in-place.
"""
b, s = input.shape[:2]
num_seq = seq_len.shape[0]
@@ -137,8 +154,6 @@ def _cuda_cached_causal_conv1d(
# Flatten tokens
bs = b * s
inp_flat = input.reshape(bs, *input.shape[2:]) # [total_s, C_in]
y = torch.empty(b, s, weight.shape[0], device=input.device, dtype=input.dtype)
y_flat = y.view(bs, *y.shape[2:])

# Prepare weight as [dim, width] (depthwise)
if weight.ndim == 3:
@@ -180,18 +195,16 @@ def _cuda_cached_causal_conv1d(
activation=activation,
pad_slot_id=PAD_SLOT_ID,
) # (dim, total_prefill_tokens)

# Scatter outputs back to y
y_prefill = y_varlen.transpose(0, 1) # [total_prefill_tokens, C_out]
y_flat[:total_prefill_tokens].copy_(y_prefill)
# Scatter outputs back to input buffer
inp_flat[:total_prefill_tokens] = y_varlen.transpose(0, 1)

# DECODE: batch update for single-token sequences
if num_decode > 0:
x_decode = inp_flat[
total_prefill_tokens : total_prefill_tokens + num_decode
] # [num_decode, C_in]

y_dec = causal_conv1d_update(
causal_conv1d_update(
x_decode, # [batch, dim]
conv_state_cache,
w2d,
@@ -202,12 +215,7 @@ def _cuda_cached_causal_conv1d(
pad_slot_id=PAD_SLOT_ID,
)

if y_dec.dim() == 3:
y_dec = y_dec.squeeze(-1)
y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(y_dec)

# Custom op must not return an alias of any input; return a fresh tensor
return y
return


@_cuda_cached_causal_conv1d.register_fake
@@ -231,9 +239,12 @@ def _cuda_cached_causal_conv1d_fake(
padding_mode: str,
activation: Optional[str],
):
return torch.empty(
input.shape[0], input.shape[1], weight.shape[0], device=input.device, dtype=input.dtype
)
return


def cuda_cached_causal_conv1d_wrapper(input, *args, **kwargs):
torch.ops.auto_deploy.cuda_cached_causal_conv1d(input, *args, **kwargs)
return input
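For context, a self-contained toy analogue of this pattern (hypothetical `demo::inplace_scale` op, not part of TRT-LLM): the custom op mutates its input and returns `None`, the fake impl mirrors that, and a thin Python wrapper returns the mutated buffer so downstream graph nodes still receive a tensor output.

import torch

@torch.library.custom_op("demo::inplace_scale", mutates_args={"x"})
def _inplace_scale(x: torch.Tensor, factor: float) -> None:
    # Write the result into `x`; a mutating op must not return an alias of an input.
    x.mul_(factor)

@_inplace_scale.register_fake
def _inplace_scale_fake(x: torch.Tensor, factor: float) -> None:
    # Fake impl matches the real op: no outputs.
    return

def inplace_scale_wrapper(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Return the mutated buffer so the fx graph keeps a tensor-valued node.
    torch.ops.demo.inplace_scale(x, factor)
    return x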


@AttentionRegistry.register("cuda_causal_conv")
@@ -259,7 +270,7 @@ def get_source_attention_op(cls) -> OpOverloadPacket:

@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.auto_deploy.cuda_cached_causal_conv1d
return cuda_cached_causal_conv1d_wrapper

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
@@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple

import torch
@@ -185,13 +200,13 @@ def _triton_cached_ssm(
C_flat = C.reshape(bs, *C.shape[2:]) # [bs, G, N]
dt_flat = dt.reshape(bs, dt.shape[2]) # [bs, H]

y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format)
y_flat = y.view(bs, *y.shape[2:])

ssm_state_size = B.shape[3]

num_prefill, num_prefill_tokens, num_decode = batch_info_tensor.tolist()

y_prefill = None
y_decode = None

# Prefill: concatenate tokens at the front and run combined scan
if num_prefill > 0:
hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H, D]
@@ -232,7 +247,6 @@ def _triton_cached_ssm(
mamba_ssm_cache_dtype=ssm_state_cache.dtype,
)

y_flat[:num_prefill_tokens] = y_prefill[0].to(y_flat.dtype)
ssm_state_cache.index_copy_(
0, slot_idx[:num_prefill], varlen_states.to(ssm_state_cache.dtype)
)
@@ -251,7 +265,7 @@ def _triton_cached_ssm(
A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size)
D_full = D[..., None].expand(num_heads, head_dim)

y_dec = selective_state_update(
y_decode = selective_state_update(
ssm_state_cache,
x_decode,
dt_hp,
@@ -265,9 +279,19 @@ def _triton_cached_ssm(
state_batch_indices=slot_idx_decode,
) # [nd, H, D]

y_flat[num_prefill_tokens : num_prefill_tokens + num_decode].copy_(y_dec.to(y_flat.dtype))

return y
# Dispatch return logic
if num_prefill > 0 and num_decode > 0:
y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format)
y_flat = y.view(bs, *y.shape[2:])
y_flat[:num_prefill_tokens].copy_(y_prefill[0])
y_flat[num_prefill_tokens : num_prefill_tokens + num_decode].copy_(y_decode)
return y
elif num_prefill > 0:
return y_prefill[0].view(b, s, num_heads, head_dim).to(hidden_states.dtype)
elif num_decode > 0:
return y_decode.view(b, s, num_heads, head_dim).to(hidden_states.dtype)
else:
return torch.empty_like(hidden_states)
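For illustration, a hypothetical example of the flattened token layout this dispatch assumes: prefill tokens are packed at the front of the flattened dimension, followed by one token per single-token decode sequence, and `batch_info_tensor` carries `[num_prefill, num_prefill_tokens, num_decode]`.

import torch

# Hypothetical sizes: two prefill sequences (3 and 5 tokens) plus 4 decode sequences.
batch_info_tensor = torch.tensor([2, 8, 4])
num_prefill, num_prefill_tokens, num_decode = batch_info_tensor.tolist()

prefill_tokens = slice(0, num_prefill_tokens)                               # flattened tokens 0..7
decode_tokens = slice(num_prefill_tokens, num_prefill_tokens + num_decode)  # flattened tokens 8..11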


@_triton_cached_ssm.register_fake
@@ -1,3 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fusion transform for fusing activation functions into causal_conv1d operations."""

from typing import List, Optional, Tuple
@@ -6,6 +21,7 @@
import torch.nn.functional as F
from torch.fx import GraphModule, Node

from ...custom_ops.mamba.cuda_backend_causal_conv import cuda_cached_causal_conv1d_wrapper
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.node_utils import is_op
@@ -85,10 +101,12 @@ def _apply(
) -> Tuple[GraphModule, TransformInfo]:
graph = gm.graph

target_op = cuda_cached_causal_conv1d_wrapper

# Step 1: Identify causal_conv + activation pattern
matches = _match_causal_conv_activation_pattern(
graph,
target_op=torch.ops.auto_deploy.cuda_cached_causal_conv1d,
target_op=target_op,
)

# Step 2: Replace matched patterns with fused version
@@ -98,7 +116,7 @@ def _apply(
# Replace the last arg (activation=None) with activation_name
new_args = list(conv_node.args[:-1]) + [activation_name]
fused_node = graph.call_function(
torch.ops.auto_deploy.cuda_cached_causal_conv1d,
target_op,
args=tuple(new_args),
)
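Illustrative before/after of this rewrite (hypothetical node values; the activation is shown as silu here, but any matched activation is folded the same way):

#   before: y = cuda_cached_causal_conv1d_wrapper(x, ..., activation=None)
#           y = torch.ops.aten.silu.default(y)
#   after:  y = cuda_cached_causal_conv1d_wrapper(x, ..., activation="silu")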
