Skip to content

Commit 27c14b9

Browse files
authored
[fix] Avoid error if prompts & output_value=None (#3327)
* Avoid error if prompts & output_value=None
* Refactor, avoid function & batch_size override
1 parent 03dff58 commit 27c14b9

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

sentence_transformers/SentenceTransformer.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -700,9 +700,15 @@ def encode(
700700
embeddings.append(token_emb[0 : last_mask_id + 1])
701701
elif output_value is None: # Return all outputs
702702
embeddings = []
703-
for sent_idx in range(len(out_features["sentence_embedding"])):
704-
row = {name: out_features[name][sent_idx] for name in out_features}
705-
embeddings.append(row)
703+
for idx in range(len(out_features["sentence_embedding"])):
704+
batch_item = {}
705+
for name, value in out_features.items():
706+
try:
707+
batch_item[name] = value[idx]
708+
except TypeError:
709+
# Handle non-indexable values (like prompt_length)
710+
batch_item[name] = value
711+
embeddings.append(batch_item)
706712
else: # Sentence embeddings
707713
embeddings = out_features[output_value]
708714
embeddings = embeddings.detach()

tests/test_sentence_transformer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,27 @@ def test_save_load_prompts() -> None:
376376
assert fresh_model.default_prompt_name == "query"
377377

378378

379+
def test_prompt_output_value_None(stsb_bert_tiny_model_reused) -> None:
380+
model = stsb_bert_tiny_model_reused
381+
outputs = model.encode(
382+
["Text one", "Text two"],
383+
prompt="query: ",
384+
output_value=None,
385+
)
386+
assert len(outputs) == 2
387+
assert isinstance(outputs, list)
388+
expected_keys = {
389+
"input_ids",
390+
"token_type_ids",
391+
"attention_mask",
392+
"sentence_embedding",
393+
"token_embeddings",
394+
"prompt_length",
395+
}
396+
assert set(outputs[0].keys()) == expected_keys
397+
assert set(outputs[1].keys()) == expected_keys
398+
399+
379400
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test float16 support.")
380401
def test_load_with_torch_dtype() -> None:
381402
model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")

0 commit comments

Comments (0)