3 changes: 3 additions & 0 deletions backends/candle/src/layers/mod.rs
@@ -2,6 +2,7 @@
mod cublaslt;
mod layer_norm;
mod linear;
mod radix_mlp;
#[allow(dead_code, unused)]
mod rms_norm;
mod rotary;
@@ -10,5 +11,7 @@ pub use cublaslt::get_cublas_lt_wrapper;
pub use layer_norm::{LayerNorm, LayerNormNoBias};
pub use linear::{HiddenAct, Linear};
#[allow(unused_imports)]
pub use radix_mlp::CompactUnfoldTensors;
#[allow(unused_imports)]
pub use rms_norm::RMSNorm;
pub use rotary::{apply_rotary, get_cos_sin, get_inv_freqs, RopeScaling};
82 changes: 82 additions & 0 deletions backends/candle/src/layers/radix_mlp.rs
@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// Published under RadixMLP by Michael Feil
// Copyright (c) 2025 michaelfeil

use candle::{Device, Result, Tensor};
use text_embeddings_backend_core::Batch;

/// Helper struct to manage compact/unfold tensor operations for RadixMLP.
///
/// If the batch is not compact, all operations are no-ops.
#[allow(dead_code)]
pub struct CompactUnfoldTensors {
pub scatter_unfold: Option<Tensor>,
pub fold_gather: Option<Tensor>,
pub position_ids_compact: Tensor,
}

#[allow(dead_code)]
impl CompactUnfoldTensors {
/// Create compact/unfold tensors from batch data,
/// returning the `input_ids` tensor and the compact/unfold tensors if applicable.
pub fn from_batch(batch: &Batch, device: &Device) -> Result<(Tensor, Self)> {
let shape = batch.input_ids.len();

let (input_ids, compact_tensors) =
if let (Some(compact_ids), Some(compact_pos), Some(scatter), Some(fold)) = (
batch.compact_input_ids.as_ref(),
batch.compact_position_ids.as_ref(),
batch.scatter_unfold.as_ref(),
batch.fold_gather.as_ref(),
) {
let m = compact_ids.len();
let compact_ids_t = Tensor::from_vec(compact_ids.clone(), m, device)?;
let scatter_t = Tensor::from_vec(scatter.clone(), shape, device)?;
let fold_t = Tensor::from_vec(fold.clone(), m, device)?;
let position_ids_compact = Tensor::from_vec(compact_pos.clone(), m, device)?;

(
compact_ids_t,
CompactUnfoldTensors {
scatter_unfold: Some(scatter_t),
fold_gather: Some(fold_t),
position_ids_compact,
},
)
} else {
let input_ids = Tensor::from_vec(batch.input_ids.clone(), shape, device)?;
let position_ids = Tensor::from_vec(batch.position_ids.clone(), shape, device)?;
(
input_ids,
CompactUnfoldTensors {
scatter_unfold: None,
fold_gather: None,
position_ids_compact: position_ids,
},
)
};

Ok((input_ids, compact_tensors))
}

/// Expand compact → original using `scatter_unfold`, if present.
#[inline]
pub fn scatter_unfold(&self, tensor: &Tensor) -> Result<Tensor> {
if let Some(scatter) = &self.scatter_unfold {
tensor.index_select(scatter, 0)?.contiguous()
} else {
Ok(tensor.clone())
}
}

/// Gather original → compact using `fold_gather`, if present.
/// Identity path: returns a shallow handle clone (no device copy).
#[inline]
pub fn fold_gather(&self, tensor: &Tensor) -> Result<Tensor> {
if let Some(gather) = &self.fold_gather {
tensor.index_select(gather, 0)?.contiguous()
} else {
Ok(tensor.clone())
}
}
}
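For readers new to the fold/unfold contract: `CompactUnfoldTensors` only wraps two `index_select` calls, so the semantics live entirely in the index vectors that arrive precomputed on the `Batch`. Below is a minimal, self-contained sketch of that mapping with made-up indices for two sequences that share a two-token prefix; the concrete `scatter_unfold` and `fold_gather` values in real batches are produced upstream, not by this helper.

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Two sequences [1, 2, 3] and [1, 2, 4] share the prefix [1, 2].
    // The compact buffer stores the shared prefix once: [1, 2, 3, 4].
    let compact = Tensor::from_vec(vec![1u32, 2, 3, 4], 4, &device)?;
    // `scatter_unfold` has one entry per ORIGINAL token and names the compact
    // row it came from, recovering the original layout [1, 2, 3, 1, 2, 4].
    let scatter_unfold = Tensor::from_vec(vec![0u32, 1, 2, 0, 1, 3], 6, &device)?;
    let unfolded = compact.index_select(&scatter_unfold, 0)?.contiguous()?;
    // `fold_gather` has one entry per COMPACT row and picks a representative
    // original position, taking the buffer back to length 4.
    let fold_gather = Tensor::from_vec(vec![0u32, 1, 2, 5], 4, &device)?;
    let refolded = unfolded.index_select(&fold_gather, 0)?.contiguous()?;
    println!("unfolded: {unfolded}");
    println!("refolded: {refolded}");
    Ok(())
}

This also suggests why the models below unfold only around attention: flash_attn_varlen consumes per-sequence cu_seqlens in the original layout, while the token-wise projections, norms, and MLP are position-independent and can stay in the compact layout.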
4 changes: 4 additions & 0 deletions backends/candle/src/lib.rs
@@ -588,6 +588,10 @@ impl Backend for CandleBackend {
self.model.is_padded()
}

fn supports_radix_mlp(&self) -> bool {
self.model.supports_radix_mlp()
}

fn embed(&self, batch: Batch) -> Result<Embeddings, BackendError> {
let batch_size = batch.len();
let pooled_indices = batch.pooled_indices.clone();
46 changes: 35 additions & 11 deletions backends/candle/src/models/flash_mistral.rs
@@ -1,5 +1,5 @@
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm};
use crate::layers::{get_cos_sin, get_inv_freqs, CompactUnfoldTensors, HiddenAct, Linear, RMSNorm};
use crate::models::{MistralConfig, Model};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
@@ -69,6 +69,7 @@ impl MistralAttention {
cos: &Tensor,
sin: &Tensor,
max_s: usize,
compact_tensors: &CompactUnfoldTensors,
) -> Result<Tensor> {
let _enter = self.span.enter();

@@ -93,6 +94,10 @@

apply_rotary_inplace(&q, &k, &cos, &sin, true)?;

let q = compact_tensors.scatter_unfold(&q)?;
let k = compact_tensors.scatter_unfold(&k)?;
let v = compact_tensors.scatter_unfold(&v)?;

let attention = flash_attn_varlen(
&q,
&k,
@@ -109,6 +114,8 @@
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;

let attention = compact_tensors.fold_gather(&attention)?;

self.o_proj.forward(&attention)
}
}
@@ -207,13 +214,19 @@ impl MistralLayer {
cos: &Tensor,
sin: &Tensor,
max_s: usize,
compact_tensors: &CompactUnfoldTensors,
) -> Result<(Tensor, Tensor)> {
let _enter = self.span.enter();

let (normed_hidden_states, res) = self.input_layer_norm.forward(hidden_states, residual)?;
let attn_output =
self.attention
.forward(&normed_hidden_states, cu_seqlens, cos, sin, max_s)?;
let attn_output = self.attention.forward(
&normed_hidden_states,
cu_seqlens,
cos,
sin,
max_s,
compact_tensors,
)?;
let (normed_attn_res_output, attn_res) = self
.post_attention_layer_norm
.forward(&attn_output, Some(&res))?;
@@ -296,19 +309,22 @@ impl FlashMistralModel {
let batch_size = batch.cumulative_seq_lengths.len() - 1;
let shape = batch.input_ids.len();

// Create Cuda tensors
let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?;
let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?;
// Create compact/unfold tensors and get embeddings
let (input_ids, compact_tensors) = CompactUnfoldTensors::from_batch(&batch, &self.device)?;
let mut hidden_states = self.embeddings.forward(&input_ids)?.contiguous()?;

let cu_seqlens = Tensor::from_vec(
batch.cumulative_seq_lengths.clone(),
batch_size + 1,
&self.device,
)?;

let mut hidden_states = self.embeddings.forward(&input_ids)?;

let cos = self.cos_cache.index_select(&position_ids, 0)?;
let sin = self.sin_cache.index_select(&position_ids, 0)?;
let cos = self
.cos_cache
.index_select(&compact_tensors.position_ids_compact, 0)?;
let sin = self
.sin_cache
.index_select(&compact_tensors.position_ids_compact, 0)?;

let mut residual = None;
for layer in &self.layers {
@@ -319,13 +335,16 @@
&cos,
&sin,
batch.max_length as usize,
&compact_tensors,
)?;
hidden_states = h;
residual = Some(r);
}

let (outputs, _) = self.norm.forward(&hidden_states, residual.as_ref())?;

let outputs = compact_tensors.scatter_unfold(&outputs)?;

let has_pooling_requests = !batch.pooled_indices.is_empty();
let has_raw_requests = !batch.raw_indices.is_empty();

@@ -442,6 +461,11 @@ impl Model for FlashMistralModel {
fn is_padded(&self) -> bool {
false
}

fn supports_radix_mlp(&self) -> bool {
true
}

fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
self.forward(batch)
}
3 changes: 3 additions & 0 deletions backends/candle/src/models/flash_qwen2.rs
@@ -111,6 +111,9 @@ impl Qwen2Attention {
max_s,
max_s,
self.softmax_scale,
// TODO: Qwen2 models are generally not causal; this is a bug.
// e.g. https://huggingface.co/jinaai/jina-code-embeddings-0.5b
// breaks for this reason.
false,
None,
None,
51 changes: 39 additions & 12 deletions backends/candle/src/models/flash_qwen3.rs
@@ -1,5 +1,5 @@
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm};
use crate::layers::{get_cos_sin, get_inv_freqs, CompactUnfoldTensors, HiddenAct, Linear, RMSNorm};
use crate::models::{Model, Qwen3Config};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
@@ -109,6 +109,7 @@ impl Qwen3Attention {
cos: &Tensor,
sin: &Tensor,
max_s: usize,
compact_tensors: &CompactUnfoldTensors,
) -> Result<Tensor> {
let _enter = self.span.enter();

@@ -146,8 +147,14 @@
let (q, _) = self.q_norm.forward(&q, None)?;
let (k, _) = self.k_norm.forward(&k, None)?;

// Apply RoPE in COMPACT space
apply_rotary_inplace(&q, &k, &cos, &sin, true)?;

// Expand Q, K, V to ORIGINAL layout for attention
let q = compact_tensors.scatter_unfold(&q)?;
let k = compact_tensors.scatter_unfold(&k)?;
let v = compact_tensors.scatter_unfold(&v)?;

let attention = flash_attn_varlen(
&q,
&k,
@@ -164,6 +171,9 @@
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;

// Compact attention output back to COMPACT layout before o_proj
let attention = compact_tensors.fold_gather(&attention)?;

self.o_proj.forward(&attention)
}
}
@@ -262,14 +272,20 @@ impl Qwen3Layer {
cos: &Tensor,
sin: &Tensor,
max_s: usize,
compact_tensors: &CompactUnfoldTensors,
) -> Result<(Tensor, Tensor)> {
let _enter = self.span.enter();

let (normed_hidden_states, res) = self.input_layer_norm.forward(hidden_states, residual)?;

let attn_output =
self.attention
.forward(&normed_hidden_states, cu_seqlens, cos, sin, max_s)?;
let attn_output = self.attention.forward(
&normed_hidden_states,
cu_seqlens,
cos,
sin,
max_s,
compact_tensors,
)?;

let (normed_attn_res_output, attn_res) = self
.post_attention_layer_norm
@@ -360,21 +376,24 @@ impl FlashQwen3Model {
let _enter = self.span.enter();

let batch_size = batch.cumulative_seq_lengths.len() - 1;
let shape = batch.input_ids.len();

// Create Cuda tensors
let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?;
let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?;
// Create compact/unfold tensors and get embeddings
let (input_ids, compact_tensors) = CompactUnfoldTensors::from_batch(&batch, &self.device)?;
let mut hidden_states = self.embeddings.forward(&input_ids)?.contiguous()?;

let cu_seqlens = Tensor::from_vec(
batch.cumulative_seq_lengths.clone(),
batch_size + 1,
&self.device,
)?;

let mut hidden_states = self.embeddings.forward(&input_ids)?;

let cos = self.cos_cache.index_select(&position_ids, 0)?;
let sin = self.sin_cache.index_select(&position_ids, 0)?;
// sin and cos are applied in the compact layout, so they must be indexed with the compact position ids
let cos = self
.cos_cache
.index_select(&compact_tensors.position_ids_compact, 0)?;
let sin = self
.sin_cache
.index_select(&compact_tensors.position_ids_compact, 0)?;

let mut residual = None;
for layer in &self.layers {
Expand All @@ -385,13 +404,16 @@ impl FlashQwen3Model {
&cos,
&sin,
batch.max_length as usize,
&compact_tensors,
)?;
hidden_states = h;
residual = Some(r);
}

let (outputs, _) = self.norm.forward(&hidden_states, residual.as_ref())?;

// Expand final outputs to original layout for pooling/raw extraction
let outputs = compact_tensors.scatter_unfold(&outputs)?;
let has_pooling_requests = !batch.pooled_indices.is_empty();
let has_raw_requests = !batch.raw_indices.is_empty();

@@ -474,6 +496,7 @@ impl FlashQwen3Model {
let raw_embeddings = if has_raw_requests {
if batch_size > 1 && has_pooling_requests {
// Create indexing vector for the embeddings
let shape = batch.input_ids.len();
let mut final_indices: Vec<u32> = Vec::with_capacity(shape);
for i in batch.raw_indices.into_iter() {
let i = i as usize;
@@ -509,6 +532,10 @@ impl Model for FlashQwen3Model {
false
}

fn supports_radix_mlp(&self) -> bool {
true
}

fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
self.forward(batch)
}
4 changes: 4 additions & 0 deletions backends/candle/src/models/mod.rs
@@ -98,6 +98,10 @@ pub use flash_qwen3::FlashQwen3Model;
pub(crate) trait Model {
fn is_padded(&self) -> bool;

fn supports_radix_mlp(&self) -> bool {
false
}

fn embed(&self, _batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
candle::bail!("`embed` is not implemented for this model");
}
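For illustration only (no such model exists in this PR): a model opts in by overriding the trait default once its forward pass threads `CompactUnfoldTensors` through attention, exactly as `FlashMistralModel` and `FlashQwen3Model` do above; every other `Model` implementation keeps the `false` default untouched. The `HypotheticalFlashModel` name below is made up.

// Hypothetical model, shown only to sketch the opt-in pattern.
struct HypotheticalFlashModel;

impl Model for HypotheticalFlashModel {
    fn is_padded(&self) -> bool {
        false
    }

    fn supports_radix_mlp(&self) -> bool {
        // Only safe once RoPE/MLP run in the compact layout and attention
        // unfolds/refolds around the varlen kernel, as in the diffs above.
        true
    }
}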