
Commit 48e7178

fixed nemotron sharding
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent cac7fe4

2 files changed: +3 -2 lines changed


tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
Lines changed: 1 addition & 1 deletion

@@ -641,7 +641,7 @@ def detect_sharding_from_factory_config(
                         world_size=world_size,
                         dist_op=None,
                         min_local_shape=min_local_shape,
-                        layer_type=LayerType.MAMBA,
+                        layer_type=LayerType.MAMBA_FULL,
                     )
                 )
                 num_row_col_shards += 1
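Together with the change below, this keeps the layer type emitted by detect_sharding_from_factory_config in sync with the branch in sharding_utils.apply() that selects the full-Mamba sharding path.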

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
Lines changed: 2 additions & 1 deletion

@@ -557,6 +557,7 @@ def check_and_apply(self, gm: GraphModule, node: Node) -> bool:
 class LayerType(Enum):
     ATTENTION = "attention"
     MAMBA = "mamba"
+    MAMBA_FULL = "mamba_full"
     MLP = "mlp"
     MOE = "moe"
 
@@ -612,7 +613,7 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
 
     def apply(self, gm: GraphModule, node: Node) -> None:
         """Apply TP sharding transformation to the graph module."""
-        if self.layer_type == LayerType.MAMBA:
+        if self.layer_type == LayerType.MAMBA_FULL:
             _insert_sharded_mamba(
                 gm=gm,
                 entry_node=node,
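For context, here is a minimal runnable sketch of the enum-gated dispatch this commit adjusts. Only the LayerType members come from the diff; ShardingTransform, the stub helper, and the usage lines are hypothetical stand-ins, not the TensorRT-LLM source.

# Minimal sketch, assuming a simplified ShardingTransform; only the LayerType
# members below come from the diff, everything else is a hypothetical stand-in.
from enum import Enum


class LayerType(Enum):
    ATTENTION = "attention"
    MAMBA = "mamba"
    MAMBA_FULL = "mamba_full"  # new member added by this commit
    MLP = "mlp"
    MOE = "moe"


def _insert_sharded_mamba(gm, entry_node):
    """Hypothetical stub for the full-Mamba sharding helper named in the diff."""
    print(f"full-Mamba sharding applied at {entry_node}")


class ShardingTransform:
    """Hypothetical stand-in for the transform class in sharding_utils.py."""

    def __init__(self, layer_type: LayerType):
        self.layer_type = layer_type

    def apply(self, gm, node) -> None:
        # After the fix, only MAMBA_FULL takes the full-Mamba path; a layer
        # tagged with plain MAMBA no longer triggers _insert_sharded_mamba.
        if self.layer_type == LayerType.MAMBA_FULL:
            _insert_sharded_mamba(gm=gm, entry_node=node)


ShardingTransform(LayerType.MAMBA_FULL).apply(gm=None, node="mamba_layer_0")
ShardingTransform(LayerType.MAMBA).apply(gm=None, node="mamba_layer_1")  # no-op now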
