Add Megatron-LM cross-entropy integration #1207
@@ -0,0 +1,179 @@ (new file: Megatron cross-entropy benchmark script)

```python
"""Benchmark Liger's Megatron-LM cross-entropy wrapper.

Benchmarks the Liger [seq, batch, vocab] cross-entropy wrapper against PyTorch's
native ``F.cross_entropy`` on equivalent input shapes. When megatron-core is
installed, Megatron's own ``fused_vocab_parallel_cross_entropy`` is added as a
third provider to reproduce end-to-end comparisons.

Requires a Liger-supported accelerator (CUDA / ROCm). With megatron-core not
installed, the "megatron" provider is silently skipped.
"""

import torch
import torch.nn.functional as F
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.megatron.cross_entropy import _build_wrapper
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.utils import infer_device

device = infer_device()

try:
    from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy

    _MEGATRON_AVAILABLE = True
except ImportError:
    fused_vocab_parallel_cross_entropy = None
    _MEGATRON_AVAILABLE = False


def _make_inputs(s: int, b: int, v: int, requires_grad: bool = True):
    logits = torch.randn(s, b, v, device=device, dtype=torch.bfloat16, requires_grad=requires_grad)
    target = torch.randint(0, v, (s, b), device=device, dtype=torch.long)
    return logits, target


def _pytorch_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    s, b, v = logits.shape
    return F.cross_entropy(
        logits.reshape(-1, v).float(),
        target.reshape(-1),
        reduction="none",
    ).reshape(s, b)


def _ensure_single_rank_tp_group():
    """Initialize torch.distributed (single-rank) and return a usable TP group.

    For a single-process benchmark we use the world group of
    size 1, where the internal all-reduce becomes a no-op.
    """
    import os

    import torch.distributed as dist

    if not dist.is_initialized():
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        os.environ.setdefault("WORLD_SIZE", "1")
        os.environ.setdefault("RANK", "0")
        os.environ.setdefault("LOCAL_RANK", "0")
        dist.init_process_group(backend="nccl")
    return dist.group.WORLD


def _select_fwd(provider: str):
    if provider == "liger":
        wrapper = _build_wrapper(LigerCrossEntropyLoss(reduction="none"))
        return wrapper
    if provider == "torch":
        return _pytorch_cross_entropy
    if provider == "megatron":
        if not _MEGATRON_AVAILABLE:
            raise RuntimeError("megatron-core not installed; cannot benchmark 'megatron' provider")
        tp_group = _ensure_single_rank_tp_group()

        def _megatron_call(logits, target):
            return fused_vocab_parallel_cross_entropy(logits, target, tp_group)

        return _megatron_call
    raise ValueError(f"unknown provider: {provider!r}")


def bench_speed_megatron_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    v = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    s = input.extra_benchmark_config["S"]
    b = input.extra_benchmark_config["B"]

    logits, target = _make_inputs(s, b, v)
    fwd_fn = _select_fwd(provider)

    def fwd():
        return fwd_fn(logits, target)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
    elif mode == "backward":
        # Megatron's fused CE writes gradients in-place into saved tensors during backward,
        # which breaks the standard retain_graph=True / repeated-backward pattern do_bench
        # uses elsewhere. Run a fresh fwd+bwd each iteration so each backward sees an
        # unmodified autograd graph. Measurement therefore includes forward time —
        # subtract the "forward" measurement to derive backward-only timing.
        def _fwd_bwd():
            if logits.grad is not None:
                logits.grad = None
            out = fwd_fn(logits, target)
            out.sum().backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(_fwd_bwd, rep=100, quantiles=QUANTILES)
    elif mode == "full":

        def full():
            y = fwd()
            y.sum().backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES)
    else:
        raise ValueError(f"unknown mode: {mode!r}")

    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)


def bench_memory_megatron_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    v = input.x
    provider = input.kernel_provider
    s = input.extra_benchmark_config["S"]
    b = input.extra_benchmark_config["B"]

    logits, target = _make_inputs(s, b, v)
    fwd_fn = _select_fwd(provider)

    def full():
        y = fwd_fn(logits, target)
        y.sum().backward()

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    providers = ["liger", "torch"]
    if _MEGATRON_AVAILABLE:
        providers.append("megatron")

    common_configs = {
        "kernel_name": "megatron_cross_entropy",
        "x_name": "V",
        "x_label": "vocab size",
        "x_values": [2**i for i in range(12, 18)],
        "kernel_providers": providers,
        "extra_benchmark_configs": [{"S": 2048, "B": 4}],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_megatron_cross_entropy,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_megatron_cross_entropy,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
```
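The speed and memory numbers are only meaningful if the providers agree numerically on the same inputs. A minimal parity sketch (illustrative, not part of this PR; it reuses the wrapper and loss class imported above and assumes a Liger-supported accelerator) could look like:

```python
# Minimal parity sketch: compare the Liger [seq, batch, vocab] wrapper against
# the plain PyTorch baseline on one shape. Illustrative only, not from the PR.
import torch
import torch.nn.functional as F

from liger_kernel.megatron.cross_entropy import _build_wrapper
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.utils import infer_device

device = infer_device()
s, b, v = 128, 2, 4096
logits = torch.randn(s, b, v, device=device, dtype=torch.bfloat16)
target = torch.randint(0, v, (s, b), device=device, dtype=torch.long)

# Compute the reference first, and pass a clone to Liger, because Liger's
# kernel may write gradients in-place into the logits buffer.
ref = F.cross_entropy(
    logits.reshape(-1, v).float(), target.reshape(-1), reduction="none"
).reshape(s, b)
liger = _build_wrapper(LigerCrossEntropyLoss(reduction="none"))(logits.clone(), target)

assert torch.allclose(liger.float(), ref, atol=1e-2, rtol=1e-2)
```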
@@ -0,0 +1,3 @@ (new file: liger_kernel/megatron/__init__.py)

```python
from liger_kernel.megatron.cross_entropy import apply_liger_kernel_to_megatron

__all__ = ["apply_liger_kernel_to_megatron"]
```
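For orientation, a sketch of how this exported entry point is meant to be called from a Megatron training script (the surrounding script structure is illustrative and not part of the PR; only apply_liger_kernel_to_megatron comes from this change):

```python
# Illustrative only: a hypothetical Megatron pretraining entrypoint.
# The patch must run before any forward pass reaches Megatron's loss path.
from liger_kernel.megatron import apply_liger_kernel_to_megatron


def main() -> None:
    apply_liger_kernel_to_megatron(
        ignore_index=-100,    # match the ignore index used by your data pipeline
        label_smoothing=0.0,  # pass cfg.label_smoothing_factor (or equivalent) if enabled
    )
    # ... initialize Megatron and launch pretraining as usual ...


if __name__ == "__main__":
    main()
```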
@@ -0,0 +1,151 @@ (new file: liger_kernel/megatron/cross_entropy.py)

```python
import logging

import torch

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

logger = logging.getLogger(__name__)

_ACTIVATION_LOGGED = False


def _check_tensor_parallel_size_at_patch_time() -> None:
    """Raise RuntimeError if Megatron's parallel state already reports TP>1.

    If Megatron is importable but the parallel state is not yet initialized
    (for example, ``apply_liger_kernel_to_megatron`` is called before
    ``initialize_megatron``), silently defer; the wrapper checks again at call
    time against the ``tp_group`` argument Megatron supplies.
    """
    try:
        from megatron.core import parallel_state
    except ImportError:
        return
    try:
        tp_size = parallel_state.get_tensor_model_parallel_world_size()
    except (AssertionError, RuntimeError):
        return
    if tp_size > 1:
        raise RuntimeError(
            f"apply_liger_kernel_to_megatron currently requires tensor_model_parallel_size=1, "
            f"got {tp_size}. Vocab-parallel cross-entropy support is planned as follow-up work."
        )


def _build_wrapper(loss_fn: LigerCrossEntropyLoss):
    """Build a drop-in replacement for ``fused_vocab_parallel_cross_entropy``.

    The returned callable has exactly the same parameter list Megatron expects
    (``vocab_parallel_logits``, ``target``, ``tp_group``). Any unknown kwargs
    will raise ``TypeError`` naturally — this is intentional: if a future
    Megatron release adds new parameters to the fused-CE contract, we want to
    fail loudly rather than silently drop them.
    """

    def liger_fused_vocab_parallel_cross_entropy(
        vocab_parallel_logits: torch.Tensor,
        target: torch.Tensor,
        tp_group=None,
    ) -> torch.Tensor:
        global _ACTIVATION_LOGGED
        if not _ACTIVATION_LOGGED:
```
Comment on lines +50 to +51 (Collaborator): Is this necessary?
```python
            logger.info(
                "Liger cross-entropy kernel is active for Megatron-LM "
                "(replacing megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy)."
            )
            _ACTIVATION_LOGGED = True

        if tp_group is not None and hasattr(tp_group, "size") and tp_group.size() > 1:
            raise RuntimeError(
                f"Liger Megatron cross-entropy wrapper requires tensor_model_parallel_size=1, "
                f"got tp_group.size()={tp_group.size()}. Vocab-parallel support is tracked as "
                f"follow-up work."
            )

        s, b, v = vocab_parallel_logits.shape
        logits_2d = vocab_parallel_logits.reshape(-1, v)
        target_1d = target.reshape(-1)
        loss = loss_fn(logits_2d, target_1d)
        return loss.reshape(s, b)

    return liger_fused_vocab_parallel_cross_entropy
```
Comment on the apply_liger_kernel_to_megatron definition (Collaborator): Can we move it to another file like
```python
def apply_liger_kernel_to_megatron(
    reduction: str = "none",
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
) -> None:
    """Replace Megatron-LM's fused_vocab_parallel_cross_entropy with Liger's Triton cross-entropy.

    This monkey-patches
    ``megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy``
    so that Megatron training pipelines use Liger's Triton kernel (online
    softmax, in-place gradients, no full-softmax materialization) instead of
    Megatron's native fused implementation.

    Args:
        reduction: Must be ``"none"``; Megatron's fused-CE contract returns
            per-token loss shaped ``[seq, batch]`` and handles reduction itself
            downstream.
        ignore_index: Target index to ignore. Pass the value used in your
            Megatron training config.
        label_smoothing: Cross-entropy label smoothing factor. Liger does not
            auto-detect this — callers should pass
            ``cfg.label_smoothing_factor`` (or equivalent) from their
            Megatron ``TransformerConfig`` if label smoothing is enabled, to
            preserve the native behavior.

    Scope:
        Initial release supports ``tensor_model_parallel_size=1`` only. With
        TP>1, each rank holds a vocab-sharded logits slice ``[N, V/tp]`` and
        computing cross-entropy requires cross-rank all-reduces that Liger's
        kernel does not perform. A ``RuntimeError`` is raised at patch time if
        the Megatron parallel state already reports TP>1, and again at call
        time if a multi-rank ``tp_group`` is passed.

    Raises:
        AssertionError: If ``reduction != "none"``.
        ImportError: If ``megatron.core.fusions.fused_cross_entropy`` is not
            importable, or if the expected
            ``fused_vocab_parallel_cross_entropy`` symbol is missing from that
            module (indicating an incompatible Megatron version).
        RuntimeError: If tensor model parallelism > 1 is detected.

    Example:
        >>> from liger_kernel.megatron import apply_liger_kernel_to_megatron
        >>> apply_liger_kernel_to_megatron(
        ...     ignore_index=-100,
        ...     label_smoothing=cfg.label_smoothing_factor,
        ... )
        >>> # call before Megatron's forward pass reaches compute_language_model_loss
    """
    assert reduction == "none", (
        f"Megatron's fused_vocab_parallel_cross_entropy contract requires per-token loss; "
        f"reduction must be 'none', got {reduction!r}."
    )

    try:
        import megatron.core.fusions.fused_cross_entropy as fce
    except ImportError as exc:
        raise ImportError(
            "apply_liger_kernel_to_megatron requires megatron-core to be installed. "
            "Expected symbol path: "
            "megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy."
        ) from exc

    if not hasattr(fce, "fused_vocab_parallel_cross_entropy"):
        raise ImportError(
            "megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy not found. "
            "The symbol path may have changed in your Megatron-LM version. Please file an issue "
            "on https://github.com/linkedin/Liger-Kernel with your megatron-core version."
        )

    _check_tensor_parallel_size_at_patch_time()

    loss_fn = LigerCrossEntropyLoss(
        ignore_index=ignore_index,
        label_smoothing=label_smoothing,
        reduction="none",
    )
    fce.fused_vocab_parallel_cross_entropy = _build_wrapper(loss_fn)
```
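Because the integration works by rebinding a module attribute, it is easy to confirm the patch took effect. A minimal check (illustrative, not part of the PR; the expected name comes from the wrapper defined above):

```python
# Illustrative check that the monkey-patch was applied: after calling
# apply_liger_kernel_to_megatron(), the module attribute should point at the
# wrapper built by _build_wrapper.
import megatron.core.fusions.fused_cross_entropy as fce

from liger_kernel.megatron import apply_liger_kernel_to_megatron

apply_liger_kernel_to_megatron()
assert fce.fused_vocab_parallel_cross_entropy.__name__ == "liger_fused_vocab_parallel_cross_entropy"
```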
This is a constraint that needs to be addressed in the future given that TP is a common use case in Megatron, but it's a great start supporting Megatron! BTW, does this patching also not support other parallel strategies (Sequence Parallel, etc.)?

It feels a bit awkward to me to have patching and function-wrapping logic on the Liger side. Surely it is a simpler way to use Liger's CE without touching the Megatron codebase. However, if supporting the Megatron framework is not on our roadmap, and we are not going to add it to our test suite in the short term, it will be quite inconvenient to maintain this support in the long run. WDYT?

BTW, Megatron's SP requires TP>1.
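To make the TP>1 limitation discussed above concrete: with tensor parallelism each rank holds only a vocab shard of the logits, so the per-token loss cannot be computed locally. A simplified schematic of the cross-rank reductions vocab-parallel cross-entropy requires (an illustration of the general technique, not Megatron's or Liger's actual implementation; it assumes an initialized process group and that vocab_start is the first global vocab index owned by the rank):

```python
# Schematic sketch of vocab-parallel cross-entropy under TP>1 (illustrative).
# shard_logits: [N, V/tp] on each rank; target: [N] global vocab indices.
import torch
import torch.distributed as dist


def vocab_parallel_ce_sketch(shard_logits, target, vocab_start, tp_group):
    shard_size = shard_logits.size(-1)

    # 1) Global max over the full vocab (numerical stability).
    global_max = shard_logits.max(dim=-1).values
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group)

    # 2) Global sum of exp(logit - max) across all shards.
    sum_exp = (shard_logits - global_max.unsqueeze(-1)).exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=tp_group)

    # 3) The target logit lives on exactly one rank; other ranks contribute 0.
    local_idx = target - vocab_start
    in_shard = (local_idx >= 0) & (local_idx < shard_size)
    gathered = shard_logits.gather(
        -1, local_idx.clamp(0, shard_size - 1).unsqueeze(-1)
    ).squeeze(-1)
    target_logit = torch.where(in_shard, gathered, torch.zeros_like(global_max))
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=tp_group)

    # Per-token loss: log(sum_exp) + max - target_logit.
    return sum_exp.log() + global_max - target_logit
```

Liger's single-rank kernel performs none of these collectives, which is why this PR restricts the wrapper to tensor_model_parallel_size=1.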