@@ -236,6 +236,8 @@ def from_pretrained(
     auto_model = AutoModelForVision2Seq,
     use_gradient_checkpointing = "unsloth",
     supports_sdpa = True,
+    whisper_language = None,
+    whisper_task = None,
     **kwargs,
 ):
     if model_types is None:
@@ -304,7 +306,8 @@ def from_pretrained(
         do_forced_float32 = True
     pass
     # Stop SDPA for some archs like Pixtral / Mistral3
-    kwargs["attn_implementation"] = "sdpa"
+    if not ("attn_implementation" in kwargs):
+        kwargs["attn_implementation"] = "sdpa"
     if not supports_sdpa:
         print(f"Unsloth: {model_type_arch.title()} does not support SDPA - switching to eager!")
         del kwargs["attn_implementation"]
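Note: with this hunk the SDPA default is only applied when the caller has not already passed an attention implementation, so a user-supplied value survives. A minimal standalone sketch of the resulting behaviour (not part of the diff):

    kwargs = {"attn_implementation": "eager"}    # caller-supplied override
    if "attn_implementation" not in kwargs:
        kwargs["attn_implementation"] = "sdpa"   # default only fills the gap
    assert kwargs["attn_implementation"] == "eager"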
@@ -352,6 +355,7 @@ def from_pretrained(
     # Check if using forced float32 - we load it in bfloat16, then cast to float16!
     torch_dtype = dtype
     if do_forced_float32: torch_dtype = torch.bfloat16
+
     model = auto_model.from_pretrained(
         model_name,
         device_map = device_map,
@@ -367,12 +371,23 @@ def from_pretrained(
 
     # Counteract saved tokenizers
     tokenizer_name = model_name if tokenizer_name is None else tokenizer_name
-    auto_processor = AutoProcessor if auto_model is AutoModelForVision2Seq else AutoTokenizer
-    tokenizer = auto_processor.from_pretrained(
-        tokenizer_name,
-        padding_side = "right",
-        token = token,
-    )
+    is_vlm = (auto_model is AutoModelForVision2Seq)
+    is_whisper = (whisper_language is not None and whisper_task is not None)
+    auto_processor = AutoProcessor if (is_vlm or is_whisper) else AutoTokenizer
+    if whisper_language and whisper_task:
+        tokenizer = auto_processor.from_pretrained(
+            tokenizer_name,
+            padding_side = "right",
+            token = token,
+            language = whisper_language,
+            task = whisper_task,
+        )
+    else:
+        tokenizer = auto_processor.from_pretrained(
+            tokenizer_name,
+            padding_side = "right",
+            token = token,
+        )
     if hasattr(tokenizer, "tokenizer"):
         __tokenizer = tokenizer.tokenizer
         # Add padding side as well
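Note: the new whisper_language / whisper_task keywords make the loader pick AutoProcessor and forward the language and task to it, which is what Whisper's processor expects. A hedged usage sketch; the FastModel entry point and the checkpoint name are illustrative assumptions, not part of this diff:

    from unsloth import FastModel  # assumed wrapper that forwards kwargs to this from_pretrained

    model, processor = FastModel.from_pretrained(
        "openai/whisper-large-v3",       # illustrative checkpoint
        whisper_language = "English",    # forwarded to AutoProcessor.from_pretrained
        whisper_task     = "transcribe",
    )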
@@ -469,6 +484,7 @@ def get_peft_model(
     modules_to_save = None,
     init_lora_weights = True,
     loftq_config = {},
+    task_type = TaskType.CAUSAL_LM,
     temporary_location = "_unsloth_temporary_saved_buffers",
     **kwargs,
 ):
@@ -492,7 +508,7 @@ def get_peft_model(
         finetune_attention_modules = True
         finetune_mlp_modules = True
     pass
-    if target_modules is None:
+    if target_modules is None or target_modules == "all-linear":
         target_modules = get_peft_regex(
             model,
             finetune_vision_layers = finetune_vision_layers,
@@ -503,7 +519,7 @@ def get_peft_model(
     else:
         assert(type(target_modules) in (list, tuple,))
     pass
-
+
     # Clear deleted GPU items
     for _ in range(3):
         gc.collect()
@@ -516,7 +532,7 @@ def get_peft_model(
         target_modules = target_modules,
         lora_dropout = lora_dropout,
         bias = bias,
-        task_type = TaskType.CAUSAL_LM,
+        task_type = task_type,
     )
     model = prepare_model_for_kbit_training(
         model,
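Note: together the get_peft_model changes allow a configurable PEFT task type (instead of the hard-coded TaskType.CAUSAL_LM) and accept the PEFT shorthand "all-linear" for target_modules, which is then expanded through get_peft_regex. A hedged call sketch; the entry point and parameter values are illustrative, not prescribed by this diff:

    from peft import TaskType

    model = FastModel.get_peft_model(
        model,
        r              = 16,
        lora_alpha     = 16,
        target_modules = "all-linear",           # expanded via get_peft_regex
        task_type      = TaskType.SEQ_2_SEQ_LM,  # e.g. for Whisper-style seq2seq models
    )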