Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions modelopt/torch/export/hf_config_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@
(["dense_attention_every_n_layers"], "dense_attention_every_n_layers"), # Phi3-small
(["gegelu_limit"], "gegelu_limit"), # Phi3-small
(
["num_local_experts", "moe_num_experts"],
["num_local_experts", "moe_num_experts", "n_routed_experts"],
"moe_num_experts",
), # Mixture of Experts (Mixtral, DBRX)
), # Mixture of Experts (Mixtral, DBRX, DeepSeek)
(["num_experts_per_tok", "moe_top_k"], "moe_top_k"), # Mixture of Experts (Mixtral, DBRX)
(["model_type"], "qwen_type"), # qwen
(["lru_width"], "rnn_hidden_size"), # Recurrent Gemma
Expand Down
48 changes: 43 additions & 5 deletions modelopt/torch/export/layer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def get_experts_list(module: torch.nn.Module, model_type: str):
]
):
linear_names = ["gate_proj", "down_proj", "up_proj"]
elif "deepseek" in model_type:
linear_names = ["gate_proj", "down_proj", "up_proj"]
else:
raise NotImplementedError(f" {model_type} not supported")

Expand Down Expand Up @@ -150,6 +152,33 @@ def check_model_compatibility(module_list: list[nn.Module]) -> tuple[bool, bool,

def get_transformer_layers(model: nn.Module) -> list[nn.Module]:
"""Returns the root module of the transformer model."""
if "Megatron" in type(model).__name__:
if hasattr(model, "model") and "GPTModel" in type(model.model).__name__:
# NEMO mcore models can be handled with the following branch.
model = model.model

# NEMO non mcore models, we need to find the language_model module first.
children = [model]
language_model = None
while children and not language_model:
next_children = []
for child in children:
if type(child).__name__ == "TransformerLanguageModel":
language_model = child
break
next_children.extend(list(child.children()))
children = next_children
if language_model:
warn("Warning: this is an old NEMO checkpoint format and will be deprecated soon.")
layers = list(language_model.embedding.children()) + list(
language_model.encoder.children()
)

if hasattr(language_model, "output_layer"):
layers.append(language_model.output_layer)

return layers

if "GPTModel" in type(model).__name__:
# mcore models
layers = []
Expand Down Expand Up @@ -298,14 +327,20 @@ def is_mlp(module: nn.Module) -> bool:
return any(key in type(module).__name__.upper() for key in ("MLP", "T5DENSE"))


def _is_deepseek_moe_name(module_name: str) -> bool:
return "deepseek" in module_name and "moe" in module_name


def is_moe(module: nn.Module) -> bool:
"""Returns whether the module is an MOE layer."""
name = type(module).__name__.lower()
# Auto-detect common MoE patterns
if name.endswith("sparsemoeblock") or "moelayer" in name:
return True
if _is_deepseek_moe_name(name) and hasattr(module, "gate") and hasattr(module, "experts"):
return True
# Explicit matches for non-standard naming
return any(key in name for key in ["arcticmoe", "deepseekmoe", "dbrxffn"])
return any(key in name for key in ["arcticmoe", "dbrxffn", "gptossmoe"])


def is_quantlinear(module: nn.Module) -> bool:
Expand Down Expand Up @@ -358,7 +393,7 @@ def build_qkv(
num_kv_heads = ext_config.num_kv_heads

if "ColumnParallelLinear" in type(qkv_module).__name__:
# For Megatron-core model, num_kv_heads/num_attention_heads is the first dimension of QKV
# For NEMO model, num_kv_heads/num_attention_heads is the first dimension of QKV
model_metadata_config["head_is_first_dim"] = True

qkv_weight = qkv_module.weight
Expand Down Expand Up @@ -965,14 +1000,17 @@ def module_match_name_list(module, name_list):
"""
return any(name.lower() in type(module).__name__.lower() for name in name_list)

if module_match_name_list(
module_name = type(module).__name__.lower()

if _is_deepseek_moe_name(module_name):
return ["gate_proj", "down_proj", "up_proj"]
elif module_match_name_list(
module,
[
"Qwen2MoeSparseMoeBlock",
"Qwen3MoeSparseMoeBlock",
"Qwen3NextSparseMoeBlock",
"Qwen3_5MoeSparseMoeBlock",
"DeepseekMoE",
],
):
return ["gate_proj", "down_proj", "up_proj"]
Expand Down Expand Up @@ -1455,7 +1493,7 @@ def _set_layer_config_from_metaconfig(layer_config, metaconfig):
if k in metaconfig:
setattr(layer_config, name, metaconfig[k])

# MCore use "rope" as an alias for "rope_gpt_neox"
# MCore / NeMo use "rope" as an alias for "rope_gpt_neox"
if layer_config.position_embedding_type == "rope":
layer_config.position_embedding_type = "rope_gpt_neox"

Expand Down
22 changes: 20 additions & 2 deletions modelopt/torch/export/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,12 +1216,30 @@ def _update_svdquant(modules, new_pre_quant_scale):
# Mathematical equivalence:
# Before: o_proj_out = [attn @ (v_proj_in @ v_proj.W^T)^T * scale] @ o_proj.W^T
# After: o_proj_out = [attn @ (v_proj_in @ (v_proj.W * scale)^T)^T] @ o_proj.W^T
(["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj")),
(
[
"LlamaAttention",
"Qwen3Attention",
"Qwen3MoeAttention",
"DeepseekV2Attention",
"DeepseekV3Attention",
],
("v_proj", "o_proj"),
),
# MLP: Fuse down_proj's pre_quant_scale into up_proj's output dimension
# Mathematical equivalence:
# Before: down_proj_out = {[act_fn(self.gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
# After: down_proj_out = {[act_fn(self.gate_proj(x)) * (up_proj(x) * scale)]} @ down_proj.W^T
(["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")),
(
[
"LlamaMLP",
"Qwen3MLP",
"Qwen3MoeMLP",
"DeepseekV2MLP",
"DeepseekV3MLP",
],
("up_proj", "down_proj"),
),
]


Expand Down
85 changes: 85 additions & 0 deletions tests/unit/torch/export/test_deepseek_export_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn

from modelopt.torch.export.hf_config_map import HF_CONFIG_MAP
from modelopt.torch.export.layer_utils import get_expert_linear_names, get_experts_list, is_moe
from modelopt.torch.export.quant_utils import PQS_FUSE_MODULE_MAPPING


class _FakeDeepseekExpert(nn.Module):
def __init__(self):
super().__init__()
self.gate_proj = nn.Linear(8, 16, bias=False)
self.down_proj = nn.Linear(16, 8, bias=False)
self.up_proj = nn.Linear(8, 16, bias=False)


class _FakeDeepseekGate(nn.Module):
def __init__(self, num_experts=2):
super().__init__()
self.top_k = 1
self.n_routed_experts = num_experts
self.gating_dim = 8
self.weight = nn.Parameter(torch.empty(num_experts, 8))
nn.init.normal_(self.weight)


class DeepseekV3MoE(nn.Module):
    """Fake DeepSeek-V3 MoE block: a router, routed experts, and shared experts."""

    def __init__(self, num_experts=2):
        super().__init__()
        self.gate = _FakeDeepseekGate(num_experts)
        routed_experts = [_FakeDeepseekExpert() for _ in range(num_experts)]
        self.experts = nn.ModuleList(routed_experts)
        self.shared_experts = _FakeDeepseekExpert()


def test_is_moe_detects_deepseek_v3_moe():
    """A gate+experts DeepseekV3MoE block must be recognized as an MoE layer."""
    moe_block = DeepseekV3MoE()
    assert is_moe(moe_block)


def test_get_expert_linear_names_for_deepseek_v3():
    """DeepSeek experts expose the LLaMA-style projection triple."""
    expected_names = ["gate_proj", "down_proj", "up_proj"]
    assert get_expert_linear_names(DeepseekV3MoE()) == expected_names


def test_get_experts_list_for_deepseek_model_type():
    """get_experts_list groups each expert's three linears, in expert order."""
    num_experts = 3
    module = DeepseekV3MoE(num_experts=num_experts)

    experts_list = get_experts_list(module, "deepseekv3forcausallm")

    assert len(experts_list) == num_experts
    for expert_group in experts_list:
        assert len(expert_group) == 3
    # Each group is ordered [gate_proj, down_proj, up_proj] and holds the
    # identical Linear objects from the corresponding expert.
    assert experts_list[0][0] is module.experts[0].gate_proj
    assert experts_list[1][1] is module.experts[1].down_proj
    assert experts_list[2][2] is module.experts[2].up_proj


def test_hf_config_map_supports_deepseek_num_experts():
    """DeepSeek's ``n_routed_experts`` must map onto ``moe_num_experts``."""
    matching_entries = [
        output_name
        for input_names, output_name in HF_CONFIG_MAP
        if output_name == "moe_num_experts" and "n_routed_experts" in input_names
    ]
    assert matching_entries


def test_prequant_fuse_mapping_covers_deepseek_v3():
    """DeepSeek-V3 attention and MLP must appear in the PQS fuse mapping."""

    def _mapping_covers(class_name, expected_pair):
        # True when some mapping entry lists the class and fuses the expected pair.
        return any(
            class_name in module_names and linear_pair == expected_pair
            for module_names, linear_pair in PQS_FUSE_MODULE_MAPPING
        )

    assert _mapping_covers("DeepseekV3Attention", ("v_proj", "o_proj"))
    assert _mapping_covers("DeepseekV3MLP", ("up_proj", "down_proj"))