
Commit 85ec645

benHeid and tomaarsen authored
MegaBatchMarginLoss use mini batches for positives too (#3550)
* Use mini-batches also for positive embeddings. This fixes possible out-of-memory errors with very large batch sizes, which occurred because all positives were embedded at once.
* Remove unintended change
* Slight refactor: range until batch_size and move the for-loop deeper

---------

Co-authored-by: Tom Aarsen <[email protected]>
1 parent 8e4c85b commit 85ec645

File tree

1 file changed: +7 −1 lines changed


sentence_transformers/losses/MegaBatchMarginLoss.py

Lines changed: 7 additions & 1 deletion
@@ -95,11 +95,17 @@ def __init__(
     def forward_mini_batched(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
         anchor, positive = sentence_features
         feature_names = list(anchor.keys())
+        batch_size = len(positive[next(iter(positive))])
 
+        all_positive_emb = []
         with torch.no_grad():
             self.model.eval()
-            all_positive_emb = self.model(positive)["sentence_embedding"].detach()
+            for start_idx in range(0, batch_size, self.mini_batch_size):
+                end_idx = start_idx + self.mini_batch_size
+                input_mini_batch = {k: v[start_idx:end_idx] for k, v in positive.items()}
+                all_positive_emb.append(self.model(input_mini_batch)["sentence_embedding"].detach())
             self.model.train()
+        all_positive_emb = torch.cat(all_positive_emb, dim=0)
 
         diagonal_matrix = torch.eye(len(all_positive_emb), len(all_positive_emb), device=all_positive_emb.device)
 
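For context, a minimal usage sketch of the loss after this change. This is not part of the commit: the constructor arguments follow the MegaBatchMarginLoss signature in recent sentence-transformers releases (use_mini_batched_version, mini_batch_size), and the model name and example pairs are placeholders.

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

# Placeholder checkpoint; any SentenceTransformer model works here.
model = SentenceTransformer("all-MiniLM-L6-v2")

# (anchor, positive) pairs; this loss is intended for large train batch sizes,
# which is where embedding positives in mini-batches prevents OOMs.
train_examples = [
    InputExample(texts=["How do I reset my password?", "Steps to reset a forgotten password"]),
    InputExample(texts=["Best pizza in town", "Top-rated pizzerias nearby"]),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)

# With use_mini_batched_version=True, anchors (and, after this commit, positives)
# are embedded in chunks of mini_batch_size instead of all at once.
train_loss = losses.MegaBatchMarginLoss(model=model, use_mini_batched_version=True, mini_batch_size=50)

model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1)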