diff --git a/examples/fused_linear_cross_entropy.py b/examples/fused_linear_cross_entropy.py
new file mode 100644
index 00000000..19314938
--- /dev/null
+++ b/examples/fused_linear_cross_entropy.py
@@ -0,0 +1,242 @@
+"""Fused linear cross entropy implementation for Helion.
+
+This implementation uses Liger's chunking strategy to reduce memory usage.
+"""
+
+from __future__ import annotations
+
+import os
+
+import torch
+
+import helion
+from helion._testing import run_example
+import helion.language as hl
+
+# TritonBench configuration
+if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1":
+    # Low memory configuration
+    TRITONBENCH_ARGS = {"hidden_size": 2048, "vocab_size": 32000}
+
+# Maximum chunk size (similar to Liger's MAX_FUSED_SIZE)
+MAX_FUSED_SIZE = 65536 // 2
+
+
+@helion.kernel(static_shapes=True)
+def cross_entropy_kernel(
+    logits_chunk: torch.Tensor,  # [chunk_size, vocab_size]
+    target_chunk: torch.Tensor,  # [chunk_size]
+    loss_chunk: torch.Tensor,  # [chunk_size]
+    chunk_size: int,
+    vocab_size: int,
+    n_total_samples: int,  # Total number of samples for mean reduction
+) -> torch.Tensor:
+    # Grid over samples - each program handles one sample
+    for program_id in hl.grid(chunk_size):
+        target_idx = target_chunk[program_id].unsqueeze(0)
+
+        # Online softmax: first pass - find max and sum
+        m = hl.full([], float("-inf"))  # max value
+        d = hl.full([], 0.0)  # sum of exp
+
+        # Store original value at target
+        ori_logit_y = logits_chunk[program_id, target_idx]
+
+        # Process in blocks like Liger
+        for vocab_tile in hl.tile(vocab_size):
+            # Create block offsets (like tl.arange in Triton)
+            block_offsets = vocab_tile.index
+
+            # Masked load of block
+            mask = block_offsets < vocab_size
+            logits_block = torch.where(
+                mask, logits_chunk[program_id, block_offsets], float("-inf")
+            )
+
+            # Find block max
+            block_max = torch.max(logits_block)
+
+            # Online softmax update
+            m_new = torch.maximum(m, block_max)
+            d = d * torch.exp(m - m_new) + torch.sum(torch.exp(logits_block - m_new))
+            m = m_new
+
+        # Compute log-sum-exp
+        lse = m + torch.log(d)
+        loss = lse - ori_logit_y
+        # Apply mean reduction inside the kernel
+        loss_chunk[program_id] = (loss / n_total_samples).squeeze(0)
+
+        # Second pass: compute gradients with block processing
+        for vocab_tile in hl.tile(vocab_size):
+            block_offsets = vocab_tile.index
+            mask = block_offsets < vocab_size
+
+            # Load block
+            logits_block = torch.where(
+                mask, logits_chunk[program_id, block_offsets], 0.0
+            )
+
+            # Compute softmax
+            softmax_block = torch.exp(logits_block - m) / d
+
+            # Special handling for target
+            is_target_block = block_offsets == target_idx
+            grad_block = torch.where(
+                is_target_block, softmax_block - 1.0, softmax_block
+            )
+
+            # Apply mean reduction to gradients
+            grad_block = grad_block / n_total_samples
+
+            # Masked store using torch.where pattern
+            # First, load existing values for positions that will be masked out
+            existing_values = logits_chunk[program_id, block_offsets]
+
+            # Apply mask to the gradient block
+            logits_chunk[program_id, block_offsets] = torch.where(
+                mask, grad_block, existing_values
+            )
+
+    # Return the loss chunk for testing purposes
+    return loss_chunk
+
+
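+# For reference, the per-chunk math implemented by the kernel above can be
+# written in a few lines of eager PyTorch. This is an illustrative sketch only
+# (the helper name is ours and nothing in this example calls it); it mirrors
+# the kernel's mean reduction and its "gradients overwrite the logits"
+# convention, without the tiled online-softmax machinery.
+def _eager_chunk_reference(
+    logits_chunk: torch.Tensor,  # [chunk_size, vocab_size]
+    target_chunk: torch.Tensor,  # [chunk_size]
+    n_total_samples: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # Per-sample loss: log-sum-exp minus the logit at the target index
+    lse = torch.logsumexp(logits_chunk, dim=-1)
+    target_logits = logits_chunk.gather(1, target_chunk.unsqueeze(1)).squeeze(1)
+    loss = (lse - target_logits) / n_total_samples
+    # d(loss)/d(logits) for mean-reduced CE: (softmax - one_hot(target)) / N
+    grad = torch.softmax(logits_chunk, dim=-1)
+    rows = torch.arange(logits_chunk.shape[0], device=logits_chunk.device)
+    grad[rows, target_chunk] -= 1.0
+    return loss, grad / n_total_samples
+
+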
+def fused_linear_cross_entropy_forward(
+    _input: torch.Tensor,
+    weight: torch.Tensor,
+    target: torch.Tensor,
+    bias: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+    """Forward pass with chunking strategy similar to Liger."""
+    device = _input.device
+    BT, H = _input.shape
+    V = weight.shape[0]
+
+    # Calculate chunk size to limit memory usage: the materialized logits hold
+    # chunk_size * V values, so shrink the chunk roughly in proportion to V / H
+    inc_factor = (V + H - 1) // H  # ceil(V / H)
+    chunk_size = min(MAX_FUSED_SIZE, (BT + inc_factor - 1) // inc_factor)
+    chunk_size = min(chunk_size, BT)
+    num_chunks = (BT + chunk_size - 1) // chunk_size
+
+    # Initialize gradients and loss
+    grad_input = torch.zeros_like(_input)
+    grad_weight = torch.zeros_like(weight) if weight.requires_grad else None
+    grad_bias = torch.zeros_like(bias) if bias is not None else None
+    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
+
+    # Process in chunks
+    for chunk_id in range(num_chunks):
+        start_idx = chunk_id * chunk_size
+        end_idx = min((chunk_id + 1) * chunk_size, BT)
+        actual_chunk_size = end_idx - start_idx
+
+        # Get chunk of input and target
+        input_chunk = _input[start_idx:end_idx]  # [chunk_size, H]
+        target_chunk = target[start_idx:end_idx]  # [chunk_size]
+
+        # Compute logits for this chunk
+        logits_chunk = input_chunk @ weight.t()  # [chunk_size, V]
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+
+        # Ensure contiguous for kernel
+        logits_chunk = logits_chunk.contiguous()
+        target_chunk = target_chunk.contiguous()
+
+        # Get loss slice
+        loss_chunk = loss_1d[start_idx:end_idx]
+
+        # Call kernel - this modifies logits_chunk in-place to contain gradients
+        cross_entropy_kernel(
+            logits_chunk,
+            target_chunk,
+            loss_chunk,
+            actual_chunk_size,
+            V,
+            BT,  # Pass total number of samples for mean reduction
+        )
+
+        # Now logits_chunk contains gradients
+        # Compute input gradient: grad_input = grad_logits @ weight
+        grad_input[start_idx:end_idx] = logits_chunk.detach() @ weight.detach()
+
+        # Accumulate weight gradients if needed
+        if grad_weight is not None:
+            # grad_weight += grad_logits.T @ input
+            # Detach tensors to avoid autograd issues with in-place operations
+            torch.addmm(
+                input=grad_weight,
+                mat1=logits_chunk.detach().t(),
+                mat2=input_chunk.detach(),
+                out=grad_weight,
+                alpha=1.0,
+                beta=1.0,
+            )
+
+        if grad_bias is not None:
+            torch.add(
+                input=grad_bias,
+                other=logits_chunk.detach().sum(dim=0),
+                out=grad_bias,
+                alpha=1.0,
+            )
+
+    # Return total loss
+    loss = loss_1d.sum()
+
+    return loss, grad_input, grad_weight, grad_bias
+
+
+# User-facing function
+def fused_linear_cross_entropy(
+    input_tensor: torch.Tensor,
+    weight: torch.Tensor,
+    labels: torch.Tensor,
+    bias: torch.Tensor | None = None,
+) -> torch.Tensor:
+    """Fused linear + cross entropy."""
+    loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+        input_tensor, weight, labels, bias
+    )
+
+    # For this example, we just return the loss
+    # In a real implementation with autograd, we'd save the gradients for the
+    # backward pass, as the sketch below illustrates
+    return loss
+
+
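+# A minimal sketch of the autograd wiring mentioned above, assuming we simply
+# cache the gradients computed during the forward pass and replay them in
+# backward. The class name is illustrative and nothing in this example calls
+# it; Liger's implementation takes a similar approach.
+class _FusedLinearCrossEntropyFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input_tensor, weight, labels, bias=None):
+        loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+            input_tensor, weight, labels, bias
+        )
+        ctx.save_for_backward(grad_input, grad_weight, grad_bias)
+        return loss
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_input, grad_weight, grad_bias = ctx.saved_tensors
+        # Chain rule: scale the cached gradients by the upstream loss gradient
+        return (
+            grad_input * grad_output,
+            grad_weight * grad_output if grad_weight is not None else None,
+            None,  # integer labels carry no gradient
+            grad_bias * grad_output if grad_bias is not None else None,
+        )
+
+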
kernel_name="helion", + baseline_name="torch", + rtol=1e-3, + atol=1e-3, + ) + + +if __name__ == "__main__": + main() diff --git a/test/test_examples.expected b/test/test_examples.expected index bad1eef1..dc44cae8 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -705,6 +705,139 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): _launcher(_fp8_gemm_kernel, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) return out +--- assertExpectedJournal(TestExamples.test_fused_linear_cross_entropy) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from torch._inductor.runtime import triton_helpers +from torch._inductor.runtime.triton_helpers import math as tl_math +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _cross_entropy_kernel_kernel(target_chunk, logits_chunk, loss_chunk, vocab_size, n_total_samples, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + load = tl.load(target_chunk + offset_0 * 1, None) + target_idx = load[None] + m = tl.full([], float('-inf'), tl.float32) + d = tl.full([], 0.0, tl.float32) + ori_logit_y = tl.load(logits_chunk + (offset_0 * 1000 + target_idx * 1), None) + for offset_1 in tl.range(0, vocab_size.to(tl.int32), _BLOCK_SIZE_1): + indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) + mask_1 = indices_1 < vocab_size + m_copy = m + d_copy = d + m_copy_0 = m_copy + d_copy_0 = d_copy + v_0 = vocab_size.to(tl.int32) + v_1 = indices_1 < v_0 + load_1 = tl.load(logits_chunk + (offset_0 * 1000 + indices_1 * 1), mask_1, other=0) + v_2 = float('-inf') + v_3 = v_2[None] + v_4 = tl.where(v_1, load_1, v_3) + _mask_to = tl.where(mask_1, v_4, float('-inf')) + block_max = tl.max(_mask_to, 0) + v_5 = triton_helpers.maximum(m_copy_0, block_max) + v_6 = m_copy_0 - v_5 + v_7 = tl_math.exp(v_6) + v_8 = d_copy_0 * v_7 + v_9 = v_5[None] + v_10 = v_4 - v_9 + v_11 = tl_math.exp(v_10) + _mask_to_1 = tl.where(mask_1, v_11, 0) + sum_1 = tl.sum(_mask_to_1, 0) + d = v_8 + sum_1 + m = v_5 + v_13 = tl_math.log(d) + v_14 = m + v_13 + v_15 = v_14[None] + v_16 = v_15 - ori_logit_y + v_17 = n_total_samples.to(tl.float32) + v_18 = v_16 / v_17 + squeeze = tl.reshape(v_18, []) + tl.store(loss_chunk + offset_0 * 1, squeeze, None) + for offset_2 in tl.range(0, vocab_size.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < vocab_size + m_copy_1 = m + d_copy_1 = d + target_idx_copy = target_idx + m_copy_1_0 = m_copy_1 + d_copy_1_0 = d_copy_1 + target_idx_copy_0 = target_idx_copy + v_19 = vocab_size.to(tl.int32) + v_20 = indices_2 < v_19 + load_2 = tl.load(logits_chunk + (offset_0 * 1000 + indices_2 * 1), mask_2, other=0) + v_21 = 0.0 + v_22 = v_21[None] + v_23 = tl.where(v_20, load_2, v_22) + v_24 = m_copy_1_0[None] + v_25 = v_23 - v_24 + v_26 = tl_math.exp(v_25) + v_27 = d_copy_1_0[None] + v_28 = v_26 / v_27 + v_29 = indices_2.to(tl.int64) + v_30 = v_29 == target_idx_copy_0 + v_31 = 1.0 + v_32 = v_28 - v_31 + v_33 = tl.where(v_30, v_32, v_28) + v_34 = n_total_samples.to(tl.float32) + v_35 = v_33 / v_34 + existing_values = tl.load(logits_chunk + (offset_0 * 1000 + indices_2 * 1), mask_2, other=0) + v_36 = tl.where(v_20, v_35, existing_values) + tl.store(logits_chunk + (offset_0 * 1000 + indices_2 * 1), v_36, mask_2) + +def 
+def cross_entropy_kernel(logits_chunk: torch.Tensor, target_chunk: torch.Tensor, loss_chunk: torch.Tensor, chunk_size: int, vocab_size: int, n_total_samples: int, *, _launcher=_default_launcher):
+    _BLOCK_SIZE_1 = 32
+    _BLOCK_SIZE_2 = 32
+    _launcher(_cross_entropy_kernel_kernel, (chunk_size,), target_chunk, logits_chunk, loss_chunk, vocab_size, n_total_samples, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return loss_chunk
+
+--- assertExpectedJournal(TestExamples.test_fused_linear_cross_entropy)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_helpers import math as tl_math
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _cross_entropy_loss_kernel(labels, base_indices, logits_flat, logits, losses, base_indices_stride_0, labels_stride_0, logits_stride_0, logits_stride_1, logits_flat_stride_0, losses_stride_0, v, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0
+    indices_0 = offset_0 + tl.zeros([1], tl.int32)
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    mask_1 = indices_1 < v
+    labels_tile = tl.load(labels + indices_0 * labels_stride_0, None)
+    base_indices_tile = tl.load(base_indices + indices_0 * base_indices_stride_0, None)
+    v_0 = base_indices_tile + labels_tile
+    logits_at_target = tl.load(logits_flat + v_0 * logits_flat_stride_0, None)
+    logits_rows = tl.load(logits + (indices_0[:, None] * logits_stride_0 + indices_1[None, :] * logits_stride_1), mask_1[None, :], other=0)
+    _mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), logits_rows, float('-inf'))
+    max_logits = tl.reshape(tl.max(_mask_to, 1), [1, 1])
+    v_1 = logits_rows - max_logits
+    v_2 = tl_math.exp(v_1)
+    _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), v_2, 0)
+    sum_exp = tl.reshape(tl.sum(_mask_to_1, 1), [1, 1])
+    squeeze = tl.reshape(max_logits, [1])
+    squeeze_1 = tl.reshape(sum_exp, [1])
+    v_3 = tl_math.log(squeeze_1)
+    v_4 = squeeze + v_3
+    v_5 = v_4 - logits_at_target
+    tl.store(losses + indices_0 * losses_stride_0, v_5, None)
+
+def cross_entropy_loss(logits: torch.Tensor, labels: torch.Tensor, *, _launcher=_default_launcher):
+    n, v = logits.shape
+    losses = torch.zeros([n], dtype=torch.float32, device=logits.device)
+    base_indices = torch.arange(n, device=logits.device) * v
+    logits_flat = logits.view(-1)
+    _RDIM_SIZE_1 = triton.next_power_of_2(v)
+    _launcher(_cross_entropy_loss_kernel, (n,), labels, base_indices, logits_flat, logits, losses, base_indices.stride(0), labels.stride(0), logits.stride(0), logits.stride(1), logits_flat.stride(0), losses.stride(0), v, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return losses.mean()
+
 --- assertExpectedJournal(TestExamples.test_jagged_dense_add)
 from __future__ import annotations
diff --git a/test/test_examples.py b/test/test_examples.py
index 95a41575..8c9a9875 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -281,6 +281,51 @@ def test_cross_entropy(self):
             )
         )
 
+    def test_fused_linear_cross_entropy(self):
+        # Test the cross_entropy_kernel
+        chunk_size, vocab_size = 32, 1000
+        n_total_samples = 128
+
+        # Create test data
+        logits_chunk = torch.randn(
+            chunk_size, vocab_size, device=DEVICE, dtype=torch.float32
+        )
+        target_chunk = torch.randint(
+            0, vocab_size, (chunk_size,), device=DEVICE, dtype=torch.long
+        )
+        loss_chunk = torch.zeros(chunk_size, device=DEVICE, dtype=torch.float32)
+
+        # Make a copy for reference computation
+        logits_copy = logits_chunk.clone()
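+        # Note: the kernel overwrites logits_chunk in place with the gradients,
+        # so the reference loss below must be computed from this untouched copy.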
+
+        # Prepare args for the kernel
+        args = (
+            logits_chunk,
+            target_chunk,
+            loss_chunk,
+            chunk_size,
+            vocab_size,
+            n_total_samples,
+        )
+
+        # Compute expected loss per sample
+        expected_losses = (
+            torch.nn.functional.cross_entropy(
+                logits_copy, target_chunk, reduction="none"
+            )
+            / n_total_samples
+        )
+
+        # Test using check_example
+        self.assertExpectedJournal(
+            check_example(
+                "fused_linear_cross_entropy",
+                args,
+                expected_losses,
+                fn_name="cross_entropy_kernel",
+            )
+        )
+
     def test_rms_norm(self):
         args = (
             torch.randn([128, 256], device=DEVICE, dtype=torch.float16),