[Example] Add fused_linear_cross_entropy example and unit test #342

Open · wants to merge 1 commit into main
242 changes: 242 additions & 0 deletions examples/fused_linear_cross_entropy.py
@@ -0,0 +1,242 @@
"""Fused linear cross entropy implementation for Helion.

This implementation uses Liger's chunking strategy to reduce memory usage.
"""

from __future__ import annotations

import os

import torch

import helion
from helion._testing import run_example
import helion.language as hl

# TritonBench configuration
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1":
    # Low memory configuration
    TRITONBENCH_ARGS = {"hidden_size": 2048, "vocab_size": 32000}

# Maximum chunk size (similar to Liger's MAX_FUSED_SIZE)
MAX_FUSED_SIZE = 65536 // 2


@helion.kernel(static_shapes=True)
def cross_entropy_kernel(
    logits_chunk: torch.Tensor,  # [chunk_size, vocab_size]
    target_chunk: torch.Tensor,  # [chunk_size]
    loss_chunk: torch.Tensor,  # [chunk_size]
    chunk_size: int,
    vocab_size: int,
    n_total_samples: int,  # Total number of samples for mean reduction
) -> torch.Tensor:
    # Grid over samples - each program handles one sample
    for program_id in hl.grid(chunk_size):
Contributor: Could this be an hl.tile loop to allow tiling this dimension with block_size > 1?
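For illustration, a rough sketch of what that tiled variant might look like; treating the sample dimension with hl.tile and turning the per-sample running max/sum into vectors are assumptions here, untested:

# Hypothetical tiled variant of the outer loop (first pass only):
for sample_tile in hl.tile(chunk_size):
    m = hl.full([sample_tile], float("-inf"))  # per-sample running max
    d = hl.full([sample_tile], 0.0)  # per-sample running sum of exp
    for vocab_tile in hl.tile(vocab_size):
        logits_block = logits_chunk[sample_tile, vocab_tile]  # [tile, vocab_tile]
        block_max = torch.amax(logits_block, dim=1)
        m_new = torch.maximum(m, block_max)
        d = d * torch.exp(m - m_new) + torch.sum(
            torch.exp(logits_block - m_new[:, None]), dim=1
        )
        m = m_new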

        target_idx = target_chunk[program_id].unsqueeze(0)

        # Online softmax: first pass - find max and sum
        m = hl.full([], float("-inf"))  # max value
        d = hl.full([], 0.0)  # sum of exp

        # Store original value at target
        ori_logit_y = logits_chunk[program_id, target_idx]

        # Process in blocks like Liger
        for vocab_tile in hl.tile(vocab_size):
            # Create block offsets (like tl.arange in Triton)
            block_offsets = vocab_tile.index
Contributor: Is the alias needed after you remove the extra masking?


            # Masked load of block
            mask = block_offsets < vocab_size
            logits_block = torch.where(
Contributor: This masking should be added automatically in helion.

                mask, logits_chunk[program_id, block_offsets], float("-inf")
            )

            # Find block max
            block_max = torch.max(logits_block)

            # Online softmax update
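            # Rescale the previous sum by exp(m - m_new) so it is expressed
            # relative to the new running max before adding this block's terms.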
            m_new = torch.maximum(m, block_max)
            d = d * torch.exp(m - m_new) + torch.sum(torch.exp(logits_block - m_new))
            m = m_new
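If Helion does apply this masking automatically (the generated Triton in the expected journal below already loads with mask_1 and masks its reductions), the first-pass loop above could presumably drop the manual torch.where, roughly:

# Hypothetical simplification relying on Helion's automatic masking:
for vocab_tile in hl.tile(vocab_size):
    logits_block = logits_chunk[program_id, vocab_tile.index]
    block_max = torch.max(logits_block)
    m_new = torch.maximum(m, block_max)
    d = d * torch.exp(m - m_new) + torch.sum(torch.exp(logits_block - m_new))
    m = m_new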

        # Compute log-sum-exp
        lse = m + torch.log(d)
        loss = lse - ori_logit_y
        # Apply mean reduction inside the kernel
        loss_chunk[program_id] = (loss / n_total_samples).squeeze(0)

        # Second pass: compute gradients with block processing
        for vocab_tile in hl.tile(vocab_size):
            block_offsets = vocab_tile.index
            mask = block_offsets < vocab_size

            # Load block
            logits_block = torch.where(
Contributor: Masking should be automatic. Are you sure this is needed?

                mask, logits_chunk[program_id, block_offsets], 0.0
            )

            # Compute softmax
            softmax_block = torch.exp(logits_block - m) / d

            # Special handling for target
            is_target_block = block_offsets == target_idx
            grad_block = torch.where(
                is_target_block, softmax_block - 1.0, softmax_block
            )

            # Apply mean reduction to gradients
            grad_block = grad_block / n_total_samples

            # Masked store using torch.where pattern
            # First, load existing values for positions that will be masked out
            existing_values = logits_chunk[program_id, block_offsets]

            # Apply mask to the gradient block
            logits_chunk[program_id, block_offsets] = torch.where(
                mask, grad_block, existing_values
            )
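Similarly, if both tile loads and stores are masked automatically, the second pass above might shrink to something like this (hypothetical, untested):

# Hypothetical simplification of the gradient pass:
for vocab_tile in hl.tile(vocab_size):
    logits_block = logits_chunk[program_id, vocab_tile.index]
    softmax_block = torch.exp(logits_block - m) / d
    grad_block = torch.where(
        vocab_tile.index == target_idx, softmax_block - 1.0, softmax_block
    )
    # Store gradients back in place of the logits
    logits_chunk[program_id, vocab_tile.index] = grad_block / n_total_samples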

    # Return the loss chunk for testing purposes
    return loss_chunk


def fused_linear_cross_entropy_forward(
    _input: torch.Tensor,
    weight: torch.Tensor,
    target: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Forward pass with chunking strategy similar to Liger."""
    device = _input.device
    BT, H = _input.shape
    V = weight.shape[0]

    # Calculate chunk size to limit memory usage
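    # inc_factor is roughly V / H, so chunk_size * V stays on the order of
    # BT * H: a chunk's materialized logits use about as much memory as the
    # full input. E.g. for BT=128, H=512, V=1000: inc_factor=2, chunk_size=64,
    # num_chunks=2.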
    inc_factor = (V + H - 1) // H
    chunk_size = min(MAX_FUSED_SIZE, (BT + inc_factor - 1) // inc_factor)
    chunk_size = min(chunk_size, BT)
    num_chunks = (BT + chunk_size - 1) // chunk_size

    # Initialize gradients and loss
    grad_input = torch.zeros_like(_input)
    grad_weight = torch.zeros_like(weight) if weight.requires_grad else None
    grad_bias = torch.zeros_like(bias) if bias is not None else None
    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)

    # Process in chunks
    for chunk_id in range(num_chunks):
        start_idx = chunk_id * chunk_size
        end_idx = min((chunk_id + 1) * chunk_size, BT)
        actual_chunk_size = end_idx - start_idx

        # Get chunk of input and target
        input_chunk = _input[start_idx:end_idx]  # [chunk_size, H]
        target_chunk = target[start_idx:end_idx]  # [chunk_size]

        # Compute logits for this chunk
        logits_chunk = input_chunk @ weight.t()  # [chunk_size, V]
        if bias is not None:
            logits_chunk = logits_chunk + bias

        # Ensure contiguous for kernel
        logits_chunk = logits_chunk.contiguous()
        target_chunk = target_chunk.contiguous()

        # Get loss slice
        loss_chunk = loss_1d[start_idx:end_idx]

        # Call kernel - this modifies logits_chunk in-place to contain gradients
        cross_entropy_kernel(
            logits_chunk,
            target_chunk,
            loss_chunk,
            actual_chunk_size,
            V,
            BT,  # Pass total number of samples for mean reduction
        )

        # Now logits_chunk contains gradients
        # Compute input gradient: grad_input = grad_logits @ weight
        grad_input[start_idx:end_idx] = logits_chunk.detach() @ weight.detach()

        # Accumulate weight gradients if needed
        if grad_weight is not None:
            # grad_weight += grad_logits.T @ input
            # Detach tensors to avoid autograd issues with in-place operations
            torch.addmm(
                input=grad_weight,
                mat1=logits_chunk.detach().t(),
                mat2=input_chunk.detach(),
                out=grad_weight,
                alpha=1.0,
                beta=1.0,
            )

        if grad_bias is not None:
            torch.add(
                input=grad_bias,
                other=logits_chunk.detach().sum(dim=0),
                out=grad_bias,
                alpha=1.0,
            )

    # Return total loss
    loss = loss_1d.sum()

    return loss, grad_input, grad_weight, grad_bias


# User-facing function
def fused_linear_cross_entropy(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Fused linear + cross entropy."""
    loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
        input_tensor, weight, labels, bias
    )
Contributor (author): Other implementations of fused_linear_cross_entropy in tritonbench also include both the Python-level chunking code and the Triton kernel (per-chunk processing) in the benchmark timing measurement, so we do the same here for a fair comparison.


    # For this example, we just return the loss
    # In a real implementation with autograd, we'd save gradients for backward
    return loss
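As a hedged illustration of that last comment, one way to wire the precomputed gradients into autograd might be a custom torch.autograd.Function along these lines; the class name is hypothetical and this is not part of the PR:

# Hypothetical sketch: reuse the gradients computed in the forward pass.
class _FusedLinearCrossEntropy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor, weight, labels, bias=None):
        loss, grad_input, grad_weight, grad_bias = (
            fused_linear_cross_entropy_forward(input_tensor, weight, labels, bias)
        )
        # Save the already-computed gradients for the backward pass
        ctx.save_for_backward(grad_input, grad_weight, grad_bias)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        # Scale by the incoming gradient (1.0 for a plain loss.backward())
        return (
            grad_input * grad_output,
            grad_weight * grad_output if grad_weight is not None else None,
            None,  # labels receive no gradient
            grad_bias * grad_output if grad_bias is not None else None,
        )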


def fused_linear_cross_entropy_pytorch(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """PyTorch reference implementation for fused linear cross entropy."""
    # Compute logits
    logits = torch.matmul(input_tensor, weight.T)
    if bias is not None:
        logits = logits + bias
    # Compute cross entropy
    return torch.nn.functional.cross_entropy(logits, labels)


def main() -> None:
    n, h, v = 128, 512, 1000
    torch.manual_seed(42)
    input_tensor = torch.randn(n, h, device="cuda", dtype=torch.float32)
    weight = torch.randn(v, h, device="cuda", dtype=torch.float32)
    labels = torch.randint(0, v, (n,), device="cuda", dtype=torch.long)

    run_example(
        fused_linear_cross_entropy,
        fused_linear_cross_entropy_pytorch,
        (input_tensor, weight, labels),
        kernel_name="helion",
        baseline_name="torch",
        rtol=1e-3,
        atol=1e-3,
    )


if __name__ == "__main__":
    main()
133 changes: 133 additions & 0 deletions test/test_examples.expected
@@ -705,6 +705,139 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
    _launcher(_fp8_gemm_kernel, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
    return out

--- assertExpectedJournal(TestExamples.test_fused_linear_cross_entropy)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_helpers import math as tl_math
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _cross_entropy_kernel_kernel(target_chunk, logits_chunk, loss_chunk, vocab_size, n_total_samples, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    load = tl.load(target_chunk + offset_0 * 1, None)
    target_idx = load[None]
    m = tl.full([], float('-inf'), tl.float32)
    d = tl.full([], 0.0, tl.float32)
    ori_logit_y = tl.load(logits_chunk + (offset_0 * 1000 + target_idx * 1), None)
    for offset_1 in tl.range(0, vocab_size.to(tl.int32), _BLOCK_SIZE_1):
        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
        mask_1 = indices_1 < vocab_size
        m_copy = m
        d_copy = d
        m_copy_0 = m_copy
        d_copy_0 = d_copy
        v_0 = vocab_size.to(tl.int32)
        v_1 = indices_1 < v_0
        load_1 = tl.load(logits_chunk + (offset_0 * 1000 + indices_1 * 1), mask_1, other=0)
        v_2 = float('-inf')
        v_3 = v_2[None]
        v_4 = tl.where(v_1, load_1, v_3)
        _mask_to = tl.where(mask_1, v_4, float('-inf'))
        block_max = tl.max(_mask_to, 0)
        v_5 = triton_helpers.maximum(m_copy_0, block_max)
        v_6 = m_copy_0 - v_5
        v_7 = tl_math.exp(v_6)
        v_8 = d_copy_0 * v_7
        v_9 = v_5[None]
        v_10 = v_4 - v_9
        v_11 = tl_math.exp(v_10)
        _mask_to_1 = tl.where(mask_1, v_11, 0)
        sum_1 = tl.sum(_mask_to_1, 0)
        d = v_8 + sum_1
        m = v_5
    v_13 = tl_math.log(d)
    v_14 = m + v_13
    v_15 = v_14[None]
    v_16 = v_15 - ori_logit_y
    v_17 = n_total_samples.to(tl.float32)
    v_18 = v_16 / v_17
    squeeze = tl.reshape(v_18, [])
    tl.store(loss_chunk + offset_0 * 1, squeeze, None)
    for offset_2 in tl.range(0, vocab_size.to(tl.int32), _BLOCK_SIZE_2):
        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
        mask_2 = indices_2 < vocab_size
        m_copy_1 = m
        d_copy_1 = d
        target_idx_copy = target_idx
        m_copy_1_0 = m_copy_1
        d_copy_1_0 = d_copy_1
        target_idx_copy_0 = target_idx_copy
        v_19 = vocab_size.to(tl.int32)
        v_20 = indices_2 < v_19
        load_2 = tl.load(logits_chunk + (offset_0 * 1000 + indices_2 * 1), mask_2, other=0)
        v_21 = 0.0
        v_22 = v_21[None]
        v_23 = tl.where(v_20, load_2, v_22)
        v_24 = m_copy_1_0[None]
        v_25 = v_23 - v_24
        v_26 = tl_math.exp(v_25)
        v_27 = d_copy_1_0[None]
        v_28 = v_26 / v_27
        v_29 = indices_2.to(tl.int64)
        v_30 = v_29 == target_idx_copy_0
        v_31 = 1.0
        v_32 = v_28 - v_31
        v_33 = tl.where(v_30, v_32, v_28)
        v_34 = n_total_samples.to(tl.float32)
        v_35 = v_33 / v_34
        existing_values = tl.load(logits_chunk + (offset_0 * 1000 + indices_2 * 1), mask_2, other=0)
        v_36 = tl.where(v_20, v_35, existing_values)
        tl.store(logits_chunk + (offset_0 * 1000 + indices_2 * 1), v_36, mask_2)

def cross_entropy_kernel(logits_chunk: torch.Tensor, target_chunk: torch.Tensor, loss_chunk: torch.Tensor, chunk_size: int, vocab_size: int, n_total_samples: int, *, _launcher=_default_launcher):
    _BLOCK_SIZE_1 = 32
    _BLOCK_SIZE_2 = 32
    _launcher(_cross_entropy_kernel_kernel, (chunk_size,), target_chunk, logits_chunk, loss_chunk, vocab_size, n_total_samples, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
    return loss_chunk

--- assertExpectedJournal(TestExamples.test_fused_linear_cross_entropy)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_helpers import math as tl_math
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _cross_entropy_loss_kernel(labels, base_indices, logits_flat, logits, losses, base_indices_stride_0, labels_stride_0, logits_stride_0, logits_stride_1, logits_flat_stride_0, losses_stride_0, v, _RDIM_SIZE_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    indices_0 = offset_0 + tl.zeros([1], tl.int32)
    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
    mask_1 = indices_1 < v
    labels_tile = tl.load(labels + indices_0 * labels_stride_0, None)
    base_indices_tile = tl.load(base_indices + indices_0 * base_indices_stride_0, None)
    v_0 = base_indices_tile + labels_tile
    logits_at_target = tl.load(logits_flat + v_0 * logits_flat_stride_0, None)
    logits_rows = tl.load(logits + (indices_0[:, None] * logits_stride_0 + indices_1[None, :] * logits_stride_1), mask_1[None, :], other=0)
    _mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), logits_rows, float('-inf'))
    max_logits = tl.reshape(tl.max(_mask_to, 1), [1, 1])
    v_1 = logits_rows - max_logits
    v_2 = tl_math.exp(v_1)
    _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _RDIM_SIZE_1]), v_2, 0)
    sum_exp = tl.reshape(tl.sum(_mask_to_1, 1), [1, 1])
    squeeze = tl.reshape(max_logits, [1])
    squeeze_1 = tl.reshape(sum_exp, [1])
    v_3 = tl_math.log(squeeze_1)
    v_4 = squeeze + v_3
    v_5 = v_4 - logits_at_target
    tl.store(losses + indices_0 * losses_stride_0, v_5, None)

def cross_entropy_loss(logits: torch.Tensor, labels: torch.Tensor, *, _launcher=_default_launcher):
    n, v = logits.shape
    losses = torch.zeros([n], dtype=torch.float32, device=logits.device)
    base_indices = torch.arange(n, device=logits.device) * v
    logits_flat = logits.view(-1)
    _RDIM_SIZE_1 = triton.next_power_of_2(v)
    _launcher(_cross_entropy_loss_kernel, (n,), labels, base_indices, logits_flat, logits, losses, base_indices.stride(0), labels.stride(0), logits.stride(0), logits.stride(1), logits_flat.stride(0), losses.stride(0), v, _RDIM_SIZE_1, num_warps=4, num_stages=3)
    return losses.mean()

--- assertExpectedJournal(TestExamples.test_jagged_dense_add)
from __future__ import annotations
