Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/models/image_embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
model_code: String::from("Qdrant/clip-ViT-B-32-vision"),
model_file: String::from("model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: ImageEmbeddingModel::Resnet50,
Expand All @@ -34,6 +35,7 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
model_code: String::from("Qdrant/resnet50-onnx"),
model_file: String::from("model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: ImageEmbeddingModel::UnicomVitB16,
Expand All @@ -42,6 +44,7 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
model_code: String::from("Qdrant/Unicom-ViT-B-16"),
model_file: String::from("model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: ImageEmbeddingModel::UnicomVitB32,
Expand All @@ -50,6 +53,7 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
model_code: String::from("Qdrant/Unicom-ViT-B-32"),
model_file: String::from("model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: ImageEmbeddingModel::NomicEmbedVisionV15,
Expand All @@ -58,6 +62,7 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
model_code: String::from("nomic-ai/nomic-embed-vision-v1.5"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
];

Expand Down
5 changes: 4 additions & 1 deletion src/models/model_info.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
use crate::RerankerModel;
use crate::{OutputKey, RerankerModel};

/// Data struct about the available models
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ModelInfo<T> {
    /// The model variant this metadata describes (e.g. an `EmbeddingModel` enum value).
    pub model: T,
    /// Dimensionality of the embedding vectors the model produces.
    pub dim: usize,
    /// Human-readable description of the model.
    pub description: String,
    /// Hugging Face repository identifier used to fetch the model (e.g. "Qdrant/all-MiniLM-L6-v2-onnx").
    pub model_code: String,
    /// Path of the ONNX model file within the repository.
    pub model_file: String,
    /// Extra files (e.g. external weight data such as "model.onnx_data") that must be downloaded alongside the model file.
    pub additional_files: Vec<String>,
    /// Explicit ONNX output to read embeddings from; `None` presumably falls back to the
    /// default output-key precedence — NOTE(review): confirm against the session output handling.
    pub output_key: Option<OutputKey>,
}

/// Data struct about the available reranker models
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct RerankerModelInfo {
pub model: RerankerModel,
pub description: String,
Expand Down
1 change: 1 addition & 0 deletions src/models/sparse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub fn models_list() -> Vec<ModelInfo<SparseModel>> {
model_code: String::from("Qdrant/Splade_PP_en_v1"),
model_file: String::from("model.onnx"),
additional_files: Vec::new(),
output_key: None,
}]
}

Expand Down
41 changes: 41 additions & 0 deletions src/models/text_embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ pub enum EmbeddingModel {
ClipVitB32,
/// jinaai/jina-embeddings-v2-base-code
JinaEmbeddingsV2BaseCode,
/// onnx-community/embeddinggemma-300m-ONNX
EmbeddingGemma300M,
}

/// Centralized function to initialize the models map.
Expand All @@ -80,6 +82,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Qdrant/all-MiniLM-L6-v2-onnx"),
model_file: String::from("model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::AllMiniLML6V2Q,
Expand All @@ -88,6 +91,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Xenova/all-MiniLM-L6-v2"),
model_file: String::from("onnx/model_quantized.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::AllMiniLML12V2,
Expand All @@ -96,6 +100,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Xenova/all-MiniLM-L12-v2"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::AllMiniLML12V2Q,
Expand All @@ -104,6 +109,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Xenova/all-MiniLM-L12-v2"),
model_file: String::from("onnx/model_quantized.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::BGEBaseENV15,
Expand All @@ -112,6 +118,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Xenova/bge-base-en-v1.5"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::BGEBaseENV15Q,
Expand All @@ -120,6 +127,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Qdrant/bge-base-en-v1.5-onnx-Q"),
model_file: String::from("model_optimized.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::BGELargeENV15,
Expand All @@ -128,6 +136,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Xenova/bge-large-en-v1.5"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::BGELargeENV15Q,
Expand All @@ -136,6 +145,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Qdrant/bge-large-en-v1.5-onnx-Q"),
model_file: String::from("model_optimized.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::BGESmallENV15,
Expand All @@ -144,6 +154,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Xenova/bge-small-en-v1.5"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::BGESmallENV15Q,
Expand All @@ -154,6 +165,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Qdrant/bge-small-en-v1.5-onnx-Q"),
model_file: String::from("model_optimized.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::NomicEmbedTextV1,
Expand All @@ -162,6 +174,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("nomic-ai/nomic-embed-text-v1"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::NomicEmbedTextV15,
Expand All @@ -170,6 +183,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("nomic-ai/nomic-embed-text-v1.5"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::NomicEmbedTextV15Q,
Expand All @@ -180,6 +194,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("nomic-ai/nomic-embed-text-v1.5"),
model_file: String::from("onnx/model_quantized.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::ParaphraseMLMiniLML12V2Q,
Expand All @@ -188,6 +203,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q"),
model_file: String::from("model_optimized.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::ParaphraseMLMiniLML12V2,
Expand All @@ -196,6 +212,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Xenova/paraphrase-multilingual-MiniLM-L12-v2"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::ParaphraseMLMpnetBaseV2,
Expand All @@ -206,6 +223,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Xenova/paraphrase-multilingual-mpnet-base-v2"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::BGESmallZHV15,
Expand All @@ -214,6 +232,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Xenova/bge-small-zh-v1.5"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::BGELargeZHV15,
Expand All @@ -222,6 +241,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Xenova/bge-large-zh-v1.5"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::ModernBertEmbedLarge,
Expand All @@ -230,6 +250,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("lightonai/modernbert-embed-large"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::MultilingualE5Small,
Expand All @@ -238,6 +259,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("intfloat/multilingual-e5-small"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::MultilingualE5Base,
Expand All @@ -246,6 +268,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("intfloat/multilingual-e5-base"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::MultilingualE5Large,
Expand All @@ -254,6 +277,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Qdrant/multilingual-e5-large-onnx"),
model_file: String::from("model.onnx"),
additional_files: vec!["model.onnx_data".to_string()],
output_key: None,
},
ModelInfo {
model: EmbeddingModel::MxbaiEmbedLargeV1,
Expand All @@ -262,6 +286,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("mixedbread-ai/mxbai-embed-large-v1"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::MxbaiEmbedLargeV1Q,
Expand All @@ -270,6 +295,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("mixedbread-ai/mxbai-embed-large-v1"),
model_file: String::from("onnx/model_quantized.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::GTEBaseENV15,
Expand All @@ -278,6 +304,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Alibaba-NLP/gte-base-en-v1.5"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::GTEBaseENV15Q,
Expand All @@ -286,6 +313,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Alibaba-NLP/gte-base-en-v1.5"),
model_file: String::from("onnx/model_quantized.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::GTELargeENV15,
Expand All @@ -294,6 +322,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Alibaba-NLP/gte-large-en-v1.5"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::GTELargeENV15Q,
Expand All @@ -302,6 +331,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Alibaba-NLP/gte-large-en-v1.5"),
model_file: String::from("onnx/model_quantized.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::ClipVitB32,
Expand All @@ -310,6 +340,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("Qdrant/clip-ViT-B-32-text"),
model_file: String::from("model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
model: EmbeddingModel::JinaEmbeddingsV2BaseCode,
Expand All @@ -318,6 +349,16 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
model_code: String::from("jinaai/jina-embeddings-v2-base-code"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
output_key: None,
},
ModelInfo {
    model: EmbeddingModel::EmbeddingGemma300M,
    dim: 768,
    // Fixed grammar: the original string read "a 300M parameter from Google",
    // dropping the word "model".
    description: String::from("EmbeddingGemma is a 300M parameter embedding model from Google"),
    model_code: String::from("onnx-community/embeddinggemma-300m-ONNX"),
    model_file: String::from("onnx/model.onnx"),
    // External weight data stored alongside the ONNX graph; must be fetched too.
    additional_files: vec!["onnx/model.onnx_data".to_string()],
    // This model exposes its embeddings under a named output rather than the default.
    output_key: Some(crate::OutputKey::ByName("sentence_embedding")),
},
];

Expand Down
8 changes: 7 additions & 1 deletion src/output/output_precedence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//! e.g. reading the output keys from the model file.

/// Enum for defining the key of the output.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OutputKey {
OnlyOne,
ByOrder(usize),
Expand Down Expand Up @@ -41,3 +41,9 @@ impl OutputPrecedence for &[OutputKey] {
self.iter()
}
}

/// Allows a single `OutputKey` reference to serve as an output precedence:
/// the precedence consists of exactly that one key.
impl OutputPrecedence for &OutputKey {
    fn key_precedence(&self) -> impl Iterator<Item = &OutputKey> {
        // `self` is `&&OutputKey`; yield the inner `&OutputKey` exactly once.
        Some(*self).into_iter()
    }
}
Loading