# Patch summary (free text ahead of the first "diff --git" header).
#
# Purpose: extend the modelopt.torch.export helpers so DeepSeek-style MoE
# modules are recognised, and add unit tests for that support.
#
# modelopt/torch/export/hf_config_map.py
#   * "n_routed_experts" becomes an accepted source key for "moe_num_experts".
#
# modelopt/torch/export/layer_utils.py
#   * get_experts_list: any model_type containing "deepseek" uses the
#     gate_proj / down_proj / up_proj linear names.
#   * get_transformer_layers: new leading branch for classes whose type name
#     contains "Megatron"; it unwraps model.model when that is a GPTModel and
#     otherwise walks children breadth-first looking for a
#     "TransformerLanguageModel" module.
#   * New private helper _is_deepseek_moe_name(module_name).
#   * is_moe: a DeepSeek MoE name now counts only when the module also has
#     "gate" and "experts" attributes; "deepseekmoe" leaves the explicit
#     substring list and "gptossmoe" is added to it.
#   * Expert-linear-name lookup (the hunk after module_match_name_list): the
#     DeepSeek check moves ahead of the Qwen name list and "DeepseekMoE" is
#     dropped from that list.
#   * Two comment-only wording changes (build_qkv and
#     _set_layer_config_from_metaconfig).
#
# modelopt/torch/export/quant_utils.py
#   * DeepseekV2/V3 Attention and MLP class names are added to the
#     (v_proj, o_proj) and (up_proj, down_proj) fusion entries
#     (presumably PQS_FUSE_MODULE_MAPPING, the name the new test imports).
#
# tests/unit/torch/export/test_deepseek_export_support.py (new file)
#   * Fake DeepSeek expert / gate / MoE modules plus five tests covering the
#     changes above.
#
# NOTE(review): this patch was received with newlines and indentation
# collapsed. Line breaks below follow the hunk line counts, which are
# consistent with every "@@" header. Leading indentation inside the hunks is
# reconstructed from Python syntax, not taken from the target files - verify
# against the real sources before applying.
diff --git a/modelopt/torch/export/hf_config_map.py b/modelopt/torch/export/hf_config_map.py
index 377e1887b5..664f855662 100644
--- a/modelopt/torch/export/hf_config_map.py
+++ b/modelopt/torch/export/hf_config_map.py
@@ -64,9 +64,9 @@
     (["dense_attention_every_n_layers"], "dense_attention_every_n_layers"),  # Phi3-small
     (["gegelu_limit"], "gegelu_limit"),  # Phi3-small
     (
-        ["num_local_experts", "moe_num_experts"],
+        ["num_local_experts", "moe_num_experts", "n_routed_experts"],
         "moe_num_experts",
-    ),  # Mixture of Experts (Mixtral, DBRX)
+    ),  # Mixture of Experts (Mixtral, DBRX, DeepSeek)
     (["num_experts_per_tok", "moe_top_k"], "moe_top_k"),  # Mixture of Experts (Mixtral, DBRX)
     (["model_type"], "qwen_type"),  # qwen
     (["lru_width"], "rnn_hidden_size"),  # Recurrent Gemma
diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py
index 9a2cd4b2f0..330c306706 100755
--- a/modelopt/torch/export/layer_utils.py
+++ b/modelopt/torch/export/layer_utils.py
@@ -98,6 +98,8 @@ def get_experts_list(module: torch.nn.Module, model_type: str):
         ]
     ):
         linear_names = ["gate_proj", "down_proj", "up_proj"]
+    elif "deepseek" in model_type:
+        linear_names = ["gate_proj", "down_proj", "up_proj"]
     else:
         raise NotImplementedError(f" {model_type} not supported")
 
