@@ -236,7 +236,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
 if "args" in call_args and "model" in call_args :
     mixed_precision = \
     "use_bf16 = getattr(args, 'bf16', False)\n " \
+    "if type(use_bf16) is not bool: use_bf16 = False\n " \
     "use_fp16 = getattr(args, 'fp16', False)\n " \
+    "if type(use_fp16) is not bool: use_fp16 = False\n " \
     "force_float32 = False\n " \
     "if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':\n " \
     " print('Unsloth: Switching to float32 training since model cannot work with float16')\n " \
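
For context, a minimal standalone sketch of what the two added guards do once the generated trainer code runs: a non-boolean bf16/fp16 value (say, coming from a loosely typed config) is coerced to False before mixed precision is chosen. The SimpleNamespace stand-in for args below is hypothetical, not TRL's TrainingArguments.

# Hypothetical stand-in for the trainer's args; not TRL's TrainingArguments.
from types import SimpleNamespace

args = SimpleNamespace(bf16 = "auto", fp16 = None)  # non-bool values on purpose

# Same guard pattern the patch injects into the generated trainer code.
use_bf16 = getattr(args, 'bf16', False)
if type(use_bf16) is not bool: use_bf16 = False
use_fp16 = getattr(args, 'fp16', False)
if type(use_fp16) is not bool: use_fp16 = False

print(use_bf16, use_fp16)  # False False -> neither bf16 nor fp16 is assumed
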
@@ -293,7 +295,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
     " if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size\n " \
     " if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps\n " \
     "fp16_full_eval = getattr(args, 'fp16_full_eval', False)\n " \
+    "if type(fp16_full_eval) is not bool: fp16_full_eval = False\n " \
     "bf16_full_eval = getattr(args, 'bf16_full_eval', False)\n " \
+    "if type(bf16_full_eval) is not bool: bf16_full_eval = False\n " \
     "if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True\n " \
     "if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False\n " \
     "if force_float32:\n " \
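
Likewise, a small sketch (again with a hypothetical SimpleNamespace in place of the real args) of how the two new full-eval guards interact with the precision reconciliation that follows them: a non-bool fp16_full_eval/bf16_full_eval is treated as False, after which the eval flags are forced to match the training precision.

# Hypothetical args object; the real generated code operates on TRL's TrainingArguments.
from types import SimpleNamespace

args = SimpleNamespace(fp16 = True, bf16 = False,
                       fp16_full_eval = "auto",  # non-bool on purpose
                       bf16_full_eval = True)

fp16_full_eval = getattr(args, 'fp16_full_eval', False)
if type(fp16_full_eval) is not bool: fp16_full_eval = False
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
if type(bf16_full_eval) is not bool: bf16_full_eval = False

# Keep full-precision eval flags consistent with the training precision flags.
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False

print(args.fp16_full_eval, args.bf16_full_eval)  # True False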