diff --git a/README.md b/README.md index b3509ec..215aff1 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ - [**BAAI/bge-small-en-v1.5**](https://huggingface.co/BAAI/bge-small-en-v1.5) - Default - [**sentence-transformers/all-MiniLM-L6-v2**](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) +- [**sentence-transformers/all-mpnet-base-v2**](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) - [**mixedbread-ai/mxbai-embed-large-v1**](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1) - [**Qdrant/clip-ViT-B-32-text**](https://huggingface.co/Qdrant/clip-ViT-B-32-text) - pairs with `clip-ViT-B-32-vision` for image-to-text search - [**BAAI/bge-large-en-v1.5**](https://huggingface.co/BAAI/bge-large-en-v1.5) diff --git a/src/models/text_embedding.rs b/src/models/text_embedding.rs index edca79f..a7947a4 100644 --- a/src/models/text_embedding.rs +++ b/src/models/text_embedding.rs @@ -15,6 +15,8 @@ pub enum EmbeddingModel { AllMiniLML12V2, /// Quantized sentence-transformers/all-MiniLM-L12-v2 AllMiniLML12V2Q, + /// sentence-transformers/all-mpnet-base-v2 + AllMpnetBaseV2, /// BAAI/bge-base-en-v1.5 BGEBaseENV15, /// Quantized BAAI/bge-base-en-v1.5 @@ -111,6 +113,15 @@ fn init_models_map() -> HashMap> { additional_files: Vec::new(), output_key: None, }, + ModelInfo { + model: EmbeddingModel::AllMpnetBaseV2, + dim: 768, + description: String::from("Sentence Transformer model, mpnet-base-v2"), + model_code: String::from("Xenova/all-mpnet-base-v2"), + model_file: String::from("onnx/model.onnx"), + additional_files: Vec::new(), + output_key: None, + }, ModelInfo { model: EmbeddingModel::BGEBaseENV15, dim: 768, diff --git a/src/text_embedding/impl.rs b/src/text_embedding/impl.rs index e83530d..e8bce67 100644 --- a/src/text_embedding/impl.rs +++ b/src/text_embedding/impl.rs @@ -173,6 +173,7 @@ impl TextEmbedding { EmbeddingModel::ParaphraseMLMiniLML12V2 => Some(Pooling::Mean), EmbeddingModel::ParaphraseMLMiniLML12V2Q => Some(Pooling::Mean), EmbeddingModel::ParaphraseMLMpnetBaseV2 => Some(Pooling::Mean), + EmbeddingModel::AllMpnetBaseV2 => Some(Pooling::Mean), EmbeddingModel::ModernBertEmbedLarge => Some(Pooling::Mean), diff --git a/tests/embeddings.rs b/tests/embeddings.rs index 6ef5616..7172b64 100644 --- a/tests/embeddings.rs +++ b/tests/embeddings.rs @@ -38,6 +38,7 @@ fn verify_embeddings(model: &EmbeddingModel, embeddings: &[Embedding]) -> Result EmbeddingModel::AllMiniLML12V2Q => [-0.07808663, 0.27919534, -0.0770612, -0.75660324], EmbeddingModel::AllMiniLML6V2 => [0.59605527, 0.36542925, -0.16450031, -0.40903988], EmbeddingModel::AllMiniLML6V2Q => [0.5677276, 0.40180072, -0.15454668, -0.4672576], + EmbeddingModel::AllMpnetBaseV2=> [-0.51290065, -0.4844747, -0.53036124, -0.5337459], EmbeddingModel::BGEBaseENV15 => [-0.51290065, -0.4844747, -0.53036124, -0.5337459], EmbeddingModel::BGEBaseENV15Q => [-0.5130697, -0.48461288, -0.53067875, -0.5337806], EmbeddingModel::BGELargeENV15 => [-0.19347441, -0.28394595, -0.1549195, -0.22201893],