
Commit 6e45dad
cleanup
1 parent eee5a62

2 files changed: +5, -10 lines

unsloth/models/llama.py

Lines changed: 2 additions & 10 deletions
@@ -762,16 +762,8 @@ def LlamaModel_fast_forward(
     seq_length_with_past = seq_length

     # Fix out of bounds tokenization unless we were given packed metadata
-    allow_overlength = getattr(self, "_unsloth_allow_packed_overlength", False) or any(
-        key in kwargs
-        for key in (
-            "cu_seq_lens_q",
-            "cu_seq_lens",
-            "cu_seqlens",
-            "max_length_q",
-            "max_seqlen",
-            "packed_seq_lengths",
-        )
+    allow_overlength = getattr(self, "_unsloth_allow_packed_overlength", False) or (
+        "packed_seq_lengths" in kwargs
     )
     if hasattr(self, "max_seq_length") and not allow_overlength:
         if seq_length > self.max_seq_length:
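After this change, only the "packed_seq_lengths" kwarg (or the explicit _unsloth_allow_packed_overlength flag) suppresses the bounds check; the other packed-metadata keys (cu_seq_lens_q, cu_seqlens, max_seqlen, ...) no longer do. A minimal standalone sketch of the resulting guard; the check_overlength helper and the ValueError are illustrative assumptions, since the failure branch is not shown in the hunk:

# Standalone sketch of the simplified guard; check_overlength and the
# ValueError are assumptions, not unsloth's actual code.
def check_overlength(model, seq_length, **kwargs):
    allow_overlength = getattr(model, "_unsloth_allow_packed_overlength", False) or (
        "packed_seq_lengths" in kwargs
    )
    if hasattr(model, "max_seq_length") and not allow_overlength:
        if seq_length > model.max_seq_length:
            # Assumption: the real branch body (elided above) rejects
            # the overlength batch.
            raise ValueError(
                f"seq_length={seq_length} exceeds "
                f"max_seq_length={model.max_seq_length}"
            )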

unsloth/models/qwen3.py

Lines changed: 3 additions & 0 deletions
@@ -212,6 +212,9 @@ def Qwen3Attention_fast_forward(
     # Must be contiguous or else results are False!
     # https://github.com/pytorch/pytorch/issues/112577
     Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
+    # Needs (batch_size, n_heads, seq_len, head_dim)
+    # is_causal and attention_mask must not both be set!
+    # when q_len == k_len and attn_mask is None, we should use causal attention
     Q_len = Q.shape[-2]
     K_len = K.shape[-2]
     if seq_info is not None and attention_mask is None:
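The new comments restate PyTorch's scaled_dot_product_attention contract: inputs are (batch_size, n_heads, seq_len, head_dim), and attn_mask and is_causal=True cannot both be passed, so causal masking is only enabled when no explicit mask is given and q_len == k_len. A minimal sketch of that dispatch rule, illustrative only and not this repo's actual attention path:

import torch
import torch.nn.functional as F

def sdpa_dispatch(Q, K, V, attention_mask=None):
    # SDPA wants contiguous (batch_size, n_heads, seq_len, head_dim) inputs;
    # see https://github.com/pytorch/pytorch/issues/112577
    Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
    q_len, k_len = Q.shape[-2], K.shape[-2]
    if attention_mask is None and q_len == k_len:
        # No explicit mask: let SDPA apply the causal mask itself.
        return F.scaled_dot_product_attention(Q, K, V, is_causal=True)
    # Explicit mask given: is_causal must stay False.
    return F.scaled_dot_product_attention(Q, K, V, attn_mask=attention_mask)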
