Moe bf16 ep #4144

Merged

Commits (17)
- d767bca  refactor pytorch.nn.moe (grimoire)
- d0865d4  add ep support (grimoire)
- f663b34  fix tp (grimoire)
- 30fbac1  support blocked fp8 moe with split_size<world_size (grimoire)
- 2e64bb5  unit test allow both fa3 and fa (grimoire)
- 1107fe2  add singleton (grimoire)
- 89e7050  singleton and ctxmgrbase (grimoire)
- bd53fb5  comment (grimoire)
- 6a09849  Merge branch 'main' into moe-bf16-ep (grimoire)
- 6d581eb  add static (grimoire)
- 32e0699  remove chunk (grimoire)
- eeaafef  merge main (grimoire)
- 640bc25  remove forward dptp (grimoire)
- 260fba7  bound check (grimoire)
- 755a987  remove monkey patch (grimoire)
- 1e1d669  merge main (grimoire)
- 00a096c  rename kernel (grimoire)
Files changed

One file in the diff was deleted; its contents are not shown below.
New file (@@ -0,0 +1,4 @@):

```python
# Copyright (c) OpenMMLab. All rights reserved.
from .blocked_fp8 import TritonFusedMoEBlockedF8Builder  # noqa: F401
from .default import TritonFusedMoEBuilder  # noqa: F401
from .w8a8 import TritonFusedMoEW8A8Builder  # noqa: F401
```
New file (@@ -0,0 +1,259 @@):
```python
# Copyright (c) OpenMMLab. All rights reserved.
import contextlib
from typing import Callable, List

import torch
import torch.distributed as dist

from lmdeploy.pytorch.backends.moe import FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl
from lmdeploy.pytorch.distributed import get_dist_manager
from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8
from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
from lmdeploy.pytorch.kernels.cuda.fused_moe import _renormalize
from lmdeploy.pytorch.model_inputs import get_step_ctx_manager
from lmdeploy.utils import get_logger

from .ep_utils import gather_outputs_by_attn_tp, split_inputs_by_attn_tp

logger = get_logger('lmdeploy')


class TritonFusedMoEBlockedF8Impl(FusedMoEBlockedF8Impl):
    """Triton fused moe blocked f8 implementation."""

    def __init__(self,
                 top_k: int,
                 num_experts: int,
                 renormalize: bool = False,
                 block_size: int = 128,
                 out_dtype: torch.dtype = torch.float16):
        self.num_experts = num_experts
        self.top_k = top_k
        self.renormalize = renormalize
        self.block_size = block_size
        self.out_dtype = out_dtype

    def ep_expert_list(self, world_size: int, rank: int):
        """Experts list of current rank."""
        num_experts = self.num_experts
        expert_per_rank = (num_experts + world_size - 1) // world_size
        first_expert = rank * expert_per_rank
        last_expert = min(first_expert + expert_per_rank, num_experts)
        return list(range(first_expert, last_expert))
```
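`ep_expert_list` partitions experts across ranks by ceil division, so when the expert count does not divide evenly the trailing ranks own fewer experts, or none. A quick standalone illustration with made-up sizes (not taken from this PR):

```python
# Hypothetical sizes: 160 experts partitioned over 64 EP ranks,
# mirroring the ceil-division logic in ep_expert_list above.
num_experts, world_size = 160, 64
expert_per_rank = (num_experts + world_size - 1) // world_size  # -> 3
for rank in (0, 53, 54):
    first = rank * expert_per_rank
    last = min(first + expert_per_rank, num_experts)
    print(rank, list(range(first, last)))
# 0  [0, 1, 2]
# 53 [159]   last partially filled rank
# 54 []      trailing ranks hold no experts
```

The forward of the same class quantizes the activations to blocked FP8 and calls the fused kernel: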
```python
    def forward(self,
                hidden_states: torch.Tensor,
                topk_weights: torch.Tensor,
                topk_ids: torch.LongTensor,
                gate_up_weights: torch.Tensor,
                gate_up_scale: torch.Tensor,
                down_weights: torch.Tensor,
                down_scale: torch.Tensor,
                gate_up_bias: torch.Tensor = None,
                down_bias: torch.Tensor = None,
                expert_list: List[int] = None,
                act_func: Callable = None):
        """forward."""
        input_size = hidden_states.shape
        hidden_states = hidden_states.flatten(0, -2)
        input_quant, input_scale = quant_fp8(hidden_states, self.block_size, dtype=gate_up_weights.dtype)

        expert_offset = 0
        num_experts = None
        if expert_list is not None and len(expert_list) != self.num_experts:
            expert_offset = expert_list[0]
            num_experts = self.num_experts
        output = fused_moe_blocked_fp8(input_quant,
                                       input_scale,
                                       gate_up_weights,
                                       gate_up_scale,
                                       down_weights,
                                       down_scale,
                                       topk_weights=topk_weights,
                                       topk_ids=topk_ids,
                                       topk=self.top_k,
                                       w1_bias=gate_up_bias,
                                       w2_bias=down_bias,
                                       out_dtype=hidden_states.dtype,
                                       expert_offset=expert_offset,
                                       num_experts=num_experts,
                                       renormalize=self.renormalize,
                                       act_func=act_func)
        output = output.unflatten(0, input_size[:-1])
        return output
```
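When `expert_list` is a strict subset of all experts (the EP case), `forward` passes the first local expert id as `expert_offset` together with the global expert count, so that the kernel can relate global routing ids to the locally held expert weights (the exact handling lives in `fused_moe_blocked_fp8`, not shown here). A small trace of that branch with made-up numbers:

```python
# Made-up numbers tracing the expert_offset branch above.
self_num_experts = 64              # stand-in for self.num_experts
expert_list = list(range(16, 32))  # this EP rank hosts experts 16..31

expert_offset = 0
num_experts = None
if expert_list is not None and len(expert_list) != self_num_experts:
    expert_offset = expert_list[0]  # -> 16
    num_experts = self_num_experts  # -> 64 (global count)
print(expert_offset, num_experts)   # 16 64
```

Next comes a context manager that temporarily redirects deep_gemm and dlblas onto the DeepGEMM entry points shipped in lmdeploy.pytorch.third_party: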
```python
@contextlib.contextmanager
def monk_deep_gemm():
    from dlblas.kernels.fused_moe_v3 import use_deep_gemm
    if use_deep_gemm:
        yield
        return

    # patch deep_gemm
    import deep_gemm
    import dlblas

    from lmdeploy.pytorch.third_party import deep_gemm as patched_deep_gemm
    func0_ = getattr(deep_gemm, 'get_col_major_tma_aligned_tensor', None)
    func1_ = getattr(deep_gemm, 'm_grouped_gemm_fp8_fp8_bf16_nt_masked', None)
    deep_gemm.get_col_major_tma_aligned_tensor = patched_deep_gemm.get_mn_major_tma_aligned_tensor
    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = patched_deep_gemm.m_grouped_fp8_gemm_nt_masked

    # patch dlblas
    dlblas.kernels.fused_moe_v3.use_deep_gemm = True
    dlblas.kernels.fused_moe_v3.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous = \
        patched_deep_gemm.m_grouped_fp8_gemm_nt_contiguous
    yield

    # unpatch dlblas
    dlblas.kernels.fused_moe_v3.use_deep_gemm = False

    # unpatch deep_gemm
    if func0_ is not None:
        deep_gemm.get_col_major_tma_aligned_tensor = func0_
        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = func1_
    else:
        del deep_gemm.get_col_major_tma_aligned_tensor
        del deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked
```
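`monk_deep_gemm` rebinds two deep_gemm symbols to lmdeploy's patched versions and flips dlblas's `use_deep_gemm` flag, then restores (or deletes) them after the wrapped call. The underlying idea is the usual attribute patch/restore context manager; a minimal generic sketch (hypothetical helper, not code from this PR):

```python
import contextlib


@contextlib.contextmanager
def patch_attr(obj, name, replacement):
    """Temporarily rebind obj.name to replacement, restoring it on exit."""
    missing = object()
    original = getattr(obj, name, missing)
    setattr(obj, name, replacement)
    try:
        yield
    finally:
        if original is missing:
            delattr(obj, name)  # attribute did not exist before patching
        else:
            setattr(obj, name, original)
```

Unlike this sketch, `monk_deep_gemm` does not guard its `yield` with try/finally, so an exception inside the patched region would leave the modules patched. The expert-parallel implementation that uses it follows: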
```python
class FusedDeepEpMoEBlockedF8Impl(TritonFusedMoEBlockedF8Impl):

    def __init__(self,
                 ep_size: int,
                 ep_group: dist.ProcessGroup,
                 top_k: int,
                 num_experts: int,
                 hidden_dim: int,
                 renormalize: bool = False,
                 block_size: int = 128,
                 out_dtype: torch.dtype = torch.bfloat16,
                 layer_idx: int = 0):
        super().__init__(top_k, num_experts, renormalize, block_size, out_dtype)
        self.num_experts = num_experts
        self.ep_size = ep_size
        self.ep_group = ep_group
        self.hidden_dim = hidden_dim
        self.block_size = block_size
        self.out_dtype = out_dtype
        self.layer_idx = layer_idx
        try:
            import deep_gemm  # noqa: F401
            self.use_deep_gemm = True
        except ImportError:
            self.use_deep_gemm = False
            logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM')

        # pre-allocate buffer
        self.fusedmoe_build(True)

    def ep_expert_list(self, world_size: int, rank: int):
        """Experts list of current rank."""
        if get_dist_manager().current_context().dist_config.enable_eplb:
            from dlblas.layers.moe.eplb import get_eplb_phy2log_metadata_by_layer
            phy2log = get_eplb_phy2log_metadata_by_layer(self.layer_idx)
            expert_per_rank = (self.num_experts + world_size - 1) // world_size
            first_expert = rank * expert_per_rank
            last_expert = min(first_expert + expert_per_rank, self.num_experts)
            sliced_phy2log = phy2log[first_expert:last_expert].tolist()
            return sliced_phy2log
        else:
            return super().ep_expert_list(world_size=world_size, rank=rank)
```
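With EPLB enabled, each rank hosts a contiguous slice of physical expert slots, and `phy2log` maps every physical slot back to a logical expert id (duplicates are possible when hot experts are replicated), so the returned list holds the logical ids for the locally hosted slots. A made-up illustration of the slicing:

```python
import torch

# Hypothetical phy2log mapping: 8 physical slots over 4 EP ranks.
phy2log = torch.tensor([3, 7, 1, 5, 0, 6, 2, 4])
num_experts, world_size, rank = 8, 4, 1
expert_per_rank = (num_experts + world_size - 1) // world_size  # -> 2
first = rank * expert_per_rank
last = min(first + expert_per_rank, num_experts)
print(phy2log[first:last].tolist())  # rank 1 hosts logical experts [1, 5]
```

Its forward splits the inputs across attention-TP ranks, renormalizes the routing weights, and dispatches through a DeepEP-based fused MoE built by dlblas, switching to a low-latency build during decoding when DeepGEMM is available: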
```python
    def forward(self,
                hidden_states: torch.Tensor,
                topk_weights: torch.Tensor,
                topk_ids: torch.LongTensor,
                gate_up_weights: torch.Tensor,
                gate_up_scale: torch.Tensor,
                down_weights: torch.Tensor,
                down_scale: torch.Tensor,
                gate_up_bias: torch.Tensor = None,
                down_bias: torch.Tensor = None,
                expert_list: List[int] = None,
                act_func: Callable = None,
                **kwargs):
        """forward."""
        hidden_states, topk_weights, topk_ids, split_size = split_inputs_by_attn_tp(hidden_states, topk_weights,
                                                                                    topk_ids)

        topk_weights = self.do_renormalize(topk_weights)
        step_ctx = get_step_ctx_manager().current_context()
        low_latency_mode = step_ctx.is_decoding and self.use_deep_gemm
        moe = self.fusedmoe_build(low_latency_mode)
        out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights,
                                 down_scale, expert_list)

        out_states = gather_outputs_by_attn_tp(out_states, split_size)
        return out_states

    def do_renormalize(self, topk_weights):
        return _renormalize(topk_weights, self.renormalize)

    def fusedmoe_build(self, low_latency_mode: bool = False):
        from dlblas.layers.moe.ep_moe import build_deepep_moe
        deepep_moe = build_deepep_moe(low_latency_mode,
                                      self.ep_size,
                                      self.ep_group,
                                      self.num_experts,
                                      self.hidden_dim,
                                      self.block_size,
                                      self.top_k,
                                      self.out_dtype,
                                      layer_idx=self.layer_idx,
                                      chunk_size=16 * 1024)

        # patch forward
        _origin_forward = deepep_moe.forward
        _origin_fusedmoe_forward = deepep_moe.fusedmoe_forward

        def _patched_forward(*args, **kwargs):
            with monk_deep_gemm():
                out = _origin_forward(*args, **kwargs)
            return out

        def _patched_fusedmoe_forward(*args, **kwargs):
            with monk_deep_gemm():
                out = _origin_fusedmoe_forward(*args, **kwargs)
            return out

        deepep_moe.forward = _patched_forward
        deepep_moe.fusedmoe_forward = _patched_fusedmoe_forward

        return deepep_moe
```
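`fusedmoe_build` constructs the dlblas DeepEP MoE and wraps its `forward` and `fusedmoe_forward` so the DeepGEMM patching is active only for the duration of each call rather than globally at import time. The wrapping is equivalent to the generic sketch below (hypothetical helper, not part of the PR):

```python
import functools


def run_under(ctx_factory, fn):
    """Wrap fn so every invocation runs inside ctx_factory()."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with ctx_factory():
            return fn(*args, **kwargs)
    return wrapper

# e.g. deepep_moe.forward = run_under(monk_deep_gemm, deepep_moe.forward)
```

Finally, the builder selects between the single-device Triton implementation and the DeepEP expert-parallel one: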
```python
class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):
    """Triton fused moe blocked f8 builder."""

    @staticmethod
    def build(top_k: int,
              num_experts: int,
              hidden_dim: int = 1,
              renormalize: bool = False,
              block_size: int = 128,
              ep_size: int = 1,
              ep_group: dist.ProcessGroup = None,
              out_dtype: torch.dtype = torch.float16,
              layer_idx: int = 0,
              custom_gateup_act: bool = False):
        """Build from mlp."""
        if ep_size > 1:
            assert custom_gateup_act is False, 'Custom gate up activation is not supported in EP MoE.'
            return FusedDeepEpMoEBlockedF8Impl(ep_size=ep_size,
                                               ep_group=ep_group,
                                               top_k=top_k,
                                               num_experts=num_experts,
                                               hidden_dim=hidden_dim,
                                               renormalize=renormalize,
                                               block_size=block_size,
                                               out_dtype=out_dtype,
                                               layer_idx=layer_idx)
        else:
            return TritonFusedMoEBlockedF8Impl(top_k=top_k,
                                               num_experts=num_experts,
                                               renormalize=renormalize,
                                               block_size=block_size,
                                               out_dtype=out_dtype)
```
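For reference, a hedged usage sketch of the builder; the argument values are made up, and only the non-EP branch is shown, since the EP branch additionally requires a DeepEP-capable process group and dlblas at runtime:

```python
import torch

# Hypothetical sizes; the default ep_size=1 returns the plain Triton implementation.
impl = TritonFusedMoEBlockedF8Builder.build(top_k=6,
                                            num_experts=64,
                                            hidden_dim=4096,
                                            renormalize=True,
                                            block_size=128,
                                            out_dtype=torch.bfloat16)
assert isinstance(impl, TritonFusedMoEBlockedF8Impl)
```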