@@ -150,6 +152,33 @@ def check_model_compatibility(module_list: list[nn.Module]) -> tuple[bool, bool,
 
 def get_transformer_layers(model: nn.Module) -> list[nn.Module]:
     """Returns the root module of the transformer model."""
+    if "Megatron" in type(model).__name__:
+        if hasattr(model, "model") and "GPTModel" in type(model.model).__name__:
+            # NEMO mcore models can be handled with the following branch.
+            model = model.model
+        # NEMO non mcore models, we need to find the language_model module first.
+        children = [model]
+        language_model = None
+        while children and not language_model:
+            next_children = []
+            for child in children:
+                if type(child).__name__ == "TransformerLanguageModel":
+                    language_model = child
+                    break
+                next_children.extend(list(child.children()))
+            children = next_children
+        if language_model:
+            warn("Warning: this is an old NEMO checkpoint format and will be deprecated soon.")
+            layers = list(language_model.embedding.children()) + list(
+                language_model.encoder.children()
+            )
+
+            if hasattr(language_model, "output_layer"):
+                layers.append(language_model.output_layer)
+
+            return layers
+
+
     if "GPTModel" in type(model).__name__:
         # mcore models
         layers = []
@@ -298,14 +327,20 @@ def is_mlp(module: nn.Module) -> bool:
     return any(key in type(module).__name__.upper() for key in ("MLP", "T5DENSE"))
 
 
+def _is_deepseek_moe_name(module_name: str) -> bool:
+    return "deepseek" in module_name and "moe" in module_name
+
+
 def is_moe(module: nn.Module) -> bool:
     """Returns whether the module is an MOE layer."""
     name = type(module).__name__.lower()
     # Auto-detect common MoE patterns
     if name.endswith("sparsemoeblock") or "moelayer" in name:
         return True
+    if _is_deepseek_moe_name(name) and hasattr(module, "gate") and hasattr(module, "experts"):
+        return True
     # Explicit matches for non-standard naming
-    return any(key in name for key in ["arcticmoe", "deepseekmoe", "dbrxffn"])
+    return any(key in name for key in ["arcticmoe", "dbrxffn", "gptossmoe"])
 
 
 def is_quantlinear(module: nn.Module) -> bool:
@@ -358,7 +393,7 @@ def build_qkv(
     num_kv_heads = ext_config.num_kv_heads
 
     if "ColumnParallelLinear" in type(qkv_module).__name__:
-        # For Megatron-core model, num_kv_heads/num_attention_heads is the first dimension of QKV
+        # For NEMO model, num_kv_heads/num_attention_heads is the first dimension of QKV
         model_metadata_config["head_is_first_dim"] = True
 
     qkv_weight = qkv_module.weight
@@ -965,14 +1000,17 @@ def module_match_name_list(module, name_list):
         """
         return any(name.lower() in type(module).__name__.lower() for name in name_list)
 
-    if module_match_name_list(
+    module_name = type(module).__name__.lower()
+
+    if _is_deepseek_moe_name(module_name):
+        return ["gate_proj", "down_proj", "up_proj"]
+    elif module_match_name_list(
         module,
         [
             "Qwen2MoeSparseMoeBlock",
             "Qwen3MoeSparseMoeBlock",
             "Qwen3NextSparseMoeBlock",
             "Qwen3_5MoeSparseMoeBlock",
-            "DeepseekMoE",
         ],
     ):
         return ["gate_proj", "down_proj", "up_proj"]
@@ -1455,7 +1493,7 @@ def _set_layer_config_from_metaconfig(layer_config, metaconfig):
         if k in metaconfig:
             setattr(layer_config, name, metaconfig[k])
 
-    # MCore use "rope" as an alias for "rope_gpt_neox"
+    # MCore / NeMo use "rope" as an alias for "rope_gpt_neox"
     if layer_config.position_embedding_type == "rope":
         layer_config.position_embedding_type = "rope_gpt_neox"
 
diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index 4ceb51cd2c..50e95f5a06 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -1216,12 +1216,30 @@ def _update_svdquant(modules, new_pre_quant_scale):
     # Mathematical equivalence:
     # Before: o_proj_out = [attn @ (v_proj_in @ v_proj.W^T)^T * scale] @ o_proj.W^T
     # After: o_proj_out = [attn @ (v_proj_in @ (v_proj.W * scale)^T)^T] @ o_proj.W^T
-    (["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj")),
+    (
+        [
+            "LlamaAttention",
+            "Qwen3Attention",
+            "Qwen3MoeAttention",
+            "DeepseekV2Attention",
+            "DeepseekV3Attention",
+        ],
+        ("v_proj", "o_proj"),
+    ),
     # MLP: Fuse down_proj's pre_quant_scale into up_proj's output dimension
     # Mathematical equivalence:
     # Before: down_proj_out = {[act_fn(self.gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
     # After: down_proj_out = {[act_fn(self.gate_proj(x)) * (up_proj(x) * scale)]} @ down_proj.W^T
-    (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")),
+    (
+        [
+            "LlamaMLP",
+            "Qwen3MLP",
+            "Qwen3MoeMLP",
+            "DeepseekV2MLP",
+            "DeepseekV3MLP",
+        ],
+        ("up_proj", "down_proj"),
+    ),
 ]
 
 
diff --git a/tests/unit/torch/export/test_deepseek_export_support.py b/tests/unit/torch/export/test_deepseek_export_support.py
new file mode 100644
index 0000000000..c0c5443a50
--- /dev/null
+++ b/tests/unit/torch/export/test_deepseek_export_support.py
@@ -0,0 +1,85 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+
+from modelopt.torch.export.hf_config_map import HF_CONFIG_MAP
+from modelopt.torch.export.layer_utils import get_expert_linear_names, get_experts_list, is_moe
+from modelopt.torch.export.quant_utils import PQS_FUSE_MODULE_MAPPING
+
+
+class _FakeDeepseekExpert(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.gate_proj = nn.Linear(8, 16, bias=False)
+        self.down_proj = nn.Linear(16, 8, bias=False)
+        self.up_proj = nn.Linear(8, 16, bias=False)
+
+
+class _FakeDeepseekGate(nn.Module):
+    def __init__(self, num_experts=2):
+        super().__init__()
+        self.top_k = 1
+        self.n_routed_experts = num_experts
+        self.gating_dim = 8
+        self.weight = nn.Parameter(torch.empty(num_experts, 8))
+        nn.init.normal_(self.weight)
+
+
+class DeepseekV3MoE(nn.Module):
+    def __init__(self, num_experts=2):
+        super().__init__()
+        self.gate = _FakeDeepseekGate(num_experts)
+        self.experts = nn.ModuleList([_FakeDeepseekExpert() for _ in range(num_experts)])
+        self.shared_experts = _FakeDeepseekExpert()
+
+
+def test_is_moe_detects_deepseek_v3_moe():
+    assert is_moe(DeepseekV3MoE())
+
+
+def test_get_expert_linear_names_for_deepseek_v3():
+    assert get_expert_linear_names(DeepseekV3MoE()) == ["gate_proj", "down_proj", "up_proj"]
+
+
+def test_get_experts_list_for_deepseek_model_type():
+    module = DeepseekV3MoE(num_experts=3)
+
+    experts_list = get_experts_list(module, "deepseekv3forcausallm")
+
+    assert len(experts_list) == 3
+    assert all(len(expert_group) == 3 for expert_group in experts_list)
+    assert experts_list[0][0] is module.experts[0].gate_proj
+    assert experts_list[1][1] is module.experts[1].down_proj
+    assert experts_list[2][2] is module.experts[2].up_proj
+
+
+def test_hf_config_map_supports_deepseek_num_experts():
+    assert any(
+        output_name == "moe_num_experts" and "n_routed_experts" in input_names
+        for input_names, output_name in HF_CONFIG_MAP
+    )
+
+
+def test_prequant_fuse_mapping_covers_deepseek_v3():
+    assert any(
+        "DeepseekV3Attention" in module_names and linear_pair == ("v_proj", "o_proj")
+        for module_names, linear_pair in PQS_FUSE_MODULE_MAPPING
+    )
+    assert any(
+        "DeepseekV3MLP" in module_names and linear_pair == ("up_proj", "down_proj")
+        for module_names, linear_pair in PQS_FUSE_MODULE_MAPPING
+    )