Skip to content

Commit f6f5e63

Browse files
authored
feat: Support custom auto_model for wider model compatibility (Whisper, BERT, etc.) & attn_implementation support (#2263)
* Update loader.py; * Update vision.py; * Update vision.py — fix attn_implementation; * Refactor: Improve parameter handling and checks in loader/vision
1 parent 05f4875 commit f6f5e63

File tree

2 files changed

+35
-12
lines changed

2 files changed

+35
-12
lines changed

unsloth/models/loader.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,10 +469,14 @@ def from_pretrained(
469469
return_logits = False, # Return logits
470470
fullgraph = True, # No graph breaks
471471
use_exact_model_name = False,
472+
auto_model = None,
473+
whisper_language = None,
474+
whisper_task = None,
472475
*args, **kwargs,
473476
):
474477
if token is None: token = get_token()
475-
478+
if whisper_language is not None: assert(type(whisper_language) is str)
479+
if whisper_task is not None: assert(type(whisper_task) is str)
476480
SUPPORTS_BFLOAT16 = is_bfloat16_supported()
477481
if dtype is None:
478482
dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
@@ -709,7 +713,8 @@ def from_pretrained(
709713
# Check if VLM
710714
is_vlm = any(x.endswith("ForConditionalGeneration") for x in model_config.architectures)
711715
is_vlm = is_vlm or hasattr(model_config, "vision_config")
712-
auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM
716+
if auto_model is None:
717+
auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM
713718

714719
model, tokenizer = FastBaseModel.from_pretrained(
715720
model_name = model_name,
@@ -727,6 +732,8 @@ def from_pretrained(
727732
auto_model = auto_model,
728733
use_gradient_checkpointing = use_gradient_checkpointing,
729734
supports_sdpa = supports_sdpa,
735+
whisper_language = whisper_language,
736+
whisper_task = whisper_task,
730737
*args, **kwargs,
731738
)
732739

unsloth/models/vision.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,8 @@ def from_pretrained(
236236
auto_model = AutoModelForVision2Seq,
237237
use_gradient_checkpointing = "unsloth",
238238
supports_sdpa = True,
239+
whisper_language = None,
240+
whisper_task = None,
239241
**kwargs,
240242
):
241243
if model_types is None:
@@ -304,7 +306,8 @@ def from_pretrained(
304306
do_forced_float32 = True
305307
pass
306308
# Stop SDPA for some archs like Pixtral / Mistral3
307-
kwargs["attn_implementation"] = "sdpa"
309+
if not ("attn_implementation" in kwargs):
310+
kwargs["attn_implementation"] = "sdpa"
308311
if not supports_sdpa:
309312
print(f"Unsloth: {model_type_arch.title()} does not support SDPA - switching to eager!")
310313
del kwargs["attn_implementation"]
@@ -352,6 +355,7 @@ def from_pretrained(
352355
# Check if using forced float32 - we load it in bfloat16, then cast to float16!
353356
torch_dtype = dtype
354357
if do_forced_float32: torch_dtype = torch.bfloat16
358+
355359
model = auto_model.from_pretrained(
356360
model_name,
357361
device_map = device_map,
@@ -367,12 +371,23 @@ def from_pretrained(
367371

368372
# Counteract saved tokenizers
369373
tokenizer_name = model_name if tokenizer_name is None else tokenizer_name
370-
auto_processor = AutoProcessor if auto_model is AutoModelForVision2Seq else AutoTokenizer
371-
tokenizer = auto_processor.from_pretrained(
372-
tokenizer_name,
373-
padding_side = "right",
374-
token = token,
375-
)
374+
is_vlm = (auto_model is AutoModelForVision2Seq)
375+
is_whisper = (whisper_language is not None and whisper_task is not None)
376+
auto_processor = AutoProcessor if (is_vlm or is_whisper) else AutoTokenizer
377+
if whisper_language and whisper_task:
378+
tokenizer = auto_processor.from_pretrained(
379+
tokenizer_name,
380+
padding_side = "right",
381+
token = token,
382+
language = whisper_language,
383+
task = whisper_task,
384+
)
385+
else:
386+
tokenizer = auto_processor.from_pretrained(
387+
tokenizer_name,
388+
padding_side = "right",
389+
token = token,
390+
)
376391
if hasattr(tokenizer, "tokenizer"):
377392
__tokenizer = tokenizer.tokenizer
378393
# Add padding side as well
@@ -469,6 +484,7 @@ def get_peft_model(
469484
modules_to_save = None,
470485
init_lora_weights = True,
471486
loftq_config = {},
487+
task_type = TaskType.CAUSAL_LM,
472488
temporary_location = "_unsloth_temporary_saved_buffers",
473489
**kwargs,
474490
):
@@ -492,7 +508,7 @@ def get_peft_model(
492508
finetune_attention_modules = True
493509
finetune_mlp_modules = True
494510
pass
495-
if target_modules is None:
511+
if target_modules is None or target_modules == "all-linear":
496512
target_modules = get_peft_regex(
497513
model,
498514
finetune_vision_layers = finetune_vision_layers,
@@ -503,7 +519,7 @@ def get_peft_model(
503519
else:
504520
assert(type(target_modules) in (list, tuple,))
505521
pass
506-
522+
507523
# Clear deleted GPU items
508524
for _ in range(3):
509525
gc.collect()
@@ -516,7 +532,7 @@ def get_peft_model(
516532
target_modules = target_modules,
517533
lora_dropout = lora_dropout,
518534
bias = bias,
519-
task_type = TaskType.CAUSAL_LM,
535+
task_type = task_type,
520536
)
521537
model = prepare_model_for_kbit_training(
522538
model,

0 commit comments

Comments
 (0)