179 changes: 179 additions & 0 deletions benchmark/scripts/benchmark_megatron_cross_entropy.py
@@ -0,0 +1,179 @@
"""Benchmark Liger's Megatron-LM cross-entropy wrapper.

Benchmarks the Liger [seq, batch, vocab] cross-entropy wrapper against PyTorch's
native ``F.cross_entropy`` on equivalent input shapes. When megatron-core is
installed, Megatron's own ``fused_vocab_parallel_cross_entropy`` is added as a
third provider to reproduce end-to-end comparisons.

Requires a Liger-supported accelerator (CUDA / ROCm). If megatron-core is not
installed, the "megatron" provider is silently skipped.
"""

import torch
import torch.nn.functional as F
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.megatron.cross_entropy import _build_wrapper
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.utils import infer_device

device = infer_device()

try:
    from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy

    _MEGATRON_AVAILABLE = True
except ImportError:
    fused_vocab_parallel_cross_entropy = None
    _MEGATRON_AVAILABLE = False


def _make_inputs(s: int, b: int, v: int, requires_grad: bool = True):
    logits = torch.randn(s, b, v, device=device, dtype=torch.bfloat16, requires_grad=requires_grad)
    target = torch.randint(0, v, (s, b), device=device, dtype=torch.long)
    return logits, target


def _pytorch_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    s, b, v = logits.shape
    return F.cross_entropy(
        logits.reshape(-1, v).float(),
        target.reshape(-1),
        reduction="none",
    ).reshape(s, b)


def _ensure_single_rank_tp_group():
    """Initialize torch.distributed (single-rank) and return a usable TP group.

    For a single-process benchmark we use the world group of size 1, where the
    internal all-reduce becomes a no-op.
    """
    import os

    import torch.distributed as dist

    if not dist.is_initialized():
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        os.environ.setdefault("WORLD_SIZE", "1")
        os.environ.setdefault("RANK", "0")
        os.environ.setdefault("LOCAL_RANK", "0")
        dist.init_process_group(backend="nccl")
    return dist.group.WORLD


def _select_fwd(provider: str):
    if provider == "liger":
        return _build_wrapper(LigerCrossEntropyLoss(reduction="none"))
    if provider == "torch":
        return _pytorch_cross_entropy
    if provider == "megatron":
        if not _MEGATRON_AVAILABLE:
            raise RuntimeError("megatron-core not installed; cannot benchmark 'megatron' provider")
        tp_group = _ensure_single_rank_tp_group()

        def _megatron_call(logits, target):
            return fused_vocab_parallel_cross_entropy(logits, target, tp_group)

        return _megatron_call
    raise ValueError(f"unknown provider: {provider!r}")


def bench_speed_megatron_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    v = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    s = input.extra_benchmark_config["S"]
    b = input.extra_benchmark_config["B"]

    logits, target = _make_inputs(s, b, v)
    fwd_fn = _select_fwd(provider)

    def fwd():
        return fwd_fn(logits, target)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
    elif mode == "backward":
        # Megatron's fused CE writes gradients in-place into saved tensors during backward,
        # which breaks the standard retain_graph=True / repeated-backward pattern do_bench
        # uses elsewhere. Run a fresh fwd+bwd each iteration so each backward sees an
        # unmodified autograd graph. Measurement therefore includes forward time —
        # subtract the "forward" measurement to derive backward-only timing.
        def _fwd_bwd():
            if logits.grad is not None:
                logits.grad = None
            out = fwd_fn(logits, target)
            out.sum().backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(_fwd_bwd, rep=100, quantiles=QUANTILES)
    elif mode == "full":

        def full():
            y = fwd()
            y.sum().backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
    else:
        raise ValueError(f"unknown mode: {mode!r}")

    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)


def bench_memory_megatron_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    v = input.x
    provider = input.kernel_provider
    s = input.extra_benchmark_config["S"]
    b = input.extra_benchmark_config["B"]

    logits, target = _make_inputs(s, b, v)
    fwd_fn = _select_fwd(provider)

    def full():
        y = fwd_fn(logits, target)
        y.sum().backward()

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    providers = ["liger", "torch"]
    if _MEGATRON_AVAILABLE:
        providers.append("megatron")

    common_configs = {
        "kernel_name": "megatron_cross_entropy",
        "x_name": "V",
        "x_label": "vocab size",
        "x_values": [2**i for i in range(12, 18)],
        "kernel_providers": providers,
        "extra_benchmark_configs": [{"S": 2048, "B": 4}],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_megatron_cross_entropy,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_megatron_cross_entropy,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
42 changes: 42 additions & 0 deletions docs/High-Level-APIs.md
@@ -91,3 +91,45 @@ You can also use the Patching APIs to use the kernels for a specific model architecture
        extra:
            show_docstring: true
            show_signature: true

---

## Megatron-LM

Liger also exposes a patch for the [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)
training framework, replacing Megatron's native
`fused_vocab_parallel_cross_entropy` with Liger's Triton cross-entropy kernel.

| **Framework** | **API** | **Supported Operations** |
|---------------|--------------------------------------------------------|--------------------------|
| Megatron-LM | `liger_kernel.megatron.apply_liger_kernel_to_megatron` | CrossEntropyLoss |

**Scope**: Initial release supports `tensor_model_parallel_size=1` only.
Vocab-parallel cross-entropy (TP>1) is follow-up work — with TP>1, each rank
holds a sharded `[N, V/tp]` logits slice and cross-entropy requires cross-rank
all-reduces that Liger's kernel does not perform. The patch raises a
`RuntimeError` at patch time or call time if TP>1 is detected.
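
For intuition, here is a minimal sketch (illustrative only; not Liger or Megatron
code, and `vocab_parallel_logsumexp` is a hypothetical helper) of the cross-rank
reductions a vocab-parallel cross-entropy needs:

```python
import torch
import torch.distributed as dist

def vocab_parallel_logsumexp(shard: torch.Tensor, tp_group) -> torch.Tensor:
    """Numerically stable log-sum-exp over a vocab dimension sharded as [N, V/tp]."""
    local_max = shard.max(dim=-1).values
    dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=tp_group)  # global max
    sum_exp = (shard - local_max.unsqueeze(-1)).exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=tp_group)  # global sum of exp
    return local_max + sum_exp.log()
```

The per-token loss is then `logsumexp - target_logit`, and fetching each target
logit from whichever rank owns it requires one more reduction across the TP group.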

**Usage**:

```python
from liger_kernel.megatron import apply_liger_kernel_to_megatron

# Call before Megatron's forward pass reaches compute_language_model_loss.
# Match Megatron's config: pass the same ignore_index and label_smoothing
# values used by your training setup (Liger does not auto-detect them).
apply_liger_kernel_to_megatron(
    ignore_index=-100,
    label_smoothing=cfg.label_smoothing_factor,
)
```

Ensure Megatron's fused-CE code path is enabled in your training config (e.g.
`--cross-entropy-loss-fusion` in the Megatron-LM CLI) — if the unfused path is
selected, the patched symbol is never called.
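
As a quick sanity check (a hypothetical snippet, not part of the public API), you
can confirm the patch took effect by inspecting the module-level symbol, which
points at Liger's wrapper after patching:

```python
import megatron.core.fusions.fused_cross_entropy as fce

# After apply_liger_kernel_to_megatron(), the native symbol has been rebound.
assert fce.fused_vocab_parallel_cross_entropy.__name__ == "liger_fused_vocab_parallel_cross_entropy"
```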

::: liger_kernel.megatron.apply_liger_kernel_to_megatron
    options:
        extra:
            show_docstring: true
            show_signature: true
3 changes: 3 additions & 0 deletions src/liger_kernel/megatron/__init__.py
@@ -0,0 +1,3 @@
from liger_kernel.megatron.cross_entropy import apply_liger_kernel_to_megatron

__all__ = ["apply_liger_kernel_to_megatron"]
151 changes: 151 additions & 0 deletions src/liger_kernel/megatron/cross_entropy.py
@@ -0,0 +1,151 @@
import logging

import torch

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

logger = logging.getLogger(__name__)

_ACTIVATION_LOGGED = False


def _check_tensor_parallel_size_at_patch_time() -> None:
    """Raise RuntimeError if Megatron's parallel state already reports TP>1.

    If Megatron is importable but the parallel state is not yet initialized
    (for example, ``apply_liger_kernel_to_megatron`` is called before
    ``initialize_megatron``), silently defer; the wrapper checks again at call
    time against the ``tp_group`` argument Megatron supplies.
    """
    try:
        from megatron.core import parallel_state
    except ImportError:
        return
    try:
        tp_size = parallel_state.get_tensor_model_parallel_world_size()
    except (AssertionError, RuntimeError):
        return
    if tp_size > 1:
        raise RuntimeError(
            f"apply_liger_kernel_to_megatron currently requires tensor_model_parallel_size=1, "
            f"got {tp_size}. Vocab-parallel cross-entropy support is planned as follow-up work."
        )
Comment on lines +28 to +32
Collaborator:

This is a constraint that needs to be addressed in the future, given that TP is a common use case in Megatron, but it's a great start for supporting Megatron! BTW, does this patching also not support other parallel strategies (Sequence Parallel, etc.)?

Collaborator:

It feels a bit awkward to me to have the patching and function-wrapping logic on the Liger side. It is surely a simpler way to use Liger's CE without touching the Megatron codebase. However, if supporting the Megatron framework is not on our roadmap, and we are not going to add it to our test suite in the short term, it will be quite inconvenient to maintain this support in the long run. WDYT?

Collaborator:

BTW, Megatron's SP requires TP>1.



def _build_wrapper(loss_fn: LigerCrossEntropyLoss):
    """Build a drop-in replacement for ``fused_vocab_parallel_cross_entropy``.

    The returned callable has exactly the same parameter list Megatron expects
    (``vocab_parallel_logits``, ``target``, ``tp_group``). Any unknown kwargs
    will raise ``TypeError`` naturally — this is intentional: if a future
    Megatron release adds new parameters to the fused-CE contract, we want to
    fail loudly rather than silently drop them.
    """

    def liger_fused_vocab_parallel_cross_entropy(
        vocab_parallel_logits: torch.Tensor,
        target: torch.Tensor,
        tp_group=None,
    ) -> torch.Tensor:
        global _ACTIVATION_LOGGED
        if not _ACTIVATION_LOGGED:
Comment on lines +50 to +51
Collaborator:

Is this necessary?

            logger.info(
                "Liger cross-entropy kernel is active for Megatron-LM "
                "(replacing megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy)."
            )
            _ACTIVATION_LOGGED = True

        if tp_group is not None and hasattr(tp_group, "size") and tp_group.size() > 1:
            raise RuntimeError(
                f"Liger Megatron cross-entropy wrapper requires tensor_model_parallel_size=1, "
                f"got tp_group.size()={tp_group.size()}. Vocab-parallel support is tracked as "
                f"follow-up work."
            )

        s, b, v = vocab_parallel_logits.shape
        logits_2d = vocab_parallel_logits.reshape(-1, v)
        target_1d = target.reshape(-1)
        loss = loss_fn(logits_2d, target_1d)
        return loss.reshape(s, b)

    return liger_fused_vocab_parallel_cross_entropy


def apply_liger_kernel_to_megatron(
Collaborator:

Can we move it to another file like monkey_patch.py under the same directory? If we want to add more kernels besides CE, it would be cleaner to separate the framework-level and kernel-specific logic. You can mirror src/liger_kernel/transformers/:

src/liger_kernel/megatron/
    monkey_patch.py      # apply_liger_kernel_to_megatron + TP check
    cross_entropy.py     # _build_wrapper + _patch_fused_vocab_parallel_ce
    other_future_kernel.py

    reduction: str = "none",
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
) -> None:
    """Replace Megatron-LM's fused_vocab_parallel_cross_entropy with Liger's Triton cross-entropy.

    This monkey-patches
    ``megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy``
    so that Megatron training pipelines use Liger's Triton kernel (online
    softmax, in-place gradients, no full-softmax materialization) instead of
    Megatron's native fused implementation.

    Args:
        reduction: Must be ``"none"``; Megatron's fused-CE contract returns
            per-token loss shaped ``[seq, batch]`` and handles reduction itself
            downstream.
        ignore_index: Target index to ignore. Pass the value used in your
            Megatron training config.
        label_smoothing: Cross-entropy label smoothing factor. Liger does not
            auto-detect this — callers should pass
            ``cfg.label_smoothing_factor`` (or equivalent) from their Megatron
            ``TransformerConfig`` if label smoothing is enabled, to preserve
            the native behavior.

    Scope:
        Initial release supports ``tensor_model_parallel_size=1`` only. With
        TP>1, each rank holds a vocab-sharded logits slice ``[N, V/tp]`` and
        computing cross-entropy requires cross-rank all-reduces that Liger's
        kernel does not perform. A ``RuntimeError`` is raised at patch time if
        the Megatron parallel state already reports TP>1, and again at call
        time if a multi-rank ``tp_group`` is passed.

    Raises:
        AssertionError: If ``reduction != "none"``.
        ImportError: If ``megatron.core.fusions.fused_cross_entropy`` is not
            importable, or if the expected
            ``fused_vocab_parallel_cross_entropy`` symbol is missing from that
            module (indicating an incompatible Megatron version).
        RuntimeError: If tensor model parallelism > 1 is detected.

    Example:
        >>> from liger_kernel.megatron import apply_liger_kernel_to_megatron
        >>> apply_liger_kernel_to_megatron(
        ...     ignore_index=-100,
        ...     label_smoothing=cfg.label_smoothing_factor,
        ... )
        >>> # call before Megatron's forward pass reaches compute_language_model_loss
    """
    assert reduction == "none", (
        f"Megatron's fused_vocab_parallel_cross_entropy contract requires per-token loss; "
        f"reduction must be 'none', got {reduction!r}."
    )

    try:
        import megatron.core.fusions.fused_cross_entropy as fce
    except ImportError as exc:
        raise ImportError(
            "apply_liger_kernel_to_megatron requires megatron-core to be installed. "
            "Expected symbol path: "
            "megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy."
        ) from exc

    if not hasattr(fce, "fused_vocab_parallel_cross_entropy"):
        raise ImportError(
            "megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy not found. "
            "The symbol path may have changed in your Megatron-LM version. Please file an issue "
            "on https://github.com/linkedin/Liger-Kernel with your megatron-core version."
        )

    _check_tensor_parallel_size_at_patch_time()

    loss_fn = LigerCrossEntropyLoss(
        ignore_index=ignore_index,
        label_smoothing=label_smoothing,
        reduction="none",
    )
    fce.fused_vocab_parallel_cross_entropy = _build_wrapper(loss_fn)
Empty file added test/megatron/__init__.py
Empty file.