diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index b9f46da0a..f66118d9a 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -41,6 +41,7 @@
 from .quantization import quantize_embeddings
 from .util import (
     batch_to_device,
+    cos_sim,
     get_device_name,
     import_from_string,
     is_sentence_transformer_model,
@@ -691,6 +692,119 @@ def forward(self, input: dict[str, Tensor], **kwargs) -> dict[str, Tensor]:
             input = module(input, **module_kwargs)
         return input
 
+    def classify(
+        self,
+        sentences: str | list[str],
+        labels: str | list[str],
+        label_template: str | None = "The main subject of this text is {}.",
+        prompt_name: str | None = None,
+        prompt: str | None = None,
+        batch_size: int = 32,
+        show_progress_bar: bool | None = None,
+        normalize_embeddings: bool = False,
+        **kwargs,
+    ) -> list[list[tuple[str, float]]]:
+        """
+        Perform zero-shot classification by comparing sentence embeddings with label embeddings.
+
+        Each label is inserted into ``label_template``, the sentences and the resulting label texts are
+        encoded with :meth:`SentenceTransformer.encode`, and the cosine similarities between them are
+        converted into scores with a softmax over the labels.
+
+        Args:
+            sentences (Union[str, List[str]]): The sentence(s) to classify. A single string is treated as a
+                list containing one sentence.
+            labels (Union[str, List[str]]): The candidate labels. A single string is treated as one label.
+            label_template (Optional[str], optional): A template with a ``{}`` placeholder that is filled with
+                each label before encoding. For example, if the label template is "This is a label: {}", then
+                the label "positive" will be encoded as "This is a label: positive". If None or empty, the
+                labels are encoded as they are. Defaults to "The main subject of this text is {}.".
+            prompt_name (Optional[str], optional): The name of the prompt to use for encoding. Must be a key in
+                the ``prompts`` dictionary, which is either set in the constructor or loaded from the model
+                configuration. For example if ``prompt_name`` is "query" and the ``prompts`` is
+                {"query": "query: ", ...}, then the sentence "What is the capital of France?" will be encoded as
+                "query: What is the capital of France?" because the sentence is appended to the prompt. The
+                prompt is applied to the sentences and to the label texts alike. If ``prompt`` is also set,
+                this argument is ignored. Defaults to None.
+            prompt (Optional[str], optional): The prompt to use for encoding. For example, if the prompt is
+                "query: ", then the sentence "What is the capital of France?" will be encoded as
+                "query: What is the capital of France?" because the sentence is appended to the prompt. The
+                prompt is applied to the sentences and to the label texts alike. If ``prompt`` is set,
+                ``prompt_name`` is ignored. Defaults to None.
+            batch_size (int, optional): The batch size used for the computation. Defaults to 32.
+            show_progress_bar (bool, optional): Whether to output a progress bar when encoding. Defaults to None.
+            normalize_embeddings (bool, optional): Whether to normalize the sentence and label embeddings to
+                length 1 when encoding. Defaults to False.
+            **kwargs: Additional keyword arguments that are passed on to both
+                :meth:`SentenceTransformer.encode` calls.
+
+        Returns:
+            List[List[Tuple[str, float]]]: One result per sentence, in input order. Each result is a list of
+            ``(label, score)`` tuples holding the original (untemplated) labels, sorted from the highest to
+            the lowest score. The scores of one sentence sum to 1.
+
+        Raises:
+            ValueError: If ``labels`` is empty.
+
+        Note:
+            The scores are a softmax over cosine similarities, which all lie in [-1, 1]. They are therefore
+            relative to the given set of labels and should be used for ranking, not as calibrated
+            probabilities.
+
+        Example:
+            ::
+
+                from sentence_transformers import SentenceTransformer
+
+                # Load a pre-trained SentenceTransformer model
+                model = SentenceTransformer("all-mpnet-base-v2")
+
+                sentences = [
+                    "The weather is lovely today.",
+                    "It's so sunny outside!",
+                    "He drove to the stadium.",
+                ]
+                labels = ["weather", "sports", "politics"]
+                results = model.classify(sentences, labels)
+
+                for sentence, result in zip(sentences, results):
+                    print(f"Classification for '{sentence}':")
+                    for label, score in result:
+                        print(f"{label}: {score:.4f}")
+        """
+        if isinstance(sentences, str):
+            sentences = [sentences]
+        if isinstance(labels, str):
+            labels = [labels]
+        raw_labels = list(labels)
+        if not raw_labels:
+            raise ValueError("`labels` must contain at least one label.")
+
+        if label_template:
+            label_texts = [label_template.format(label) for label in raw_labels]
+        else:
+            label_texts = raw_labels
+
+        # Sentences and labels share every encoding option so that their embeddings are comparable.
+        encode_kwargs = {
+            "batch_size": batch_size,
+            "show_progress_bar": show_progress_bar,
+            "prompt": prompt,
+            "prompt_name": prompt_name,
+            "normalize_embeddings": normalize_embeddings,
+            **kwargs,
+        }
+        logger.debug("Encoding input sentences")
+        text_embeddings = self.encode(sentences, **encode_kwargs)
+        logger.debug("Encoding labels")
+        label_embeddings = self.encode(label_texts, **encode_kwargs)
+
+        # Row i holds the similarities of sentence i to all labels; softmax turns each row into scores.
+        similarities = cos_sim(text_embeddings, label_embeddings)
+        scores = torch.nn.functional.softmax(similarities, dim=1).cpu().tolist()
+
+        return [sorted(zip(raw_labels, row), key=lambda pair: pair[1], reverse=True) for row in scores]
+
     @property
     def similarity_fn_name(self) -> Literal["cosine", "dot", "euclidean", "manhattan"]:
         """Return the name of the similarity function used by :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise`.
diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py
index 28636d728..eda95b7f7 100644
--- a/tests/test_sentence_transformer.py
+++ b/tests/test_sentence_transformer.py
@@ -612,6 +612,38 @@ def test_similarity_score(stsb_bert_tiny_model_reused: SentenceTransformer, simi
     if similarity_fn_name in ("cosine", "dot"):
         assert (pairwise_scores > 0.5).all()
 
+
+def test_classify(stsb_bert_tiny_model_reused: SentenceTransformer) -> None:
+    model = stsb_bert_tiny_model_reused
+
+    sentences = [
+        "The weather is so nice!",
+        "It's so sunny outside.",
+        "He's driving to the movie theater.",
+        "She's going to the cinema.",
+    ]
+    labels = ["travel", "cooking", "dancing"]
+    results = model.classify(sentences, labels)
+    assert len(results) == len(sentences)
+    for result in results:
+        predicted_labels = [label for label, _ in result]
+        scores = [score for _, score in result]
+        # Every label is reported exactly once, using its original (untemplated) text.
+        assert sorted(predicted_labels) == sorted(labels)
+        # Scores form a distribution over the labels and are ordered best-first.
+        assert all(0 <= score <= 1 for score in scores)
+        assert np.isclose(sum(scores), 1)
+        assert scores == sorted(scores, reverse=True)
+
+    # A single sentence yields exactly one result.
+    single_result = model.classify(sentences[0], labels)
+    assert len(single_result) == 1
+    assert len(single_result[0]) == len(labels)
+
+    # An empty label set cannot be classified against.
+    with pytest.raises(ValueError):
+        model.classify(sentences, [])
+
 
 def test_similarity_score_save(stsb_bert_tiny_model: SentenceTransformer) -> None:
     model = stsb_bert_tiny_model