diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py
index 95e9b7919..d5c7be68d 100644
--- a/unsloth/models/rl.py
+++ b/unsloth/models/rl.py
@@ -545,10 +545,17 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
     # Warn on too large or too small learning rate
     if " learning_rate" in call_args:
         learning_rate_check = \
-            "if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! "\
-            "Consider increasing it, otherwise gradient updates will be close to 0!')\n"\
-            "if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! "\
-            "Consider decreasing it to 1e-1, otherwise gradient updates will explode!')\n"
+            "use_strict_mode = os.environ.get('UNSLOTH_USE_STRICT_MODE', '1') == '1'\n"\
+            "lower_limit_msg = f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! "\
+            "Consider increasing it, otherwise gradient updates will be close to 0!'\n"\
+            "upper_limit_msg = f'Unsloth: Your learning rate of `{learning_rate}` is way too large and more than 1! "\
+            "Consider decreasing it to 1e-1, otherwise gradient updates will explode!'\n"\
+            "if learning_rate < 1e-7:\n"\
+            "    if use_strict_mode: raise FloatingPointError(lower_limit_msg)\n"\
+            "    else: print(lower_limit_msg)\n"\
+            "if learning_rate > 1:\n"\
+            "    if use_strict_mode: raise OverflowError(upper_limit_msg)\n"\
+            "    else: print(upper_limit_msg)\n"
         extra_args += learning_rate_check
     pass
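
For reference, the snippet injected by this patch expands to roughly the following standalone check. This is a minimal sketch, assuming `learning_rate` is the trainer's configured value; the helper name `check_learning_rate` is hypothetical and not part of the patch:

import os

def check_learning_rate(learning_rate: float) -> None:
    # Hypothetical helper mirroring the generated check above.
    # Strict mode (the default, UNSLOTH_USE_STRICT_MODE='1') raises;
    # setting UNSLOTH_USE_STRICT_MODE='0' downgrades the errors to printed warnings.
    use_strict_mode = os.environ.get('UNSLOTH_USE_STRICT_MODE', '1') == '1'
    lower_limit_msg = (
        f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! '
        'Consider increasing it, otherwise gradient updates will be close to 0!'
    )
    upper_limit_msg = (
        f'Unsloth: Your learning rate of `{learning_rate}` is way too large and more than 1! '
        'Consider decreasing it to 1e-1, otherwise gradient updates will explode!'
    )
    if learning_rate < 1e-7:
        if use_strict_mode: raise FloatingPointError(lower_limit_msg)
        else: print(lower_limit_msg)
    if learning_rate > 1:
        if use_strict_mode: raise OverflowError(upper_limit_msg)
        else: print(upper_limit_msg)

# Usage sketch:
# os.environ['UNSLOTH_USE_STRICT_MODE'] = '0'  # opt out of hard errors
# check_learning_rate(2.0)                     # prints the warning instead of raising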