From e3050aa4d2b9638fd159e5d2c2b372444f37dfbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zacar=C3=ADas=20F=2E=20Ojeda?= Date: Sat, 13 Sep 2025 21:13:46 -0300 Subject: [PATCH 1/2] feat: adding EmbeddingGemma300M model * add output_key field to ModelInfo and update models to include it. --- src/models/image_embedding.rs | 5 ++++ src/models/model_info.rs | 3 ++- src/models/sparse.rs | 1 + src/models/text_embedding.rs | 41 +++++++++++++++++++++++++++++++++ src/output/output_precedence.rs | 8 ++++++- src/text_embedding/impl.rs | 28 ++++++++++++++++------ src/text_embedding/init.rs | 5 +++- tests/embeddings.rs | 1 + 8 files changed, 82 insertions(+), 10 deletions(-) diff --git a/src/models/image_embedding.rs b/src/models/image_embedding.rs index 4913de6..59a4fde 100644 --- a/src/models/image_embedding.rs +++ b/src/models/image_embedding.rs @@ -26,6 +26,7 @@ pub fn models_list() -> Vec> { model_code: String::from("Qdrant/clip-ViT-B-32-vision"), model_file: String::from("model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: ImageEmbeddingModel::Resnet50, @@ -34,6 +35,7 @@ pub fn models_list() -> Vec> { model_code: String::from("Qdrant/resnet50-onnx"), model_file: String::from("model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: ImageEmbeddingModel::UnicomVitB16, @@ -42,6 +44,7 @@ pub fn models_list() -> Vec> { model_code: String::from("Qdrant/Unicom-ViT-B-16"), model_file: String::from("model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: ImageEmbeddingModel::UnicomVitB32, @@ -50,6 +53,7 @@ pub fn models_list() -> Vec> { model_code: String::from("Qdrant/Unicom-ViT-B-32"), model_file: String::from("model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: ImageEmbeddingModel::NomicEmbedVisionV15, @@ -58,6 +62,7 @@ pub fn models_list() -> Vec> { model_code: String::from("nomic-ai/nomic-embed-vision-v1.5"), model_file: String::from("onnx/model.onnx"), 
additional_files: Vec::new(), + output_key: None, }, ]; diff --git a/src/models/model_info.rs b/src/models/model_info.rs index 4d7b184..2e10fe1 100644 --- a/src/models/model_info.rs +++ b/src/models/model_info.rs @@ -1,4 +1,4 @@ -use crate::RerankerModel; +use crate::{OutputKey, RerankerModel}; /// Data struct about the available models #[derive(Debug, Clone)] @@ -9,6 +9,7 @@ pub struct ModelInfo { pub model_code: String, pub model_file: String, pub additional_files: Vec, + pub output_key: Option, } /// Data struct about the available reranker models diff --git a/src/models/sparse.rs b/src/models/sparse.rs index a89f93c..1b52d2f 100644 --- a/src/models/sparse.rs +++ b/src/models/sparse.rs @@ -17,6 +17,7 @@ pub fn models_list() -> Vec> { model_code: String::from("Qdrant/Splade_PP_en_v1"), model_file: String::from("model.onnx"), additional_files: Vec::new(), + output_key: None, }] } diff --git a/src/models/text_embedding.rs b/src/models/text_embedding.rs index 56da9a1..8292872 100644 --- a/src/models/text_embedding.rs +++ b/src/models/text_embedding.rs @@ -68,6 +68,8 @@ pub enum EmbeddingModel { ClipVitB32, /// jinaai/jina-embeddings-v2-base-code JinaEmbeddingsV2BaseCode, + /// onnx-community/embeddinggemma-300m-ONNX + EmbeddingGemma300M, } /// Centralized function to initialize the models map. 
@@ -80,6 +82,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Qdrant/all-MiniLM-L6-v2-onnx"), model_file: String::from("model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::AllMiniLML6V2Q, @@ -88,6 +91,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Xenova/all-MiniLM-L6-v2"), model_file: String::from("onnx/model_quantized.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::AllMiniLML12V2, @@ -96,6 +100,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Xenova/all-MiniLM-L12-v2"), model_file: String::from("onnx/model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::AllMiniLML12V2Q, @@ -104,6 +109,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Xenova/all-MiniLM-L12-v2"), model_file: String::from("onnx/model_quantized.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::BGEBaseENV15, @@ -112,6 +118,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Xenova/bge-base-en-v1.5"), model_file: String::from("onnx/model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::BGEBaseENV15Q, @@ -120,6 +127,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Qdrant/bge-base-en-v1.5-onnx-Q"), model_file: String::from("model_optimized.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::BGELargeENV15, @@ -128,6 +136,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Xenova/bge-large-en-v1.5"), model_file: String::from("onnx/model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::BGELargeENV15Q, @@ -136,6 +145,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Qdrant/bge-large-en-v1.5-onnx-Q"), model_file: String::from("model_optimized.onnx"), 
additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::BGESmallENV15, @@ -144,6 +154,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Xenova/bge-small-en-v1.5"), model_file: String::from("onnx/model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::BGESmallENV15Q, @@ -154,6 +165,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Qdrant/bge-small-en-v1.5-onnx-Q"), model_file: String::from("model_optimized.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::NomicEmbedTextV1, @@ -162,6 +174,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("nomic-ai/nomic-embed-text-v1"), model_file: String::from("onnx/model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::NomicEmbedTextV15, @@ -170,6 +183,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("nomic-ai/nomic-embed-text-v1.5"), model_file: String::from("onnx/model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::NomicEmbedTextV15Q, @@ -180,6 +194,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("nomic-ai/nomic-embed-text-v1.5"), model_file: String::from("onnx/model_quantized.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::ParaphraseMLMiniLML12V2Q, @@ -188,6 +203,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q"), model_file: String::from("model_optimized.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::ParaphraseMLMiniLML12V2, @@ -196,6 +212,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Xenova/paraphrase-multilingual-MiniLM-L12-v2"), model_file: String::from("onnx/model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: 
EmbeddingModel::ParaphraseMLMpnetBaseV2, @@ -206,6 +223,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Xenova/paraphrase-multilingual-mpnet-base-v2"), model_file: String::from("onnx/model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::BGESmallZHV15, @@ -214,6 +232,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Xenova/bge-small-zh-v1.5"), model_file: String::from("onnx/model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::BGELargeZHV15, @@ -222,6 +241,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Xenova/bge-large-zh-v1.5"), model_file: String::from("onnx/model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::ModernBertEmbedLarge, @@ -230,6 +250,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("lightonai/modernbert-embed-large"), model_file: String::from("onnx/model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::MultilingualE5Small, @@ -238,6 +259,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("intfloat/multilingual-e5-small"), model_file: String::from("onnx/model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::MultilingualE5Base, @@ -246,6 +268,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("intfloat/multilingual-e5-base"), model_file: String::from("onnx/model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::MultilingualE5Large, @@ -254,6 +277,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Qdrant/multilingual-e5-large-onnx"), model_file: String::from("model.onnx"), additional_files: vec!["model.onnx_data".to_string()], + output_key: None, }, ModelInfo { model: EmbeddingModel::MxbaiEmbedLargeV1, @@ -262,6 +286,7 @@ fn init_models_map() -> HashMap> { 
model_code: String::from("mixedbread-ai/mxbai-embed-large-v1"), model_file: String::from("onnx/model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::MxbaiEmbedLargeV1Q, @@ -270,6 +295,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("mixedbread-ai/mxbai-embed-large-v1"), model_file: String::from("onnx/model_quantized.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::GTEBaseENV15, @@ -278,6 +304,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Alibaba-NLP/gte-base-en-v1.5"), model_file: String::from("onnx/model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::GTEBaseENV15Q, @@ -286,6 +313,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Alibaba-NLP/gte-base-en-v1.5"), model_file: String::from("onnx/model_quantized.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::GTELargeENV15, @@ -294,6 +322,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Alibaba-NLP/gte-large-en-v1.5"), model_file: String::from("onnx/model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::GTELargeENV15Q, @@ -302,6 +331,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Alibaba-NLP/gte-large-en-v1.5"), model_file: String::from("onnx/model_quantized.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::ClipVitB32, @@ -310,6 +340,7 @@ fn init_models_map() -> HashMap> { model_code: String::from("Qdrant/clip-ViT-B-32-text"), model_file: String::from("model.onnx"), additional_files: Vec::new(), + output_key: None, }, ModelInfo { model: EmbeddingModel::JinaEmbeddingsV2BaseCode, @@ -318,6 +349,16 @@ fn init_models_map() -> HashMap> { model_code: String::from("jinaai/jina-embeddings-v2-base-code"), model_file: String::from("onnx/model.onnx"), 
additional_files: Vec::new(), + output_key: None, + }, + ModelInfo { + model: EmbeddingModel::EmbeddingGemma300M, + dim: 768, + description: String::from("EmbeddingGemma is a 300M parameter embedding model from Google"), + model_code: String::from("onnx-community/embeddinggemma-300m-ONNX"), + model_file: String::from("onnx/model.onnx"), + additional_files: vec!["onnx/model.onnx_data".to_string()], + output_key: Some(crate::OutputKey::ByName("sentence_embedding")), }, ]; diff --git a/src/output/output_precedence.rs b/src/output/output_precedence.rs index ead5bb7..5dcc6a3 100644 --- a/src/output/output_precedence.rs +++ b/src/output/output_precedence.rs @@ -7,7 +7,7 @@ //! e.g. reading the output keys from the model file. /// Enum for defining the key of the output. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum OutputKey { OnlyOne, ByOrder(usize), @@ -41,3 +41,9 @@ impl OutputPrecedence for &[OutputKey] { self.iter() } } + +impl OutputPrecedence for &OutputKey { + fn key_precedence(&self) -> impl Iterator { + std::iter::once(*self) + } +} diff --git a/src/text_embedding/impl.rs b/src/text_embedding/impl.rs index e85c942..0ec031e 100644 --- a/src/text_embedding/impl.rs +++ b/src/text_embedding/impl.rs @@ -3,8 +3,10 @@ #[cfg(feature = "hf-hub")] use crate::common::load_tokenizer_hf_hub; use crate::{ - common::load_tokenizer, models::text_embedding::models_list, models::ModelTrait, - pooling::Pooling, Embedding, EmbeddingModel, EmbeddingOutput, ModelInfo, QuantizationMode, + common::load_tokenizer, + models::{text_embedding::models_list, ModelTrait}, + pooling::Pooling, + Embedding, EmbeddingModel, EmbeddingOutput, ModelInfo, OutputKey, QuantizationMode, SingleBatchOutput, }; #[cfg(feature = "hf-hub")] @@ -80,6 +82,7 @@ impl TextEmbedding { session, post_processing, TextEmbedding::get_quantization_mode(&model_name), + model_info.output_key.clone(), )) } @@ -109,6 +112,7 @@ impl TextEmbedding { session, model.pooling, model.quantization,
)) } @@ -118,6 +122,7 @@ impl TextEmbedding { session: Session, post_process: Option, quantization: QuantizationMode, + output_key: Option, ) -> Self { let need_token_type_ids = session .inputs @@ -130,6 +135,7 @@ impl TextEmbedding { need_token_type_ids, pooling: post_process, quantization, + output_key, } } /// Return the TextEmbedding model's directory from cache or remote retrieval @@ -185,6 +191,8 @@ impl TextEmbedding { EmbeddingModel::ClipVitB32 => Some(Pooling::Mean), EmbeddingModel::JinaEmbeddingsV2BaseCode => Some(Pooling::Mean), + + EmbeddingModel::EmbeddingGemma300M => Some(Pooling::Mean), } } @@ -365,10 +373,16 @@ impl TextEmbedding { batch_size: Option, ) -> Result> { let batches = self.transform(texts, batch_size)?; - - batches.export_with_transformer(output::transformer_with_precedence( - output::OUTPUT_TYPE_PRECEDENCE, - self.pooling.clone(), - )) + if let Some(output_key) = &self.output_key { + return batches.export_with_transformer(output::transformer_with_precedence( + output_key, + self.pooling.clone(), + )); + } else { + batches.export_with_transformer(output::transformer_with_precedence( + output::OUTPUT_TYPE_PRECEDENCE, + self.pooling.clone(), + )) + } } } diff --git a/src/text_embedding/init.rs b/src/text_embedding/init.rs index 2e4393f..54b28fe 100644 --- a/src/text_embedding/init.rs +++ b/src/text_embedding/init.rs @@ -5,7 +5,7 @@ use crate::{ common::TokenizerFiles, init::{HasMaxLength, InitOptionsWithLength}, pooling::Pooling, - EmbeddingModel, QuantizationMode, + EmbeddingModel, OutputKey, QuantizationMode, }; use ort::{execution_providers::ExecutionProviderDispatch, session::Session}; use tokenizers::Tokenizer; @@ -80,6 +80,7 @@ pub struct UserDefinedEmbeddingModel { pub tokenizer_files: TokenizerFiles, pub pooling: Option, pub quantization: QuantizationMode, + pub output_key: Option, } impl UserDefinedEmbeddingModel { @@ -89,6 +90,7 @@ impl UserDefinedEmbeddingModel { tokenizer_files, quantization: QuantizationMode::None, pooling: 
None, + output_key: None, } } @@ -110,4 +112,5 @@ pub struct TextEmbedding { pub(crate) session: Session, pub(crate) need_token_type_ids: bool, pub(crate) quantization: QuantizationMode, + pub(crate) output_key: Option, } diff --git a/tests/embeddings.rs b/tests/embeddings.rs index 70e7772..6ef5616 100644 --- a/tests/embeddings.rs +++ b/tests/embeddings.rs @@ -64,6 +64,7 @@ fn verify_embeddings(model: &EmbeddingModel, embeddings: &[Embedding]) -> Result EmbeddingModel::ParaphraseMLMpnetBaseV2 => [0.39132136, 0.49490625, 0.65497226, 0.34237382], EmbeddingModel::ClipVitB32 => [0.7057363, 1.3549932, 0.46823958, 0.52351093], EmbeddingModel::JinaEmbeddingsV2BaseCode => [-0.31383067, -0.3758629, -0.24878195, -0.35373706], + EmbeddingModel::EmbeddingGemma300M => [0.22703816, 0.6947083, 0.07579082, 1.6958784], _ => panic!("Model {model} not found. If you have just inserted this `EmbeddingModel` variant, please update the expected embeddings."), }; From 1ed7444d3571666cafbb974182f053407c13c150 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Mon, 15 Sep 2025 22:42:16 +0530 Subject: [PATCH 2/2] chore: Misc. 
updates Signed-off-by: Anush008 --- src/models/model_info.rs | 2 ++ src/text_embedding/impl.rs | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/models/model_info.rs b/src/models/model_info.rs index 2e10fe1..f65170c 100644 --- a/src/models/model_info.rs +++ b/src/models/model_info.rs @@ -2,6 +2,7 @@ use crate::{OutputKey, RerankerModel}; /// Data struct about the available models #[derive(Debug, Clone)] +#[non_exhaustive] pub struct ModelInfo { pub model: T, pub dim: usize, @@ -14,6 +15,7 @@ pub struct ModelInfo { /// Data struct about the available reranker models #[derive(Debug, Clone)] +#[non_exhaustive] pub struct RerankerModelInfo { pub model: RerankerModel, pub description: String, diff --git a/src/text_embedding/impl.rs b/src/text_embedding/impl.rs index 0ec031e..3a0fdc0 100644 --- a/src/text_embedding/impl.rs +++ b/src/text_embedding/impl.rs @@ -374,10 +374,10 @@ impl TextEmbedding { ) -> Result> { let batches = self.transform(texts, batch_size)?; if let Some(output_key) = &self.output_key { - return batches.export_with_transformer(output::transformer_with_precedence( + batches.export_with_transformer(output::transformer_with_precedence( output_key, self.pooling.clone(), - )); + )) } else { batches.export_with_transformer(output::transformer_with_precedence( output::OUTPUT_TYPE_PRECEDENCE,