Skip to content

Commit ac79d54

Browse files
authored
[bugfix] fix reranker_padding_free (#6989)
1 parent 4234e70 commit ac79d54

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

swift/plugin/loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ def generative_reranker_loss(outputs,
575575

576576
# Extract logits at the last valid (non-padding) token position for each sample
577577
batch_size = logits.shape[0]
578-
last_valid_indices = get_last_valid_indices(attention_mask)
578+
last_valid_indices = -1 if attention_mask is None else get_last_valid_indices(attention_mask)
579579
batch_indices = torch.arange(batch_size, device=logits.device)
580580
last_valid_logits = logits[batch_indices, last_valid_indices, :]
581581

@@ -743,7 +743,7 @@ def listwise_generative_reranker_loss(outputs,
743743

744744
# Extract logits at the last valid (non-padding) token position for each sample
745745
batch_size = logits.shape[0]
746-
last_valid_indices = get_last_valid_indices(attention_mask)
746+
last_valid_indices = -1 if attention_mask is None else get_last_valid_indices(attention_mask)
747747
batch_indices = torch.arange(batch_size, device=logits.device)
748748
last_valid_logits = logits[batch_indices, last_valid_indices, :]
749749

swift/trainers/trainers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
190190
labels,
191191
num_items_in_batch=num_items_in_batch,
192192
trainer=self,
193-
attention_mask=inputs['attention_mask'])
193+
attention_mask=inputs.get('attention_mask'))
194194
else:
195195
# Fallback to model's loss
196196
loss = outputs.loss

0 commit comments

Comments (0)