|
| 1 | +"""Fused linear cross entropy implementation for Helion. |
| 2 | +
|
| 3 | +This implementation uses Liger's chunking strategy to reduce memory usage. |
| 4 | +""" |
| 5 | + |
| 6 | +from __future__ import annotations |
| 7 | + |
| 8 | +import os |
| 9 | + |
| 10 | +import torch |
| 11 | + |
| 12 | +import helion |
| 13 | +from helion._testing import run_example |
| 14 | +import helion.language as hl |
| 15 | + |
| 16 | +# TritonBench configuration |
| 17 | +if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1": |
| 18 | + # Low memory configuration |
| 19 | + TRITONBENCH_ARGS = {"hidden_size": 2048, "vocab_size": 32000} |
| 20 | + |
| 21 | +# Maximum chunk size (similar to Liger's MAX_FUSED_SIZE) |
| 22 | +MAX_FUSED_SIZE = 65536 // 2 |
| 23 | + |
| 24 | + |
| 25 | +@helion.kernel(static_shapes=True) |
| 26 | +def cross_entropy_kernel( |
| 27 | + logits_chunk: torch.Tensor, # [chunk_size, vocab_size] |
| 28 | + target_chunk: torch.Tensor, # [chunk_size] |
| 29 | + loss_chunk: torch.Tensor, # [chunk_size] |
| 30 | + chunk_size: int, |
| 31 | + vocab_size: int, |
| 32 | + n_total_samples: int, # Total number of samples for mean reduction |
| 33 | +) -> torch.Tensor: |
| 34 | + # Grid over samples - each program handles one sample |
| 35 | + for program_id in hl.grid(chunk_size): |
| 36 | + target_idx = target_chunk[program_id].unsqueeze(0) |
| 37 | + |
| 38 | + # Online softmax: first pass - find max and sum |
| 39 | + m = hl.full([], float("-inf")) # max value |
| 40 | + d = hl.full([], 0.0) # sum of exp |
| 41 | + |
| 42 | + # Store original value at target |
| 43 | + ori_logit_y = logits_chunk[program_id, target_idx] |
| 44 | + |
| 45 | + # Process in blocks like Liger |
| 46 | + for vocab_tile in hl.tile(vocab_size): |
| 47 | + # Create block offsets (like tl.arange in Triton) |
| 48 | + block_offsets = vocab_tile.index |
| 49 | + |
| 50 | + # Masked load of block |
| 51 | + mask = block_offsets < vocab_size |
| 52 | + logits_block = torch.where( |
| 53 | + mask, logits_chunk[program_id, block_offsets], float("-inf") |
| 54 | + ) |
| 55 | + |
| 56 | + # Find block max |
| 57 | + block_max = torch.max(logits_block) |
| 58 | + |
| 59 | + # Online softmax update |
| 60 | + m_new = torch.maximum(m, block_max) |
| 61 | + d = d * torch.exp(m - m_new) + torch.sum(torch.exp(logits_block - m_new)) |
| 62 | + m = m_new |
| 63 | + |
| 64 | + # Compute log-sum-exp |
| 65 | + lse = m + torch.log(d) |
| 66 | + loss = lse - ori_logit_y |
| 67 | + # Apply mean reduction inside the kernel |
| 68 | + loss_chunk[program_id] = (loss / n_total_samples).squeeze(0) |
| 69 | + |
| 70 | + # Second pass: compute gradients with block processing |
| 71 | + for vocab_tile in hl.tile(vocab_size): |
| 72 | + block_offsets = vocab_tile.index |
| 73 | + mask = block_offsets < vocab_size |
| 74 | + |
| 75 | + # Load block |
| 76 | + logits_block = torch.where( |
| 77 | + mask, logits_chunk[program_id, block_offsets], 0.0 |
| 78 | + ) |
| 79 | + |
| 80 | + # Compute softmax |
| 81 | + softmax_block = torch.exp(logits_block - m) / d |
| 82 | + |
| 83 | + # Special handling for target |
| 84 | + is_target_block = block_offsets == target_idx |
| 85 | + grad_block = torch.where( |
| 86 | + is_target_block, softmax_block - 1.0, softmax_block |
| 87 | + ) |
| 88 | + |
| 89 | + # Apply mean reduction to gradients |
| 90 | + grad_block = grad_block / n_total_samples |
| 91 | + |
| 92 | + # Masked store using torch.where pattern |
| 93 | + # First, load existing values for positions that will be masked out |
| 94 | + existing_values = logits_chunk[program_id, block_offsets] |
| 95 | + |
| 96 | + # Apply mask to the gradient block |
| 97 | + logits_chunk[program_id, block_offsets] = torch.where( |
| 98 | + mask, grad_block, existing_values |
| 99 | + ) |
| 100 | + |
| 101 | + # Return the loss chunk for testing purposes |
| 102 | + return loss_chunk |
| 103 | + |
| 104 | + |
| 105 | +def fused_linear_cross_entropy_forward( |
| 106 | + _input: torch.Tensor, |
| 107 | + weight: torch.Tensor, |
| 108 | + target: torch.Tensor, |
| 109 | + bias: torch.Tensor | None = None, |
| 110 | +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: |
| 111 | + """Forward pass with chunking strategy similar to Liger.""" |
| 112 | + device = _input.device |
| 113 | + BT, H = _input.shape |
| 114 | + V = weight.shape[0] |
| 115 | + |
| 116 | + # Calculate chunk size to limit memory usage |
| 117 | + inc_factor = (V + H - 1) // H |
| 118 | + chunk_size = min(MAX_FUSED_SIZE, (BT + inc_factor - 1) // inc_factor) |
| 119 | + chunk_size = min(chunk_size, BT) |
| 120 | + num_chunks = (BT + chunk_size - 1) // chunk_size |
| 121 | + |
| 122 | + # Initialize gradients and loss |
| 123 | + grad_input = torch.zeros_like(_input) |
| 124 | + grad_weight = torch.zeros_like(weight) if weight.requires_grad else None |
| 125 | + grad_bias = torch.zeros_like(bias) if bias is not None else None |
| 126 | + loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) |
| 127 | + |
| 128 | + # Process in chunks |
| 129 | + for chunk_id in range(num_chunks): |
| 130 | + start_idx = chunk_id * chunk_size |
| 131 | + end_idx = min((chunk_id + 1) * chunk_size, BT) |
| 132 | + actual_chunk_size = end_idx - start_idx |
| 133 | + |
| 134 | + # Get chunk of input and target |
| 135 | + input_chunk = _input[start_idx:end_idx] # [chunk_size, H] |
| 136 | + target_chunk = target[start_idx:end_idx] # [chunk_size] |
| 137 | + |
| 138 | + # Compute logits for this chunk |
| 139 | + logits_chunk = input_chunk @ weight.t() # [chunk_size, V] |
| 140 | + if bias is not None: |
| 141 | + logits_chunk = logits_chunk + bias |
| 142 | + |
| 143 | + # Ensure contiguous for kernel |
| 144 | + logits_chunk = logits_chunk.contiguous() |
| 145 | + target_chunk = target_chunk.contiguous() |
| 146 | + |
| 147 | + # Get loss slice |
| 148 | + loss_chunk = loss_1d[start_idx:end_idx] |
| 149 | + |
| 150 | + # Call kernel - this modifies logits_chunk in-place to contain gradients |
| 151 | + cross_entropy_kernel( |
| 152 | + logits_chunk, |
| 153 | + target_chunk, |
| 154 | + loss_chunk, |
| 155 | + actual_chunk_size, |
| 156 | + V, |
| 157 | + BT, # Pass total number of samples for mean reduction |
| 158 | + ) |
| 159 | + |
| 160 | + # Now logits_chunk contains gradients |
| 161 | + # Compute input gradient: grad_input = grad_logits @ weight |
| 162 | + grad_input[start_idx:end_idx] = logits_chunk.detach() @ weight.detach() |
| 163 | + |
| 164 | + # Accumulate weight gradients if needed |
| 165 | + if grad_weight is not None: |
| 166 | + # grad_weight += grad_logits.T @ input |
| 167 | + # Detach tensors to avoid autograd issues with in-place operations |
| 168 | + torch.addmm( |
| 169 | + input=grad_weight, |
| 170 | + mat1=logits_chunk.detach().t(), |
| 171 | + mat2=input_chunk.detach(), |
| 172 | + out=grad_weight, |
| 173 | + alpha=1.0, |
| 174 | + beta=1.0, |
| 175 | + ) |
| 176 | + |
| 177 | + if grad_bias is not None: |
| 178 | + torch.add( |
| 179 | + input=grad_bias, |
| 180 | + other=logits_chunk.detach().sum(dim=0), |
| 181 | + out=grad_bias, |
| 182 | + alpha=1.0, |
| 183 | + ) |
| 184 | + |
| 185 | + # Return total loss |
| 186 | + loss = loss_1d.sum() |
| 187 | + |
| 188 | + return loss, grad_input, grad_weight, grad_bias |
| 189 | + |
| 190 | + |
| 191 | +# User-facing function |
| 192 | +def fused_linear_cross_entropy( |
| 193 | + input_tensor: torch.Tensor, |
| 194 | + weight: torch.Tensor, |
| 195 | + labels: torch.Tensor, |
| 196 | + bias: torch.Tensor | None = None, |
| 197 | +) -> torch.Tensor: |
| 198 | + """Fused linear + cross entropy.""" |
| 199 | + loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward( |
| 200 | + input_tensor, weight, labels, bias |
| 201 | + ) |
| 202 | + |
| 203 | + # For this example, we just return the loss |
| 204 | + # In a real implementation with autograd, we'd save gradients for backward |
| 205 | + return loss |
| 206 | + |
| 207 | + |
| 208 | +def fused_linear_cross_entropy_pytorch( |
| 209 | + input_tensor: torch.Tensor, |
| 210 | + weight: torch.Tensor, |
| 211 | + labels: torch.Tensor, |
| 212 | + bias: torch.Tensor | None = None, |
| 213 | +) -> torch.Tensor: |
| 214 | + """PyTorch reference implementation for fused linear cross entropy.""" |
| 215 | + # Compute logits |
| 216 | + logits = torch.matmul(input_tensor, weight.T) |
| 217 | + if bias is not None: |
| 218 | + logits = logits + bias |
| 219 | + # Compute cross entropy |
| 220 | + return torch.nn.functional.cross_entropy(logits, labels) |
| 221 | + |
| 222 | + |
| 223 | +def main() -> None: |
| 224 | + n, h, v = 128, 512, 1000 |
| 225 | + torch.manual_seed(42) |
| 226 | + input_tensor = torch.randn(n, h, device="cuda", dtype=torch.float32) |
| 227 | + weight = torch.randn(v, h, device="cuda", dtype=torch.float32) |
| 228 | + labels = torch.randint(0, v, (n,), device="cuda", dtype=torch.long) |
| 229 | + |
| 230 | + run_example( |
| 231 | + fused_linear_cross_entropy, |
| 232 | + fused_linear_cross_entropy_pytorch, |
| 233 | + (input_tensor, weight, labels), |
| 234 | + kernel_name="helion", |
| 235 | + baseline_name="torch", |
| 236 | + rtol=1e-3, |
| 237 | + atol=1e-3, |
| 238 | + ) |
| 239 | + |
| 240 | + |
| 241 | +if __name__ == "__main__": |
| 242 | + main() |
0 commit comments