
Commit 4a910ae

[Example] Add fused_linear_cross_entropy example and unit test
stack-info: PR: #342, branch: yf225/stack/37
1 parent 41fe6e9 commit 4a910ae

File tree

3 files changed: +420 -0 lines changed

Lines changed: 242 additions & 0 deletions
@@ -0,0 +1,242 @@
"""Fused linear cross entropy implementation for Helion.

This implementation uses Liger's chunking strategy to reduce memory usage.
"""

from __future__ import annotations

import os

import torch

import helion
from helion._testing import run_example
import helion.language as hl

# TritonBench configuration
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1":
    # Low memory configuration
    TRITONBENCH_ARGS = {"hidden_size": 2048, "vocab_size": 32000}

# Maximum chunk size (similar to Liger's MAX_FUSED_SIZE)
MAX_FUSED_SIZE = 65536 // 2


@helion.kernel(static_shapes=True)
def cross_entropy_kernel(
    logits_chunk: torch.Tensor,  # [chunk_size, vocab_size]
    target_chunk: torch.Tensor,  # [chunk_size]
    loss_chunk: torch.Tensor,  # [chunk_size]
    chunk_size: int,
    vocab_size: int,
    n_total_samples: int,  # Total number of samples for mean reduction
) -> torch.Tensor:
    # Grid over samples - each program handles one sample
    for program_id in hl.grid(chunk_size):
        target_idx = target_chunk[program_id].unsqueeze(0)

        # Online softmax: first pass - find the max and the sum of exp
        m = hl.full([], float("-inf"))  # running max
        d = hl.full([], 0.0)  # running sum of exp

        # Store the original logit at the target index
        ori_logit_y = logits_chunk[program_id, target_idx]

        # Process the vocabulary in blocks, like Liger
        for vocab_tile in hl.tile(vocab_size):
            # Create block offsets (like tl.arange in Triton)
            block_offsets = vocab_tile.index

            # Masked load of the block
            mask = block_offsets < vocab_size
            logits_block = torch.where(
                mask, logits_chunk[program_id, block_offsets], float("-inf")
            )

            # Find the block max
            block_max = torch.max(logits_block)

            # Online softmax update
            m_new = torch.maximum(m, block_max)
            d = d * torch.exp(m - m_new) + torch.sum(torch.exp(logits_block - m_new))
            m = m_new

        # Compute the log-sum-exp and the per-sample loss
        lse = m + torch.log(d)
        loss = lse - ori_logit_y
        # Apply mean reduction inside the kernel
        loss_chunk[program_id] = (loss / n_total_samples).squeeze(0)

        # Second pass: compute gradients with block processing
        for vocab_tile in hl.tile(vocab_size):
            block_offsets = vocab_tile.index
            mask = block_offsets < vocab_size

            # Load the block
            logits_block = torch.where(
                mask, logits_chunk[program_id, block_offsets], 0.0
            )

            # Compute softmax
            softmax_block = torch.exp(logits_block - m) / d

            # Special handling for the target: softmax - 1 at the target index
            is_target_block = block_offsets == target_idx
            grad_block = torch.where(
                is_target_block, softmax_block - 1.0, softmax_block
            )

            # Apply mean reduction to the gradients
            grad_block = grad_block / n_total_samples

            # Masked store using the torch.where pattern:
            # first load the existing values for positions that will be masked out
            existing_values = logits_chunk[program_id, block_offsets]

            # Apply the mask to the gradient block
            logits_chunk[program_id, block_offsets] = torch.where(
                mask, grad_block, existing_values
            )

    # Return the loss chunk for testing purposes
    return loss_chunk
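

# For reference, the math the two passes above implement (per sample with
# logits x and target class y). First pass, online softmax over blocks B:
#     m_new = max(m, max_{j in B} x_j)
#     d     = d * exp(m - m_new) + sum_{j in B} exp(x_j - m_new)
# After the last block, lse = m + log(d) = log(sum_j exp(x_j)), so
#     loss = lse - x_y.
# Second pass, the gradient written back into logits_chunk in place:
#     dloss/dx_j = softmax(x)_j - 1[j == y],
# with both loss and gradient divided by n_total_samples for mean reduction.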


def fused_linear_cross_entropy_forward(
    _input: torch.Tensor,
    weight: torch.Tensor,
    target: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Forward pass with a chunking strategy similar to Liger's."""
    device = _input.device
    BT, H = _input.shape
    V = weight.shape[0]

    # Calculate the chunk size to limit memory usage
    inc_factor = (V + H - 1) // H
    chunk_size = min(MAX_FUSED_SIZE, (BT + inc_factor - 1) // inc_factor)
    chunk_size = min(chunk_size, BT)
    num_chunks = (BT + chunk_size - 1) // chunk_size
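    # Worked example (illustrative, not from the original source): with the
    # shapes used in main() below (BT=128, H=512, V=1000), inc_factor = 2,
    # chunk_size = min(32768, 64, 128) = 64 and num_chunks = 2, so each chunk
    # materializes a 64 x 1000 logits buffer instead of the full 128 x 1000.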

    # Initialize gradients and loss
    grad_input = torch.zeros_like(_input)
    grad_weight = torch.zeros_like(weight) if weight.requires_grad else None
    grad_bias = torch.zeros_like(bias) if bias is not None else None
    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)

    # Process in chunks
    for chunk_id in range(num_chunks):
        start_idx = chunk_id * chunk_size
        end_idx = min((chunk_id + 1) * chunk_size, BT)
        actual_chunk_size = end_idx - start_idx

        # Get this chunk of the input and target
        input_chunk = _input[start_idx:end_idx]  # [chunk_size, H]
        target_chunk = target[start_idx:end_idx]  # [chunk_size]

        # Compute logits for this chunk
        logits_chunk = input_chunk @ weight.t()  # [chunk_size, V]
        if bias is not None:
            logits_chunk = logits_chunk + bias

        # Ensure contiguous tensors for the kernel
        logits_chunk = logits_chunk.contiguous()
        target_chunk = target_chunk.contiguous()

        # Get the loss slice
        loss_chunk = loss_1d[start_idx:end_idx]

        # Call the kernel - this modifies logits_chunk in place to contain gradients
        cross_entropy_kernel(
            logits_chunk,
            target_chunk,
            loss_chunk,
            actual_chunk_size,
            V,
            BT,  # Pass the total number of samples for mean reduction
        )

        # Now logits_chunk contains gradients
        # Compute the input gradient: grad_input = grad_logits @ weight
        grad_input[start_idx:end_idx] = logits_chunk.detach() @ weight.detach()

        # Accumulate weight gradients if needed
        if grad_weight is not None:
            # grad_weight += grad_logits.T @ input
            # Detach tensors to avoid autograd issues with in-place operations
            torch.addmm(
                input=grad_weight,
                mat1=logits_chunk.detach().t(),
                mat2=input_chunk.detach(),
                out=grad_weight,
                alpha=1.0,
                beta=1.0,
            )
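            # torch.addmm computes beta * input + alpha * (mat1 @ mat2), so with
            # beta=1.0 each chunk's contribution is accumulated into grad_weight
            # rather than overwriting it.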

        if grad_bias is not None:
            torch.add(
                input=grad_bias,
                other=logits_chunk.detach().sum(dim=0),
                out=grad_bias,
                alpha=1.0,
            )

    # Return the total loss
    loss = loss_1d.sum()

    return loss, grad_input, grad_weight, grad_bias


# User-facing function
def fused_linear_cross_entropy(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Fused linear + cross entropy."""
    loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
        input_tensor, weight, labels, bias
    )

    # For this example, we just return the loss
    # In a real implementation with autograd, we'd save the gradients for backward
    return loss
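

# A minimal sketch (an assumption, not part of the tested example) of how the
# gradients returned by fused_linear_cross_entropy_forward could be wired into
# autograd, as the comment above suggests. The class name
# _FusedLinearCrossEntropyFn is hypothetical.
class _FusedLinearCrossEntropyFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor, weight, labels, bias=None):
        loss, grad_input, grad_weight, grad_bias = (
            fused_linear_cross_entropy_forward(input_tensor, weight, labels, bias)
        )
        # Save the precomputed gradients for the backward pass
        # (save_for_backward accepts None entries).
        ctx.save_for_backward(grad_input, grad_weight, grad_bias)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        # Scale the precomputed gradients by the incoming gradient of the
        # scalar loss; labels are integer targets and get no gradient.
        return (
            grad_input * grad_output,
            grad_weight * grad_output if grad_weight is not None else None,
            None,
            grad_bias * grad_output if grad_bias is not None else None,
        )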


def fused_linear_cross_entropy_pytorch(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """PyTorch reference implementation of fused linear cross entropy."""
    # Compute logits
    logits = torch.matmul(input_tensor, weight.T)
    if bias is not None:
        logits = logits + bias
    # Compute cross entropy
    return torch.nn.functional.cross_entropy(logits, labels)
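

# Note: torch.nn.functional.cross_entropy defaults to reduction="mean", which is
# why the kernel divides each per-sample loss by n_total_samples before the
# final loss_1d.sum() in the Helion path.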
221+
222+
223+
def main() -> None:
224+
n, h, v = 128, 512, 1000
225+
torch.manual_seed(42)
226+
input_tensor = torch.randn(n, h, device="cuda", dtype=torch.float32)
227+
weight = torch.randn(v, h, device="cuda", dtype=torch.float32)
228+
labels = torch.randint(0, v, (n,), device="cuda", dtype=torch.long)
229+
230+
run_example(
231+
fused_linear_cross_entropy,
232+
fused_linear_cross_entropy_pytorch,
233+
(input_tensor, weight, labels),
234+
kernel_name="helion",
235+
baseline_name="torch",
236+
rtol=1e-3,
237+
atol=1e-3,
238+
)
239+
240+
241+
if __name__ == "__main__":
242+
main()

test/test_examples.expected

Lines changed: 133 additions & 0 deletions
@@ -705,6 +705,139 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
    _launcher(_fp8_gemm_kernel, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
    return out

--- assertExpectedJournal(TestExamples.test_fused_linear_cross_entropy)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_helpers import math as tl_math
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _cross_entropy_kernel_kernel(target_chunk, logits_chunk, loss_chunk, vocab_size, n_total_samples, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    load = tl.load(target_chunk + offset_0 * 1, None)
    target_idx = load[None]
    m = tl.full([], float('-inf'), tl.float32)
    d = tl.full([], 0.0, tl.float32)
    ori_logit_y = tl.load(logits_chunk + (offset_0 * 1000 + target_idx * 1), None)
    for offset_1 in tl.range(0, vocab_size.to(tl.int32), _BLOCK_SIZE_1):
        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
        mask_1 = indices_1 < vocab_size
        m_copy = m
        d_copy = d
        m_copy_0 = m_copy
        d_copy_0 = d_copy
        v_0 = vocab_size.to(tl.int32)
        v_1 = indices_1 < v_0
        load_1 = tl.load(logits_chunk + (offset_0 * 1000 + indices_1 * 1), mask_1, other=0)
        v_2 = float('-inf')
        v_3 = v_2[None]
        v_4 = tl.where(v_1, load_1, v_3)
        _mask_to = tl.where(mask_1, v_4, float('-inf'))
        block_max = tl.max(_mask_to, 0)
        v_5 = triton_helpers.maximum(m_copy_0, block_max)
        v_6 = m_copy_0 - v_5
        v_7 = tl_math.exp(v_6)
        v_8 = d_copy_0 * v_7
        v_9 = v_5[None]
        v_10 = v_4 - v_9
        v_11 = tl_math.exp(v_10)
        _mask_to_1 = tl.where(mask_1, v_11, 0)
        sum_1 = tl.sum(_mask_to_1, 0)
        d = v_8 + sum_1
        m = v_5
    v_13 = tl_math.log(d)
    v_14 = m + v_13
    v_15 = v_14[None]
    v_16 = v_15 - ori_logit_y
    v_17 = n_total_samples.to(tl.float32)
    v_18 = v_16 / v_17
    squeeze = tl.reshape(v_18, [])
    tl.store(loss_chunk + offset_0 * 1, squeeze, None)
    for offset_2 in tl.range(0, vocab_size.to(tl.int32), _BLOCK_SIZE_2):
        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
        mask_2 = indices_2 < vocab_size
        m_copy_1 = m
        d_copy_1 = d
        target_idx_copy = target_idx
        m_copy_1_0 = m_copy_1
        d_copy_1_0 = d_copy_1
        target_idx_copy_0 = target_idx_copy
        v_19 = vocab_size.to(tl.int32)
        v_20 = indices_2 < v_19
        load_2 = tl.load(logits_chunk + (offset_0 * 1000 + indices_2 * 1), mask_2, other=0)
        v_21 = 0.0
        v_22 = v_21[None]
        v_23 = tl.where(v_20, load_2, v_22)
        v_24 = m_copy_1_0[None]
        v_25 = v_23 - v_24
        v_26 = tl_math.exp(v_25)
        v_27 = d_copy_1_0[None]
        v_28 = v_26 / v_27
        v_29 = indices_2.to(tl.int64)
        v_30 = v_29 == target_idx_copy_0
        v_31 = 1.0
        v_32 = v_28 - v_31
        v_33 = tl.where(v_30, v_32, v_28)
        v_34 = n_total_samples.to(tl.float32)
        v_35 = v_33 / v_34
        existing_values = tl.load(logits_chunk + (offset_0 * 1000 + indices_2 * 1), mask_2, other=0)
        v_36 = tl.where(v_20, v_35, existing_values)
        tl.store(logits_chunk + (offset_0 * 1000 + indices_2 * 1), v_36, mask_2)

def cross_entropy_kernel(logits_chunk: torch.Tensor, target_chunk: torch.Tensor, loss_chunk: torch.Tensor, chunk_size: int, vocab_size: int, n_total_samples: int, *, _launcher=_default_launcher):
    _BLOCK_SIZE_1 = 32
    _BLOCK_SIZE_2 = 32
    _launcher(_cross_entropy_kernel_kernel, (chunk_size,), target_chunk, logits_chunk, loss_chunk, vocab_size, n_total_samples, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
    return loss_chunk

--- assertExpectedJournal(TestExamples.test_fused_linear_cross_entropy)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_helpers import math as tl_math
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _cross_entropy_loss_kernel(labels, base_indices, logits_flat, logits, losses, base_indices_stride_0, labels_stride_0, logits_stride_0, logits_stride_1, logits_flat_stride_0, losses_stride_0, v, _RDIM_SIZE_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    indices_0 = offset_0 + tl.zeros([1], tl.int32)
    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
    mask_1 = indices_1 < v
    labels_tile = tl.load(labels + indices_0 * labels_stride_0, None)
    base_indices_tile = tl.load(base_indices + indices_0 * base_indices_stride_0, None)
    v_0 = base_indices_tile + labels_tile
    logits_at_target = tl.load(logits_flat + v_0 * logits_flat_stride_0, None)
    logits_rows = tl.load(logits + (indices_0[:, None] * logits_stride_0 + indices_1[None, :] * logits_stride_1), mask_1[None, :], other=0)
    _mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), logits_rows, float('-inf'))
    max_logits = tl.reshape(tl.max(_mask_to, 1), [1, 1])
    v_1 = logits_rows - max_logits
    v_2 = tl_math.exp(v_1)
    _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), v_2, 0)
    sum_exp = tl.reshape(tl.sum(_mask_to_1, 1), [1, 1])
    squeeze = tl.reshape(max_logits, [1])
    squeeze_1 = tl.reshape(sum_exp, [1])
    v_3 = tl_math.log(squeeze_1)
    v_4 = squeeze + v_3
    v_5 = v_4 - logits_at_target
    tl.store(losses + indices_0 * losses_stride_0, v_5, None)

def cross_entropy_loss(logits: torch.Tensor, labels: torch.Tensor, *, _launcher=_default_launcher):
    n, v = logits.shape
    losses = torch.zeros([n], dtype=torch.float32, device=logits.device)
    base_indices = torch.arange(n, device=logits.device) * v
    logits_flat = logits.view(-1)
    _RDIM_SIZE_1 = triton.next_power_of_2(v)
    _launcher(_cross_entropy_loss_kernel, (n,), labels, base_indices, logits_flat, logits, losses, base_indices.stride(0), labels.stride(0), logits.stride(0), logits.stride(1), logits_flat.stride(0), losses.stride(0), v, _RDIM_SIZE_1, num_warps=4, num_stages=3)
    return losses.mean()
--- assertExpectedJournal(TestExamples.test_jagged_dense_add)
from __future__ import annotations
