sentence_transformers/models/StaticEmbedding.py (107 changes: 73 additions & 34 deletions)
@@ -4,8 +4,7 @@
import logging
import math
import os
from pathlib import Path
from typing import Any
from typing import Any, cast

try:
from typing import Self
@@ -17,7 +16,7 @@
from safetensors.torch import save_file as save_safetensors_file
from tokenizers import Tokenizer
from torch import nn
from transformers import PreTrainedTokenizerFast
from transformers import PreTrainedTokenizerBase, PreTrainedTokenizerFast

from sentence_transformers.models.InputModule import InputModule
from sentence_transformers.util import get_device_name
@@ -28,22 +27,25 @@
class StaticEmbedding(InputModule):
def __init__(
self,
tokenizer: Tokenizer | PreTrainedTokenizerFast,
tokenizer: PreTrainedTokenizerBase | Tokenizer,
embedding_weights: np.ndarray | torch.Tensor | None = None,
embedding_dim: int | None = None,
**kwargs,
max_seq_length: int | None = None,
**kwargs: Any,
) -> None:
"""
Initializes the StaticEmbedding model given a tokenizer. The model is a simple embedding bag model that
takes the mean of trained per-token embeddings to compute text embeddings.

Args:
tokenizer (Tokenizer | PreTrainedTokenizerFast): The tokenizer to be used. Must be a fast tokenizer
from ``transformers`` or ``tokenizers``.
tokenizer (PreTrainedTokenizerBase | Tokenizer): The tokenizer to be used.
If this is a ``Tokenizer`` from the ``tokenizers`` library, it will be wrapped in a ``PreTrainedTokenizerFast``.
embedding_weights (np.ndarray | torch.Tensor | None, optional): Pre-trained embedding weights.
Defaults to None.
embedding_dim (int | None, optional): Dimension of the embeddings. Required if embedding_weights
is not provided. Defaults to None.
max_seq_length (int | None, optional): Maximum sequence length for the tokenizer.
If None, no truncation is applied. Defaults to None.

.. tip::

@@ -79,69 +81,104 @@ def __init__(
"""
super().__init__()

if isinstance(tokenizer, PreTrainedTokenizerFast):
tokenizer = tokenizer._tokenizer
elif not isinstance(tokenizer, Tokenizer):
raise ValueError(
"The tokenizer must be fast (i.e. Rust-backed) to use this class. "
"Use Tokenizer.from_pretrained() from `tokenizers` to load a fast tokenizer."
)
if isinstance(tokenizer, Tokenizer):
tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)

padding_token = tokenizer.special_tokens_map.get("pad_token", None)
vocabulary = tokenizer.get_vocab()
# This is more of a safeguard. According to the typing, pad_token can be a list, but in practice it never is.
if not isinstance(padding_token, list) and padding_token is not None:
pad_token_id = vocabulary.get(padding_token, None)
else:
pad_token_id = None

if embedding_weights is not None:
if isinstance(embedding_weights, np.ndarray):
embedding_weights = torch.from_numpy(embedding_weights)

self.embedding = nn.EmbeddingBag.from_pretrained(embedding_weights, freeze=False)
self.embedding = nn.EmbeddingBag.from_pretrained(embedding_weights, freeze=False, padding_idx=pad_token_id)
elif embedding_dim is not None:
self.embedding = nn.EmbeddingBag(tokenizer.get_vocab_size(), embedding_dim)
# len(get_vocab()) is safer here, as the vocab size is typed inconsistently.
vocab_size = len(tokenizer.get_vocab())
self.embedding = nn.EmbeddingBag(vocab_size, embedding_dim, padding_idx=pad_token_id)
else:
raise ValueError("Either `embedding_weights` or `embedding_dim` must be provided.")

self.num_embeddings = self.embedding.num_embeddings
self.embedding_dim = self.embedding.embedding_dim

self.tokenizer: Tokenizer = tokenizer
self.tokenizer.no_padding()
self.tokenizer = tokenizer
self._tokenizer_kwargs = {}
# Assigning via the max_seq_length setter also populates self._tokenizer_kwargs
self.max_seq_length = max_seq_length

# For the model card
self.base_model = kwargs.get("base_model", None)

def tokenize(self, texts: list[str], **kwargs) -> dict[str, torch.Tensor]:
encodings = self.tokenizer.encode_batch(texts, add_special_tokens=False)
encodings_ids = [encoding.ids for encoding in encodings]

offsets = torch.from_numpy(np.cumsum([0] + [len(token_ids) for token_ids in encodings_ids[:-1]]))
input_ids = torch.tensor([token_id for token_ids in encodings_ids for token_id in token_ids], dtype=torch.long)
return {"input_ids": input_ids, "offsets": offsets}
def get_word_embedding_dimension(self) -> int:
"""The embedding dimension is the same for word and sentence embeddings."""
return self.embedding_dim

def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
def tokenize(self, texts: list[str], **kwargs: Any) -> dict[str, torch.Tensor]:
"""Tokenizes the input texts and returns a dictionary of tokenized features."""
out_features = {}
# Because no framework (return_tensors) is passed, the tokenizer returns plain Python lists,
# which its type hints don't reflect; the cast documents the actual return type for the keys we use.
tokenized = cast(
dict[str, list[list[int]]], self.tokenizer(texts, add_special_tokens=False, **self._tokenizer_kwargs)
)
ids = []
offsets = [0]
for token_ids in tokenized["input_ids"]:
ids.append(torch.LongTensor(token_ids))
offsets.append(offsets[-1] + len(token_ids))

out_features["input_ids"] = torch.cat(ids)
out_features["offsets"] = torch.LongTensor(offsets[:-1])

return out_features

def forward(self, features: dict[str, torch.Tensor], **kwargs: Any) -> dict[str, torch.Tensor]:
features["sentence_embedding"] = self.embedding(features["input_ids"], features["offsets"])
return features

@property
def max_seq_length(self) -> int:
return math.inf
def max_seq_length(self) -> int | float:
"""Gets the maximum sequence length for the tokenizer."""
return math.inf if self._max_seq_length is None else self._max_seq_length

@max_seq_length.setter
def max_seq_length(self, value: int | None) -> None:
"""Sets the maximum sequence length for the tokenizer."""
self._max_seq_length = value
if value is None:
self._tokenizer_kwargs.pop("max_length", None)
self._tokenizer_kwargs.pop("truncation", None)
else:
self._tokenizer_kwargs["max_length"] = value
self._tokenizer_kwargs["truncation"] = True

def get_sentence_embedding_dimension(self) -> int:
"""Returns the dimension of the sentence embeddings."""
return self.embedding_dim

def save(self, output_path: str, *args, safe_serialization: bool = True, **kwargs) -> None:
def save(self, output_path: str, *args: Any, safe_serialization: bool = True, **kwargs: Any) -> None:
if safe_serialization:
save_safetensors_file(self.state_dict(), os.path.join(output_path, "model.safetensors"))
else:
torch.save(self.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
self.tokenizer.save(str(Path(output_path) / "tokenizer.json"))
self.save_tokenizer(output_path, **kwargs)

@classmethod
def load(
cls,
cls: type[Self],
model_name_or_path: str,
subfolder: str = "",
token: bool | str | None = None,
cache_folder: str | None = None,
revision: str | None = None,
local_files_only: bool = False,
**kwargs,
**kwargs: Any,
) -> Self:
hub_kwargs = {
"subfolder": subfolder,
@@ -153,13 +190,15 @@ def load(
tokenizer_path = cls.load_file_path(model_name_or_path, filename="tokenizer.json", **hub_kwargs)
tokenizer = Tokenizer.from_file(tokenizer_path)

weights = cls.load_torch_weights(model_name_or_path=model_name_or_path, **hub_kwargs)
weights = cast(
dict[str, torch.FloatTensor], cls.load_torch_weights(model_name_or_path=model_name_or_path, **hub_kwargs)
)
try:
weights = weights["embedding.weight"]
except KeyError:
# For compatibility with model2vec models, which are saved with just an "embeddings" key
weights = weights["embeddings"]
return StaticEmbedding(tokenizer, embedding_weights=weights)
return cls(tokenizer, embedding_weights=weights)

@classmethod
def from_distillation(
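Below is a minimal usage sketch of the constructor and tokenization path this diff introduces: a raw `tokenizers.Tokenizer` is wrapped in a `PreTrainedTokenizerFast`, the pad token (when present) becomes the `EmbeddingBag` padding index, and `max_seq_length` enables truncation. The model name, embedding dimension, and sequence length below are illustrative placeholders, not values from this PR.

```python
from tokenizers import Tokenizer

from sentence_transformers.models import StaticEmbedding

# Any tokenizer.json-backed tokenizer works; "bert-base-uncased" is only an illustrative choice.
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")

# Randomly initialized 256-dimensional embeddings; inputs are truncated to 512 tokens.
static = StaticEmbedding(tokenizer, embedding_dim=256, max_seq_length=512)

features = static.tokenize(["a first sentence", "a second, slightly longer sentence"])
# tokenize() returns a flat "input_ids" tensor plus "offsets" marking where each text starts.
embeddings = static(features)["sentence_embedding"]  # shape: (2, 256)
```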
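The new `tokenize` / `forward` pair relies on `nn.EmbeddingBag`'s flattened-input convention: all token ids are concatenated into one 1-D tensor, and `offsets` records where each text begins, so the bag mean-pools one vector per text. A toy sketch with made-up token ids:

```python
import torch
from torch import nn

# Two "texts": the first contributes 3 token ids, the second 2 (ids are made up).
input_ids = torch.tensor([5, 9, 2, 7, 4], dtype=torch.long)
offsets = torch.tensor([0, 3], dtype=torch.long)  # start index of each text within input_ids

bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4)  # default mode="mean"
sentence_embeddings = bag(input_ids, offsets)  # shape: (2, 4), one mean-pooled vector per text
```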