From 3fb5aef9a2a4290c28d7a7016d492c1de4296027 Mon Sep 17 00:00:00 2001 From: mamba-chen Date: Thu, 7 May 2026 14:52:20 +0800 Subject: [PATCH 1/6] add benchmark script for ops/grpo_loss --- .../scripts/benchmark_grpo_loss_kernel.py | 365 ++++++++++++++++++ 1 file changed, 365 insertions(+) create mode 100644 benchmark/scripts/benchmark_grpo_loss_kernel.py diff --git a/benchmark/scripts/benchmark_grpo_loss_kernel.py b/benchmark/scripts/benchmark_grpo_loss_kernel.py new file mode 100644 index 000000000..b1287aaec --- /dev/null +++ b/benchmark/scripts/benchmark_grpo_loss_kernel.py @@ -0,0 +1,365 @@ +""" +Benchmark for GRPO loss kernel (ops layer). + +This benchmark tests the performance of the low-level GRPO loss kernel +(liger_kernel.ops.grpo_loss.GrpoLossFunction) against a pure PyTorch baseline. + +Unlike benchmark_grpo_loss.py which tests the chunked_loss layer (fused linear + loss), +this benchmark focuses on the kernel-level implementation without the linear layer fusion. + +Usage: + python benchmark_grpo_loss_kernel.py [--overwrite] +""" +import os +import sys + +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.utils import infer_device + +device = infer_device() + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +############################################################################# +# Torch baseline implementation for GRPO loss kernel +############################################################################# + + +def torch_grpo_loss_kernel( + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + loss_type="grpo", +): + """ + Torch baseline implementation of GRPO loss kernel. + This mimics the kernel-level API without chunking. + """ + B, L_ADD_1, N = logits.shape + L = L_ADD_1 - 1 + + # Compute log probabilities + logits_for_loss = logits[:, :-1, :] / temperature + log_probs = torch.nn.functional.log_softmax(logits_for_loss, dim=-1) + + # Gather log probs for selected tokens + completion_ids_expanded = completion_ids.unsqueeze(-1) + logp = log_probs.gather(dim=-1, index=completion_ids_expanded).squeeze(-1) + + # Compute importance ratio + if old_logp is None: + old_logp_val = logp + else: + old_logp_val = old_logp + + coef_1 = torch.exp(logp - old_logp_val) + + # Compute clipped coefficient + coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high) + + # Expand advantages to per-token + advantages_expanded = advantages.unsqueeze(-1).expand(-1, L) + + # Compute per-token loss + per_token_loss1 = coef_1 * advantages_expanded + per_token_loss2 = coef_2 * advantages_expanded + per_token_loss = -torch.minimum(per_token_loss1, per_token_loss2) + + # Add KL penalty if beta > 0 + if beta != 0.0 and ref_logp is not None: + kl = torch.exp(ref_logp - logp) - (ref_logp - logp) - 1 + per_token_loss += beta * kl + + # Apply mask and reduce + if completion_mask is not None: + mask = completion_mask.float() + else: + mask = torch.ones(B, L, device=logits.device) + + # Reduce loss (GRPO uses per-sequence mean) + loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() + + return loss + + +############################################################################# +# Test the memory consumption of the GRPO loss kernel +############################################################################# + + +def bench_memory_grpo_loss_kernel( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + from liger_kernel.ops.grpo_loss import GrpoLossFunction + + B = input.x + T = input.extra_benchmark_config["T"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + + # Create inputs + logits = torch.randn(B, T + 1, V, requires_grad=True, dtype=dtype, device=device) + completion_ids = torch.randint(0, V, (B, T), dtype=torch.long, device=device) + advantages = torch.randn(B, dtype=dtype, device=device) + completion_mask = torch.ones(B, T, dtype=torch.bool, device=device) + old_logp = None # On-policy case + ref_logp = torch.randn(B, T, dtype=torch.float32, device=device) if input.extra_benchmark_config.get("beta", 0.0) > 0 else None + + temperature = 1.0 + beta = input.extra_benchmark_config.get("beta", 0.0) + eps_low = 0.2 + eps_high = 0.2 + + def liger_fwd(): + return GrpoLossFunction.apply( + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + False, # inplace + "grpo", # loss_type + None, # max_completion_length + True, # reduce + "token", # importance_sampling_level + 1.0, # sapo_temperature_pos + 1.05, # sapo_temperature_neg + None, # vllm_is_ratio + None, # delta + False, # use_bias_correction_kl + )[0] + + def torch_fwd(): + return torch_grpo_loss_kernel( + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + ) + + def fwd(): + if provider == "liger": + return liger_fwd() + elif provider == "torch": + return torch_fwd() + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +############################################################################# +# Test the speed of the GRPO loss kernel +############################################################################# + + +def bench_speed_grpo_loss_kernel( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + from liger_kernel.ops.grpo_loss import GrpoLossFunction + + B = input.x + T = input.extra_benchmark_config["T"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + # Create inputs + logits = torch.randn(B, T + 1, V, requires_grad=True, dtype=dtype, device=device) + completion_ids = torch.randint(0, V, (B, T), dtype=torch.long, device=device) + advantages = torch.randn(B, dtype=dtype, device=device) + completion_mask = torch.ones(B, T, dtype=torch.bool, device=device) + old_logp = None # On-policy case + ref_logp = torch.randn(B, T, dtype=torch.float32, device=device) if input.extra_benchmark_config.get("beta", 0.0) > 0 else None + + temperature = 1.0 + beta = input.extra_benchmark_config.get("beta", 0.0) + eps_low = 0.2 + eps_high = 0.2 + + def liger_fwd(): + return GrpoLossFunction.apply( + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + False, # inplace + "grpo", # loss_type + None, # max_completion_length + True, # reduce + "token", # importance_sampling_level + 1.0, # sapo_temperature_pos + 1.05, # sapo_temperature_neg + None, # vllm_is_ratio + None, # delta + False, # use_bias_correction_kl + )[0] + + def torch_fwd(): + return torch_grpo_loss_kernel( + logits, + old_logp, + ref_logp, + completion_ids, + advantages, + completion_mask, + temperature, + beta, + eps_low, + eps_high, + ) + + def fwd(): + if provider == "liger": + return liger_fwd() + elif provider == "torch": + return torch_fwd() + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[logits], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + # Benchmark GRPO loss kernel without KL penalty + no_kl_configs = { + "kernel_name": "grpo_loss_kernel_no_kl", + "x_name": "B", + "x_label": "Batch Size", + "x_values": [2**i for i in range(1, 5)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "T": 512, + "V": 32000, + "beta": 0.0, + "dtype": torch.bfloat16, + } + ], + "overwrite": args.overwrite, + } + + # Benchmark GRPO loss kernel with KL penalty + with_kl_configs = { + "kernel_name": "grpo_loss_kernel_with_kl", + "x_name": "B", + "x_label": "Batch Size", + "x_values": [2**i for i in range(1, 5)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "T": 512, + "V": 32000, + "beta": 0.1, + "dtype": torch.bfloat16, + } + ], + "overwrite": args.overwrite, + } + + # Run benchmarks without KL + print("Benchmarking GRPO loss kernel (no KL penalty)...") + run_benchmarks( + bench_test_fn=bench_speed_grpo_loss_kernel, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **no_kl_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_grpo_loss_kernel, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **no_kl_configs, + ) + + # Run benchmarks with KL + print("Benchmarking GRPO loss kernel (with KL penalty)...") + run_benchmarks( + bench_test_fn=bench_speed_grpo_loss_kernel, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **with_kl_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_grpo_loss_kernel, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **with_kl_configs, + ) From d9cf8cd85abe4430511428c802f0a771fefd3e7d Mon Sep 17 00:00:00 2001 From: mamba-chen Date: Thu, 7 May 2026 15:37:47 +0800 Subject: [PATCH 2/6] add benchmark script for ops/grpo_loss --- .../scripts/benchmark_grpo_loss_kernel.py | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/benchmark/scripts/benchmark_grpo_loss_kernel.py b/benchmark/scripts/benchmark_grpo_loss_kernel.py index b1287aaec..df79befe6 100644 --- a/benchmark/scripts/benchmark_grpo_loss_kernel.py +++ b/benchmark/scripts/benchmark_grpo_loss_kernel.py @@ -113,6 +113,7 @@ def bench_memory_grpo_loss_kernel( T = input.extra_benchmark_config["T"] V = input.extra_benchmark_config["V"] dtype = input.extra_benchmark_config["dtype"] + importance_sampling_level = input.extra_benchmark_config["importance_sampling_level"] provider = input.kernel_provider # Create inputs @@ -144,7 +145,7 @@ def liger_fwd(): "grpo", # loss_type None, # max_completion_length True, # reduce - "token", # importance_sampling_level + importance_sampling_level, # importance_sampling_level 1.0, # sapo_temperature_pos 1.05, # sapo_temperature_neg None, # vllm_is_ratio @@ -198,6 +199,7 @@ def bench_speed_grpo_loss_kernel( T = input.extra_benchmark_config["T"] V = input.extra_benchmark_config["V"] dtype = input.extra_benchmark_config["dtype"] + importance_sampling_level = input.extra_benchmark_config["importance_sampling_level"] provider = input.kernel_provider mode = input.kernel_operation_mode @@ -230,7 +232,7 @@ def liger_fwd(): "grpo", # loss_type None, # max_completion_length True, # reduce - "token", # importance_sampling_level + importance_sampling_level, # importance_sampling_level 1.0, # sapo_temperature_pos 1.05, # sapo_temperature_neg None, # vllm_is_ratio @@ -294,9 +296,9 @@ def full(): if __name__ == "__main__": args = parse_benchmark_script_args() - # Benchmark GRPO loss kernel without KL penalty - no_kl_configs = { - "kernel_name": "grpo_loss_kernel_no_kl", + # Benchmark token-level importance sampling (original GRPO) + token_configs = { + "kernel_name": "grpo_loss_kernel_token", "x_name": "B", "x_label": "Batch Size", "x_values": [2**i for i in range(1, 5)], @@ -305,16 +307,17 @@ def full(): { "T": 512, "V": 32000, - "beta": 0.0, + "beta": 0.1, "dtype": torch.bfloat16, + "importance_sampling_level": "token", } ], "overwrite": args.overwrite, } - # Benchmark GRPO loss kernel with KL penalty - with_kl_configs = { - "kernel_name": "grpo_loss_kernel_with_kl", + # Benchmark sequence-level importance sampling (GSPO) + sequence_configs = { + "kernel_name": "grpo_loss_kernel_sequence", "x_name": "B", "x_label": "Batch Size", "x_values": [2**i for i in range(1, 5)], @@ -325,41 +328,42 @@ def full(): "V": 32000, "beta": 0.1, "dtype": torch.bfloat16, + "importance_sampling_level": "sequence", } ], "overwrite": args.overwrite, } - # Run benchmarks without KL - print("Benchmarking GRPO loss kernel (no KL penalty)...") + # Run benchmarks for token-level (GRPO) + print("Benchmarking GRPO (token-level importance sampling)...") run_benchmarks( bench_test_fn=bench_speed_grpo_loss_kernel, kernel_operation_modes=["forward", "full", "backward"], metric_name="speed", metric_unit="ms", - **no_kl_configs, + **token_configs, ) run_benchmarks( bench_test_fn=bench_memory_grpo_loss_kernel, kernel_operation_modes=["full"], metric_name="memory", metric_unit="MB", - **no_kl_configs, + **token_configs, ) - # Run benchmarks with KL - print("Benchmarking GRPO loss kernel (with KL penalty)...") + # Run benchmarks for sequence-level (GSPO) + print("Benchmarking GSPO (sequence-level importance sampling)...") run_benchmarks( bench_test_fn=bench_speed_grpo_loss_kernel, kernel_operation_modes=["forward", "full", "backward"], metric_name="speed", metric_unit="ms", - **with_kl_configs, + **sequence_configs, ) run_benchmarks( bench_test_fn=bench_memory_grpo_loss_kernel, kernel_operation_modes=["full"], metric_name="memory", metric_unit="MB", - **with_kl_configs, + **sequence_configs, ) From 7fd609688f8735bb74e705ce96585bf1cff50f89 Mon Sep 17 00:00:00 2001 From: mamba-chen Date: Thu, 7 May 2026 16:49:37 +0800 Subject: [PATCH 3/6] add benchmark script for ops/grpo_loss --- benchmark/scripts/benchmark_grpo_loss_kernel.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/benchmark/scripts/benchmark_grpo_loss_kernel.py b/benchmark/scripts/benchmark_grpo_loss_kernel.py index df79befe6..8c4882c89 100644 --- a/benchmark/scripts/benchmark_grpo_loss_kernel.py +++ b/benchmark/scripts/benchmark_grpo_loss_kernel.py @@ -107,7 +107,7 @@ def torch_grpo_loss_kernel( def bench_memory_grpo_loss_kernel( input: SingleBenchmarkRunInput, ) -> SingleBenchmarkRunOutput: - from liger_kernel.ops.grpo_loss import GrpoLossFunction + from liger_kernel.ops import GrpoLossFunction B = input.x T = input.extra_benchmark_config["T"] @@ -122,7 +122,8 @@ def bench_memory_grpo_loss_kernel( advantages = torch.randn(B, dtype=dtype, device=device) completion_mask = torch.ones(B, T, dtype=torch.bool, device=device) old_logp = None # On-policy case - ref_logp = torch.randn(B, T, dtype=torch.float32, device=device) if input.extra_benchmark_config.get("beta", 0.0) > 0 else None + ref_logp = torch.randn(B, T, dtype=torch.float32, device=device) if input.extra_benchmark_config.get("beta", + 0.0) > 0 else None temperature = 1.0 beta = input.extra_benchmark_config.get("beta", 0.0) @@ -209,7 +210,8 @@ def bench_speed_grpo_loss_kernel( advantages = torch.randn(B, dtype=dtype, device=device) completion_mask = torch.ones(B, T, dtype=torch.bool, device=device) old_logp = None # On-policy case - ref_logp = torch.randn(B, T, dtype=torch.float32, device=device) if input.extra_benchmark_config.get("beta", 0.0) > 0 else None + ref_logp = torch.randn(B, T, dtype=torch.float32, device=device) if input.extra_benchmark_config.get("beta", + 0.0) > 0 else None temperature = 1.0 beta = input.extra_benchmark_config.get("beta", 0.0) @@ -301,7 +303,7 @@ def full(): "kernel_name": "grpo_loss_kernel_token", "x_name": "B", "x_label": "Batch Size", - "x_values": [2**i for i in range(1, 5)], + "x_values": [2 ** i for i in range(1, 5)], "kernel_providers": ["liger", "torch"], "extra_benchmark_configs": [ { @@ -320,7 +322,7 @@ def full(): "kernel_name": "grpo_loss_kernel_sequence", "x_name": "B", "x_label": "Batch Size", - "x_values": [2**i for i in range(1, 5)], + "x_values": [2 ** i for i in range(1, 5)], "kernel_providers": ["liger", "torch"], "extra_benchmark_configs": [ { From dc7e58e162f1cab5baa5494d48b520e0eb79ec89 Mon Sep 17 00:00:00 2001 From: mamba-chen Date: Fri, 8 May 2026 14:49:33 +0800 Subject: [PATCH 4/6] add benchmark script for ops/grpo_loss --- benchmark/scripts/benchmark_grpo_loss_kernel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/scripts/benchmark_grpo_loss_kernel.py b/benchmark/scripts/benchmark_grpo_loss_kernel.py index 8c4882c89..001aa3285 100644 --- a/benchmark/scripts/benchmark_grpo_loss_kernel.py +++ b/benchmark/scripts/benchmark_grpo_loss_kernel.py @@ -194,7 +194,7 @@ def full(): def bench_speed_grpo_loss_kernel( input: SingleBenchmarkRunInput, ) -> SingleBenchmarkRunOutput: - from liger_kernel.ops.grpo_loss import GrpoLossFunction + from liger_kernel.ops import GrpoLossFunction B = input.x T = input.extra_benchmark_config["T"] From d9919a62d1eb362dbdc8ad36f95159638e8caa42 Mon Sep 17 00:00:00 2001 From: mamba-chen Date: Thu, 14 May 2026 09:56:15 +0800 Subject: [PATCH 5/6] add benchmark script for ops/grpo_loss --- .../scripts/benchmark_grpo_loss_kernel.py | 287 ++++++++---------- 1 file changed, 129 insertions(+), 158 deletions(-) diff --git a/benchmark/scripts/benchmark_grpo_loss_kernel.py b/benchmark/scripts/benchmark_grpo_loss_kernel.py index 001aa3285..dd4c691be 100644 --- a/benchmark/scripts/benchmark_grpo_loss_kernel.py +++ b/benchmark/scripts/benchmark_grpo_loss_kernel.py @@ -14,6 +14,7 @@ import sys import torch +import torch.nn as nn import triton from utils import QUANTILES @@ -31,106 +32,82 @@ ############################################################################# -# Torch baseline implementation for GRPO loss kernel +# Module wrappers for GRPO loss kernel ############################################################################# -def torch_grpo_loss_kernel( - logits, - old_logp, - ref_logp, - completion_ids, - advantages, - completion_mask, - temperature, - beta, - eps_low, - eps_high, - loss_type="grpo", -): - """ - Torch baseline implementation of GRPO loss kernel. - This mimics the kernel-level API without chunking. - """ - B, L_ADD_1, N = logits.shape - L = L_ADD_1 - 1 - - # Compute log probabilities - logits_for_loss = logits[:, :-1, :] / temperature - log_probs = torch.nn.functional.log_softmax(logits_for_loss, dim=-1) - - # Gather log probs for selected tokens - completion_ids_expanded = completion_ids.unsqueeze(-1) - logp = log_probs.gather(dim=-1, index=completion_ids_expanded).squeeze(-1) - - # Compute importance ratio - if old_logp is None: - old_logp_val = logp - else: - old_logp_val = old_logp - - coef_1 = torch.exp(logp - old_logp_val) - - # Compute clipped coefficient - coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high) - - # Expand advantages to per-token - advantages_expanded = advantages.unsqueeze(-1).expand(-1, L) - - # Compute per-token loss - per_token_loss1 = coef_1 * advantages_expanded - per_token_loss2 = coef_2 * advantages_expanded - per_token_loss = -torch.minimum(per_token_loss1, per_token_loss2) - - # Add KL penalty if beta > 0 - if beta != 0.0 and ref_logp is not None: - kl = torch.exp(ref_logp - logp) - (ref_logp - logp) - 1 - per_token_loss += beta * kl - - # Apply mask and reduce - if completion_mask is not None: - mask = completion_mask.float() - else: - mask = torch.ones(B, L, device=logits.device) - - # Reduce loss (GRPO uses per-sequence mean) - loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() - - return loss +class TorchGRPOLoss(nn.Module): + """Torch baseline module for GRPO loss kernel.""" + def __init__(self, temperature=1.0, beta=0.0, eps_low=0.2, eps_high=0.2): + super().__init__() + self.temperature = temperature + self.beta = beta + self.eps_low = eps_low + self.eps_high = eps_high -############################################################################# -# Test the memory consumption of the GRPO loss kernel -############################################################################# + def forward(self, logits, old_logp, ref_logp, completion_ids, advantages, completion_mask): + B, L_ADD_1, N = logits.shape + L = L_ADD_1 - 1 + # Compute log probabilities + logits_for_loss = logits[:, :-1, :] / self.temperature + log_probs = torch.nn.functional.log_softmax(logits_for_loss, dim=-1) -def bench_memory_grpo_loss_kernel( - input: SingleBenchmarkRunInput, -) -> SingleBenchmarkRunOutput: - from liger_kernel.ops import GrpoLossFunction + # Gather log probs for selected tokens + completion_ids_expanded = completion_ids.unsqueeze(-1) + logp = log_probs.gather(dim=-1, index=completion_ids_expanded).squeeze(-1) - B = input.x - T = input.extra_benchmark_config["T"] - V = input.extra_benchmark_config["V"] - dtype = input.extra_benchmark_config["dtype"] - importance_sampling_level = input.extra_benchmark_config["importance_sampling_level"] - provider = input.kernel_provider + # Compute importance ratio + if old_logp is None: + old_logp_val = logp + else: + old_logp_val = old_logp - # Create inputs - logits = torch.randn(B, T + 1, V, requires_grad=True, dtype=dtype, device=device) - completion_ids = torch.randint(0, V, (B, T), dtype=torch.long, device=device) - advantages = torch.randn(B, dtype=dtype, device=device) - completion_mask = torch.ones(B, T, dtype=torch.bool, device=device) - old_logp = None # On-policy case - ref_logp = torch.randn(B, T, dtype=torch.float32, device=device) if input.extra_benchmark_config.get("beta", - 0.0) > 0 else None + coef_1 = torch.exp(logp - old_logp_val) - temperature = 1.0 - beta = input.extra_benchmark_config.get("beta", 0.0) - eps_low = 0.2 - eps_high = 0.2 + # Compute clipped coefficient + coef_2 = torch.clamp(coef_1, 1 - self.eps_low, 1 + self.eps_high) + + # Expand advantages to per-token + advantages_expanded = advantages.unsqueeze(-1).expand(-1, L) + + # Compute per-token loss + per_token_loss1 = coef_1 * advantages_expanded + per_token_loss2 = coef_2 * advantages_expanded + per_token_loss = -torch.minimum(per_token_loss1, per_token_loss2) + + # Add KL penalty if beta > 0 + if self.beta != 0.0 and ref_logp is not None: + kl = torch.exp(ref_logp - logp) - (ref_logp - logp) - 1 + per_token_loss += self.beta * kl + + # Apply mask and reduce + if completion_mask is not None: + mask = completion_mask.float() + else: + mask = torch.ones(B, L, device=logits.device) + + # Reduce loss (GRPO uses per-sequence mean) + loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() + + return loss + + +class LigerGRPOLoss(nn.Module): + """Liger module wrapper for GRPO loss kernel.""" + + def __init__(self, temperature=1.0, beta=0.0, eps_low=0.2, eps_high=0.2, importance_sampling_level="token"): + super().__init__() + self.temperature = temperature + self.beta = beta + self.eps_low = eps_low + self.eps_high = eps_high + self.importance_sampling_level = importance_sampling_level + + def forward(self, logits, old_logp, ref_logp, completion_ids, advantages, completion_mask): + from liger_kernel.ops import GrpoLossFunction - def liger_fwd(): return GrpoLossFunction.apply( logits, old_logp, @@ -138,15 +115,15 @@ def liger_fwd(): completion_ids, advantages, completion_mask, - temperature, - beta, - eps_low, - eps_high, + self.temperature, + self.beta, + self.eps_low, + self.eps_high, False, # inplace "grpo", # loss_type None, # max_completion_length True, # reduce - importance_sampling_level, # importance_sampling_level + self.importance_sampling_level, 1.0, # sapo_temperature_pos 1.05, # sapo_temperature_neg None, # vllm_is_ratio @@ -154,25 +131,50 @@ def liger_fwd(): False, # use_bias_correction_kl )[0] - def torch_fwd(): - return torch_grpo_loss_kernel( - logits, - old_logp, - ref_logp, - completion_ids, - advantages, - completion_mask, - temperature, - beta, - eps_low, - eps_high, - ) + +############################################################################# +# Test the memory consumption of the GRPO loss kernel +############################################################################# + + +def bench_memory_grpo_loss_kernel( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + B = input.x + T = input.extra_benchmark_config["T"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + importance_sampling_level = input.extra_benchmark_config["importance_sampling_level"] + provider = input.kernel_provider + + temperature = 1.0 + beta = input.extra_benchmark_config.get("beta", 0.0) + eps_low = 0.2 + eps_high = 0.2 + + # Instantiate modules + torch_grpo = TorchGRPOLoss(temperature=temperature, beta=beta, eps_low=eps_low, eps_high=eps_high).to(device) + liger_grpo = LigerGRPOLoss( + temperature=temperature, + beta=beta, + eps_low=eps_low, + eps_high=eps_high, + importance_sampling_level=importance_sampling_level, + ).to(device) + + # Create inputs + logits = torch.randn(B, T + 1, V, requires_grad=True, dtype=dtype, device=device) + completion_ids = torch.randint(0, V, (B, T), dtype=torch.long, device=device) + advantages = torch.randn(B, dtype=dtype, device=device) + completion_mask = torch.ones(B, T, dtype=torch.bool, device=device) + old_logp = None # On-policy case + ref_logp = torch.randn(B, T, dtype=torch.float32, device=device) if beta > 0 else None def fwd(): if provider == "liger": - return liger_fwd() + return liger_grpo(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask) elif provider == "torch": - return torch_fwd() + return torch_grpo(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask) def full(): y = fwd() @@ -194,8 +196,6 @@ def full(): def bench_speed_grpo_loss_kernel( input: SingleBenchmarkRunInput, ) -> SingleBenchmarkRunOutput: - from liger_kernel.ops import GrpoLossFunction - B = input.x T = input.extra_benchmark_config["T"] V = input.extra_benchmark_config["V"] @@ -204,63 +204,34 @@ def bench_speed_grpo_loss_kernel( provider = input.kernel_provider mode = input.kernel_operation_mode - # Create inputs - logits = torch.randn(B, T + 1, V, requires_grad=True, dtype=dtype, device=device) - completion_ids = torch.randint(0, V, (B, T), dtype=torch.long, device=device) - advantages = torch.randn(B, dtype=dtype, device=device) - completion_mask = torch.ones(B, T, dtype=torch.bool, device=device) - old_logp = None # On-policy case - ref_logp = torch.randn(B, T, dtype=torch.float32, device=device) if input.extra_benchmark_config.get("beta", - 0.0) > 0 else None - temperature = 1.0 beta = input.extra_benchmark_config.get("beta", 0.0) eps_low = 0.2 eps_high = 0.2 - def liger_fwd(): - return GrpoLossFunction.apply( - logits, - old_logp, - ref_logp, - completion_ids, - advantages, - completion_mask, - temperature, - beta, - eps_low, - eps_high, - False, # inplace - "grpo", # loss_type - None, # max_completion_length - True, # reduce - importance_sampling_level, # importance_sampling_level - 1.0, # sapo_temperature_pos - 1.05, # sapo_temperature_neg - None, # vllm_is_ratio - None, # delta - False, # use_bias_correction_kl - )[0] + # Instantiate modules + torch_grpo = TorchGRPOLoss(temperature=temperature, beta=beta, eps_low=eps_low, eps_high=eps_high).to(device) + liger_grpo = LigerGRPOLoss( + temperature=temperature, + beta=beta, + eps_low=eps_low, + eps_high=eps_high, + importance_sampling_level=importance_sampling_level, + ).to(device) - def torch_fwd(): - return torch_grpo_loss_kernel( - logits, - old_logp, - ref_logp, - completion_ids, - advantages, - completion_mask, - temperature, - beta, - eps_low, - eps_high, - ) + # Create inputs + logits = torch.randn(B, T + 1, V, requires_grad=True, dtype=dtype, device=device) + completion_ids = torch.randint(0, V, (B, T), dtype=torch.long, device=device) + advantages = torch.randn(B, dtype=dtype, device=device) + completion_mask = torch.ones(B, T, dtype=torch.bool, device=device) + old_logp = None # On-policy case + ref_logp = torch.randn(B, T, dtype=torch.float32, device=device) if beta > 0 else None def fwd(): if provider == "liger": - return liger_fwd() + return liger_grpo(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask) elif provider == "torch": - return torch_fwd() + return torch_grpo(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask) if mode == "forward": ms_50, ms_20, ms_80 = triton.testing.do_bench( From af72a53b7dfb160044aaadb0b900a84651a5f717 Mon Sep 17 00:00:00 2001 From: mamba-chen Date: Sat, 16 May 2026 16:32:51 +0800 Subject: [PATCH 6/6] checkstyle --- benchmark/scripts/benchmark_grpo_loss_kernel.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/benchmark/scripts/benchmark_grpo_loss_kernel.py b/benchmark/scripts/benchmark_grpo_loss_kernel.py index dd4c691be..f06aa3a66 100644 --- a/benchmark/scripts/benchmark_grpo_loss_kernel.py +++ b/benchmark/scripts/benchmark_grpo_loss_kernel.py @@ -10,6 +10,7 @@ Usage: python benchmark_grpo_loss_kernel.py [--overwrite] """ + import os import sys @@ -274,7 +275,7 @@ def full(): "kernel_name": "grpo_loss_kernel_token", "x_name": "B", "x_label": "Batch Size", - "x_values": [2 ** i for i in range(1, 5)], + "x_values": [2**i for i in range(1, 5)], "kernel_providers": ["liger", "torch"], "extra_benchmark_configs": [ { @@ -293,7 +294,7 @@ def full(): "kernel_name": "grpo_loss_kernel_sequence", "x_name": "B", "x_label": "Batch Size", - "x_values": [2 ** i for i in range(1, 5)], + "x_values": [2**i for i in range(1, 5)], "kernel_providers": ["liger", "torch"], "extra_benchmark_configs": [ {