diff --git a/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py b/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py index d57105cee..a611797d1 100644 --- a/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py +++ b/unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py @@ -2,7 +2,11 @@ # Copyright 2023-present the Unsloth team. All rights reserved. """ -Autotuning utils +Autotuning utilities for GPU kernel configuration generation and optimization. + +This module provides functions to generate and prune kernel configurations +for grouped GEMM operations, including forward pass, backward pass (dX), and +weight gradient (dW) computations. """ import logging @@ -24,6 +28,16 @@ def val_to_list(val): + """ + Convert a single value to a list or return None if input is None. + + Args: + val: Input value that can be None, a list, or a single value + + Returns: + None if input is None, the original list if input is a list, + or a single-element list containing the input value + """ if val is None: return None elif isinstance(val, list): @@ -33,6 +47,15 @@ def val_to_list(val): def convert_args_to_list(args): + """ + Convert each argument in a list to a list format using val_to_list. + + Args: + args: List of arguments to convert + + Returns: + List where each element has been processed by val_to_list + """ return [val_to_list(arg) for arg in args] @@ -47,6 +70,23 @@ def get_forward_configs( num_stages=DEFAULT_NUM_STAGES, num_ctas=DEFAULT_NUM_CTAS, ): + """ + Generate kernel configurations for forward pass GEMM operations. + + Args: + BLOCK_M: Block sizes for M dimension + BLOCK_N: Block sizes for N dimension + BLOCK_K: Block sizes for K dimension + TMA_LOAD_X: Whether to use TMA (Tensor Memory Accelerator) for loading X + TMA_LOAD_W: Whether to use TMA for loading weights + TMA_STORE: Whether to use TMA for storing results (currently disabled) + num_warps: Number of warps per thread block + num_stages: Number of pipeline stages + num_ctas: Number of cooperative thread arrays + + Returns: + List of triton.Config objects containing all combinations of the input parameters + """ ( BLOCK_M, BLOCK_N, @@ -122,6 +162,23 @@ def get_dX_kernel_configs( num_stages=DEFAULT_NUM_STAGES, num_ctas=DEFAULT_NUM_CTAS, ): + """ + Generate kernel configurations for backward pass dX gradient computation. + + Args: + BLOCK_M: Block sizes for M dimension + BLOCK_N: Block sizes for N dimension + BLOCK_K: Block sizes for K dimension + TMA_LOAD_dY: Whether to use TMA for loading output gradients + TMA_LOAD_W: Whether to use TMA for loading weights + TMA_STORE: Whether to use TMA for storing results (currently disabled) + num_warps: Number of warps per thread block + num_stages: Number of pipeline stages + num_ctas: Number of cooperative thread arrays + + Returns: + List of triton.Config objects for dX gradient computation + """ ( BLOCK_M, BLOCK_N, @@ -197,6 +254,23 @@ def get_dW_kernel_configs( TMA_LOAD_X=True, TMA_STORE=False, ): + """ + Generate kernel configurations for weight gradient (dW) computation. 
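To illustrate how these config builders expand their arguments, here is a minimal sketch of the combination step: scalars are promoted to single-element lists (as `val_to_list` does) and the block sizes, TMA flags, and launch parameters are crossed with `itertools.product`. Plain dicts stand in for the `triton.Config` objects the real helpers return, and the kwarg names (`USE_TMA_LOAD_X`, `BLOCK_SIZE_M`, ...) are illustrative assumptions, so this runs without Triton or a GPU.

```python
from itertools import product

def to_list(val):
    # Mirror of val_to_list: None passes through, scalars become 1-element lists.
    if val is None:
        return None
    return val if isinstance(val, list) else [val]

def sketch_forward_configs(BLOCK_M=[64, 128], BLOCK_N=64, BLOCK_K=64,
                           TMA_LOAD_X=True, TMA_LOAD_W=True,
                           num_warps=[4, 8], num_stages=[3, 4], num_ctas=1):
    args = [to_list(a) for a in
            (BLOCK_M, BLOCK_N, BLOCK_K, TMA_LOAD_X, TMA_LOAD_W,
             num_warps, num_stages, num_ctas)]
    configs = []
    # Every combination of block sizes, TMA flags, and launch params becomes
    # one candidate config (a triton.Config in the real module).
    for bm, bn, bk, tma_x, tma_w, warps, stages, ctas in product(*args):
        configs.append(dict(BLOCK_SIZE_M=bm, BLOCK_SIZE_N=bn, BLOCK_SIZE_K=bk,
                            USE_TMA_LOAD_X=tma_x, USE_TMA_LOAD_W=tma_w,
                            num_warps=warps, num_stages=stages, num_ctas=ctas))
    return configs

print(len(sketch_forward_configs()))  # 2 * 2 * 2 = 8 candidate configs
```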
+ + Args: + BLOCK_M: Block sizes for M dimension + BLOCK_N: Block sizes for N dimension + BLOCK_K: Block sizes for K dimension + num_warps: Number of warps per thread block + num_stages: Number of pipeline stages + num_ctas: Number of cooperative thread arrays + TMA_LOAD_dY: Whether to use TMA for loading output gradients + TMA_LOAD_X: Whether to use TMA for loading input data + TMA_STORE: Whether to use TMA for storing results + + Returns: + List of triton.Config objects for weight gradient computation + """ ( BLOCK_M, BLOCK_N, @@ -268,6 +342,19 @@ def estimate_smem_reqs( BLOCK_SIZE_K: int, dtype: torch.dtype, ): + """ + Estimate shared memory requirements for a kernel configuration. + + Args: + num_stages: Number of pipeline stages + BLOCK_SIZE_M: Block size in M dimension + BLOCK_SIZE_N: Block size in N dimension + BLOCK_SIZE_K: Block size in K dimension + dtype: Data type of the tensors + + Returns: + Estimated shared memory requirement in bytes + """ num_bytes = dtype.itemsize return ( num_stages * BLOCK_SIZE_K * (BLOCK_SIZE_M + BLOCK_SIZE_N) @@ -284,6 +371,21 @@ def exceeds_smem_capacity( smem_size: int, slack: float = 50000, ): + """ + Check if a kernel configuration exceeds shared memory capacity. + + Args: + num_stages: Number of pipeline stages + BLOCK_SIZE_M: Block size in M dimension + BLOCK_SIZE_N: Block size in N dimension + BLOCK_SIZE_K: Block size in K dimension + dtype: Data type of the tensors + smem_size: Available shared memory size in bytes + slack: Additional buffer space to account for overhead + + Returns: + True if the configuration exceeds shared memory capacity + """ smem_reqs = estimate_smem_reqs( num_stages, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, dtype ) @@ -291,6 +393,17 @@ def exceeds_smem_capacity( def common_prune_criteria(config: triton.Config, kwargs: dict, dtype): + """ + Apply common pruning criteria to filter out invalid kernel configurations. + + Args: + config: Triton kernel configuration to evaluate + kwargs: Kernel arguments containing problem dimensions and flags + dtype: Data type of the tensors + + Returns: + True if the configuration should be pruned (removed) + """ from grouped_gemm.interface import supports_tma from grouped_gemm.kernels.tuning import get_device_properties @@ -323,6 +436,12 @@ def common_prune_criteria(config: triton.Config, kwargs: dict, dtype): def maybe_disable_tma(config: triton.Config): + """ + Disable TMA (Tensor Memory Accelerator) features if not supported by the GPU. + + Args: + config: Triton kernel configuration to modify in-place + """ from grouped_gemm.interface import supports_tma tma_keys = [k for k in config.kwargs.keys() if k.startswith("USE_TMA_")] @@ -333,6 +452,17 @@ def maybe_disable_tma(config: triton.Config): def prune_kernel_configs_fwd(configs: list[triton.Config], args, **kwargs): + """ + Prune kernel configurations for forward pass operations. + + Args: + configs: List of kernel configurations to filter + args: Positional arguments (unused) + **kwargs: Keyword arguments containing tensor pointers and operation flags + + Returns: + Filtered list of valid kernel configurations + """ x = kwargs["x_ptr"] dtype = x.dtype @@ -358,6 +488,17 @@ def prune_kernel_configs_fwd(configs: list[triton.Config], args, **kwargs): def prune_dX_configs(configs: List[triton.Config], args, **kwargs): + """ + Prune kernel configurations for dX gradient computation. 
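As a concrete example of the capacity check the pruning helpers rely on, the sketch below reproduces the `estimate_smem_reqs` accounting (`num_stages * BLOCK_K * (BLOCK_M + BLOCK_N) * dtype.itemsize`) and compares it against a shared-memory budget minus the slack term. The 232448-byte figure is an assumed H100-class value and the exact comparison direction is an assumption; the real `exceeds_smem_capacity` reads the limit from the device properties.

```python
import torch

def estimate_smem_bytes(num_stages, block_m, block_n, block_k, dtype):
    # Same accounting as estimate_smem_reqs: each pipeline stage keeps a
    # BLOCK_M x BLOCK_K tile of X and a BLOCK_K x BLOCK_N tile of W resident.
    return num_stages * block_k * (block_m + block_n) * dtype.itemsize

def would_prune(num_stages, block_m, block_n, block_k, dtype,
                smem_size=232_448, slack=50_000):
    # smem_size is an assumed device limit; the slack leaves room for overhead.
    reqs = estimate_smem_bytes(num_stages, block_m, block_n, block_k, dtype)
    return reqs > smem_size - slack

cfg = dict(num_stages=4, block_m=128, block_n=128, block_k=64, dtype=torch.bfloat16)
print(estimate_smem_bytes(**cfg))  # 4 * 64 * 256 * 2 = 131072 bytes
print(would_prune(**cfg))          # False: fits within 232448 - 50000 bytes
```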
+ + Args: + configs: List of kernel configurations to filter + args: Positional arguments (unused) + **kwargs: Keyword arguments containing tensor pointers and operation flags + + Returns: + Filtered list of valid kernel configurations for dX computation + """ dtype = kwargs["w_ptr"].dtype logger.debug(f"Pruning configs: {len(configs)}") @@ -378,6 +519,17 @@ def prune_dX_configs(configs: List[triton.Config], args, **kwargs): def prune_kernel_configs_backward_dW(configs: list[triton.Config], args, **kwargs): + """ + Prune kernel configurations for weight gradient (dW) computation. + + Args: + configs: List of kernel configurations to filter + args: Positional arguments (unused) + **kwargs: Keyword arguments containing tensor pointers and operation flags + + Returns: + Filtered list of valid kernel configurations for dW computation + """ dtype = kwargs["x_ptr"].dtype pruned_configs = [] diff --git a/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py b/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py index 3d1afac9e..6e833332e 100644 --- a/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py +++ b/unsloth/kernels/moe/grouped_gemm/reference/layers/llama4_moe.py @@ -32,6 +32,24 @@ @dataclass class Llama4MoeResult: + """Container for storing intermediate and final results from Llama4 MoE forward pass. + + This dataclass holds all the intermediate tensors and computations from the MoE + forward pass, useful for debugging and analysis purposes. + + Attributes: + token_counts_by_expert: Number of tokens assigned to each expert + gather_indices: Indices for gathering tokens in expert order + topk_weights: Top-k routing weights from the router + hidden_states_after_weight_merge: Hidden states after applying routing weights + first_gemm: Output of the first grouped GEMM operation + intermediate: Output after activation and multiplication + second_gemm: Output of the second grouped GEMM operation + hidden_states_unpermute: Hidden states after unpermuting from expert order + shared_expert_out: Output from the shared expert + final_out: Final output combining expert and shared expert outputs + router_logits: Raw logits from the router (optional) + """ token_counts_by_expert: torch.Tensor gather_indices: torch.Tensor topk_weights: torch.Tensor @@ -46,6 +64,12 @@ class Llama4MoeResult: class Llama4GroupedGemmTextMoe(Llama4TextMoe): + """Llama4 MoE implementation using torch-native grouped GEMM operations. + + This class extends the standard Llama4TextMoe with optimized grouped GEMM + operations for better performance. It permutes expert weights in-place for + optimal memory layout and supports optional router-shared expert overlap. + """ EXPERT_WEIGHT_NAMES = ["experts.gate_up_proj", "experts.down_proj"] def __init__( @@ -55,6 +79,14 @@ def __init__( verbose=False, debug=False, ): + """Initialize the Llama4GroupedGemmTextMoe layer. + + Args: + config: Llama4 text configuration containing model parameters + overlap_router_shared: Whether to overlap router and shared expert computation + verbose: Whether to print detailed initialization information + debug: Whether to return detailed debug information in forward pass + """ super().__init__(config) self.overlap_router_shared = overlap_router_shared self.verbose = verbose @@ -102,6 +134,17 @@ def __init__( @torch.no_grad def copy_weights(self, other: Llama4TextMoe): + """Copy weights from another Llama4TextMoe instance. 
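The layout change that `copy_weights` and `check_weights` refer to can be shown in a few lines: the Hugging Face expert weights are stored with the last two dimensions in the opposite order (hence the `permute(0, 2, 1)` in `check_weights`), and the grouped-GEMM layer keeps the transposed, contiguous layout. A minimal sketch with random tensors and made-up sizes, not the actual model weights:

```python
import torch

num_experts, hidden_size, expert_dim = 4, 64, 16  # toy sizes for illustration

# Stand-in for an HF-layout expert weight such as experts.gate_up_proj,
# shaped [num_experts, hidden_size, 2 * expert_dim].
hf_weight = torch.randn(num_experts, hidden_size, 2 * expert_dim)

# Copying swaps the last two dims and materializes the result contiguously,
# which is what copy_weights does for the EXPERT_WEIGHT_NAMES parameters.
converted = hf_weight.permute(0, 2, 1).contiguous()
print(converted.shape, converted.is_contiguous())  # torch.Size([4, 32, 64]) True

# check_weights applies the same permutation to the reference weight before comparing.
assert converted.permute(0, 2, 1).equal(hf_weight)
```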
+ + Copies all parameters from the source model, applying necessary permutations + for expert weights to match the expected layout. + + Args: + other: Source Llama4TextMoe model to copy weights from + + Returns: + Self for method chaining + """ for name, param_to_copy in other.named_parameters(): if self.verbose: print(f"Copying {name} with shape {param_to_copy.shape}") @@ -118,6 +161,17 @@ def copy_weights(self, other: Llama4TextMoe): return self def check_weights(self, other: Llama4TextMoe): + """Verify that weights match another Llama4TextMoe instance. + + Compares all parameters with the reference model, applying necessary + permutations for expert weights before comparison. + + Args: + other: Reference Llama4TextMoe model to compare against + + Raises: + AssertionError: If any weights don't match or aren't contiguous + """ for name, other_param in other.named_parameters(): if any(n in name for n in self.EXPERT_WEIGHT_NAMES): other_param = other_param.permute(0, 2, 1) @@ -126,12 +180,34 @@ def check_weights(self, other: Llama4TextMoe): assert param.is_contiguous(), f"{name} not contiguous!" def act_and_mul(self, x: torch.Tensor) -> torch.Tensor: + """Apply activation function to gate projection and multiply with up projection. + + Splits the input tensor into gate and up projections, applies the activation + function to the gate projection, and multiplies with the up projection. + + Args: + x: Input tensor with shape [..., 2 * expert_dim] + + Returns: + Result of activation(gate_proj) * up_proj with shape [..., expert_dim] + """ assert x.shape[-1] == 2 * self.experts.expert_dim gate_proj = x[..., : self.experts.expert_dim] up_proj = x[..., self.experts.expert_dim :] return self.experts.act_fn(gate_proj) * up_proj def run_router(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Run the router to compute routing weights and select experts. + + Computes router logits, selects top-k experts, and applies sigmoid + activation to the routing weights. + + Args: + hidden_states: Input hidden states tensor + + Returns: + Tuple of (router_logits, routing_weights, selected_experts) + """ # router_logits: (batch * sequence_length, n_experts) hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = self.router(hidden_states) @@ -146,6 +222,17 @@ def run_router(self, hidden_states: torch.Tensor) -> torch.Tensor: def get_token_counts_and_gather_indices( self, selected_experts: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute token counts per expert and gather indices for permutation. + + Calculates how many tokens are assigned to each expert and generates + indices for gathering tokens in expert order. + + Args: + selected_experts: Tensor of selected expert indices for each token + + Returns: + Tuple of (token_counts_by_expert, gather_indices) + """ token_counts_by_expert, gather_indices = get_routing_indices( selected_experts, self.num_experts ) @@ -154,7 +241,17 @@ def get_token_counts_and_gather_indices( return token_counts_by_expert, gather_indices def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ """ + """Forward pass through the MoE layer. + + Processes input through router, applies routing weights, performs grouped + GEMM operations through experts, and combines with shared expert output. 
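`get_token_counts_and_gather_indices` wraps `get_routing_indices`; one common way to produce the same two outputs is a bincount for the per-expert token counts and a stable argsort of the flattened expert assignments for the gather order. The snippet below is that sketch, not the repository's actual implementation.

```python
import torch

def routing_indices_sketch(selected_experts: torch.Tensor, num_experts: int):
    # selected_experts: [num_tokens, top_k] expert ids chosen by the router.
    flat = selected_experts.reshape(-1)
    # How many (token, expert) assignments land on each expert.
    token_counts_by_expert = torch.bincount(flat, minlength=num_experts)
    # Flattened assignment positions sorted by expert id, i.e. the order in
    # which token copies are gathered so each expert's rows are contiguous.
    gather_indices = torch.argsort(flat, stable=True)
    return token_counts_by_expert, gather_indices

counts, gather = routing_indices_sketch(torch.tensor([[0, 2], [2, 1], [0, 0]]),
                                         num_experts=4)
print(counts.tolist())  # [3, 1, 2, 0]
print(gather.tolist())  # [0, 4, 5, 3, 1, 2] -> expert 0's rows first, then 1, then 2
```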
+ + Args: + hidden_states: Input tensor with shape (batch_size, sequence_length, hidden_dim) + + Returns: + Either Llama4MoeResult (if debug=True) or tuple of (final_output, routing_weights) + """ batch_size, sequence_length, hidden_dim = hidden_states.shape num_tokens = batch_size * sequence_length total_tokens = num_tokens * self.top_k @@ -253,6 +350,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Llama4TritonTextMoe(Llama4GroupedGemmTextMoe): + """Llama4 MoE implementation using Triton-optimized grouped GEMM operations. + + This class extends Llama4GroupedGemmTextMoe to use Triton kernels for grouped + GEMM operations, providing better performance through kernel fusion and + optimized memory access patterns. + """ def __init__( self, config: Llama4TextConfig, @@ -267,6 +370,21 @@ def __init__( dX_only: bool = False, verbose=False, ): + """Initialize the Llama4TritonTextMoe layer. + + Args: + config: Llama4 text configuration containing model parameters + overlap_router_shared: Whether to overlap router and shared expert computation + permute_x: Whether to permute input tensor (not supported for Llama4) + permute_y: Whether to permute output tensor + autotune: Whether to use kernel autotuning + kernel_config_fwd: Forward kernel configuration (required if autotune=False) + kernel_config_bwd_dW: Backward dW kernel configuration (required if autotune=False) + kernel_config_bwd_dX: Backward dX kernel configuration (required if autotune=False) + dW_only: Whether to compute only weight gradients + dX_only: Whether to compute only input gradients + verbose: Whether to print detailed information + """ super().__init__(config, overlap_router_shared=overlap_router_shared) assert not permute_x, ( "Llama4 triton grouped gemm does not support permute x due to pre-multiplication of router weights" @@ -288,6 +406,17 @@ def __init__( @torch.no_grad def copy_weights(self, other: Llama4TextMoe): + """Copy weights from another Llama4TextMoe instance. + + Copies all parameters from the source model, applying necessary permutations + for expert weights to match the expected layout. + + Args: + other: Source Llama4TextMoe model to copy weights from + + Returns: + Self for method chaining + """ for name, param_to_copy in other.named_parameters(): if self.verbose: print(f"Copying {name} with shape {param_to_copy.shape}") @@ -304,6 +433,17 @@ def copy_weights(self, other: Llama4TextMoe): return self def check_weights(self, other: Llama4TextMoe): + """Verify that weights match another Llama4TextMoe instance. + + Compares all parameters with the reference model, applying necessary + permutations for expert weights before comparison. + + Args: + other: Reference Llama4TextMoe model to compare against + + Raises: + AssertionError: If any weights don't match or aren't contiguous + """ for name, other_param in other.named_parameters(): if any(n in name for n in self.EXPERT_WEIGHT_NAMES): other_param = other_param.permute(0, 2, 1) @@ -312,12 +452,34 @@ def check_weights(self, other: Llama4TextMoe): assert param.is_contiguous(), f"{name} not contiguous!" def act_and_mul(self, x: torch.Tensor) -> torch.Tensor: + """Apply activation function to gate projection and multiply with up projection. + + Splits the input tensor into gate and up projections, applies the activation + function to the gate projection, and multiplies with the up projection. 
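For readability, here is a naive per-expert loop that mirrors the sequence the forward pass describes: router, routing-weight merge into the inputs, the two expert GEMMs with `act_and_mul` in between, unpermute/accumulate, and the shared-expert add. It is a sketch only: toy shapes, SiLU standing in for the configured `act_fn`, and `index_add_` in place of the fused scatter; the real layer replaces the loop with grouped GEMM kernels.

```python
import torch
import torch.nn.functional as F

def moe_forward_reference(hidden_states, router_w, gate_up, down, shared_mlp, top_k=1):
    # Shapes used by this sketch:
    #   hidden_states: [tokens, hidden]      router_w: [num_experts, hidden]
    #   gate_up:       [num_experts, 2 * expert_dim, hidden]
    #   down:          [num_experts, hidden, expert_dim]
    num_experts = router_w.shape[0]
    logits = hidden_states @ router_w.t()                  # run_router
    top_vals, top_idx = torch.topk(logits, top_k, dim=-1)
    weights = torch.sigmoid(top_vals)                      # Llama4-style sigmoid gating

    out = torch.zeros_like(hidden_states)
    for e in range(num_experts):
        token_ids, k_slot = torch.where(top_idx == e)      # tokens routed to expert e
        if token_ids.numel() == 0:
            continue
        # Routing weights are merged into the inputs before the first GEMM.
        x = hidden_states[token_ids] * weights[token_ids, k_slot, None]
        h = x @ gate_up[e].t()                             # first grouped GEMM
        gate, up = h.chunk(2, dim=-1)
        h = F.silu(gate) * up                              # act_and_mul
        out.index_add_(0, token_ids, h @ down[e].t())      # second GEMM + unpermute
    return out + shared_mlp(hidden_states)                 # add shared expert output

E, H, D, T = 4, 32, 8, 10  # toy sizes
out = moe_forward_reference(torch.randn(T, H), torch.randn(E, H),
                            torch.randn(E, 2 * D, H), torch.randn(E, H, D),
                            shared_mlp=torch.nn.Linear(H, H), top_k=2)
print(out.shape)  # torch.Size([10, 32])
```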
+ + Args: + x: Input tensor with shape [..., 2 * expert_dim] + + Returns: + Result of activation(gate_proj) * up_proj with shape [..., expert_dim] + """ assert x.shape[-1] == 2 * self.experts.expert_dim gate_proj = x[..., : self.experts.expert_dim] up_proj = x[..., self.experts.expert_dim :] return self.experts.act_fn(gate_proj) * up_proj def run_router(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Run the router to compute routing weights and select experts. + + Computes router logits, selects top-k experts, and applies sigmoid + activation to the routing weights. + + Args: + hidden_states: Input hidden states tensor + + Returns: + Tuple of (router_logits, routing_weights, selected_experts) + """ # router_logits: (batch * sequence_length, n_experts) hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = self.router(hidden_states) @@ -332,6 +494,17 @@ def run_router(self, hidden_states: torch.Tensor) -> torch.Tensor: def get_token_counts_and_gather_indices( self, selected_experts: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute token counts per expert and gather indices for permutation. + + Calculates how many tokens are assigned to each expert and generates + indices for gathering tokens in expert order. + + Args: + selected_experts: Tensor of selected expert indices for each token + + Returns: + Tuple of (token_counts_by_expert, gather_indices) + """ token_counts_by_expert, gather_indices = get_routing_indices( selected_experts, self.num_experts ) @@ -340,7 +513,17 @@ def get_token_counts_and_gather_indices( return token_counts_by_expert, gather_indices def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ """ + """Forward pass through the MoE layer using Triton grouped GEMM. + + Processes input through router, applies routing weights, performs Triton-optimized + grouped GEMM operations through experts, and combines with shared expert output. + + Args: + hidden_states: Input tensor with shape (batch_size, sequence_length, hidden_dim) + + Returns: + Tuple of (final_output, routing_weights) + """ batch_size, sequence_length, hidden_dim = hidden_states.shape num_tokens = batch_size * sequence_length total_tokens = num_tokens * self.top_k diff --git a/unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py b/unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py index 0ca4391b2..d6fd65ba6 100644 --- a/unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py +++ b/unsloth/kernels/moe/grouped_gemm/reference/layers/qwen3_moe.py @@ -37,6 +37,18 @@ @dataclass class GroupedGEMMResult: + """Container for storing intermediate and final results from grouped GEMM operations. 
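The Llama4 routing described in `run_router` keeps only the top-k scores and squashes them with a sigmoid, rather than softmax-normalizing across experts as the Qwen3 reference further below does. A minimal sketch, with a random projection standing in for `self.router`:

```python
import torch

def llama4_router_sketch(hidden_states, router_weight, top_k=1):
    # hidden_states: [num_tokens, hidden_dim]; router_weight: [num_experts, hidden_dim]
    router_logits = hidden_states @ router_weight.t()
    top_vals, selected_experts = torch.topk(router_logits, top_k, dim=-1)
    routing_weights = torch.sigmoid(top_vals)  # sigmoid on the selected scores only
    return router_logits, routing_weights, selected_experts

logits, weights, experts = llama4_router_sketch(torch.randn(6, 32), torch.randn(8, 32))
print(weights.shape, experts.shape)  # torch.Size([6, 1]) torch.Size([6, 1])
```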
+ + Attributes: + token_counts_by_expert: Number of tokens assigned to each expert + gather_indices: Indices used for token permutation and unpermutation + topk_weights: Routing weights for the top-k selected experts + first_gemm: Output from the first grouped matrix multiplication + intermediate: Result after applying activation function and element-wise multiplication + second_gemm: Output from the second grouped matrix multiplication + hidden_states_unpermute: Hidden states after unpermutation from expert order to token order + hidden_states: Final output hidden states + """ token_counts_by_expert: torch.Tensor gather_indices: torch.Tensor topk_weights: torch.Tensor @@ -48,6 +60,12 @@ class GroupedGEMMResult: class Qwen3MoeGroupedGEMMBlock(torch.nn.Module): + """Reference implementation of Qwen3 Mixture of Experts block using grouped GEMM operations. + + This implementation uses torch-native operations and stores intermediate results for debugging. + It implements the MoE routing mechanism with top-k expert selection and grouped matrix multiplications. + """ + def __init__( self, config, @@ -55,6 +73,14 @@ def __init__( gate_up_proj: torch.Tensor, down_proj: torch.Tensor, ): + """Initialize the Qwen3 MoE block with expert weights. + + Args: + config: Qwen3MoeConfig containing model configuration parameters + gate: Router gate weights for expert selection [num_experts, hidden_size] + gate_up_proj: Combined gate and up projection weights [num_experts, 2*moe_intermediate_size, hidden_size] + down_proj: Down projection weights [num_experts, hidden_size, moe_intermediate_size] + """ super().__init__() self.num_experts = config.num_experts self.top_k = config.num_experts_per_tok @@ -84,6 +110,17 @@ def __init__( @staticmethod def extract_hf_weights(moe_block: Qwen3MoeSparseMoeBlock): + """Extract and reorganize weights from a HuggingFace Qwen3MoeSparseMoeBlock. + + Args: + moe_block: HuggingFace Qwen3MoeSparseMoeBlock instance + + Returns: + Tuple containing: + - gate: Router gate weights + - gate_up_proj: Combined gate and up projection weights + - down_proj: Down projection weights + """ config: Qwen3MoeConfig = moe_block.experts[0].config num_experts = config.num_experts @@ -105,11 +142,24 @@ def extract_hf_weights(moe_block: Qwen3MoeSparseMoeBlock): @classmethod def from_hf(cls, moe_block: Qwen3MoeSparseMoeBlock): + """Create a Qwen3MoeGroupedGEMMBlock from a HuggingFace MoE block. + + Args: + moe_block: HuggingFace Qwen3MoeSparseMoeBlock instance + + Returns: + Qwen3MoeGroupedGEMMBlock instance with extracted weights + """ config: Qwen3MoeConfig = moe_block.experts[0].config gate, gate_up_proj, down_proj = cls.extract_hf_weights(moe_block) return cls(config, gate, gate_up_proj, down_proj) def check_weights(self, moe_block: Qwen3MoeSparseMoeBlock): + """Verify that the weights match those in the original HuggingFace MoE block. + + Args: + moe_block: HuggingFace Qwen3MoeSparseMoeBlock to compare against + """ for i in range(self.num_experts): assert self.gate_up_proj[i].equal( torch.cat( @@ -123,12 +173,31 @@ def check_weights(self, moe_block: Qwen3MoeSparseMoeBlock): assert self.down_proj[i].equal(moe_block.experts[i].down_proj.weight.data) def act_and_mul(self, x: torch.Tensor) -> torch.Tensor: + """Apply activation function to gate projection and multiply with up projection. 
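`extract_hf_weights` turns the per-expert `nn.Linear` weights of the HF block into the stacked tensors this class expects: gate and up projections are concatenated per expert and the experts are stacked along a leading dimension, which is also the layout `check_weights` verifies. A toy sketch with random matrices in place of the real expert modules:

```python
import torch

num_experts, hidden_size, inter = 4, 32, 8  # toy sizes

# Stand-ins for experts[i].gate_proj.weight / up_proj.weight / down_proj.weight,
# each stored by nn.Linear as [out_features, in_features].
gate_w = [torch.randn(inter, hidden_size) for _ in range(num_experts)]
up_w = [torch.randn(inter, hidden_size) for _ in range(num_experts)]
down_w = [torch.randn(hidden_size, inter) for _ in range(num_experts)]

# gate_up_proj: [num_experts, 2 * moe_intermediate_size, hidden_size]
gate_up_proj = torch.stack([torch.cat([g, u], dim=0) for g, u in zip(gate_w, up_w)])
# down_proj: [num_experts, hidden_size, moe_intermediate_size]
down_proj = torch.stack(down_w)

print(gate_up_proj.shape)  # torch.Size([4, 16, 32])
print(down_proj.shape)     # torch.Size([4, 32, 8])
```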
+ + Args: + x: Input tensor with shape [..., 2 * moe_intermediate_size] + + Returns: + Result of activation(gate_proj) * up_proj with shape [..., moe_intermediate_size] + """ assert x.shape[-1] == 2 * self.moe_intermediate_size gate_proj = x[..., : self.moe_intermediate_size] up_proj = x[..., self.moe_intermediate_size :] return self.act_fn(gate_proj) * up_proj def run_router(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Run the routing mechanism to select top-k experts for each token. + + Args: + hidden_states: Input hidden states [batch_size * seq_len, hidden_size] + + Returns: + Tuple containing: + - router_logits: Raw logits from the router + - routing_weights: Normalized weights for selected experts + - selected_experts: Indices of selected experts for each token + """ # router_logits: (batch * sequence_length, n_experts) router_logits = torch.nn.functional.linear(hidden_states, self.gate) @@ -146,6 +215,16 @@ def run_router(self, hidden_states: torch.Tensor) -> torch.Tensor: def get_token_counts_and_gather_indices( self, selected_experts: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute token counts per expert and gather indices for permutation. + + Args: + selected_experts: Indices of selected experts for each token + + Returns: + Tuple containing: + - token_counts_by_expert: Number of tokens assigned to each expert + - gather_indices: Indices for permuting tokens from token order to expert order + """ token_counts_by_expert, gather_indices = get_routing_indices( selected_experts, self.num_experts ) @@ -154,7 +233,16 @@ def get_token_counts_and_gather_indices( return token_counts_by_expert, gather_indices def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ """ + """Forward pass through the MoE block. + + Args: + hidden_states: Input tensor [batch_size, seq_len, hidden_size] + + Returns: + Tuple containing: + - GroupedGEMMResult: Container with all intermediate results + - router_logits: Raw routing logits for auxiliary loss computation + """ batch_size, sequence_length, hidden_dim = hidden_states.shape num_tokens = batch_size * sequence_length total_tokens = num_tokens * self.top_k @@ -214,6 +302,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Qwen3MoeFusedGroupedGEMMBlock(Qwen3MoeGroupedGEMMBlock): + """Optimized Qwen3 MoE block using fused grouped GEMM kernels. + + This implementation uses Triton-based grouped GEMM kernels for improved performance + and supports various optimization options like permutation fusion and kernel tuning. + """ + def __init__( self, config: Qwen3MoeConfig, @@ -229,6 +323,22 @@ def __init__( dW_only: bool = False, dX_only: bool = False, ): + """Initialize the fused grouped GEMM MoE block. 
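In contrast to the Llama4 sigmoid gating above, the Qwen3 routing described here softmaxes the logits, takes the top-k probabilities, and normalizes them so the selected weights sum to one per token. A minimal sketch; the HF block gates this renormalization on `config.norm_topk_prob`, which is assumed on here.

```python
import torch

def qwen3_router_sketch(hidden_states, gate_weight, top_k=2):
    # hidden_states: [num_tokens, hidden_size]; gate_weight: [num_experts, hidden_size]
    router_logits = torch.nn.functional.linear(hidden_states, gate_weight)
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(probs, top_k, dim=-1)
    # Renormalize the selected probabilities so they sum to 1 for each token.
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    return router_logits, routing_weights, selected_experts

logits, weights, experts = qwen3_router_sketch(torch.randn(5, 32), torch.randn(8, 32))
print(weights.sum(dim=-1))  # ~1.0 for every token
```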
+ + Args: + config: Qwen3MoeConfig containing model configuration + gate: Router gate weights + gate_up_proj: Combined gate and up projection weights + down_proj: Down projection weights + permute_x: Whether to fuse input permutation in the first GEMM + permute_y: Whether to fuse output unpermutation in the second GEMM + autotune: Whether to automatically tune kernel configurations + kernel_config_fwd: Manual kernel configuration for forward pass + kernel_config_bwd_dW: Manual kernel configuration for weight gradients + kernel_config_bwd_dX: Manual kernel configuration for input gradients + dW_only: Whether to compute only weight gradients + dX_only: Whether to compute only input gradients + """ super().__init__(config, gate, gate_up_proj, down_proj) self.permute_x = permute_x self.permute_y = permute_y @@ -258,6 +368,22 @@ def from_hf( dW_only: bool = False, dX_only: bool = False, ): + """Create a fused grouped GEMM MoE block from a HuggingFace MoE block. + + Args: + moe_block: HuggingFace Qwen3MoeSparseMoeBlock instance + permute_x: Whether to fuse input permutation in the first GEMM + permute_y: Whether to fuse output unpermutation in the second GEMM + autotune: Whether to automatically tune kernel configurations + kernel_config_fwd: Manual kernel configuration for forward pass + kernel_config_bwd_dW: Manual kernel configuration for weight gradients + kernel_config_bwd_dX: Manual kernel configuration for input gradients + dW_only: Whether to compute only weight gradients + dX_only: Whether to compute only input gradients + + Returns: + Qwen3MoeFusedGroupedGEMMBlock instance with extracted weights and configurations + """ config: Qwen3MoeConfig = moe_block.experts[0].config gate, gate_up_proj, down_proj = Qwen3MoeGroupedGEMMBlock.extract_hf_weights( moe_block @@ -278,6 +404,16 @@ def from_hf( ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Forward pass using fused grouped GEMM kernels. + + Args: + hidden_states: Input tensor [batch_size, seq_len, hidden_size] + + Returns: + Tuple containing: + - hidden_states: Output tensor [batch_size, seq_len, hidden_size] + - router_logits: Raw routing logits for auxiliary loss computation + """ batch_size, sequence_length, hidden_dim = hidden_states.shape num_tokens = batch_size * sequence_length total_tokens = num_tokens * self.top_k
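To tie the pieces together, here is a hypothetical end-to-end usage of the fused block: build a small HF `Qwen3MoeSparseMoeBlock`, convert it with `from_hf`, and run a forward pass. The repo import path, the config field values, and the CUDA/bfloat16 setup are assumptions for illustration; autotuning requires a GPU supported by the Triton kernels.

```python
import torch
from transformers import Qwen3MoeConfig
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock

# Assumed import path, mirroring the repo layout under unsloth/kernels/moe/.
from grouped_gemm.reference.layers.qwen3_moe import Qwen3MoeFusedGroupedGEMMBlock

config = Qwen3MoeConfig(
    hidden_size=256, moe_intermediate_size=128,
    num_experts=8, num_experts_per_tok=2,
)
hf_block = Qwen3MoeSparseMoeBlock(config).cuda().to(torch.bfloat16)

# Convert the HF block; permute_x/permute_y fuse the (un)permutation into the kernels.
fused = Qwen3MoeFusedGroupedGEMMBlock.from_hf(
    hf_block, permute_x=True, permute_y=True, autotune=True,
)
fused.check_weights(hf_block)  # sanity-check the extracted weights

x = torch.randn(2, 16, config.hidden_size, device="cuda", dtype=torch.bfloat16)
out, router_logits = fused(x)
print(out.shape, router_logits.shape)  # (2, 16, 256) and (2 * 16, 8)
```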