Commit 41421bf

Fix bf16 = None
1 parent 7985307 commit 41421bf

File tree

1 file changed: 4 additions, 0 deletions


unsloth/models/rl.py

Lines changed: 4 additions & 0 deletions
@@ -236,7 +236,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
     if "args" in call_args and "model" in call_args:
         mixed_precision = \
             "use_bf16 = getattr(args, 'bf16', False)\n"\
+            "if type(use_bf16) is not bool: use_bf16 = False\n"\
             "use_fp16 = getattr(args, 'fp16', False)\n"\
+            "if type(use_fp16) is not bool: use_fp16 = False\n"\
             "force_float32 = False\n"\
             "if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':\n"\
             "    print('Unsloth: Switching to float32 training since model cannot work with float16')\n"\
@@ -293,7 +295,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
             "    if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size\n"\
             "    if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps\n"\
             "fp16_full_eval = getattr(args, 'fp16_full_eval', False)\n"\
+            "if type(fp16_full_eval) is not bool: fp16_full_eval = False\n"\
             "bf16_full_eval = getattr(args, 'bf16_full_eval', False)\n"\
+            "if type(bf16_full_eval) is not bool: bf16_full_eval = False\n"\
             "if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True\n"\
             "if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False\n"\
             "if force_float32:\n"\
