diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c99dd25bb..2a8974c49 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -69,6 +69,11 @@ jobs: python -m pip install '.[train, onnx, openvino, dev]' python -m pip install --upgrade 'huggingface_hub[hf_xet]' + - name: Install ipex backend + if: matrix.os == 'ubuntu-latest' + run: | + python -m pip install '.[ipex]' + - name: Install model2vec run: python -m pip install model2vec if: ${{ contains(fromJSON('["3.10", "3.11", "3.12"]'), matrix.python-version) }} diff --git a/docs/installation.md b/docs/installation.md index 77efb1568..116cd2c92 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -4,6 +4,7 @@ We recommend **Python 3.9+**, **[PyTorch 1.11.0+](https://pytorch.org/get-starte * **Default:** This allows for loading, saving, and inference (i.e., getting embeddings) of models. * **ONNX:** This allows for loading, saving, inference, optimizing, and quantizing of models using the ONNX backend. * **OpenVINO:** This allows for loading, saving, and inference of models using the OpenVINO backend. +* **IPEX:** This allows for loading, saving, and inference of models using the IPEX backend. * **Default and Training**: Like **Default**, plus training. * **Development**: All of the above plus some dependencies for developing Sentence Transformers, see [Editable Install](#editable-install). @@ -37,6 +38,12 @@ Note that you can mix and match the various extras, e.g. ``pip install -U "sente pip install -U "sentence-transformers[openvino]" +.. tab:: IPEX + + :: + + pip install -U "sentence-transformers[ipex]" + .. tab:: Default and Training :: @@ -87,6 +94,12 @@ Note that you can mix and match the various extras, e.g. ``pip install -U "sente pip install -U "sentence-transformers[openvino]" +.. tab:: IPEX + + :: + + pip install -U "sentence-transformers[ipex]" + .. tab:: Default and Training :: @@ -139,6 +152,12 @@ You can install ``sentence-transformers`` directly from source to take advantage pip install -U "sentence-transformers[openvino] @ git+https://github.com/UKPLab/sentence-transformers.git" +.. tab:: IPEX + + :: + + pip install -U "sentence-transformers[ipex] @ git+https://github.com/UKPLab/sentence-transformers.git" + .. tab:: Default and Training :: diff --git a/pyproject.toml b/pyproject.toml index 4511d2dca..628135350 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ train = ["datasets", "accelerate>=0.20.3"] onnx = ["optimum[onnxruntime]>=1.23.1"] onnx-gpu = ["optimum[onnxruntime-gpu]>=1.23.1"] openvino = ["optimum-intel[openvino]>=1.20.0"] +ipex = ["optimum-intel[ipex]>=1.21.0"] dev = ["datasets", "accelerate>=0.20.3", "pre-commit", "pytest", "pytest-cov", "peft"] [build-system] diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index fe808a314..fd9354508 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -132,7 +132,7 @@ class SentenceTransformer(nn.Sequential, FitMixin, PeftAdapterMixin): model_card_data (:class:`~sentence_transformers.model_card.SentenceTransformerModelCardData`, optional): A model card data object that contains information about the model. This is used to generate a model card when saving the model. If not set, a default model card data object is created. - backend (str): The backend to use for inference. Can be one of "torch" (default), "onnx", or "openvino". + backend (str): The backend to use for inference. Can be one of "torch" (default), "onnx", "openvino", or "ipex". See https://sbert.net/docs/sentence_transformer/usage/efficiency.html for benchmarking information on the different backends. @@ -181,7 +181,7 @@ def __init__( tokenizer_kwargs: dict[str, Any] | None = None, config_kwargs: dict[str, Any] | None = None, model_card_data: SentenceTransformerModelCardData | None = None, - backend: Literal["torch", "onnx", "openvino"] = "torch", + backend: Literal["torch", "onnx", "openvino", "ipex"] = "torch", ) -> None: # Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name` self.prompts = prompts or {} @@ -386,8 +386,8 @@ def __init__( # Pass the model to the model card data for later use in generating a model card upon saving this model self.model_card_data.register_model(self) - def get_backend(self) -> Literal["torch", "onnx", "openvino"]: - """Return the backend used for inference, which can be one of "torch", "onnx", or "openvino". + def get_backend(self) -> Literal["torch", "onnx", "openvino", "ipex"]: + """Return the backend used for inference, which can be one of "torch", "onnx", "openvino" or "ipex". Returns: str: The backend used for inference. diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py index 5bb245e1e..b97c4aadc 100644 --- a/sentence_transformers/models/Transformer.py +++ b/sentence_transformers/models/Transformer.py @@ -55,7 +55,7 @@ class Transformer(InputModule): tokenizer_name_or_path: Name or path of the tokenizer. When None, then model_name_or_path is used backend: Backend used for model inference. Can be `torch`, `onnx`, - or `openvino`. Default is `torch`. + `openvino`, or `ipex`. Default is `torch`. """ config_file_name: str = "sentence_bert_config.json" @@ -195,8 +195,12 @@ def _load_model( self._load_onnx_model(model_name_or_path, config, cache_dir, **model_args) elif backend == "openvino": self._load_openvino_model(model_name_or_path, config, cache_dir, **model_args) + elif backend == "ipex": + self._load_ipex_model(model_name_or_path, config, cache_dir, **model_args) else: - raise ValueError(f"Unsupported backend '{backend}'. `backend` should be `torch`, `onnx`, or `openvino`.") + raise ValueError( + f"Unsupported backend '{backend}'. `backend` should be `torch`, `onnx`, `openvino`, or `ipex`." + ) def _load_peft_model(self, model_name_or_path: str, config: PeftConfig, cache_dir: str, **model_args) -> None: from peft import PeftModel @@ -262,6 +266,24 @@ def _load_openvino_model( if export: self._backend_warn_to_save(model_name_or_path, is_local, backend_name) + def _load_ipex_model(self, model_name_or_path, config, cache_dir, **model_args) -> None: + try: + from optimum.intel import IPEXModel + except ModuleNotFoundError: + raise Exception( + "Using the IPEX backend requires installing Optimum and IPEX. " + "You can install them with pip: `pip install optimum-intel[ipex]`." + ) + + self.auto_model: IPEXModel = IPEXModel.from_pretrained( + model_name_or_path, + config=config, + cache_dir=cache_dir, + **model_args, + ) + # Wrap the save_pretrained method to save the model in the correct subfolder + self.auto_model._save_pretrained = _save_pretrained_wrapper(self.auto_model._save_pretrained, self.backend) + def _load_onnx_model( self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args ) -> None: diff --git a/tests/test_backends.py b/tests/test_backends.py index d5d8148cc..73986a3df 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -9,6 +9,7 @@ import numpy as np import pytest +from sentence_transformers import SentenceTransformer from tests.utils import is_ci try: @@ -17,7 +18,15 @@ except ImportError: pytest.skip("OpenVINO and ONNX backends are not available", allow_module_level=True) -from sentence_transformers import SentenceTransformer + +BACKENDS = [("onnx", ORTModelForFeatureExtraction), ("openvino", OVModelForFeatureExtraction)] +try: + from optimum.intel import IPEXModel + + BACKENDS.append(("ipex", IPEXModel)) +except ImportError: + pass + if is_ci(): pytest.skip("Skip test in CI to try and avoid 429 Client Error", allow_module_level=True) @@ -26,10 +35,7 @@ ## Testing exporting: @pytest.mark.parametrize( ["backend", "expected_auto_model_class"], - [ - ("onnx", ORTModelForFeatureExtraction), - ("openvino", OVModelForFeatureExtraction), - ], + BACKENDS, ) @pytest.mark.parametrize( "model_kwargs", [{}, {"file_name": "wrong_file_name"}]