Commit e2aafc7: Fix GRPO (#2787)

Authored by danielhanchen, with co-authors naliazheli, jeromeku, jackswl, and shimmyshimmer.
Squashed commit history (repeated "Update <file>" entries collapsed):

* Update _utils.py, llama.py, vision.py (many iterations); versioning; HF Transfer
* fix(utils): add missing importlib import to fix NameError (#2134). Fixes a NameError raised when `importlib` is referenced in _utils.py without being imported, especially when UNSLOTH_USE_MODELSCOPE=1 is enabled.
* Add QLoRA Train and Merge16bit Test (#2130): reference and Unsloth LoRA merging tests, test/dataset printing, running tests from the repo root, a QLoRA test README, ruff formatting, and an Apache license.
* Update pyproject.toml, loader.py, mapper.py (many iterations); revert; bug fix; check SDPA for Mistral 3, Pixtral; update rl_replacements.py, README.md
* Model registry: add model registry; move HF hub utils to unsloth/utils; refactor global model info dicts to dataclasses; remap literal quant types to a QuantType Enum; handle quant types per model size; register Llama 3.2 (base and instruct, plus vision), Qwen2.5, QwenQVQ, Gemma3, Phi, DeepSeek V3, DeepSeek R1 (base, zero, and distill models), and Mistral Small; add a model search method; remove the deprecated registration API; add quant type and registration tests; add and update the registry README.
* Update _auto_install.py; Llama4
* Synthetic data: many iterations on synthetic.py; Xet and Synthetic; update chat_templates.py; update pyproject.toml; delete .gitignore
* Sesame: force float16 / float32; fix Sesame; is_multimodal; UNSLOTH_DISABLE_STATIC_GENERATION; auto vision detection; Whisper
* RL: iterations on rl.py and rl_replacements.py; logging; logits / temperature; versioning

Co-authored-by: naliazheli <[email protected]>
Co-authored-by: jeromeku <[email protected]>
Co-authored-by: Jack Shi Wei Lun <[email protected]>
Co-authored-by: Michael Han <[email protected]>
1 parent 8767244 commit e2aafc7

File tree

3 files changed: +37 −4 lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions

@@ -37,7 +37,7 @@ triton = [
 ]

 huggingface = [
-    "unsloth_zoo>=2025.6.3",
+    "unsloth_zoo>=2025.6.4",
     "packaging",
     "tyro",
     "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2",
@@ -381,7 +381,7 @@ colab-ampere-torch220 = [
     "flash-attn>=2.6.3",
 ]
 colab-new = [
-    "unsloth_zoo>=2025.6.3",
+    "unsloth_zoo>=2025.6.4",
     "packaging",
     "tyro",
     "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2",
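The bumped `unsloth_zoo` pin and the `transformers` exclusion markers above are standard PEP 440 version specifiers. As a quick sanity check, such specifiers can be evaluated with the `packaging` library (a sketch; assumes `packaging` is installed, which is typical since pip vendors it):

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

# The transformers constraint from the diff above.
spec = SpecifierSet(">=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2")

print(Version("4.51.3") in spec)  # True: the lower bound is inclusive
print(Version("4.52.0") in spec)  # False: explicitly excluded
print(Version("4.52.3") in spec)  # True: only 4.52.0 through 4.52.2 are excluded
```

The `!=` markers let the pin skip individual broken releases without capping the upper bound, so newer patch releases are picked up automatically.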

unsloth/models/rl.py

Lines changed: 26 additions & 0 deletions

@@ -486,6 +486,21 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
         arguments = re.sub(x, y, arguments)
     pass

+    # Fix GRPO beta default as 0.001 TRL used to be 0.04, now 0.00!
+    # https://github.com/huggingface/trl/pull/3516
+    # https://verl.readthedocs.io/en/latest/examples/config.html
+    if trainer_file == "grpo_trainer":
+        replacements = {
+            "beta" : 0.001,
+        }
+        for k, v in replacements.items():
+            x = f"{k}( = [^,\n]{{1,}})?,\n"
+            y = f"'{v}'" if type(v) is str else f"{v}"
+            y = f"{k} = {y},\n"
+            arguments = re.sub(x, y, arguments)
+        pass
+    pass
+
     # Warn on too large or too small learning rate
     if " learning_rate" in call_args:
         learning_rate_check = \
@@ -553,6 +568,17 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
         extra_args += check_num_generations
     pass

+    # Check temperature must not be <= 0. Also stop if >= 10
+    if "temperature" in call_args:
+        check_temperature = \
+        "if temperature <= 0:\n"\
+        "    raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')\n"\
+        "elif temperature >= 10:\n"\
+        "    raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')\n"\
+        "\n"
+        extra_args += check_temperature
+    pass
+
     # Edit config with anything extra
     if trainer_file in RL_CONFIG_CHANGES:
         process_extra_args = RL_CONFIG_CHANGES[trainer_file]
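The beta fix above works by rewriting the trainer's argument string with a regex: it matches `beta` together with whatever default it currently has and substitutes `beta = 0.001`. A standalone sketch of that substitution, using a hypothetical `arguments` string rather than Unsloth's actual patched source:

```python
import re

# Hypothetical fragment of a trainer signature, rendered as source text.
arguments = "learning_rate = 1e-6,\nbeta = 0.04,\nepsilon = 0.2,\n"

replacements = {"beta": 0.001}
for k, v in replacements.items():
    # Matches "beta = <existing default>,\n" (or a bare "beta,\n").
    x = f"{k}( = [^,\n]{{1,}})?,\n"
    y = f"'{v}'" if type(v) is str else f"{v}"  # quote string defaults
    arguments = re.sub(x, f"{k} = {y},\n", arguments)

print(arguments)  # beta's default is now 0.001; other arguments are untouched
```

The optional `( = ...)?` group means the same pattern also handles arguments that had no explicit default, which is why the replacement always emits a full `name = value,` entry.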

unsloth/models/rl_replacements.py

Lines changed: 9 additions & 2 deletions

@@ -269,6 +269,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep,
     # See https://github.com/huggingface/trl/issues/2770
     # logits = logits[:, -logits_to_keep:]
     # return logits
+    # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
+    # logits = logits / self.temperature
     # logps = selective_log_softmax(logits, input_ids)

     # row_indices, col_indices = torch.where(logps < -20)
@@ -325,7 +327,6 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
     else:
         ref_per_token_logps = None
     # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
-
     # x - x.detach() allows for preserving gradients from x
     advantages = inputs["advantages"]
     # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
@@ -335,10 +336,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
         old_hidden_states = inputs["old_per_token_logps"]
     else:
         old_hidden_states = None
+
     input_ids = input_ids[:, -logits_to_keep:]
     if per_token_logps is not None:

-        ref_per_token_logps = ref_per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+        if ref_per_token_logps is not None:
+            ref_per_token_logps = ref_per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+
         per_token_logps = per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

         loss, completion_length, mean_kl = grpo_compute_loss_slow(
@@ -354,6 +358,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
             epsilon_high = self.epsilon_high,
             max_completion_length = self.args.max_completion_length,
             delta = self.args.delta,
+            temperature = self.args.temperature,
         )
     else:
         if hasattr(self.args, "loss_type"):
@@ -370,6 +375,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
             epsilon_high = self.epsilon_high,
             max_completion_length = self.args.max_completion_length,
             delta = self.args.delta,
+            temperature = self.args.temperature,
         )
     else:
         # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17
@@ -381,6 +387,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
             advantages,
             old_hidden_states,
             n_chunks = self.args.unsloth_num_chunks,
+            temperature = self.args.temperature,
         )

     # Log the metrics
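The `temperature = self.args.temperature` argument threaded through these call sites lets the loss divide logits by the sampling temperature before the log-softmax, matching the commented `logits = logits / self.temperature` hint in the first hunk. A minimal pure-Python sketch of that computation (toy lists instead of tensors; the real code operates on PyTorch tensors via TRL's `selective_log_softmax`):

```python
import math

def per_token_logps(logits, token_ids, temperature=1.0):
    """Log-probability of each realized token, computed from
    temperature-scaled logits (logits / temperature, then log-softmax)."""
    out = []
    for row, tok in zip(logits, token_ids):
        scaled = [x / temperature for x in row]          # the key scaling step
        m = max(scaled)                                  # stabilize the softmax
        log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
        out.append(scaled[tok] - log_z)
    return out

logits = [[2.0, 0.5, -1.0]]
cold = per_token_logps(logits, [0], temperature=0.5)  # sharper distribution
warm = per_token_logps(logits, [0], temperature=1.0)
```

Lowering the temperature sharpens the distribution, so the argmax token's log-probability rises; omitting the division (i.e. always using temperature 1) while sampling at another temperature would make the trained policy's log-probs inconsistent with the generations.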

0 commit comments