diff --git a/examples/auto_deploy/nano_v3.yaml b/examples/auto_deploy/nano_v3.yaml
index 9d9acf6ef7f..1f2cfd0c614 100644
--- a/examples/auto_deploy/nano_v3.yaml
+++ b/examples/auto_deploy/nano_v3.yaml
@@ -24,3 +24,6 @@ transforms:
   cache_config:
     # mamba_dtype: float32
     mamba_dtype: null
+  fuse_mamba_a_log:
+    stage: post_load_fusion
+    enabled: true
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py
new file mode 100644
index 00000000000..5e967e1f35e
--- /dev/null
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py
@@ -0,0 +1,237 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Transform to fuse A_log into A for Mamba/NemotronH models."""
+
+import operator
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch._inductor.pattern_matcher import (
+    CallFunction,
+    CallMethod,
+    KeywordArg,
+    Match,
+    register_graph_pattern,
+)
+from torch.fx import GraphModule, Node
+
+from ...models.factory import ModelFactory
+from ...shim.interface import CachedSequenceInterface
+from ...utils.logger import ad_logger
+from ...utils.pattern_matcher import ADPatternMatcherPass
+from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
+
+
+def _get_attr_by_name(obj, name):
+    """Resolve a dotted attribute path (e.g. "layers.0.A_log") on a module."""
+    for part in name.split("."):
+        obj = getattr(obj, part)
+    return obj
+
+
+def _set_attr_by_name(obj, name, value):
+    """Set a dotted attribute path on a module, walking to the parent module first."""
+    parts = name.split(".")
+    for part in parts[:-1]:
+        obj = getattr(obj, part)
+    setattr(obj, parts[-1], value)
+
+
+def _del_attr_by_name(obj, name):
+    """Delete a dotted attribute path on a module, walking to the parent module first."""
+    parts = name.split(".")
+    for part in parts[:-1]:
+        obj = getattr(obj, part)
+    delattr(obj, parts[-1])
+
+
+_PATTERN_INPUT_NAME = "a_log_like"
+
+
+def _find_a_log_attr(node: Optional[Node]) -> Optional[Node]:
+    """Walk backwards through up to `max_backtrack_steps` unary nodes to find the A_log attribute."""
+    current = node
+    max_backtrack_steps = 4
+    for _ in range(max_backtrack_steps):
+        if current is None:
+            return None
+        if current.op == "get_attr":
+            return current
+        inputs = list(current.all_input_nodes)
+        if len(inputs) != 1:
+            return None
+        current = inputs[0]
+    return current if current and current.op == "get_attr" else None
+
+
+def _ensure_a_fused_param(gm: GraphModule, param_name: str) -> Optional[str]:
+    """Create (if missing) the fused parameter that replaces A_log usage."""
+    if not param_name.endswith("A_log"):
+        return None
+
+    new_param_name = param_name.replace("A_log", "A_fused")
+    try:
+        _get_attr_by_name(gm, new_param_name)
+        return new_param_name
+    except AttributeError:
+        pass
+
+    try:
+        a_log = _get_attr_by_name(gm, param_name)
+    except AttributeError:
+        ad_logger.warning(f"Could not find attribute {param_name} in gm.")
+        return None
+
+    with torch.no_grad():
+        a_fused = -torch.exp(a_log.float())
+
+    _set_attr_by_name(
+        gm,
+        new_param_name,
+        nn.Parameter(a_fused, requires_grad=False),
+    )
+    return new_param_name
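+
+
+# For intuition, the numerical identity this fusion precomputes, shown as a
+# minimal sketch with free-standing tensors (names here are illustrative and
+# not part of the transform):
+#
+#   a_log = torch.randn(8)
+#   a_fused = -torch.exp(a_log.float())
+#   # every graph use of -exp(A_log.float()) can then read `a_fused` directly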
+
+
+def _remove_unused_a_log_params(gm: GraphModule) -> bool:
+    """Remove detached A_log parameters after fusion."""
+
+    def _is_a_log_node(node: Node) -> bool:
+        return (
+            node.op == "get_attr" and isinstance(node.target, str) and node.target.endswith("A_log")
+        )
+
+    used_a_log_targets = {str(node.target) for node in gm.graph.nodes if _is_a_log_node(node)}
+    removed = False
+
+    def _maybe_remove(name: str) -> None:
+        nonlocal removed
+        if not name.endswith("A_log") or name in used_a_log_targets:
+            return
+        try:
+            _del_attr_by_name(gm, name)
+            removed = True
+        except AttributeError:
+            ad_logger.warning(f"Failed to delete unused parameter {name} from GraphModule.")
+
+    for name, _ in list(gm.named_parameters()):
+        _maybe_remove(name)
+    for name, _ in list(gm.named_buffers()):
+        _maybe_remove(name)
+
+    return removed
+
+
+def _has_a_log_attr(match: Match) -> bool:
+    """Extra check: only fire when the matched input traces back to an A_log attribute."""
+    node = match.kwargs.get(_PATTERN_INPUT_NAME)
+    attr_node = _find_a_log_attr(node if isinstance(node, Node) else None)
+    return bool(
+        attr_node and isinstance(attr_node.target, str) and attr_node.target.endswith("A_log")
+    )
+
+
+def _fuse_a_log_handler(match: Match, a_log_like: Node) -> None:
+    """Replace a matched neg(exp(A_log...)) subgraph with a read of the fused parameter."""
+    graph = match.graph
+    gm = graph.owning_module
+    if gm is None:
+        ad_logger.warning("Pattern matched but owning GraphModule is missing.")
+        return
+
+    neg_node = match.output_node()
+    exp_node = neg_node.args[0] if neg_node.args else None
+    if not isinstance(exp_node, Node):
+        ad_logger.warning("Unexpected exp node structure; skipping fusion.")
+        return
+
+    attr_node = _find_a_log_attr(a_log_like)
+    if attr_node is None or not isinstance(attr_node.target, str):
+        ad_logger.warning("Could not trace back to A_log attribute; skipping fusion.")
+        return
+
+    fused_name = _ensure_a_fused_param(gm, attr_node.target)
+    if fused_name is None:
+        return
+
+    # Insert the new get_attr node before its first use so graph order stays valid.
+    with graph.inserting_before(neg_node):
+        new_attr_node = graph.get_attr(fused_name)
+    neg_node.replace_all_uses_with(new_attr_node)
+    match.erase_nodes()
+
+
+def _register_fuse_a_log_patterns(patterns: ADPatternMatcherPass) -> None:
+    """Register neg(exp(.)) patterns that should be folded into fused constants."""
+
+    def _register(pattern):
+        register_graph_pattern(
+            pattern,
+            extra_check=_has_a_log_attr,
+            pass_dict=patterns,
+        )(_fuse_a_log_handler)
+
+    exp_call_function_targets = (
+        torch.exp,
+        torch.ops.aten.exp.default,
+    )
+    neg_call_function_targets = (
+        operator.neg,
+        torch.neg,
+        torch.ops.aten.neg.default,
+    )
+    neg_call_method_targets = ("neg",)
+    for exp_target in exp_call_function_targets:
+        exp_expr = CallFunction(exp_target, KeywordArg(_PATTERN_INPUT_NAME))
+        for neg_target in neg_call_function_targets:
+            _register(CallFunction(neg_target, exp_expr))
+        for neg_target in neg_call_method_targets:
+            _register(CallMethod(neg_target, exp_expr))
+
+    exp_call_method_targets = ("exp",)
+    for exp_target in exp_call_method_targets:
+        exp_expr = CallMethod(exp_target, KeywordArg(_PATTERN_INPUT_NAME))
+        for neg_target in neg_call_function_targets:
+            _register(CallFunction(neg_target, exp_expr))
+        for neg_target in neg_call_method_targets:
+            _register(CallMethod(neg_target, exp_expr))
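+
+
+# For orientation, a minimal module that lowers to the matched neg(exp(.))
+# subgraph after export. This is an illustrative sketch only; real
+# Mamba/NemotronH blocks carry additional structure around A_log:
+#
+#   class TinyALog(nn.Module):
+#       def __init__(self):
+#           super().__init__()
+#           self.A_log = nn.Parameter(torch.randn(4))
+#
+#       def forward(self, x):
+#           return x + (-torch.exp(self.A_log.float()))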
+
+
+@TransformRegistry.register("fuse_mamba_a_log")
+class FuseMambaALog(BaseTransform):
+    """Fuse the A_log parameter into a precomputed A parameter.
+
+    Replaces:
+        A = -torch.exp(self.A_log.float())
+    With:
+        A = self.A_fused
+    """
+
+    def _apply(
+        self,
+        gm: GraphModule,
+        cm: CachedSequenceInterface,
+        factory: ModelFactory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
+        patterns = ADPatternMatcherPass()
+        _register_fuse_a_log_patterns(patterns)
+        num_matches = patterns.apply(gm.graph)
+
+        if num_matches > 0:
+            gm.graph.eliminate_dead_code()
+            _remove_unused_a_log_params(gm)
+
+        return gm, TransformInfo(
+            skipped=False,
+            num_matches=num_matches,
+            is_clean=num_matches == 0,
+            has_valid_shapes=True,
+        )
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_mamba_a_log.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_mamba_a_log.py
new file mode 100644
index 00000000000..e31feb85386
--- /dev/null
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_mamba_a_log.py
@@ -0,0 +1,127 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+import torch.nn as nn
+from _graph_test_helpers import run_test_transformed_gm
+from torch.fx import GraphModule
+
+from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
+from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
+
+
+class DummyMambaALogModule(nn.Module):
+    """Minimal module whose forward computes -exp(A_log.float()), mimicking Mamba's A."""
+
+    def __init__(self, num_features=16, dtype=torch.float32, device="cuda"):
+        super().__init__()
+        self.register_parameter(
+            "A_log",
+            nn.Parameter(torch.randn(num_features, device=device, dtype=dtype)),
+        )
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        fused_a = -torch.exp(self.A_log.float())
+        return inputs + fused_a
+
+    def get_input(self, device="cuda", dtype=torch.float32) -> torch.Tensor:
+        return torch.randn(self.A_log.shape[0], device=device, dtype=dtype)
+
+
+def _apply_fuse_mamba_a_log(gm: GraphModule) -> GraphModule:
+    return InferenceOptimizer(
+        None,
+        {
+            "fuse_mamba_a_log": {
+                "stage": "post_load_fusion",
+            },
+        },
+    )(None, gm)
+
+
+def test_fuse_mamba_a_log_creates_fused_param():
+    device = "cuda"
+    dtype = torch.float32
+    torch.manual_seed(42)
+
+    model = DummyMambaALogModule(num_features=8, dtype=dtype, device=device).to(
+        device=device, dtype=dtype
+    )
+    x = model.get_input(device=device, dtype=dtype)
+
+    gm = torch_export_to_gm(model, args=(x,), clone=True)
+    gm_transformed = _apply_fuse_mamba_a_log(gm)
+
+    run_test_transformed_gm(
+        model,
+        x,
+        gm_transformed,
+        lambda gm_out: any(
+            node.op == "get_attr" and str(node.target).endswith("A_fused")
+            for node in gm_out.graph.nodes
+        ),
+        lambda num: num,
+        atol=1e-5,
+        rtol=1e-5,
+        test_load_hook=False,
+        strict_loading=True,
+    )
+
+    fused_params = [
+        name for name, _ in gm_transformed.named_parameters() if name.endswith("A_fused")
+    ]
+    assert fused_params, "Expected fused A parameter to be registered."
+    assert not any(name.endswith("A_log") for name, _ in gm_transformed.named_parameters()), (
+        "A_log parameter should be removed after fusion."
+    )
+    assert not any(
+        node.target in {torch.exp, torch.ops.aten.exp.default}
+        for node in gm_transformed.graph.nodes
+    ), "exp node should be removed after fusion."
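+
+
+# In deployment configs the transform is enabled under `transforms`, mirroring
+# the examples/auto_deploy/nano_v3.yaml change in this diff:
+#
+#   transforms:
+#     fuse_mamba_a_log:
+#       stage: post_load_fusion
+#       enabled: true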
+
+
+def test_fuse_mamba_a_log_memory_usage():
+    torch.manual_seed(42)
+    torch.cuda.manual_seed(42)
+    torch.cuda.empty_cache()
+
+    device = "cuda"
+    dtype = torch.float32
+    num_features = 1024 * 1024
+
+    model = DummyMambaALogModule(num_features=num_features, dtype=dtype, device=device).to(
+        device=device, dtype=dtype
+    )
+    x = model.get_input(device=device, dtype=dtype)
+    gm = torch_export_to_gm(model, args=(x,), clone=True)
+
+    torch.cuda.synchronize()
+    torch.cuda.empty_cache()
+    mem_before = torch.cuda.memory_allocated()
+
+    gm_transformed = _apply_fuse_mamba_a_log(gm)
+
+    torch.cuda.synchronize()
+    torch.cuda.empty_cache()
+    mem_after = torch.cuda.memory_allocated()
+
+    diff = mem_after - mem_before
+    tolerance = 5 * 1024  # 5 KiB tolerance for allocator variance
+
+    assert abs(diff) <= tolerance, (
+        f"Unexpected memory delta after fusion. Expected no additional memory, got {diff} bytes."
+    )
+
+    with pytest.raises(AttributeError):
+        gm_transformed.get_parameter("A_log")
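+
+
+# A quick way to exercise the transform outside pytest (a sketch; assumes a
+# CUDA device and reuses only the helpers defined above):
+#
+#   model = DummyMambaALogModule(num_features=8)
+#   gm = torch_export_to_gm(model, args=(model.get_input(),), clone=True)
+#   gm = _apply_fuse_mamba_a_log(gm)
+#   print(any(str(n.target).endswith("A_fused") for n in gm.graph.nodes))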