examples/auto_deploy/nano_v3.yaml: 3 additions & 0 deletions
@@ -24,3 +24,6 @@ transforms:
  cache_config:
    # mamba_dtype: float32
    mamba_dtype: null
  fuse_mamba_a_log:
    stage: post_load_fusion
    enabled: true
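
For context: this new `fuse_mamba_a_log` entry turns on the transform added below, at the `post_load_fusion` stage. A minimal sketch (illustrative only, not part of this diff; the module shape is assumed) of the computation the transform folds into a precomputed constant:

import torch
import torch.nn as nn

# Before fusion, Mamba/NemotronH-style mixers rebuild A on every forward:
#     A = -torch.exp(self.A_log.float())   # cast + exp + neg on each call
# After fusion, an equivalent frozen constant is registered once on the module:
a_log = nn.Parameter(torch.randn(8))  # stand-in for a mixer's A_log parameter
with torch.no_grad():
    a_fused = nn.Parameter(-torch.exp(a_log.float()), requires_grad=False)
# ...and every -exp(A_log.float()) subgraph in the exported model is rewired to read it.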
@@ -0,0 +1,237 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transform to fuse A_log into A for Mamba/NemotronH models."""

import operator
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch._inductor.pattern_matcher import (
    CallFunction,
    CallMethod,
    KeywordArg,
    Match,
    register_graph_pattern,
)
from torch.fx import GraphModule, Node

from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
from ...utils.pattern_matcher import ADPatternMatcherPass
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry


def _get_attr_by_name(obj, name):
    for part in name.split("."):
        obj = getattr(obj, part)
    return obj


def _set_attr_by_name(obj, name, value):
    parts = name.split(".")
    for part in parts[:-1]:
        obj = getattr(obj, part)
    setattr(obj, parts[-1], value)


def _del_attr_by_name(obj, name):
    parts = name.split(".")
    for part in parts[:-1]:
        obj = getattr(obj, part)
    delattr(obj, parts[-1])


_PATTERN_INPUT_NAME = "a_log_like"


def _find_a_log_attr(node: Optional[Node]) -> Optional[Node]:
    """Walk backwards through up to `max_backtrack_steps` unary nodes to find the A_log attribute."""
    current = node
    max_backtrack_steps = 4
    for _ in range(max_backtrack_steps):
        if current is None:
            return None
        if current.op == "get_attr":
            return current
        inputs = list(current.all_input_nodes)
        if len(inputs) != 1:
            return None
        current = inputs[0]
    return current if current and current.op == "get_attr" else None
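
# Example (illustrative, assuming the traced form of `self.A_log.float()`):
# starting from exp's input, the walk typically passes through a single dtype
# cast node (e.g. aten.to / aten._to_copy) before reaching get_attr("...A_log").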


def _ensure_a_fused_param(gm: GraphModule, param_name: str) -> Optional[str]:
    """Create (if missing) the fused parameter that replaces A_log usage."""
    if not param_name.endswith("A_log"):
        return None

    new_param_name = param_name.replace("A_log", "A_fused")
    try:
        _get_attr_by_name(gm, new_param_name)
        return new_param_name
    except AttributeError:
        pass

    try:
        a_log = _get_attr_by_name(gm, param_name)
    except AttributeError:
        ad_logger.warning(f"Could not find attribute {param_name} in gm.")
        return None

    with torch.no_grad():
        a_fused = -torch.exp(a_log.float())

    _set_attr_by_name(
        gm,
        new_param_name,
        nn.Parameter(a_fused, requires_grad=False),
    )
    return new_param_name


def _remove_unused_a_log_params(gm: GraphModule) -> bool:
    """Remove detached A_log parameters after fusion."""

    def _is_a_log_node(node: Node) -> bool:
        return (
            node.op == "get_attr" and isinstance(node.target, str) and node.target.endswith("A_log")
        )

    used_a_log_targets = {str(node.target) for node in gm.graph.nodes if _is_a_log_node(node)}
    removed = False

    def _maybe_remove(name: str) -> None:
        nonlocal removed
        if not name.endswith("A_log") or name in used_a_log_targets:
            return
        try:
            _del_attr_by_name(gm, name)
            removed = True
        except AttributeError:
            ad_logger.warning(f"Failed to delete unused parameter {name} from GraphModule.")

    for name, _ in list(gm.named_parameters()):
        _maybe_remove(name)
    for name, _ in list(gm.named_buffers()):
        _maybe_remove(name)

    return removed


def _has_a_log_attr(match: Match) -> bool:
    node = match.kwargs.get(_PATTERN_INPUT_NAME)
    attr_node = _find_a_log_attr(node if isinstance(node, Node) else None)
    return bool(
        attr_node and isinstance(attr_node.target, str) and attr_node.target.endswith("A_log")
    )


def _fuse_a_log_handler(match: Match, a_log_like: Node) -> None:
    graph = match.graph
    gm = graph.owning_module
    if gm is None:
        ad_logger.warning("Pattern matched but owning GraphModule is missing.")
        return

    neg_node = match.output_node()
    exp_node = neg_node.args[0] if neg_node.args else None
    if not isinstance(exp_node, Node):
        ad_logger.warning("Unexpected exp node structure; skipping fusion.")
        return

    attr_node = _find_a_log_attr(a_log_like)
    if attr_node is None or not isinstance(attr_node.target, str):
        ad_logger.warning("Could not trace back to A_log attribute; skipping fusion.")
        return

    fused_name = _ensure_a_fused_param(gm, attr_node.target)
    if fused_name is None:
        return

    new_attr_node = graph.get_attr(fused_name)
    neg_node.replace_all_uses_with(new_attr_node)
    match.erase_nodes()


def _register_fuse_a_log_patterns(patterns: ADPatternMatcherPass) -> None:
    """Register neg(exp(.)) patterns that should be folded into fused constants."""

    def _register(pattern):
        register_graph_pattern(
            pattern,
            extra_check=_has_a_log_attr,
            pass_dict=patterns,
        )(_fuse_a_log_handler)

    exp_call_function_targets = (
        torch.exp,
        torch.ops.aten.exp.default,
    )
    neg_call_function_targets = (
        operator.neg,
        torch.neg,
        torch.ops.aten.neg.default,
    )
    neg_call_method_targets = ("neg",)
    for exp_target in exp_call_function_targets:
        exp_expr = CallFunction(exp_target, KeywordArg(_PATTERN_INPUT_NAME))
        for neg_target in neg_call_function_targets:
            _register(CallFunction(neg_target, exp_expr))
        for neg_target in neg_call_method_targets:
            _register(CallMethod(neg_target, exp_expr))

    exp_call_method_targets = ("exp",)
    for exp_target in exp_call_method_targets:
        exp_expr = CallMethod(exp_target, KeywordArg(_PATTERN_INPUT_NAME))
        for neg_target in neg_call_function_targets:
            _register(CallFunction(neg_target, exp_expr))
        for neg_target in neg_call_method_targets:
            _register(CallMethod(neg_target, exp_expr))
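
# Taken together, these registrations cover the source-level spellings export may
# emit for the pattern, e.g. -torch.exp(x), torch.neg(torch.exp(x)),
# torch.exp(x).neg(), -x.exp(), and x.exp().neg(), as well as their aten-level
# forms (aten.neg.default applied to aten.exp.default).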


@TransformRegistry.register("fuse_mamba_a_log")
class FuseMambaALog(BaseTransform):
    """Fuse A_log parameter into A constant/parameter.

    Replaces:
        A = -torch.exp(self.A_log.float())
    With:
        A = self.A_fused
    """

    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        patterns = ADPatternMatcherPass()
        _register_fuse_a_log_patterns(patterns)
        num_matches = patterns.apply(gm.graph)

        if num_matches > 0:
            gm.graph.eliminate_dead_code()
            _remove_unused_a_log_params(gm)

        return gm, TransformInfo(
            skipped=False,
            num_matches=num_matches,
            is_clean=num_matches == 0,
            has_valid_shapes=True,
        )
@@ -0,0 +1,127 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
import torch.nn as nn
from _graph_test_helpers import run_test_transformed_gm
from torch.fx import GraphModule

from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer


class DummyMambaALogModule(nn.Module):
    def __init__(self, num_features=16, dtype=torch.float32, device="cuda"):
        super().__init__()
        self.register_parameter(
            "A_log",
            nn.Parameter(torch.randn(num_features, device=device, dtype=dtype)),
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        fused_a = -torch.exp(self.A_log.float())
        return inputs + fused_a

    def get_input(self, device="cuda", dtype=torch.float32) -> torch.Tensor:
        return torch.randn(self.A_log.shape[0], device=device, dtype=dtype)


def _apply_fuse_mamba_a_log(gm: GraphModule) -> GraphModule:
    return InferenceOptimizer(
        None,
        {
            "fuse_mamba_a_log": {
                "stage": "post_load_fusion",
            },
        },
    )(None, gm)
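
# Note: this config mirrors the `fuse_mamba_a_log` entry added to
# examples/auto_deploy/nano_v3.yaml above (stage: post_load_fusion).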


def test_fuse_mamba_a_log_creates_fused_param():
    device = "cuda"
    dtype = torch.float32
    torch.manual_seed(42)

    model = DummyMambaALogModule(num_features=8, dtype=dtype, device=device).to(
        device=device, dtype=dtype
    )
    x = model.get_input(device=device, dtype=dtype)

    gm = torch_export_to_gm(model, args=(x,), clone=True)
    gm_transformed = _apply_fuse_mamba_a_log(gm)

    run_test_transformed_gm(
        model,
        x,
        gm_transformed,
        lambda gm_out: any(
            node.op == "get_attr" and str(node.target).endswith("A_fused")
            for node in gm_out.graph.nodes
        ),
        lambda num: num,
        atol=1e-5,
        rtol=1e-5,
        test_load_hook=False,
        strict_loading=True,
    )

    fused_params = [
        name for name, _ in gm_transformed.named_parameters() if name.endswith("A_fused")
    ]
    assert fused_params, "Expected fused A parameter to be registered."
    assert not any(name.endswith("A_log") for name, _ in gm_transformed.named_parameters()), (
        "A_log parameter should be removed after fusion."
    )
    assert not any(
        node.target in {torch.exp, torch.ops.aten.exp.default}
        for node in gm_transformed.graph.nodes
    ), "exp node should be removed after fusion."


def test_fuse_mamba_a_log_memory_usage():
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    torch.cuda.empty_cache()

    device = "cuda"
    dtype = torch.float32
    num_features = 1024 * 1024

    model = DummyMambaALogModule(num_features=num_features, dtype=dtype, device=device).to(
        device=device, dtype=dtype
    )
    x = model.get_input(device=device, dtype=dtype)
    gm = torch_export_to_gm(model, args=(x,), clone=True)

    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    mem_before = torch.cuda.memory_allocated()

    gm_transformed = _apply_fuse_mamba_a_log(gm)

    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    mem_after = torch.cuda.memory_allocated()

    diff = mem_after - mem_before
    tolerance = 5 * 1024  # 5 KiB tolerance for allocator variance

    assert abs(diff) <= tolerance, (
        f"Unexpected memory delta after fusion. Expected no additional memory, got {diff} bytes."
    )

    with pytest.raises(AttributeError):
        gm_transformed.get_parameter("A_log")
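
For intuition, a rough sketch (node and attribute names are assumed; real export output will differ) of how the exported dummy module's graph changes under the transform:

# before fuse_mamba_a_log:
#     a_log   = get_attr("A_log")
#     cast    = aten.to(a_log, torch.float32)
#     exp     = aten.exp(cast)
#     neg     = aten.neg(exp)
#     out     = aten.add(inputs, neg)
# after fuse_mamba_a_log (cast/exp/neg eliminated as dead code):
#     a_fused = get_attr("A_fused")
#     out     = aten.add(inputs, a_fused)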