diff --git a/rectokens/kernels/ema_update.py b/rectokens/kernels/ema_update.py
new file mode 100644
index 0000000..65065bd
--- /dev/null
+++ b/rectokens/kernels/ema_update.py
@@ -0,0 +1,185 @@
+"""Fused Triton kernel for VQ-codebook EMA update.
+
+Fuses all five steps of :meth:`VQQuantizer._ema_update` into a single GPU kernel pass:
+
+1. **Scatter-accumulate** — for each codebook entry *k*, scan the batch and sum
+   the encoder outputs ``x[i]`` whose assigned code equals *k*, accumulating
+   ``cluster_size[k]`` and ``embed_sum[k]`` without materialising the
+   ``(B, K)`` one-hot matrix.
+2. **EMA update** — blend the new statistics into the running EMA buffers,
+   restricted to codes that received at least one assignment (*active-only*
+   update).
+3. **Codebook refresh** — recompute each active codebook entry as
+   ``ema_embed_sum[k] / max(ema_cluster_size[k], ε)``.
+4. **Dead-code counter** — reset ``steps_since_active`` to 0 for active codes
+   and increment by 1 for inactive ones.
+5. **Dead-code restart** — replace stranded codes (those whose counter reaches
+   ``restart_after_steps``) with a random encoder output drawn from the current
+   batch, and zero out their EMA accumulators.
+
+All five steps are executed in a single kernel launch per codebook level,
+eliminating the intermediate ``(B, K)`` allocation of the PyTorch reference and
+merging many element-wise passes into one.
+
+Grid shape: ``(K,)`` — one thread block per codebook entry. This avoids
+inter-block atomics at the cost of reading the full ``(B,)`` codes array ``K``
+times. For the typical regime (K ≤ 4096, B ≤ 32768) the resulting memory
+traffic is dominated by the ``x`` reads, which are L2-cached across blocks.
+
+Constraints (matching the existing kernel style):
+  * ``D`` must be a power of two — ``tl.arange(0, D)`` requires a power-of-two
+    extent; as a ``tl.constexpr``, each distinct ``D`` compiles its own variant.
+  * All tensors must be contiguous and on CUDA (the wrapper only makes ``x``
+    and ``codes`` contiguous; the EMA buffers and device are not checked).
+"""
+
+from __future__ import annotations
+
+import triton
+import triton.language as tl
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_B": 32}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK_B": 64}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK_B": 128}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_B": 256}, num_warps=8, num_stages=3),
+        triton.Config({"BLOCK_B": 512}, num_warps=8, num_stages=4),
+        triton.Config({"BLOCK_B": 1024}, num_warps=16, num_stages=4),
+    ],
+    key=["B", "K", "D"],
+    restore_value=[
+        "ema_cluster_size_ptr",
+        "ema_embed_sum_ptr",
+        "codebook_ptr",
+        "steps_since_active_ptr",
+    ],
+)
+@triton.jit
+def ema_update_kernel(
+    # ── inputs ──────────────────────────────────────────────────────────────
+    x_ptr,                   # (B, D) fp32 — encoder outputs (read-only)
+    codes_ptr,               # (B,) int64 — nearest-code indices (read-only)
+    rand_idx_ptr,            # (K,) int64 — pre-drawn random batch indices for
+                             #   dead-code replacement (read-only)
+    decay,                   # scalar fp32 — EMA decay factor γ
+    restart_after_steps,     # scalar int32 — dead-code restart threshold
+    B,                       # int — batch size
+    K,                       # int — codebook size (== grid dimension)
+    D: tl.constexpr,         # constexpr int — embedding dimension
+    x_stride_B,              # stride of x along the batch axis
+    x_stride_D,              # stride of x along the feature axis
+    es_stride_K,             # stride of ema_embed_sum along K axis
+    es_stride_D,             # stride of ema_embed_sum along D axis
+    cb_stride_K,             # stride of codebook along K axis
+    cb_stride_D,             # stride of codebook along D axis
+    # ── in-place buffers ────────────────────────────────────────────────────
+    ema_cluster_size_ptr,    # (K,) fp32 — EMA cluster-size statistics
+    ema_embed_sum_ptr,       # (K, D) fp32 — EMA embedding-sum statistics
+    codebook_ptr,            # (K, D) fp32 — codebook embeddings
+    steps_since_active_ptr,  # (K,) int64 — consecutive inactive steps per code
+    BLOCK_B: tl.constexpr,   # autotuned batch-tile size
+):
+    """One program per codebook entry *k*.
+
+    Scans the batch in ``BLOCK_B``-sized tiles, accumulates the statistics for
+    entry *k*, then writes back the updated EMA buffers, codebook entry, and
+    dead-code counter — all without communicating with other programs.
+    """
+    k = tl.program_id(0)
+
+    offs_D = tl.arange(0, D)  # (D,) — D-dimension offsets (D is constexpr)
+
+    # ── Step 1: scatter-accumulate cluster_size and embed_sum ───────────────
+    # Scan every batch sample; accumulate those whose code equals k.
+    cluster_size = 0.0                            # scalar fp32 accumulator
+    embed_sum = tl.zeros((D,), dtype=tl.float32)  # (D,) fp32 accumulator
+
+    for b_start in range(0, tl.cdiv(B, BLOCK_B)):
+        offs_B = b_start * BLOCK_B + tl.arange(0, BLOCK_B)
+        b_mask = offs_B < B
+
+        # Load code assignments for this batch tile; use int32 for comparison
+        # (codebook size K < 2^31 in all practical settings).
+        batch_codes = tl.load(codes_ptr + offs_B, mask=b_mask, other=-1).to(tl.int32)
+        matches = (batch_codes == k) & b_mask  # (BLOCK_B,) bool
+
+        # Count assignments for code k in this tile.
+        cluster_size += tl.sum(matches.to(tl.float32))
+
+        # Load encoder outputs for the tile; zero-fill out-of-bounds rows.
+        x_block = tl.load(
+            x_ptr + offs_B[:, None] * x_stride_B + offs_D[None, :] * x_stride_D,
+            mask=b_mask[:, None],
+            other=0.0,
+        ).to(tl.float32)  # (BLOCK_B, D)
+
+        # Accumulate x[i] for every matched sample.
+        embed_sum = embed_sum + tl.sum(
+            tl.where(matches[:, None], x_block, tl.zeros_like(x_block)),
+            axis=0,
+        )  # (D,)
+
+    # ── Step 2: active-only EMA update ──────────────────────────────────────
+    ema_cs = tl.load(ema_cluster_size_ptr + k).to(tl.float32)  # scalar
+    ema_es = tl.load(
+        ema_embed_sum_ptr + k * es_stride_K + offs_D * es_stride_D
+    ).to(tl.float32)  # (D,)
+
+    active = cluster_size > 0.0  # scalar bool
+
+    new_ema_cs = decay * ema_cs + (1.0 - decay) * cluster_size
+    new_ema_es = decay * ema_es + (1.0 - decay) * embed_sum  # (D,)
+
+    # Only overwrite statistics for codes that received an assignment;
+    # inactive codes keep their existing EMA values.
+    updated_ema_cs = tl.where(active, new_ema_cs, ema_cs)  # scalar
+    updated_ema_es = tl.where(active, new_ema_es, ema_es)  # (D,)
+
+    # ── Step 3: recompute codebook entry for active codes ───────────────────
+    n = tl.maximum(updated_ema_cs, 1e-5)  # avoid division by zero
+    new_embedding = updated_ema_es / n  # (D,)
+
+    # ── Step 4: dead-code counter ───────────────────────────────────────────
+    steps = tl.load(steps_since_active_ptr + k)  # int64 scalar
+    # Reset to 0 for active codes; increment by 1 for inactive ones.
+    # Multiply to preserve int64 type without a literal cast.
+    new_steps = tl.where(active, steps * 0, steps + 1)  # int64 scalar
+    dead = new_steps >= restart_after_steps  # bool scalar
+
+    # ── Step 5: dead-code restart ───────────────────────────────────────────
+    # Load the pre-drawn random replacement sample from the batch.
+    rand_b = tl.load(rand_idx_ptr + k)
+    replacement = tl.load(
+        x_ptr + rand_b * x_stride_B + offs_D * x_stride_D
+    ).to(tl.float32)  # (D,)
+
+    # Current codebook entry — needed for inactive, non-dead codes.
+    cur_embedding = tl.load(
+        codebook_ptr + k * cb_stride_K + offs_D * cb_stride_D
+    ).to(tl.float32)  # (D,)
+
+    # Pick: dead → replacement; else active → recomputed; else keep current.
+ final_embedding = tl.where( + dead, + replacement, + tl.where(active, new_embedding, cur_embedding), + ) # (D,) + final_ema_cs = tl.where(dead, 0.0, updated_ema_cs) # scalar + final_ema_es = tl.where( + dead, tl.zeros((D,), dtype=tl.float32), updated_ema_es + ) # (D,) + final_steps = tl.where(dead, steps * 0, new_steps) # int64 + + # ── Write back ─────────────────────────────────────────────────────────── + tl.store(ema_cluster_size_ptr + k, final_ema_cs) + tl.store( + ema_embed_sum_ptr + k * es_stride_K + offs_D * es_stride_D, + final_ema_es, + ) + tl.store( + codebook_ptr + k * cb_stride_K + offs_D * cb_stride_D, + final_embedding, + ) + tl.store(steps_since_active_ptr + k, final_steps) diff --git a/rectokens/ops/ema_update.py b/rectokens/ops/ema_update.py new file mode 100644 index 0000000..83db754 --- /dev/null +++ b/rectokens/ops/ema_update.py @@ -0,0 +1,181 @@ +"""Python dispatch layer for the fused EMA-update operation. + +Selects the Triton kernel on CUDA and falls back to pure PyTorch on CPU. +Callers should use :func:`ema_update` exclusively; the ``_cuda_*`` helper is an +implementation detail. +""" + +from __future__ import annotations + +import torch +import torch.nn.functional as F + +IS_GPU_AVAILABLE = torch.cuda.is_available() + +if IS_GPU_AVAILABLE: + from rectokens.kernels.ema_update import ema_update_kernel + + +def _cuda_ema_update( + x: torch.Tensor, + codes: torch.Tensor, + ema_cluster_size: torch.Tensor, + ema_embed_sum: torch.Tensor, + codebook: torch.Tensor, + steps_since_active: torch.Tensor, + decay: float, + restart_after_steps: int, +) -> None: + """In-place EMA codebook update via the fused Triton kernel. + + All tensor arguments are mutated in-place. ``x`` and ``codes`` are + read-only inputs; the remaining four tensors are updated buffers. + + Args: + x: Encoder outputs, shape ``(B, D)``, float32, CUDA, contiguous. + codes: Nearest-code assignments, shape ``(B,)``, int64, CUDA. 
+        ema_cluster_size: EMA cluster-size buffer, shape ``(K,)``.
+        ema_embed_sum: EMA embedding-sum buffer, shape ``(K, D)``.
+        codebook: Codebook embedding matrix, shape ``(K, D)``.
+        steps_since_active: Consecutive-inactive-step counter, shape ``(K,)``.
+        decay: EMA decay factor γ.
+        restart_after_steps: Dead-code restart threshold.
+    """
+    B, D = x.shape
+    K = ema_cluster_size.shape[0]
+
+    # Pre-draw one random batch index per codebook entry for dead-code
+    # replacement. Done on the host to keep the kernel deterministic and
+    # avoid random-state complexity inside the Triton JIT.
+    rand_idx = torch.randint(B, (K,), device=x.device, dtype=torch.int64)
+
+    # Ensure contiguous layout (kernel assumes unit inner stride for the
+    # D-dimension).
+    x = x.contiguous()
+    codes = codes.contiguous()
+
+    grid = (K,)
+    ema_update_kernel[grid](
+        x_ptr=x,
+        codes_ptr=codes,
+        rand_idx_ptr=rand_idx,
+        decay=decay,
+        restart_after_steps=restart_after_steps,
+        B=B,
+        K=K,
+        D=D,
+        x_stride_B=x.stride(0),
+        x_stride_D=x.stride(1),
+        es_stride_K=ema_embed_sum.stride(0),
+        es_stride_D=ema_embed_sum.stride(1),
+        cb_stride_K=codebook.stride(0),
+        cb_stride_D=codebook.stride(1),
+        ema_cluster_size_ptr=ema_cluster_size,
+        ema_embed_sum_ptr=ema_embed_sum,
+        codebook_ptr=codebook,
+        steps_since_active_ptr=steps_since_active,
+    )
+
+
+def _cpu_ema_update(
+    x: torch.Tensor,
+    codes: torch.Tensor,
+    ema_cluster_size: torch.Tensor,
+    ema_embed_sum: torch.Tensor,
+    codebook: torch.Tensor,
+    steps_since_active: torch.Tensor,
+    decay: float,
+    restart_after_steps: int,
+) -> None:
+    """Pure-PyTorch reference implementation (CPU fallback).
+
+    Matches the Triton kernel up to the random batch indices drawn for
+    dead-code replacement; used when CUDA is not available.
+    """
+    k = codebook.shape[0]
+
+    one_hot = F.one_hot(codes, num_classes=k).float()  # (B, K)
+    cluster_size = one_hot.sum(dim=0)  # (K,)
+    embed_sum = one_hot.t() @ x  # (K, D)
+
+    active = cluster_size > 0  # (K,) bool
+
+    new_ema_cs = decay * ema_cluster_size + (1 - decay) * cluster_size
+    ema_cluster_size.copy_(
+        torch.where(active, new_ema_cs, ema_cluster_size)
+    )
+
+    new_ema_es = decay * ema_embed_sum + (1 - decay) * embed_sum
+    ema_embed_sum.copy_(
+        torch.where(active.unsqueeze(1), new_ema_es, ema_embed_sum)
+    )
+
+    n = ema_cluster_size.clamp(min=1e-5)
+    new_embeddings = ema_embed_sum / n.unsqueeze(1)
+    codebook.copy_(
+        torch.where(active.unsqueeze(1), new_embeddings, codebook)
+    )
+
+    new_steps = torch.where(
+        active,
+        torch.zeros_like(steps_since_active),
+        steps_since_active + 1,
+    )
+    dead = new_steps >= restart_after_steps
+
+    rand_idx = torch.randint(len(x), (k,), device=x.device)
+    replacement = x[rand_idx]
+    dead_exp = dead.unsqueeze(1)
+
+    codebook.copy_(torch.where(dead_exp, replacement, codebook))
+    ema_cluster_size.copy_(
+        torch.where(dead, torch.zeros_like(ema_cluster_size), ema_cluster_size)
+    )
+    ema_embed_sum.copy_(
+        torch.where(dead_exp, torch.zeros_like(ema_embed_sum), ema_embed_sum)
+    )
+    steps_since_active.copy_(
+        torch.where(dead, torch.zeros_like(new_steps), new_steps)
+    )
+
+
+def ema_update(
+    x: torch.Tensor,
+    codes: torch.Tensor,
+    ema_cluster_size: torch.Tensor,
+    ema_embed_sum: torch.Tensor,
+    codebook: torch.Tensor,
+    steps_since_active: torch.Tensor,
+    decay: float,
+    restart_after_steps: int,
+) -> None:
+    """Fused EMA codebook update — dispatches to Triton (CUDA) or PyTorch (CPU).
+
+    Mutates ``ema_cluster_size``, ``ema_embed_sum``, ``codebook``, and
+    ``steps_since_active`` in-place.
+
+    Args:
+        x: Encoder outputs of shape ``(B, D)``.
+        codes: Nearest-code indices of shape ``(B,)``, int64.
+        ema_cluster_size: EMA cluster-size buffer of shape ``(K,)``.
+        ema_embed_sum: EMA embedding-sum buffer of shape ``(K, D)``.
+ codebook: Codebook embedding matrix of shape ``(K, D)``. + steps_since_active: Consecutive inactive-step counter of shape ``(K,)``. + decay: EMA decay factor γ ∈ (0, 1). + restart_after_steps: Replace a code once it has gone this many + consecutive steps without any assignment. + """ + if x.is_cuda: + _cuda_ema_update( + x, codes, + ema_cluster_size, ema_embed_sum, + codebook, steps_since_active, + decay, restart_after_steps, + ) + else: + _cpu_ema_update( + x, codes, + ema_cluster_size, ema_embed_sum, + codebook, steps_since_active, + decay, restart_after_steps, + ) diff --git a/rectokens/tokenizers/rqvae.py b/rectokens/tokenizers/rqvae.py index 8dfe0bc..240d8bc 100644 --- a/rectokens/tokenizers/rqvae.py +++ b/rectokens/tokenizers/rqvae.py @@ -10,6 +10,7 @@ from rectokens.codebooks.euclidean import EuclideanCodebook from rectokens.core.quantizer import Quantizer, QuantizerOutput from rectokens.core.tokenizer import TokenSequence, Tokenizer +from rectokens.ops.ema_update import ema_update from rectokens.quantizers.residual import ResidualQuantizer @@ -202,100 +203,44 @@ def _init_from_batch(self, x: torch.Tensor) -> None: def _ema_update(self, x: torch.Tensor, codes: torch.Tensor) -> None: """Update assigned codebook entries via EMA; restart dead codes. - Only codes that received at least one assignment in this batch are - EMA-updated. Applying ``mul_(decay)`` to every code including those - with zero assignments causes their accumulators to decay to 0 and their - codebook entries to be overwritten with ``0 / ε = 0``. - - Dead-code restart: ``_steps_since_active`` is incremented for every - code that receives no assignment in the current batch and reset to 0 - for codes that do. Any code whose counter reaches - ``restart_after_steps`` is replaced with a random encoder output from - the current batch. 
This handles two distinct failure modes: - - * **Never-used codes** — random-normal or K-means++ entries that the - encoder never visits; the counter starts at 0 and climbs until - restart. + Delegates to :func:`~rectokens.ops.ema_update.ema_update`, which + dispatches to a fused Triton kernel on CUDA and a pure-PyTorch + fallback on CPU. Both implementations are semantically identical: + + 1. Scatter-accumulate ``cluster_size`` and ``embed_sum`` for each + codebook entry without materialising the ``(B, K)`` one-hot matrix. + 2. Blend the new statistics into the EMA buffers for **active codes + only** (codes with zero assignments keep their existing accumulators + so they do not decay to zero). + 3. Recompute each active codebook entry as + ``ema_embed_sum / max(ema_cluster_size, ε)``. + 4. Increment ``_steps_since_active`` for inactive codes; reset to 0 + for active ones. + 5. Replace *dead* codes (counter ≥ ``restart_after_steps``) with a + random encoder output from the current batch and zero their EMA + statistics. + + Dead-code restart handles two failure modes: + + * **Never-used codes** — random-normal or K-means++ entries the + encoder never visits; counter climbs from 0 until restart. * **Abandoned codes** — entries that *were* active but became stranded - as the encoder drifted. Under the active-only EMA update their - ``_ema_cluster_size`` stays positive forever, so a simple - ``ema == 0`` check would never restart them. + as the encoder drifted. The active-only EMA update keeps their + ``_ema_cluster_size`` positive forever, so a pure ``ema == 0`` check + would never restart them. """ - k = self._codebook.size - one_hot = F.one_hot(codes, num_classes=k).float() # (B, K) - cluster_size = one_hot.sum(dim=0) # (K,) - embed_sum = one_hot.t() @ x # (K, D) - - active = cluster_size > 0 # (K,) bool - - # EMA update restricted to active codes — torch.where avoids - # data-dependent boolean indexing that would break torch.compile. 
- new_ema_cluster_size = ( - self.ema_decay * self._ema_cluster_size - + (1 - self.ema_decay) * cluster_size - ) - self._ema_cluster_size.copy_( - torch.where(active, new_ema_cluster_size, self._ema_cluster_size) - ) - - new_ema_embed_sum = ( - self.ema_decay * self._ema_embed_sum + (1 - self.ema_decay) * embed_sum - ) - self._ema_embed_sum.copy_( - torch.where(active.unsqueeze(1), new_ema_embed_sum, self._ema_embed_sum) - ) - - n = self._ema_cluster_size.clamp(min=1e-5) # (K,) - new_embeddings = self._ema_embed_sum / n.unsqueeze(1) # (K, D) with torch.no_grad(): - self._codebook.embeddings.copy_( - torch.where( - active.unsqueeze(1), new_embeddings, self._codebook.embeddings - ) + ema_update( + x=x, + codes=codes, + ema_cluster_size=self._ema_cluster_size, + ema_embed_sum=self._ema_embed_sum, + codebook=self._codebook.embeddings, + steps_since_active=self._steps_since_active, + decay=self.ema_decay, + restart_after_steps=self.restart_after_steps, ) - # Dead-code restart: track consecutive steps without assignment and - # replace stranded codes with random batch samples. - # - # This handles two failure modes that the EMA-only check misses: - # 1. Codes never seeded into the data manifold (steps_since_active - # increments from 0 until restart_after_steps is reached). - # 2. Codes that WERE active but became stranded as the encoder drifted - # (ema_cluster_size stays > 0 forever under the active-only update, - # so a pure ema==0 check would never restart them). - new_steps = torch.where( - active, - torch.zeros_like(self._steps_since_active), - self._steps_since_active + 1, - ) - dead = new_steps >= self.restart_after_steps # (K,) bool - - # Always draw k replacements (one per code) and apply only for dead - # codes via torch.where — avoids data-dependent shape from n_dead. 
- rand_idx = torch.randint(len(x), (k,), device=x.device) - replacement = x[rand_idx] # (K, D) - dead_expanded = dead.unsqueeze(1) # (K, 1) - - with torch.no_grad(): - self._codebook.embeddings.copy_( - torch.where(dead_expanded, replacement, self._codebook.embeddings) - ) - self._ema_cluster_size.copy_( - torch.where( - dead, torch.zeros_like(self._ema_cluster_size), self._ema_cluster_size - ) - ) - self._ema_embed_sum.copy_( - torch.where( - dead_expanded, - torch.zeros_like(self._ema_embed_sum), - self._ema_embed_sum, - ) - ) - self._steps_since_active.copy_( - torch.where(dead, torch.zeros_like(new_steps), new_steps) - ) - # Expose as nn.Module forward as well def forward(self, x: torch.Tensor) -> QuantizerOutput: return self.quantize(x) diff --git a/tests/test_ema_update_kernel.py b/tests/test_ema_update_kernel.py new file mode 100644 index 0000000..c4fbbf0 --- /dev/null +++ b/tests/test_ema_update_kernel.py @@ -0,0 +1,302 @@ +"""Tests for the fused Triton EMA-update kernel. + +Each test runs the same scenario through both the pure-PyTorch reference +(``_cpu_ema_update``) and the CUDA Triton kernel (``_cuda_ema_update``) and +asserts numerical equivalence. 
+ +Test matrix: + - All codes active (every codebook entry used at least once) + - Some codes inactive (zero assignments → counter incremented, no EMA change) + - Dead-code restart triggered (counter reaches threshold → replacement) + - Single batch sample (B=1) + - Large codebook relative to batch (forces many inactive / dead codes) + - Non-power-of-two batch size (verifies boundary masking in BLOCK_B loop) +""" + +from __future__ import annotations + +import copy +import unittest + +import torch + +if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA required for EMA kernel tests") + +from rectokens.ops.ema_update import _cpu_ema_update, _cuda_ema_update + +DEVICE = torch.device("cuda") +CPU = torch.device("cpu") + +# Tolerances: the Triton kernel accumulates in fp32; minor rounding differences +# vs. the PyTorch reference are expected. +ATOL = 1e-5 +RTOL = 1e-4 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_state( + K: int, + D: int, + *, + ema_decay: float = 0.99, + seed: int = 0, + device: torch.device = DEVICE, +) -> dict[str, torch.Tensor]: + """Allocate fresh EMA buffers and a codebook for K entries of dimension D.""" + g = torch.Generator(device=device) + g.manual_seed(seed) + return { + "ema_cluster_size": torch.rand(K, device=device, generator=g), + "ema_embed_sum": torch.randn(K, D, device=device, generator=g), + "codebook": torch.randn(K, D, device=device, generator=g), + "steps_since_active": torch.randint(0, 5, (K,), device=device, dtype=torch.long), + } + + +def _clone_state(state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + return {k: v.clone() for k, v in state.items()} + + +def _run_both( + x: torch.Tensor, + codes: torch.Tensor, + state: dict[str, torch.Tensor], + decay: float, + restart_after_steps: int, +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + """Run CPU 
reference and CUDA kernel on identical state copies. + + Returns ``(cpu_state, gpu_state)`` after the update. + """ + # CPU reference operates on CPU tensors. + cpu_state = {k: v.cpu().clone() for k, v in state.items()} + _cpu_ema_update( + x.cpu(), codes.cpu(), + cpu_state["ema_cluster_size"], + cpu_state["ema_embed_sum"], + cpu_state["codebook"], + cpu_state["steps_since_active"], + decay, + restart_after_steps, + ) + + # CUDA kernel operates on CUDA tensors. + gpu_state = _clone_state(state) + _cuda_ema_update( + x, codes, + gpu_state["ema_cluster_size"], + gpu_state["ema_embed_sum"], + gpu_state["codebook"], + gpu_state["steps_since_active"], + decay, + restart_after_steps, + ) + torch.cuda.synchronize() + + return cpu_state, gpu_state + + +def _assert_states_close( + cpu: dict[str, torch.Tensor], + gpu: dict[str, torch.Tensor], + msg: str = "", +) -> None: + """Assert that the CPU and GPU states are numerically close.""" + for key in ("ema_cluster_size", "ema_embed_sum"): + assert torch.allclose(cpu[key], gpu[key].cpu(), atol=ATOL, rtol=RTOL), ( + f"{msg}: mismatch in {key}\n" + f" cpu={cpu[key].flatten()[:8]}\n" + f" gpu={gpu[key].cpu().flatten()[:8]}" + ) + # steps_since_active must be bit-exact (integer). 
+ assert torch.equal(cpu["steps_since_active"], gpu["steps_since_active"].cpu()), ( + f"{msg}: mismatch in steps_since_active\n" + f" cpu={cpu['steps_since_active']}\n" + f" gpu={gpu['steps_since_active'].cpu()}" + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestEMAUpdateKernel(unittest.TestCase): + + # ── basic EMA update — all codes active ───────────────────────────────── + + def test_all_codes_active(self) -> None: + """Every codebook entry receives at least one assignment.""" + torch.manual_seed(1) + K, D, B = 16, 64, 128 + x = torch.randn(B, D, device=DEVICE) + # Assign at least one sample per code, then fill the rest randomly. + codes = torch.cat([ + torch.arange(K, device=DEVICE), + torch.randint(K, (B - K,), device=DEVICE), + ]) + state = _make_state(K, D) + cpu, gpu = _run_both(x, codes, state, decay=0.99, restart_after_steps=20) + _assert_states_close(cpu, gpu, "all_codes_active") + + # ── some codes inactive ────────────────────────────────────────────────── + + def test_some_codes_inactive(self) -> None: + """Only the first half of codes receive assignments.""" + torch.manual_seed(2) + K, D, B = 16, 64, 64 + x = torch.randn(B, D, device=DEVICE) + # Only codes [0, K//2) are assigned; codes [K//2, K) stay inactive. + codes = torch.randint(K // 2, (B,), device=DEVICE) + state = _make_state(K, D) + cpu, gpu = _run_both(x, codes, state, decay=0.99, restart_after_steps=20) + _assert_states_close(cpu, gpu, "some_codes_inactive") + + # Inactive codes should have their step counter incremented, not reset. 
+ inactive_steps_cpu = cpu["steps_since_active"][K // 2 :] + inactive_steps_gpu = gpu["steps_since_active"].cpu()[K // 2 :] + assert torch.equal(inactive_steps_cpu, inactive_steps_gpu) + + # ── dead-code restart ──────────────────────────────────────────────────── + + def test_dead_code_restart(self) -> None: + """Codes whose counter hits the threshold are replaced. + + We seed ``steps_since_active`` so that inactive codes are exactly one + step away from the threshold, then run one update step without + assigning them. After the update their counters must be 0 and the + EMA accumulators must be zeroed. + """ + torch.manual_seed(3) + K, D, B = 8, 32, 16 + restart_thresh = 5 + + x = torch.randn(B, D, device=DEVICE) + # Only code 0 is active; codes 1..K-1 are inactive and at threshold-1. + codes = torch.zeros(B, device=DEVICE, dtype=torch.long) + state = _make_state(K, D) + # Set inactive codes to one step below the threshold so this update + # pushes them to exactly restart_thresh and triggers a restart. + state["steps_since_active"][1:] = restart_thresh - 1 + + cpu, gpu = _run_both(x, codes, state, decay=0.99, restart_after_steps=restart_thresh) + + # Dead codes (1..K-1) must have their counters reset to 0. + assert (cpu["steps_since_active"][1:] == 0).all(), ( + "CPU: dead-code steps not reset" + ) + assert (gpu["steps_since_active"].cpu()[1:] == 0).all(), ( + "GPU: dead-code steps not reset" + ) + # Dead codes' EMA cluster sizes must be zeroed. + assert (cpu["ema_cluster_size"][1:] == 0).all() + assert torch.allclose(gpu["ema_cluster_size"].cpu()[1:], + torch.zeros(K - 1), atol=ATOL) + # Dead codes' EMA embed sums must be zeroed. 
+ assert (cpu["ema_embed_sum"][1:] == 0).all() + assert torch.allclose(gpu["ema_embed_sum"].cpu()[1:], + torch.zeros(K - 1, D), atol=ATOL) + + # ── single sample ──────────────────────────────────────────────────────── + + def test_single_sample(self) -> None: + """B=1 — exercises the BLOCK_B boundary masking.""" + torch.manual_seed(4) + K, D, B = 8, 64, 1 + x = torch.randn(B, D, device=DEVICE) + codes = torch.zeros(B, device=DEVICE, dtype=torch.long) + state = _make_state(K, D) + cpu, gpu = _run_both(x, codes, state, decay=0.9, restart_after_steps=10) + _assert_states_close(cpu, gpu, "single_sample") + + # ── non-power-of-two batch ─────────────────────────────────────────────── + + def test_non_pow2_batch(self) -> None: + """B=100 — batch size not a multiple of any BLOCK_B config.""" + torch.manual_seed(5) + K, D, B = 32, 64, 100 + x = torch.randn(B, D, device=DEVICE) + codes = torch.randint(K, (B,), device=DEVICE) + state = _make_state(K, D) + cpu, gpu = _run_both(x, codes, state, decay=0.99, restart_after_steps=20) + _assert_states_close(cpu, gpu, "non_pow2_batch") + + # ── large codebook ─────────────────────────────────────────────────────── + + def test_large_codebook(self) -> None: + """K=256, B=64 — most codes are inactive each step.""" + torch.manual_seed(6) + K, D, B = 256, 64, 64 + x = torch.randn(B, D, device=DEVICE) + codes = torch.randint(K, (B,), device=DEVICE) + state = _make_state(K, D) + cpu, gpu = _run_both(x, codes, state, decay=0.99, restart_after_steps=20) + _assert_states_close(cpu, gpu, "large_codebook") + + # ── high-dimensional embeddings ────────────────────────────────────────── + + def test_large_dim(self) -> None: + """D=256 — verifies register pressure is manageable.""" + torch.manual_seed(7) + K, D, B = 64, 256, 128 + x = torch.randn(B, D, device=DEVICE) + codes = torch.randint(K, (B,), device=DEVICE) + state = _make_state(K, D) + cpu, gpu = _run_both(x, codes, state, decay=0.99, restart_after_steps=20) + _assert_states_close(cpu, 
gpu, "large_dim") + + # ── EMA decay edge cases ───────────────────────────────────────────────── + + def test_decay_zero(self) -> None: + """decay=0 → EMA collapses to per-batch statistics.""" + torch.manual_seed(8) + K, D, B = 16, 64, 64 + x = torch.randn(B, D, device=DEVICE) + codes = torch.arange(K, device=DEVICE).repeat(B // K) + state = _make_state(K, D) + cpu, gpu = _run_both(x, codes, state, decay=0.0, restart_after_steps=20) + _assert_states_close(cpu, gpu, "decay_zero") + + def test_decay_one(self) -> None: + """decay=1 → EMA never changes (new statistics have zero weight).""" + torch.manual_seed(9) + K, D, B = 16, 64, 64 + x = torch.randn(B, D, device=DEVICE) + codes = torch.arange(K, device=DEVICE).repeat(B // K) + state = _make_state(K, D) + # Stash original EMA values. + orig_ema_cs = state["ema_cluster_size"].clone() + orig_ema_es = state["ema_embed_sum"].clone() + + cpu, gpu = _run_both(x, codes, state, decay=1.0, restart_after_steps=20) + _assert_states_close(cpu, gpu, "decay_one") + + # With decay=1 the EMA values must not change at all. 
+ assert torch.allclose(gpu["ema_cluster_size"].cpu(), orig_ema_cs.cpu(), atol=ATOL), ( + "ema_cluster_size changed with decay=1" + ) + assert torch.allclose(gpu["ema_embed_sum"].cpu(), orig_ema_es.cpu(), atol=ATOL), ( + "ema_embed_sum changed with decay=1" + ) + + # ── steps_since_active semantics ───────────────────────────────────────── + + def test_active_code_resets_counter(self) -> None: + """An active code's step counter must be exactly 0 after the update.""" + torch.manual_seed(10) + K, D, B = 4, 32, 8 + x = torch.randn(B, D, device=DEVICE) + codes = torch.zeros(B, device=DEVICE, dtype=torch.long) # only code 0 used + state = _make_state(K, D) + state["steps_since_active"][0] = 99 # should be reset to 0 + + cpu, gpu = _run_both(x, codes, state, decay=0.99, restart_after_steps=20) + _assert_states_close(cpu, gpu, "active_code_resets_counter") + + assert cpu["steps_since_active"][0].item() == 0, "CPU: active code step not reset" + assert gpu["steps_since_active"].cpu()[0].item() == 0, "GPU: active code step not reset"