Commit 6d34cb8

[Example] Add fused_linear_cross_entropy example and unit test
stack-info: PR: #342, branch: yf225/stack/37
1 parent 41fe6e9 commit 6d34cb8

3 files changed (+411, -0 lines changed)

Lines changed: 233 additions & 0 deletions
@@ -0,0 +1,233 @@
"""Fused linear cross entropy implementation for Helion.

This implementation uses Liger's chunking strategy to reduce memory usage.
"""

from __future__ import annotations

import os

import torch

import helion
from helion._testing import run_example
import helion.language as hl

# TritonBench configuration
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1":
    # Low memory configuration
    TRITONBENCH_ARGS = {"hidden_size": 2048, "vocab_size": 32000}

# Maximum chunk size (similar to Liger's MAX_FUSED_SIZE)
MAX_FUSED_SIZE = 65536 // 2

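# Math implemented by the kernel below (restating its code, for reference):
# per sample i, loss_i = logsumexp(logits_i) - logits_i[target_i], with mean
# reduction over n_total_samples, and the gradient written back in place is
# d(loss)/d(logits_i) = (softmax(logits_i) - onehot(target_i)) / n_total_samples.
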
@helion.kernel(static_shapes=True)
def cross_entropy_kernel(
    logits_chunk: torch.Tensor,  # [chunk_size, vocab_size]
    target_chunk: torch.Tensor,  # [chunk_size]
    loss_chunk: torch.Tensor,  # [chunk_size]
    chunk_size: int,
    vocab_size: int,
    n_total_samples: int,  # Total number of samples for mean reduction
) -> torch.Tensor:
    # Grid over samples - each program handles one sample (matching Liger kernel behavior)
    for program_id in hl.grid(chunk_size):
        # Track original dtype for writing gradients back
        orig_dtype = logits_chunk.dtype
        target_idx = target_chunk[program_id].unsqueeze(0)

        # Online softmax: first pass - find max and sum.
        # Use float32 accumulators for numerical stability, matching Liger kernel behavior.
        m = hl.full([], float("-inf")).to(torch.float32)  # max value
        d = hl.full([], 0.0).to(torch.float32)  # sum of exp

        # Store original value at target
        ori_logit_y = logits_chunk[program_id, target_idx].to(torch.float32)

        # Process in blocks like Liger
        for vocab_tile in hl.tile(vocab_size):
            # Create block offsets (like tl.arange in Triton)
            block_offsets = vocab_tile.index

            # Rely on Helion's auto-masked load; compute in float32
            logits_block = logits_chunk[program_id, block_offsets].to(torch.float32)

            # Find block max
            block_max = torch.max(logits_block)

            # Online softmax update
            m_new = torch.maximum(m, block_max)
            d = d * torch.exp(m - m_new) + torch.sum(torch.exp(logits_block - m_new))
            m = m_new

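        # Note on the loop above: after each tile, m is the maximum over all
        # vocab entries seen so far and d = sum(exp(logit - m)) over those same
        # entries, so m + log(d) below is the log-sum-exp of the full row; the
        # exp(m - m_new) factor rescales d whenever the running max increases.
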
        # Compute log-sum-exp
        lse = m + torch.log(d)
        loss = lse - ori_logit_y
        # Apply mean reduction inside the kernel
        loss_chunk[program_id] = (loss / n_total_samples).squeeze(0)

        # Second pass: compute gradients with block processing
        for vocab_tile in hl.tile(vocab_size):
            block_offsets = vocab_tile.index

            # Load block and compute in float32
            logits_block = logits_chunk[program_id, block_offsets].to(torch.float32)

            # Compute softmax
            softmax_block = torch.exp(logits_block - m) / d

            # Special handling for target
            is_target_block = block_offsets == target_idx
            grad_block = torch.where(
                is_target_block, softmax_block - 1.0, softmax_block
            )

            # Apply mean reduction to gradients
            grad_block = (grad_block / n_total_samples).to(orig_dtype)

            # Store back; Helion will auto-mask OOB lanes on the store
            logits_chunk[program_id, block_offsets] = grad_block

    # Return the loss chunk for testing purposes
    return loss_chunk

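# Illustrative cross-check, not part of the kernel or the benchmark path: a
# plain PyTorch version of the in-place gradient computed by the second pass
# above. The helper name is ours, added only for reference.
def _reference_grad_logits(
    logits: torch.Tensor, target: torch.Tensor, n_total: int
) -> torch.Tensor:
    # d(mean CE)/d(logits) = (softmax(logits) - onehot(target)) / n_total
    probs = torch.softmax(logits.float(), dim=-1)
    probs[torch.arange(logits.shape[0], device=logits.device), target] -= 1.0
    return probs / n_total

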
def fused_linear_cross_entropy_forward(
    _input: torch.Tensor,
    weight: torch.Tensor,
    target: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Forward pass with chunking strategy similar to Liger."""
    device = _input.device
    BT, H = _input.shape
    V = weight.shape[0]

    # Calculate chunk size to limit memory usage
    inc_factor = (V + H - 1) // H
    chunk_size = min(MAX_FUSED_SIZE, (BT + inc_factor - 1) // inc_factor)
    chunk_size = min(chunk_size, BT)
    num_chunks = (BT + chunk_size - 1) // chunk_size

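    # Worked sizing example (numbers assumed for illustration): with BT=4096,
    # H=512, and V=32000, inc_factor = ceil(32000 / 512) = 63, so
    # chunk_size = min(32768, ceil(4096 / 63)) = 66 and num_chunks = 63. The
    # materialized [chunk_size, V] logits buffer then holds roughly as many
    # elements as the [BT, H] input, rather than the full [BT, V] logits.
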
    # Initialize gradients and loss
    grad_input = torch.zeros_like(_input)
    grad_weight = torch.zeros_like(weight) if weight.requires_grad else None
    grad_bias = torch.zeros_like(bias) if bias is not None else None
    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)

    # Process in chunks
    for chunk_id in range(num_chunks):
        start_idx = chunk_id * chunk_size
        end_idx = min((chunk_id + 1) * chunk_size, BT)
        actual_chunk_size = end_idx - start_idx

        # Get chunk of input and target
        input_chunk = _input[start_idx:end_idx]  # [chunk_size, H]
        target_chunk = target[start_idx:end_idx]  # [chunk_size]

        # Compute logits for this chunk
        logits_chunk = input_chunk @ weight.t()  # [chunk_size, V]
        if bias is not None:
            logits_chunk = logits_chunk + bias

        # Ensure contiguous for kernel
        logits_chunk = logits_chunk.contiguous()
        target_chunk = target_chunk.contiguous()

        # Get loss slice
        loss_chunk = loss_1d[start_idx:end_idx]

        # Call kernel - this modifies logits_chunk in-place to contain gradients
        cross_entropy_kernel(
            logits_chunk,
            target_chunk,
            loss_chunk,
            actual_chunk_size,
            V,
            BT,  # Pass total number of samples for mean reduction
        )

        # Now logits_chunk contains gradients
        # Compute input gradient: grad_input = grad_logits @ weight
        grad_input[start_idx:end_idx] = logits_chunk.detach() @ weight.detach()

        # Accumulate weight gradients if needed
        if grad_weight is not None:
            # grad_weight += grad_logits.T @ input
            # Detach tensors to avoid autograd issues with in-place operations
            torch.addmm(
                input=grad_weight,
                mat1=logits_chunk.detach().t(),
                mat2=input_chunk.detach(),
                out=grad_weight,
                alpha=1.0,
                beta=1.0,
            )

        if grad_bias is not None:
            torch.add(
                input=grad_bias,
                other=logits_chunk.detach().sum(dim=0),
                out=grad_bias,
                alpha=1.0,
            )

    # Return total loss
    loss = loss_1d.sum()

    return loss, grad_input, grad_weight, grad_bias

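# One way to sanity-check the gradients returned above (illustrative sketch;
# shapes and tolerances are assumptions, and a CUDA device is required):
#
#   x = torch.randn(64, 128, device="cuda")
#   w = torch.randn(1000, 128, device="cuda")
#   t = torch.randint(0, 1000, (64,), device="cuda")
#   _, grad_input, _, _ = fused_linear_cross_entropy_forward(x, w, t)
#   x_ref = x.clone().requires_grad_(True)
#   torch.nn.functional.cross_entropy(x_ref @ w.T, t).backward()
#   torch.testing.assert_close(grad_input, x_ref.grad, rtol=1e-3, atol=1e-3)
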
# User-facing function
def fused_linear_cross_entropy(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Fused linear + cross entropy."""
    loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
        input_tensor, weight, labels, bias
    )

    # For this example, we just return the loss.
    # In a real implementation with autograd, we'd save gradients for backward.
    return loss

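# Direct usage (illustrative; shapes are assumptions): input_tensor of shape
# [num_tokens, hidden], weight of shape [vocab, hidden], and integer labels of
# shape [num_tokens] produce a scalar mean loss:
#
#   loss = fused_linear_cross_entropy(input_tensor, weight, labels)
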
def fused_linear_cross_entropy_pytorch(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """PyTorch reference implementation for fused linear cross entropy."""
    # Compute logits
    logits = torch.matmul(input_tensor, weight.T)
    if bias is not None:
        logits = logits + bias
    # Compute cross entropy
    return torch.nn.functional.cross_entropy(logits, labels)


def main() -> None:
    n, h, v = 128, 512, 1000
    torch.manual_seed(42)
    input_tensor = torch.randn(n, h, device="cuda", dtype=torch.float32)
    weight = torch.randn(v, h, device="cuda", dtype=torch.float32)
    labels = torch.randint(0, v, (n,), device="cuda", dtype=torch.long)

    run_example(
        fused_linear_cross_entropy,
        fused_linear_cross_entropy_pytorch,
        (input_tensor, weight, labels),
        kernel_name="helion",
        baseline_name="torch",
        rtol=1e-3,
        atol=1e-3,
    )


if __name__ == "__main__":
    main()

test/test_examples.expected

Lines changed: 133 additions & 0 deletions
@@ -705,6 +705,139 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
    _launcher(_fp8_gemm_kernel, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
    return out

--- assertExpectedJournal(TestExamples.test_fused_linear_cross_entropy)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_helpers import math as tl_math
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _cross_entropy_kernel_kernel(target_chunk, logits_chunk, loss_chunk, vocab_size, n_total_samples, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    load = tl.load(target_chunk + offset_0 * 1, None)
    target_idx = load[None]
    m = tl.full([], float('-inf'), tl.float32)
    d = tl.full([], 0.0, tl.float32)
    ori_logit_y = tl.load(logits_chunk + (offset_0 * 1000 + target_idx * 1), None)
    for offset_1 in tl.range(0, vocab_size.to(tl.int32), _BLOCK_SIZE_1):
        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
        mask_1 = indices_1 < vocab_size
        m_copy = m
        d_copy = d
        m_copy_0 = m_copy
        d_copy_0 = d_copy
        v_0 = vocab_size.to(tl.int32)
        v_1 = indices_1 < v_0
        load_1 = tl.load(logits_chunk + (offset_0 * 1000 + indices_1 * 1), mask_1, other=0)
        v_2 = float('-inf')
        v_3 = v_2[None]
        v_4 = tl.where(v_1, load_1, v_3)
        _mask_to = tl.where(mask_1, v_4, float('-inf'))
        block_max = tl.max(_mask_to, 0)
        v_5 = triton_helpers.maximum(m_copy_0, block_max)
        v_6 = m_copy_0 - v_5
        v_7 = tl_math.exp(v_6)
        v_8 = d_copy_0 * v_7
        v_9 = v_5[None]
        v_10 = v_4 - v_9
        v_11 = tl_math.exp(v_10)
        _mask_to_1 = tl.where(mask_1, v_11, 0)
        sum_1 = tl.sum(_mask_to_1, 0)
        d = v_8 + sum_1
        m = v_5
    v_13 = tl_math.log(d)
    v_14 = m + v_13
    v_15 = v_14[None]
    v_16 = v_15 - ori_logit_y
    v_17 = n_total_samples.to(tl.float32)
    v_18 = v_16 / v_17
    squeeze = tl.reshape(v_18, [])
    tl.store(loss_chunk + offset_0 * 1, squeeze, None)
    for offset_2 in tl.range(0, vocab_size.to(tl.int32), _BLOCK_SIZE_2):
        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
        mask_2 = indices_2 < vocab_size
        m_copy_1 = m
        d_copy_1 = d
        target_idx_copy = target_idx
        m_copy_1_0 = m_copy_1
        d_copy_1_0 = d_copy_1
        target_idx_copy_0 = target_idx_copy
        v_19 = vocab_size.to(tl.int32)
        v_20 = indices_2 < v_19
        load_2 = tl.load(logits_chunk + (offset_0 * 1000 + indices_2 * 1), mask_2, other=0)
        v_21 = 0.0
        v_22 = v_21[None]
        v_23 = tl.where(v_20, load_2, v_22)
        v_24 = m_copy_1_0[None]
        v_25 = v_23 - v_24
        v_26 = tl_math.exp(v_25)
        v_27 = d_copy_1_0[None]
        v_28 = v_26 / v_27
        v_29 = indices_2.to(tl.int64)
        v_30 = v_29 == target_idx_copy_0
        v_31 = 1.0
        v_32 = v_28 - v_31
        v_33 = tl.where(v_30, v_32, v_28)
        v_34 = n_total_samples.to(tl.float32)
        v_35 = v_33 / v_34
        existing_values = tl.load(logits_chunk + (offset_0 * 1000 + indices_2 * 1), mask_2, other=0)
        v_36 = tl.where(v_20, v_35, existing_values)
        tl.store(logits_chunk + (offset_0 * 1000 + indices_2 * 1), v_36, mask_2)

def cross_entropy_kernel(logits_chunk: torch.Tensor, target_chunk: torch.Tensor, loss_chunk: torch.Tensor, chunk_size: int, vocab_size: int, n_total_samples: int, *, _launcher=_default_launcher):
    _BLOCK_SIZE_1 = 32
    _BLOCK_SIZE_2 = 32
    _launcher(_cross_entropy_kernel_kernel, (chunk_size,), target_chunk, logits_chunk, loss_chunk, vocab_size, n_total_samples, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
    return loss_chunk

--- assertExpectedJournal(TestExamples.test_fused_linear_cross_entropy)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_helpers import math as tl_math
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _cross_entropy_loss_kernel(labels, base_indices, logits_flat, logits, losses, base_indices_stride_0, labels_stride_0, logits_stride_0, logits_stride_1, logits_flat_stride_0, losses_stride_0, v, _RDIM_SIZE_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    indices_0 = offset_0 + tl.zeros([1], tl.int32)
    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
    mask_1 = indices_1 < v
    labels_tile = tl.load(labels + indices_0 * labels_stride_0, None)
    base_indices_tile = tl.load(base_indices + indices_0 * base_indices_stride_0, None)
    v_0 = base_indices_tile + labels_tile
    logits_at_target = tl.load(logits_flat + v_0 * logits_flat_stride_0, None)
    logits_rows = tl.load(logits + (indices_0[:, None] * logits_stride_0 + indices_1[None, :] * logits_stride_1), mask_1[None, :], other=0)
    _mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), logits_rows, float('-inf'))
    max_logits = tl.reshape(tl.max(_mask_to, 1), [1, 1])
    v_1 = logits_rows - max_logits
    v_2 = tl_math.exp(v_1)
    _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), v_2, 0)
    sum_exp = tl.reshape(tl.sum(_mask_to_1, 1), [1, 1])
    squeeze = tl.reshape(max_logits, [1])
    squeeze_1 = tl.reshape(sum_exp, [1])
    v_3 = tl_math.log(squeeze_1)
    v_4 = squeeze + v_3
    v_5 = v_4 - logits_at_target
    tl.store(losses + indices_0 * losses_stride_0, v_5, None)

def cross_entropy_loss(logits: torch.Tensor, labels: torch.Tensor, *, _launcher=_default_launcher):
    n, v = logits.shape
    losses = torch.zeros([n], dtype=torch.float32, device=logits.device)
    base_indices = torch.arange(n, device=logits.device) * v
    logits_flat = logits.view(-1)
    _RDIM_SIZE_1 = triton.next_power_of_2(v)
    _launcher(_cross_entropy_loss_kernel, (n,), labels, base_indices, logits_flat, logits, losses, base_indices.stride(0), labels.stride(0), logits.stride(0), logits.stride(1), logits_flat.stride(0), losses.stride(0), v, _RDIM_SIZE_1, num_warps=4, num_stages=3)
    return losses.mean()
--- assertExpectedJournal(TestExamples.test_jagged_dense_add)
from __future__ import annotations
