154 changes: 153 additions & 1 deletion unsloth/kernels/moe/grouped_gemm/kernels/autotuning.py
@@ -2,7 +2,11 @@
# Copyright 2023-present the Unsloth team. All rights reserved.

"""
Autotuning utils
Autotuning utilities for GPU kernel configuration generation and optimization.

This module provides functions to generate and prune kernel configurations
for grouped GEMM operations, including forward pass, backward pass (dX), and
weight gradient (dW) computations.
"""

import logging
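# --- Illustrative sketch (not part of this diff) -----------------------------
# The generators in this module (get_forward_configs, get_dX_kernel_configs,
# get_dW_kernel_configs) normalize each tuning parameter to a list and emit one
# triton.Config per combination. A minimal, hypothetical version of that
# pattern (names and default values here are placeholders):
import itertools

import triton


def _example_make_configs(
    block_m=(64, 128),
    block_n=(64, 128),
    block_k=(32, 64),
    num_warps=(4, 8),
    num_stages=(2, 3),
):
    # One triton.Config per point in the cartesian product of the parameters.
    return [
        triton.Config(
            {"BLOCK_SIZE_M": bm, "BLOCK_SIZE_N": bn, "BLOCK_SIZE_K": bk},
            num_warps=w,
            num_stages=s,
        )
        for bm, bn, bk, w, s in itertools.product(
            block_m, block_n, block_k, num_warps, num_stages
        )
    ]
# ------------------------------------------------------------------------------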
@@ -24,6 +28,16 @@


def val_to_list(val):
"""
Convert a single value to a list or return None if input is None.

Args:
val: Input value that can be None, a list, or a single value

Returns:
None if input is None, the original list if input is a list,
or a single-element list containing the input value
"""
if val is None:
return None
elif isinstance(val, list):
@@ -33,6 +47,15 @@ def val_to_list(val):


def convert_args_to_list(args):
"""
Convert each argument in a list to a list format using val_to_list.

Args:
args: List of arguments to convert

Returns:
List where each element has been processed by val_to_list
"""
return [val_to_list(arg) for arg in args]
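# Illustrative usage (not part of this diff; the demo function name below is
# hypothetical). Per the docstrings above, these helpers normalize
# scalar-or-list tuning parameters so the config generators can always iterate
# over lists:
def _example_val_to_list_usage():
    assert val_to_list(None) is None
    assert val_to_list([64, 128]) == [64, 128]
    assert val_to_list(64) == [64]
    assert convert_args_to_list([64, [32, 64], None]) == [[64], [32, 64], None]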


@@ -47,6 +70,23 @@ def get_forward_configs(
num_stages=DEFAULT_NUM_STAGES,
num_ctas=DEFAULT_NUM_CTAS,
):
"""
Generate kernel configurations for forward pass GEMM operations.

Args:
BLOCK_M: Block size(s) to try for the M dimension (scalar or list)
BLOCK_N: Block size(s) to try for the N dimension (scalar or list)
BLOCK_K: Block size(s) to try for the K dimension (scalar or list)
TMA_LOAD_X: Whether to use TMA (Tensor Memory Accelerator) for loading X
TMA_LOAD_W: Whether to use TMA for loading weights
TMA_STORE: Whether to use TMA for storing results (currently disabled)
num_warps: Number of warps per thread block
num_stages: Number of pipeline stages
num_ctas: Number of cooperative thread arrays

Returns:
List of triton.Config objects containing all combinations of the input parameters
"""
(
BLOCK_M,
BLOCK_N,
@@ -122,6 +162,23 @@ def get_dX_kernel_configs(
num_stages=DEFAULT_NUM_STAGES,
num_ctas=DEFAULT_NUM_CTAS,
):
"""
Generate kernel configurations for the backward-pass dX computation (gradient with respect to the input X).

Args:
BLOCK_M: Block size(s) to try for the M dimension (scalar or list)
BLOCK_N: Block size(s) to try for the N dimension (scalar or list)
BLOCK_K: Block size(s) to try for the K dimension (scalar or list)
TMA_LOAD_dY: Whether to use TMA for loading output gradients
TMA_LOAD_W: Whether to use TMA for loading weights
TMA_STORE: Whether to use TMA for storing results (currently disabled)
num_warps: Number of warps per thread block
num_stages: Number of pipeline stages
num_ctas: Number of cooperative thread arrays

Returns:
List of triton.Config objects for dX gradient computation
"""
(
BLOCK_M,
BLOCK_N,
@@ -197,6 +254,23 @@ def get_dW_kernel_configs(
TMA_LOAD_X=True,
TMA_STORE=False,
):
"""
Generate kernel configurations for weight gradient (dW) computation.

Args:
BLOCK_M: Block size(s) to try for the M dimension (scalar or list)
BLOCK_N: Block size(s) to try for the N dimension (scalar or list)
BLOCK_K: Block size(s) to try for the K dimension (scalar or list)
num_warps: Number of warps per thread block
num_stages: Number of pipeline stages
num_ctas: Number of cooperative thread arrays
TMA_LOAD_dY: Whether to use TMA for loading output gradients
TMA_LOAD_X: Whether to use TMA for loading input data
TMA_STORE: Whether to use TMA for storing results

Returns:
List of triton.Config objects for weight gradient computation
"""
(
BLOCK_M,
BLOCK_N,
@@ -268,6 +342,19 @@ def estimate_smem_reqs(
BLOCK_SIZE_K: int,
dtype: torch.dtype,
):
"""
Estimate shared memory requirements for a kernel configuration.

Args:
num_stages: Number of pipeline stages
BLOCK_SIZE_M: Block size in M dimension
BLOCK_SIZE_N: Block size in N dimension
BLOCK_SIZE_K: Block size in K dimension
dtype: Data type of the tensors

Returns:
Estimated shared memory requirement in bytes
"""
num_bytes = dtype.itemsize
return (
num_stages * BLOCK_SIZE_K * (BLOCK_SIZE_M + BLOCK_SIZE_N)
@@ -284,13 +371,39 @@ def exceeds_smem_capacity(
smem_size: int,
slack: float = 50000,
):
"""
Check if a kernel configuration exceeds shared memory capacity.

Args:
num_stages: Number of pipeline stages
BLOCK_SIZE_M: Block size in M dimension
BLOCK_SIZE_N: Block size in N dimension
BLOCK_SIZE_K: Block size in K dimension
dtype: Data type of the tensors
smem_size: Available shared memory size in bytes
slack: Tolerance in bytes; the estimated requirement may exceed smem_size by up to this amount before the configuration is rejected

Returns:
True if the estimated shared memory requirement exceeds smem_size plus slack
"""
smem_reqs = estimate_smem_reqs(
num_stages, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, dtype
)
return smem_reqs > smem_size + slack
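# Illustrative worked example (not part of this diff; the demo function name is
# hypothetical). Assuming the estimate is roughly
# num_stages * BLOCK_SIZE_K * (BLOCK_SIZE_M + BLOCK_SIZE_N) scaled by
# dtype.itemsize (the term visible above), a bf16 config with 4 stages and
# 128x128x64 blocks needs on the order of 4 * 64 * (128 + 128) * 2 bytes
# = 131072 bytes (128 KiB) of shared memory for its input tiles alone:
def _example_smem_check():
    import torch

    # On a device reporting only 64 KiB of shared memory, that requirement
    # exceeds capacity even after the default 50 kB slack, so the config
    # would be flagged for pruning.
    return exceeds_smem_capacity(
        num_stages=4,
        BLOCK_SIZE_M=128,
        BLOCK_SIZE_N=128,
        BLOCK_SIZE_K=64,
        dtype=torch.bfloat16,
        smem_size=64 * 1024,
    )  # -> True under the assumption above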


def common_prune_criteria(config: triton.Config, kwargs: dict, dtype):
"""
Apply pruning criteria shared across kernels to filter out configurations that are invalid or infeasible on the current device.

Args:
config: Triton kernel configuration to evaluate
kwargs: Kernel arguments containing problem dimensions and flags
dtype: Data type of the tensors

Returns:
True if the configuration should be pruned (removed)
"""
from grouped_gemm.interface import supports_tma
from grouped_gemm.kernels.tuning import get_device_properties

@@ -323,6 +436,12 @@ def common_prune_criteria(config: triton.Config, kwargs: dict, dtype):


def maybe_disable_tma(config: triton.Config):
"""
Disable TMA (Tensor Memory Accelerator) features if not supported by the GPU.

Args:
config: Triton kernel configuration to modify in-place
"""
from grouped_gemm.interface import supports_tma

tma_keys = [k for k in config.kwargs.keys() if k.startswith("USE_TMA_")]
@@ -333,6 +452,17 @@


def prune_kernel_configs_fwd(configs: list[triton.Config], args, **kwargs):
"""
Prune kernel configurations for forward pass operations.

Args:
configs: List of kernel configurations to filter
args: Positional arguments (unused)
**kwargs: Keyword arguments containing tensor pointers and operation flags

Returns:
Filtered list of valid kernel configurations
"""
x = kwargs["x_ptr"]
dtype = x.dtype

@@ -358,6 +488,17 @@ def prune_kernel_configs_fwd(configs: list[triton.Config], args, **kwargs):


def prune_dX_configs(configs: List[triton.Config], args, **kwargs):
"""
Prune kernel configurations for dX gradient computation.

Args:
configs: List of kernel configurations to filter
args: Positional arguments (unused)
**kwargs: Keyword arguments containing tensor pointers and operation flags

Returns:
Filtered list of valid kernel configurations for dX computation
"""
dtype = kwargs["w_ptr"].dtype

logger.debug(f"Pruning configs: {len(configs)}")
@@ -378,6 +519,17 @@


def prune_kernel_configs_backward_dW(configs: list[triton.Config], args, **kwargs):
"""
Prune kernel configurations for weight gradient (dW) computation.

Args:
configs: List of kernel configurations to filter
args: Positional arguments (unused)
**kwargs: Keyword arguments containing tensor pointers and operation flags

Returns:
Filtered list of valid kernel configurations for dW computation
"""
dtype = kwargs["x_ptr"].dtype

pruned_configs = []
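# --- Illustrative sketch (not part of this diff) -----------------------------
# Config generators and prune functions of this kind are typically attached to
# a Triton kernel through triton.autotune's `configs` and `prune_configs_by`
# arguments. The kernel, configs, and pruner below are hypothetical
# placeholders, not the actual grouped GEMM kernels in this repository.
import triton
import triton.language as tl


def _example_prune(configs, args, **kwargs):
    # Keep only configs whose rough shared-memory footprint (bf16 assumed)
    # stays under a hypothetical 100 kB budget, mirroring the spirit of
    # common_prune_criteria above.
    kept = []
    for cfg in configs:
        m = cfg.kwargs["BLOCK_SIZE_M"]
        n = cfg.kwargs["BLOCK_SIZE_N"]
        k = cfg.kwargs["BLOCK_SIZE_K"]
        if cfg.num_stages * k * (m + n) * 2 <= 100_000:
            kept.append(cfg)
    return kept


@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_warps=4,
            num_stages=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
            num_warps=8,
            num_stages=3,
        ),
    ],
    key=["M", "N", "K"],
    prune_configs_by={"early_config_prune": _example_prune},
)
@triton.jit
def _example_kernel(
    x_ptr, w_ptr, y_ptr,
    M, N, K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Placeholder body: a real grouped GEMM kernel would compute tile offsets,
    # load blocks of x and w, accumulate with tl.dot, and store into y_ptr.
    pid = tl.program_id(axis=0)
# ------------------------------------------------------------------------------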