diff --git a/benchmark/scripts/benchmark_megatron_cross_entropy.py b/benchmark/scripts/benchmark_megatron_cross_entropy.py
new file mode 100644
index 000000000..6eca12edc
--- /dev/null
+++ b/benchmark/scripts/benchmark_megatron_cross_entropy.py
@@ -0,0 +1,179 @@
+"""Benchmark Liger's Megatron-LM cross-entropy wrapper.
+
+Benchmarks the Liger [seq, batch, vocab] cross-entropy wrapper against PyTorch's
+native ``F.cross_entropy`` on equivalent input shapes. When megatron-core is
+installed, Megatron's own ``fused_vocab_parallel_cross_entropy`` is added as a
+third provider for end-to-end comparison.
+
+Requires a Liger-supported accelerator (CUDA / ROCm). If megatron-core is not
+installed, the "megatron" provider is silently skipped.
+"""
+
+import torch
+import torch.nn.functional as F
+import triton
+
+from utils import QUANTILES
+from utils import SingleBenchmarkRunInput
+from utils import SingleBenchmarkRunOutput
+from utils import _test_memory
+from utils import parse_benchmark_script_args
+from utils import run_benchmarks
+
+from liger_kernel.megatron.cross_entropy import _build_wrapper
+from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
+from liger_kernel.utils import infer_device
+
+device = infer_device()
+
+try:
+    from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
+
+    _MEGATRON_AVAILABLE = True
+except ImportError:
+    fused_vocab_parallel_cross_entropy = None
+    _MEGATRON_AVAILABLE = False
+
+
+def _make_inputs(s: int, b: int, v: int, requires_grad: bool = True):
+    logits = torch.randn(s, b, v, device=device, dtype=torch.bfloat16, requires_grad=requires_grad)
+    target = torch.randint(0, v, (s, b), device=device, dtype=torch.long)
+    return logits, target
+
+
+def _pytorch_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    s, b, v = logits.shape
+    return F.cross_entropy(
+        logits.reshape(-1, v).float(),
+        target.reshape(-1),
+        reduction="none",
+    ).reshape(s, b)
+
+
+def _ensure_single_rank_tp_group():
+    """Initialize torch.distributed (single-rank) and return a usable TP group.
+
+    For a single-process benchmark we use the world group of
+    size 1, where the internal all-reduce becomes a no-op.
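+    The group is initialized with the NCCL backend, so a CUDA/ROCm device must be
+    visible to the benchmark process.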
+    """
+    import os
+
+    import torch.distributed as dist
+
+    if not dist.is_initialized():
+        os.environ.setdefault("MASTER_ADDR", "localhost")
+        os.environ.setdefault("MASTER_PORT", "29500")
+        os.environ.setdefault("WORLD_SIZE", "1")
+        os.environ.setdefault("RANK", "0")
+        os.environ.setdefault("LOCAL_RANK", "0")
+        dist.init_process_group(backend="nccl")
+    return dist.group.WORLD
+
+
+def _select_fwd(provider: str):
+    if provider == "liger":
+        wrapper = _build_wrapper(LigerCrossEntropyLoss(reduction="none"))
+        return wrapper
+    if provider == "torch":
+        return _pytorch_cross_entropy
+    if provider == "megatron":
+        if not _MEGATRON_AVAILABLE:
+            raise RuntimeError("megatron-core not installed; cannot benchmark 'megatron' provider")
+        tp_group = _ensure_single_rank_tp_group()
+
+        def _megatron_call(logits, target):
+            return fused_vocab_parallel_cross_entropy(logits, target, tp_group)
+
+        return _megatron_call
+    raise ValueError(f"unknown provider: {provider!r}")
+
+
+def bench_speed_megatron_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    v = input.x
+    provider = input.kernel_provider
+    mode = input.kernel_operation_mode
+    s = input.extra_benchmark_config["S"]
+    b = input.extra_benchmark_config["B"]
+
+    logits, target = _make_inputs(s, b, v)
+    fwd_fn = _select_fwd(provider)
+
+    def fwd():
+        return fwd_fn(logits, target)
+
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
+    elif mode == "backward":
+        # Megatron's fused CE writes gradients in-place into saved tensors during backward,
+        # which breaks the standard retain_graph=True / repeated-backward pattern the other
+        # benchmark scripts use with do_bench. Run a fresh fwd+bwd each iteration so each
+        # backward sees an unmodified autograd graph. Measurement therefore includes forward
+        # time — subtract the "forward" measurement to derive backward-only timing.
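+        # _fwd_bwd clears logits.grad before each run so gradients from earlier
+        # iterations do not accumulate across timed runs.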
+ def _fwd_bwd(): + if logits.grad is not None: + logits.grad = None + out = fwd_fn(logits, target) + out.sum().backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench(_fwd_bwd, rep=100, quantiles=QUANTILES) + elif mode == "full": + + def full(): + y = fwd() + y.sum().backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"unknown mode: {mode!r}") + + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_megatron_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + v = input.x + provider = input.kernel_provider + s = input.extra_benchmark_config["S"] + b = input.extra_benchmark_config["B"] + + logits, target = _make_inputs(s, b, v) + fwd_fn = _select_fwd(provider) + + def full(): + y = fwd_fn(logits, target) + y.sum().backward() + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + providers = ["liger", "torch"] + if _MEGATRON_AVAILABLE: + providers.append("megatron") + + common_configs = { + "kernel_name": "megatron_cross_entropy", + "x_name": "V", + "x_label": "vocab size", + "x_values": [2**i for i in range(12, 18)], + "kernel_providers": providers, + "extra_benchmark_configs": [{"S": 2048, "B": 4}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_megatron_cross_entropy, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_megatron_cross_entropy, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/docs/High-Level-APIs.md b/docs/High-Level-APIs.md index 5433e03d3..b90fc6e96 100644 --- a/docs/High-Level-APIs.md +++ b/docs/High-Level-APIs.md @@ -91,3 +91,45 @@ You can also use the Patching APIs to use the kernels for a specific model archi extra: show_docstring: true show_signature: true + +--- + +## Megatron-LM + +Liger also exposes a patch for the [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) +training framework, replacing Megatron's native +`fused_vocab_parallel_cross_entropy` with Liger's Triton cross-entropy kernel. + +| **Framework** | **API** | **Supported Operations** | +|---------------|--------------------------------------------------------|--------------------------| +| Megatron-LM | `liger_kernel.megatron.apply_liger_kernel_to_megatron` | CrossEntropyLoss | + +**Scope**: Initial release supports `tensor_model_parallel_size=1` only. +Vocab-parallel cross-entropy (TP>1) is follow-up work — with TP>1, each rank +holds a sharded `[N, V/tp]` logits slice and cross-entropy requires cross-rank +all-reduces that Liger's kernel does not perform. The patch raises a +`RuntimeError` at patch time or call time if TP>1 is detected. + +**Usage**: + +```python +from liger_kernel.megatron import apply_liger_kernel_to_megatron + +# Call before Megatron's forward pass reaches compute_language_model_loss. +# Match Megatron's config: pass the same ignore_index and label_smoothing +# values used by your training setup (Liger does not auto-detect them). +apply_liger_kernel_to_megatron( + ignore_index=-100, + label_smoothing=cfg.label_smoothing_factor, +) +``` + +Ensure Megatron's fused-CE code path is enabled in your training config (e.g. 
+`--cross-entropy-loss-fusion` in the Megatron-LM CLI) — if the unfused path is +selected, the patched symbol is never called. + +::: liger_kernel.megatron.apply_liger_kernel_to_megatron + options: + extra: + show_docstring: true + show_signature: true diff --git a/src/liger_kernel/megatron/__init__.py b/src/liger_kernel/megatron/__init__.py new file mode 100644 index 000000000..f002d10f9 --- /dev/null +++ b/src/liger_kernel/megatron/__init__.py @@ -0,0 +1,3 @@ +from liger_kernel.megatron.cross_entropy import apply_liger_kernel_to_megatron + +__all__ = ["apply_liger_kernel_to_megatron"] diff --git a/src/liger_kernel/megatron/cross_entropy.py b/src/liger_kernel/megatron/cross_entropy.py new file mode 100644 index 000000000..5ac423efb --- /dev/null +++ b/src/liger_kernel/megatron/cross_entropy.py @@ -0,0 +1,151 @@ +import logging + +import torch + +from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss + +logger = logging.getLogger(__name__) + +_ACTIVATION_LOGGED = False + + +def _check_tensor_parallel_size_at_patch_time() -> None: + """Raise RuntimeError if Megatron's parallel state already reports TP>1. + + If Megatron is importable but the parallel state is not yet initialized + (for example, ``apply_liger_kernel_to_megatron`` is called before + ``initialize_megatron``), silently defer; the wrapper checks again at call + time against the ``tp_group`` argument Megatron supplies. + """ + try: + from megatron.core import parallel_state + except ImportError: + return + try: + tp_size = parallel_state.get_tensor_model_parallel_world_size() + except (AssertionError, RuntimeError): + return + if tp_size > 1: + raise RuntimeError( + f"apply_liger_kernel_to_megatron currently requires tensor_model_parallel_size=1, " + f"got {tp_size}. Vocab-parallel cross-entropy support is planned as follow-up work." + ) + + +def _build_wrapper(loss_fn: LigerCrossEntropyLoss): + """Build a drop-in replacement for ``fused_vocab_parallel_cross_entropy``. + + The returned callable has exactly the same parameter list Megatron expects + (``vocab_parallel_logits``, ``target``, ``tp_group``). Any unknown kwargs + will raise ``TypeError`` naturally — this is intentional: if a future + Megatron release adds new parameters to the fused-CE contract, we want to + fail loudly rather than silently drop them. + """ + + def liger_fused_vocab_parallel_cross_entropy( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + tp_group=None, + ) -> torch.Tensor: + global _ACTIVATION_LOGGED + if not _ACTIVATION_LOGGED: + logger.info( + "Liger cross-entropy kernel is active for Megatron-LM " + "(replacing megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy)." + ) + _ACTIVATION_LOGGED = True + + if tp_group is not None and hasattr(tp_group, "size") and tp_group.size() > 1: + raise RuntimeError( + f"Liger Megatron cross-entropy wrapper requires tensor_model_parallel_size=1, " + f"got tp_group.size()={tp_group.size()}. Vocab-parallel support is tracked as " + f"follow-up work." + ) + + s, b, v = vocab_parallel_logits.shape + logits_2d = vocab_parallel_logits.reshape(-1, v) + target_1d = target.reshape(-1) + loss = loss_fn(logits_2d, target_1d) + return loss.reshape(s, b) + + return liger_fused_vocab_parallel_cross_entropy + + +def apply_liger_kernel_to_megatron( + reduction: str = "none", + ignore_index: int = -100, + label_smoothing: float = 0.0, +) -> None: + """Replace Megatron-LM's fused_vocab_parallel_cross_entropy with Liger's Triton cross-entropy. 
+ + This monkey-patches + ``megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy`` + so that Megatron training pipelines use Liger's Triton kernel (online + softmax, in-place gradients, no full-softmax materialization) instead of + Megatron's native fused implementation. + + Args: + reduction: Must be ``"none"``; Megatron's fused-CE contract returns + per-token loss shaped ``[seq, batch]`` and handles reduction itself + downstream. + ignore_index: Target index to ignore. Pass the value used in your + Megatron training config. + label_smoothing: Cross-entropy label smoothing factor. Liger does not + auto-detect this — callers should pass + ``cfg.label_smoothing_factor`` (or equivalent) from their + Megatron ``TransformerConfig`` if label smoothing is enabled, to + preserve the native behavior. + + Scope: + Initial release supports ``tensor_model_parallel_size=1`` only. With + TP>1, each rank holds a vocab-sharded logits slice ``[N, V/tp]`` and + computing cross-entropy requires cross-rank all-reduces that Liger's + kernel does not perform. A ``RuntimeError`` is raised at patch time if + the Megatron parallel state already reports TP>1, and again at call + time if a multi-rank ``tp_group`` is passed. + + Raises: + AssertionError: If ``reduction != "none"``. + ImportError: If ``megatron.core.fusions.fused_cross_entropy`` is not + importable, or if the expected + ``fused_vocab_parallel_cross_entropy`` symbol is missing from that + module (indicating an incompatible Megatron version). + RuntimeError: If tensor model parallelism > 1 is detected. + + Example: + >>> from liger_kernel.megatron import apply_liger_kernel_to_megatron + >>> apply_liger_kernel_to_megatron( + ... ignore_index=-100, + ... label_smoothing=cfg.label_smoothing_factor, + ... ) + >>> # call before Megatron's forward pass reaches compute_language_model_loss + """ + assert reduction == "none", ( + f"Megatron's fused_vocab_parallel_cross_entropy contract requires per-token loss; " + f"reduction must be 'none', got {reduction!r}." + ) + + try: + import megatron.core.fusions.fused_cross_entropy as fce + except ImportError as exc: + raise ImportError( + "apply_liger_kernel_to_megatron requires megatron-core to be installed. " + "Expected symbol path: " + "megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy." + ) from exc + + if not hasattr(fce, "fused_vocab_parallel_cross_entropy"): + raise ImportError( + "megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy not found. " + "The symbol path may have changed in your Megatron-LM version. Please file an issue " + "on https://github.com/linkedin/Liger-Kernel with your megatron-core version." + ) + + _check_tensor_parallel_size_at_patch_time() + + loss_fn = LigerCrossEntropyLoss( + ignore_index=ignore_index, + label_smoothing=label_smoothing, + reduction="none", + ) + fce.fused_vocab_parallel_cross_entropy = _build_wrapper(loss_fn) diff --git a/test/megatron/__init__.py b/test/megatron/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/megatron/test_cross_entropy.py b/test/megatron/test_cross_entropy.py new file mode 100644 index 000000000..f995b051e --- /dev/null +++ b/test/megatron/test_cross_entropy.py @@ -0,0 +1,157 @@ +"""Correctness tests for the Liger Megatron cross-entropy wrapper. 
+ +These tests exercise ``_build_wrapper`` directly without importing +megatron-core — the wrapper is the [s, b, v] -> [s, b] reshape shim around +``LigerCrossEntropyLoss`` and is meaningful to test on its own. + +The wrapper calls the underlying Triton kernel, so these tests require a +Liger-supported accelerator (same as ``test/transformers/test_cross_entropy.py``). +""" + +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.megatron.cross_entropy import _build_wrapper +from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss +from liger_kernel.utils import infer_device +from test.utils import assert_verbose_allclose +from test.utils import set_seed +from test.utils import supports_bfloat16 + +device = infer_device() +set_seed(42) + + +def _make_wrapper(ignore_index: int = -100, label_smoothing: float = 0.0): + loss_fn = LigerCrossEntropyLoss( + ignore_index=ignore_index, + label_smoothing=label_smoothing, + reduction="none", + ) + return _build_wrapper(loss_fn) + + +def _reference_loss( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + ignore_index: int, + label_smoothing: float, +) -> torch.Tensor: + s, b, v = vocab_parallel_logits.shape + loss_flat = F.cross_entropy( + vocab_parallel_logits.reshape(-1, v).float(), + target.reshape(-1), + reduction="none", + ignore_index=ignore_index, + label_smoothing=label_smoothing, + ) + return loss_flat.reshape(s, b) + + +@pytest.mark.parametrize( + "s, b, v", + [ + (8, 2, 128), + (16, 4, 4096), + (32, 1, 32000), + ], +) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-7, 1e-6), + pytest.param( + torch.bfloat16, + 1e-2, + 1e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported"), + ), + ], +) +def test_wrapper_matches_pytorch_cross_entropy(s, b, v, dtype, atol, rtol): + wrapper = _make_wrapper() + + logits = torch.randn(s, b, v, device=device, dtype=dtype) * 0.5 + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + + ref = _reference_loss(logits, target, ignore_index=-100, label_smoothing=0.0) + got = wrapper(logits, target) + + assert got.shape == (s, b) + assert_verbose_allclose(got.float(), ref.float(), atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("ignore_index", [-100, 0]) +def test_wrapper_respects_ignore_index(ignore_index): + s, b, v = 16, 2, 1024 + wrapper = _make_wrapper(ignore_index=ignore_index) + + logits = torch.randn(s, b, v, device=device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + target.view(-1)[: (s * b) // 4] = ignore_index + + ref = _reference_loss(logits, target, ignore_index=ignore_index, label_smoothing=0.0) + got = wrapper(logits, target) + assert_verbose_allclose(got.float(), ref.float(), atol=1e-6, rtol=1e-5) + + +@pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) +def test_wrapper_respects_label_smoothing(label_smoothing): + s, b, v = 8, 2, 512 + wrapper = _make_wrapper(label_smoothing=label_smoothing) + + logits = torch.randn(s, b, v, device=device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + + ref = _reference_loss(logits, target, ignore_index=-100, label_smoothing=label_smoothing) + got = wrapper(logits, target) + assert_verbose_allclose(got.float(), ref.float(), atol=1e-5, rtol=1e-4) + + +def test_wrapper_rejects_unknown_kwargs(): + wrapper = _make_wrapper() + logits = torch.randn(4, 1, 32, device=device) + target = torch.randint(0, 32, (4, 1), device=device, 
dtype=torch.long) + with pytest.raises(TypeError): + wrapper(logits, target, unknown_arg=123) + + +def test_wrapper_rejects_multi_rank_tp_group(): + wrapper = _make_wrapper() + logits = torch.randn(4, 1, 32, device=device) + target = torch.randint(0, 32, (4, 1), device=device, dtype=torch.long) + + class _FakeGroup: + def size(self): + return 2 + + with pytest.raises(RuntimeError, match="tensor_model_parallel_size=1"): + wrapper(logits, target, tp_group=_FakeGroup()) + + +def test_wrapper_accepts_single_rank_tp_group(): + wrapper = _make_wrapper() + logits = torch.randn(4, 1, 32, device=device) + target = torch.randint(0, 32, (4, 1), device=device, dtype=torch.long) + + class _FakeGroup: + def size(self): + return 1 + + out = wrapper(logits, target, tp_group=_FakeGroup()) + assert out.shape == (4, 1) + + +def test_wrapper_preserves_gradients(): + s, b, v = 8, 2, 256 + wrapper = _make_wrapper() + + logits = torch.randn(s, b, v, device=device, dtype=torch.float32, requires_grad=True) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + + loss = wrapper(logits, target).sum() + loss.backward() + + assert logits.grad is not None + assert logits.grad.shape == logits.shape diff --git a/test/megatron/test_monkey_patch.py b/test/megatron/test_monkey_patch.py new file mode 100644 index 000000000..f033e9b70 --- /dev/null +++ b/test/megatron/test_monkey_patch.py @@ -0,0 +1,166 @@ +"""Tests for apply_liger_kernel_to_megatron's patch mechanism. + +Megatron-LM is not a test dependency. We inject stub modules into +``sys.modules`` so the patch function can run entirely on CPU without a real +megatron-core install. Tests verify: + +- the patch replaces ``fused_vocab_parallel_cross_entropy`` on the stub module +- ``reduction != "none"`` is rejected +- TP>1 at patch time raises RuntimeError +- missing megatron-core raises a helpful ImportError +- missing symbol path raises a helpful ImportError +- the constructed LigerCrossEntropyLoss receives the user-supplied kwargs +""" + +import sys +import types + +from unittest.mock import patch + +import pytest + + +def _install_fake_megatron(tp_size: int = 1, with_fused_symbol: bool = True): + """Install stub megatron modules into sys.modules; return the fused module.""" + megatron = types.ModuleType("megatron") + megatron_core = types.ModuleType("megatron.core") + fusions = types.ModuleType("megatron.core.fusions") + fused_ce = types.ModuleType("megatron.core.fusions.fused_cross_entropy") + parallel_state = types.ModuleType("megatron.core.parallel_state") + + if with_fused_symbol: + + def original_fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target, tp_group=None): + raise AssertionError("original megatron kernel called — patch failed") + + fused_ce.fused_vocab_parallel_cross_entropy = original_fused_vocab_parallel_cross_entropy + + parallel_state.get_tensor_model_parallel_world_size = lambda: tp_size + + sys.modules["megatron"] = megatron + sys.modules["megatron.core"] = megatron_core + sys.modules["megatron.core.fusions"] = fusions + sys.modules["megatron.core.fusions.fused_cross_entropy"] = fused_ce + sys.modules["megatron.core.parallel_state"] = parallel_state + + megatron.core = megatron_core + megatron_core.fusions = fusions + megatron_core.parallel_state = parallel_state + fusions.fused_cross_entropy = fused_ce + + return fused_ce + + +def _uninstall_fake_megatron(): + for mod in [ + "megatron.core.parallel_state", + "megatron.core.fusions.fused_cross_entropy", + "megatron.core.fusions", + "megatron.core", + "megatron", + ]: 
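+        # pop with a default keeps teardown safe even when the stubs were never installed.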
+ sys.modules.pop(mod, None) + + +@pytest.fixture +def fake_megatron(): + fused_ce = _install_fake_megatron(tp_size=1) + try: + yield fused_ce + finally: + _uninstall_fake_megatron() + + +def test_patch_replaces_fused_symbol(fake_megatron): + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + original = fake_megatron.fused_vocab_parallel_cross_entropy + apply_liger_kernel_to_megatron() + patched = fake_megatron.fused_vocab_parallel_cross_entropy + + assert patched is not original + assert patched.__name__ == "liger_fused_vocab_parallel_cross_entropy" + + +def test_patch_rejects_non_none_reduction(fake_megatron): + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + with pytest.raises(AssertionError, match="reduction must be 'none'"): + apply_liger_kernel_to_megatron(reduction="mean") + + +def test_patch_raises_on_tp_greater_than_one(): + _install_fake_megatron(tp_size=2) + try: + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + with pytest.raises(RuntimeError, match="tensor_model_parallel_size=1"): + apply_liger_kernel_to_megatron() + finally: + _uninstall_fake_megatron() + + +def test_patch_defers_tp_check_when_parallel_state_not_initialized(): + """If get_tensor_model_parallel_world_size() raises, patch should still succeed.""" + fused_ce = _install_fake_megatron(tp_size=1) + + def raising_tp_size(): + raise AssertionError("parallel_state not initialized") + + sys.modules["megatron.core.parallel_state"].get_tensor_model_parallel_world_size = raising_tp_size + + try: + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron() + assert fused_ce.fused_vocab_parallel_cross_entropy.__name__ == "liger_fused_vocab_parallel_cross_entropy" + finally: + _uninstall_fake_megatron() + + +def test_patch_raises_when_megatron_not_installed(): + _uninstall_fake_megatron() + # Block imports of any "megatron*" module to simulate absent install. + real_import = __builtins__["__import__"] if isinstance(__builtins__, dict) else __builtins__.__import__ + + def blocking_import(name, *args, **kwargs): + if name == "megatron" or name.startswith("megatron."): + raise ImportError(f"No module named {name!r}") + return real_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=blocking_import): + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + with pytest.raises(ImportError, match="requires megatron-core"): + apply_liger_kernel_to_megatron() + + +def test_patch_raises_when_fused_symbol_missing(): + _install_fake_megatron(tp_size=1, with_fused_symbol=False) + try: + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + with pytest.raises(ImportError, match="symbol path may have changed"): + apply_liger_kernel_to_megatron() + finally: + _uninstall_fake_megatron() + + +def test_patch_forwards_ignore_index_and_label_smoothing(fake_megatron): + from liger_kernel.megatron import cross_entropy as mod + + captured = {} + + class FakeLoss: + def __init__(self, ignore_index, label_smoothing, reduction): + captured["ignore_index"] = ignore_index + captured["label_smoothing"] = label_smoothing + captured["reduction"] = reduction + + def __call__(self, _input, target): + raise AssertionError("not expected to be called in this test") + + with patch.object(mod, "LigerCrossEntropyLoss", FakeLoss): + mod.apply_liger_kernel_to_megatron(ignore_index=42, label_smoothing=0.25) + + assert captured == {"ignore_index": 42, "label_smoothing": 0.25, "reduction": "none"}