Skip to content

Commit ad8d27d

Browse files
ikuyamadatomaarsen
andauthored
Add Support for Knowledgeable Passage Retriever (KPR) (#3495)
* Add entity-related features as valid feature names * Dynamically pass arguments to auto_model based on its signature --------- Co-authored-by: Tom Aarsen <[email protected]>
1 parent 5b18f36 commit ad8d27d

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

sentence_transformers/models/Transformer.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import inspect
34
import logging
45
import os
56
from pathlib import Path
@@ -225,11 +226,9 @@ def __repr__(self) -> str:
225226

226227
def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
227228
"""Returns token_embeddings, cls_token"""
228-
trans_features = {
229-
key: value
230-
for key, value in features.items()
231-
if key in ["input_ids", "attention_mask", "token_type_ids", "inputs_embeds"]
232-
}
229+
# Get the signature of the auto_model's forward method to pass only the expected arguments from `features`
230+
model_forward_params = list(inspect.signature(self.auto_model.forward).parameters)
231+
trans_features = {key: value for key, value in features.items() if key in model_forward_params}
233232

234233
outputs = self.auto_model(**trans_features, **kwargs, return_dict=True)
235234
token_embeddings = outputs[0]

0 commit comments

Comments
 (0)