From e4fff0c35ff873f7f963afa1179525b1869c3ac5 Mon Sep 17 00:00:00 2001
From: stephantul
Date: Mon, 6 Oct 2025 20:47:23 +0200
Subject: [PATCH 1/2] feat: use regular tokenizer in static embedding, add max length

---
 .../models/StaticEmbedding.py | 101 ++++++++++++------
 1 file changed, 68 insertions(+), 33 deletions(-)

diff --git a/sentence_transformers/models/StaticEmbedding.py b/sentence_transformers/models/StaticEmbedding.py
index b83d62461..4290b0b26 100644
--- a/sentence_transformers/models/StaticEmbedding.py
+++ b/sentence_transformers/models/StaticEmbedding.py
@@ -4,8 +4,7 @@
 import logging
 import math
 import os
-from pathlib import Path
-from typing import Any
+from typing import Any, cast
 
 try:
     from typing import Self
@@ -17,7 +16,7 @@
 from safetensors.torch import save_file as save_safetensors_file
 from tokenizers import Tokenizer
 from torch import nn
-from transformers import PreTrainedTokenizerFast
+from transformers import PreTrainedTokenizerBase, PreTrainedTokenizerFast
 
 from sentence_transformers.models.InputModule import InputModule
 from sentence_transformers.util import get_device_name
@@ -28,22 +27,25 @@ class StaticEmbedding(InputModule):
     def __init__(
         self,
-        tokenizer: Tokenizer | PreTrainedTokenizerFast,
+        tokenizer: PreTrainedTokenizerBase | Tokenizer,
         embedding_weights: np.ndarray | torch.Tensor | None = None,
         embedding_dim: int | None = None,
-        **kwargs,
+        max_seq_length: int | None = None,
+        **kwargs: Any,
     ) -> None:
         """
         Initializes the StaticEmbedding model given a tokenizer. The model is a simple embedding bag model that
         takes the mean of trained per-token embeddings to compute text embeddings.
 
         Args:
-            tokenizer (Tokenizer | PreTrainedTokenizerFast): The tokenizer to be used. Must be a fast tokenizer
-                from ``transformers`` or ``tokenizers``.
+            tokenizer (Tokenizer | PreTrainedTokenizerFast): The tokenizer to be used.
+                If this is a Tokenizer from the `tokenizers` library, it will be wrapped in a PreTrainedTokenizerFast.
             embedding_weights (np.ndarray | torch.Tensor | None, optional): Pre-trained embedding weights.
                 Defaults to None.
             embedding_dim (int | None, optional): Dimension of the embeddings. Required if embedding_weights
                 is not provided. Defaults to None.
+            max_seq_length (int | None, optional): Maximum sequence length for the tokenizer.
+                If None, no truncation is applied. Defaults to None.
 
         .. tip::
@@ -79,69 +81,102 @@ def __init__(
         """
         super().__init__()
 
-        if isinstance(tokenizer, PreTrainedTokenizerFast):
-            tokenizer = tokenizer._tokenizer
-        elif not isinstance(tokenizer, Tokenizer):
-            raise ValueError(
-                "The tokenizer must be fast (i.e. Rust-backed) to use this class. "
-                "Use Tokenizer.from_pretrained() from `tokenizers` to load a fast tokenizer."
-            )
+        if isinstance(tokenizer, Tokenizer):
+            tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
+
+        padding_token = tokenizer.special_tokens_map.get("pad_token", None)
+        vocabulary = tokenizer.get_vocab()
+        # This is more of a safeguard. According to the typing, pad_token can be a list, but in practice it never is.
+        if not isinstance(padding_token, list) and padding_token is not None:
+            pad_token_id = vocabulary.get(padding_token, None)
+        else:
+            pad_token_id = None
 
         if embedding_weights is not None:
             if isinstance(embedding_weights, np.ndarray):
                 embedding_weights = torch.from_numpy(embedding_weights)
 
-            self.embedding = nn.EmbeddingBag.from_pretrained(embedding_weights, freeze=False)
+            self.embedding = nn.EmbeddingBag.from_pretrained(embedding_weights, freeze=False, padding_idx=pad_token_id)
         elif embedding_dim is not None:
-            self.embedding = nn.EmbeddingBag(tokenizer.get_vocab_size(), embedding_dim)
+            # Safer because the vocab size is typed weirdly.
+            vocab_size = len(tokenizer.get_vocab())
+            self.embedding = nn.EmbeddingBag(vocab_size, embedding_dim, padding_idx=pad_token_id)
         else:
             raise ValueError("Either `embedding_weights` or `embedding_dim` must be provided.")
 
         self.num_embeddings = self.embedding.num_embeddings
         self.embedding_dim = self.embedding.embedding_dim
 
-        self.tokenizer: Tokenizer = tokenizer
-        self.tokenizer.no_padding()
+        self.tokenizer = tokenizer
+        self._tokenizer_kwargs = {}
+        # Implicitly sets tokenizer kwargs because of the setter
+        self.max_seq_length = max_seq_length
 
         # For the model card
         self.base_model = kwargs.get("base_model", None)
 
-    def tokenize(self, texts: list[str], **kwargs) -> dict[str, torch.Tensor]:
-        encodings = self.tokenizer.encode_batch(texts, add_special_tokens=False)
-        encodings_ids = [encoding.ids for encoding in encodings]
+    def get_word_embedding_dimension(self) -> int:
+        """The embedding dimension is the same for word and sentence embeddings."""
+        return self.embedding_dim
+
+    def tokenize(self, texts: list[str], **kwargs: Any) -> dict[str, torch.Tensor]:
+        """Tokenizes the input texts and returns a dictionary of tokenized features."""
+        out_features = {}
+        # The tokenizer typing is incorrect because we don't pass a framework. Therefore, the return type
+        # is a dict of lists of lists of ints for all keys we care about.
+        tokenized = cast(dict[str, list[list[int]]], self.tokenizer(texts, add_special_tokens=False, **self._tokenizer_kwargs))
+        ids = []
+        offsets = [0]
+        for token_ids in tokenized["input_ids"]:
+            ids.append(torch.LongTensor(token_ids))
+            offsets.append(offsets[-1] + len(token_ids))
+
+        out_features["input_ids"] = torch.cat(ids)
+        out_features["offsets"] = torch.LongTensor(offsets[:-1])
 
-        offsets = torch.from_numpy(np.cumsum([0] + [len(token_ids) for token_ids in encodings_ids[:-1]]))
-        input_ids = torch.tensor([token_id for token_ids in encodings_ids for token_id in token_ids], dtype=torch.long)
-        return {"input_ids": input_ids, "offsets": offsets}
+        return out_features
 
-    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
+    def forward(self, features: dict[str, torch.Tensor], **kwargs: Any) -> dict[str, torch.Tensor]:
         features["sentence_embedding"] = self.embedding(features["input_ids"], features["offsets"])
         return features
 
     @property
-    def max_seq_length(self) -> int:
-        return math.inf
+    def max_seq_length(self) -> int | float:
+        """Gets the maximum sequence length for the tokenizer."""
+        return math.inf if self._max_seq_length is None else self._max_seq_length
+
+    @max_seq_length.setter
+    def max_seq_length(self, value: int | None) -> None:
+        """Sets the maximum sequence length for the tokenizer."""
+        self._max_seq_length = value
+        if value is None:
+            self._tokenizer_kwargs.pop("max_length", None)
+            self._tokenizer_kwargs.pop("truncation", None)
+        else:
+            self._tokenizer_kwargs["max_length"] = value
+            self._tokenizer_kwargs["truncation"] = True
 
     def get_sentence_embedding_dimension(self) -> int:
+        """Returns the dimension of the sentence embeddings."""
         return self.embedding_dim
 
-    def save(self, output_path: str, *args, safe_serialization: bool = True, **kwargs) -> None:
+    def save(self, output_path: str, *args: Any, safe_serialization: bool = True, **kwargs: Any) -> None:
         if safe_serialization:
             save_safetensors_file(self.state_dict(), os.path.join(output_path, "model.safetensors"))
         else:
             torch.save(self.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
-        self.tokenizer.save(str(Path(output_path) / "tokenizer.json"))
+        self.save_tokenizer(output_path, **kwargs)
 
     @classmethod
     def load(
-        cls,
+        cls: type[Self],
         model_name_or_path: str,
         subfolder: str = "",
         token: bool | str | None = None,
         cache_folder: str | None = None,
         revision: str | None = None,
         local_files_only: bool = False,
-        **kwargs,
+        **kwargs: Any,
     ) -> Self:
         hub_kwargs = {
             "subfolder": subfolder,
@@ -153,13 +188,13 @@ def load(
         tokenizer_path = cls.load_file_path(model_name_or_path, filename="tokenizer.json", **hub_kwargs)
         tokenizer = Tokenizer.from_file(tokenizer_path)
-        weights = cls.load_torch_weights(model_name_or_path=model_name_or_path, **hub_kwargs)
+        weights = cast(dict[str, torch.FloatTensor], cls.load_torch_weights(model_name_or_path=model_name_or_path, **hub_kwargs))
         try:
             weights = weights["embedding.weight"]
         except KeyError:
             # For compatibility with model2vec models, which are saved with just an "embeddings" key
             weights = weights["embeddings"]
 
-        return StaticEmbedding(tokenizer, embedding_weights=weights)
+        return cls(tokenizer, embedding_weights=weights)
 
     @classmethod
     def from_distillation(

From a4f85a174de6659d5845f0a165e76b243668c268 Mon Sep 17 00:00:00 2001
From: stephantul
Date: Mon, 6 Oct 2025 20:54:05 +0200
Subject: [PATCH 2/2] ruff

---
 sentence_transformers/models/StaticEmbedding.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/sentence_transformers/models/StaticEmbedding.py b/sentence_transformers/models/StaticEmbedding.py
index 4290b0b26..51a88aac7 100644
--- a/sentence_transformers/models/StaticEmbedding.py
+++ b/sentence_transformers/models/StaticEmbedding.py
@@ -124,7 +124,9 @@ def tokenize(self, texts: list[str], **kwargs: Any) -> dict[str, torch.Tensor]:
         out_features = {}
         # The tokenizer typing is incorrect because we don't pass a framework. Therefore, the return type
         # is a dict of lists of lists of ints for all keys we care about.
-        tokenized = cast(dict[str, list[list[int]]], self.tokenizer(texts, add_special_tokens=False, **self._tokenizer_kwargs))
+        tokenized = cast(
+            dict[str, list[list[int]]], self.tokenizer(texts, add_special_tokens=False, **self._tokenizer_kwargs)
+        )
         ids = []
         offsets = [0]
         for token_ids in tokenized["input_ids"]:
@@ -188,7 +190,9 @@ def load(
 
         tokenizer_path = cls.load_file_path(model_name_or_path, filename="tokenizer.json", **hub_kwargs)
         tokenizer = Tokenizer.from_file(tokenizer_path)
-        weights = cast(dict[str, torch.FloatTensor], cls.load_torch_weights(model_name_or_path=model_name_or_path, **hub_kwargs))
+        weights = cast(
+            dict[str, torch.FloatTensor], cls.load_torch_weights(model_name_or_path=model_name_or_path, **hub_kwargs)
+        )
         try:
             weights = weights["embedding.weight"]
         except KeyError:
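
Editor's note: the reworked tokenize() no longer pads; it concatenates all token ids into one flat tensor and records where each text starts, which is exactly the input format nn.EmbeddingBag expects. The following standalone sketch (not part of the patch; the token ids and sizes are made up for illustration) shows that packing and how EmbeddingBag, whose default mode is "mean", produces one pooled vector per text:

    import torch
    from torch import nn

    # Pretend tokenizer output for two texts of different lengths.
    token_ids_per_text = [[5, 12, 7], [3, 9]]

    ids, offsets = [], [0]
    for token_ids in token_ids_per_text:
        ids.append(torch.LongTensor(token_ids))
        offsets.append(offsets[-1] + len(token_ids))

    input_ids = torch.cat(ids)                      # tensor([ 5, 12,  7,  3,  9])
    offset_tensor = torch.LongTensor(offsets[:-1])  # tensor([0, 3]): start index of each text

    embedding = nn.EmbeddingBag(num_embeddings=20, embedding_dim=4)  # default mode="mean"
    sentence_embeddings = embedding(input_ids, offset_tensor)
    print(sentence_embeddings.shape)  # torch.Size([2, 4]): one mean-pooled vector per text

Because the bags are delimited by offsets rather than padding, no pad token ever enters the mean, which is why the patch can drop the old no_padding() call.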
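Editor's note: the new max_seq_length setter only stores max_length and truncation in the kwargs later forwarded to the tokenizer call. A minimal sketch of the resulting truncation behaviour, using bert-base-uncased purely as an example checkpoint (the checkpoint, text, and length are illustrative assumptions, not prescribed by the patch):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    tokenizer_kwargs = {}
    max_seq_length = 4  # as if StaticEmbedding(..., max_seq_length=4) had been used
    if max_seq_length is not None:
        # Mirrors what the max_seq_length setter stores in self._tokenizer_kwargs.
        tokenizer_kwargs["max_length"] = max_seq_length
        tokenizer_kwargs["truncation"] = True

    encoded = tokenizer(
        ["a fairly long sentence that will be cut short"],
        add_special_tokens=False,
        **tokenizer_kwargs,
    )
    print(len(encoded["input_ids"][0]))  # at most 4 token ids per text

With max_seq_length=None the two kwargs are popped again, so no truncation is applied and the getter reports math.inf, matching the previous behaviour.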