diff --git a/backends/candle/src/layers/mod.rs b/backends/candle/src/layers/mod.rs
index 24eb4cf71..ab98a05cb 100644
--- a/backends/candle/src/layers/mod.rs
+++ b/backends/candle/src/layers/mod.rs
@@ -2,6 +2,7 @@
 mod cublaslt;
 mod layer_norm;
 mod linear;
+mod radix_mlp;
 #[allow(dead_code, unused)]
 mod rms_norm;
 mod rotary;
@@ -10,5 +11,7 @@ pub use cublaslt::get_cublas_lt_wrapper;
 pub use layer_norm::{LayerNorm, LayerNormNoBias};
 pub use linear::{HiddenAct, Linear};
 #[allow(unused_imports)]
+pub use radix_mlp::CompactUnfoldTensors;
+#[allow(unused_imports)]
 pub use rms_norm::RMSNorm;
 pub use rotary::{apply_rotary, get_cos_sin, get_inv_freqs, RopeScaling};
diff --git a/backends/candle/src/layers/radix_mlp.rs b/backends/candle/src/layers/radix_mlp.rs
new file mode 100644
index 000000000..8f536712c
--- /dev/null
+++ b/backends/candle/src/layers/radix_mlp.rs
@@ -0,0 +1,82 @@
+// SPDX-License-Identifier: MIT
+// Published under RadixMLP by Michael Feil
+// Copyright (c) 2025 michaelfeil
+
+use candle::{Device, Result, Tensor};
+use text_embeddings_backend_core::Batch;
+
+/// Helper struct to manage compact/unfold tensor operations for RadixMLP.
+///
+/// If the batch is not compact, all operations are no-ops.
+#[allow(dead_code)]
+pub struct CompactUnfoldTensors {
+    pub scatter_unfold: Option<Tensor>,
+    pub fold_gather: Option<Tensor>,
+    pub position_ids_compact: Tensor,
+}
+
+#[allow(dead_code)]
+impl CompactUnfoldTensors {
+    /// Create compact/unfold tensors from batch data, returning the `input_ids`
+    /// tensor and the compact/unfold tensors if applicable.
+    pub fn from_batch(batch: &Batch, device: &Device) -> Result<(Tensor, Self)> {
+        let shape = batch.input_ids.len();
+
+        let (input_ids, compact_tensors) =
+            if let (Some(compact_ids), Some(compact_pos), Some(scatter), Some(fold)) = (
+                batch.compact_input_ids.as_ref(),
+                batch.compact_position_ids.as_ref(),
+                batch.scatter_unfold.as_ref(),
+                batch.fold_gather.as_ref(),
+            ) {
+                let m = compact_ids.len();
+                let compact_ids_t = Tensor::from_vec(compact_ids.clone(), m, device)?;
+                let scatter_t = Tensor::from_vec(scatter.clone(), shape, device)?;
+                let fold_t = Tensor::from_vec(fold.clone(), m, device)?;
+                let position_ids_compact = Tensor::from_vec(compact_pos.clone(), m, device)?;
+
+                (
+                    compact_ids_t,
+                    CompactUnfoldTensors {
+                        scatter_unfold: Some(scatter_t),
+                        fold_gather: Some(fold_t),
+                        position_ids_compact,
+                    },
+                )
+            } else {
+                let input_ids = Tensor::from_vec(batch.input_ids.clone(), shape, device)?;
+                let position_ids = Tensor::from_vec(batch.position_ids.clone(), shape, device)?;
+                (
+                    input_ids,
+                    CompactUnfoldTensors {
+                        scatter_unfold: None,
+                        fold_gather: None,
+                        position_ids_compact: position_ids,
+                    },
+                )
+            };
+
+        Ok((input_ids, compact_tensors))
+    }
+
+    /// Expand compact → original using `scatter_unfold`, if present.
+    #[inline]
+    pub fn scatter_unfold(&self, tensor: &Tensor) -> Result<Tensor> {
+        if let Some(scatter) = &self.scatter_unfold {
+            tensor.index_select(scatter, 0)?.contiguous()
+        } else {
+            Ok(tensor.clone())
+        }
+    }
+
+    /// Gather original → compact using `fold_gather`, if present.
+    /// Identity path: returns a shallow handle clone (no device copy).
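+    ///
+    /// Illustrative round trip (a sketch, assuming a compact batch was built by
+    /// `from_batch`; argument lists are elided):
+    ///
+    /// ```ignore
+    /// let q = compact_tensors.scatter_unfold(&q)?;    // compact -> original layout
+    /// let attn = flash_attn_varlen(&q, /* ... */)?;   // attention runs on the original layout
+    /// let attn = compact_tensors.fold_gather(&attn)?; // original -> compact layout
+    /// ```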
+ #[inline] + pub fn fold_gather(&self, tensor: &Tensor) -> Result { + if let Some(gather) = &self.fold_gather { + tensor.index_select(gather, 0)?.contiguous() + } else { + Ok(tensor.clone()) + } + } +} diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index ff824f555..5bf6e0b18 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -588,6 +588,10 @@ impl Backend for CandleBackend { self.model.is_padded() } + fn supports_radix_mlp(&self) -> bool { + self.model.supports_radix_mlp() + } + fn embed(&self, batch: Batch) -> Result { let batch_size = batch.len(); let pooled_indices = batch.pooled_indices.clone(); diff --git a/backends/candle/src/models/flash_mistral.rs b/backends/candle/src/models/flash_mistral.rs index c8488f360..44656b4b6 100644 --- a/backends/candle/src/models/flash_mistral.rs +++ b/backends/candle/src/models/flash_mistral.rs @@ -1,5 +1,5 @@ use crate::flash_attn::flash_attn_varlen; -use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm}; +use crate::layers::{get_cos_sin, get_inv_freqs, CompactUnfoldTensors, HiddenAct, Linear, RMSNorm}; use crate::models::{MistralConfig, Model}; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; @@ -69,6 +69,7 @@ impl MistralAttention { cos: &Tensor, sin: &Tensor, max_s: usize, + compact_tensors: &CompactUnfoldTensors, ) -> Result { let _enter = self.span.enter(); @@ -93,6 +94,10 @@ impl MistralAttention { apply_rotary_inplace(&q, &k, &cos, &sin, true)?; + let q = compact_tensors.scatter_unfold(&q)?; + let k = compact_tensors.scatter_unfold(&k)?; + let v = compact_tensors.scatter_unfold(&v)?; + let attention = flash_attn_varlen( &q, &k, @@ -109,6 +114,8 @@ impl MistralAttention { )?; let attention = attention.flatten_from(candle::D::Minus2)?; + let attention = compact_tensors.fold_gather(&attention)?; + self.o_proj.forward(&attention) } } @@ -207,13 +214,19 @@ impl MistralLayer { cos: &Tensor, sin: &Tensor, max_s: usize, + compact_tensors: &CompactUnfoldTensors, ) -> Result<(Tensor, Tensor)> { let _enter = self.span.enter(); let (normed_hidden_states, res) = self.input_layer_norm.forward(hidden_states, residual)?; - let attn_output = - self.attention - .forward(&normed_hidden_states, cu_seqlens, cos, sin, max_s)?; + let attn_output = self.attention.forward( + &normed_hidden_states, + cu_seqlens, + cos, + sin, + max_s, + compact_tensors, + )?; let (normed_attn_res_output, attn_res) = self .post_attention_layer_norm .forward(&attn_output, Some(&res))?; @@ -296,19 +309,22 @@ impl FlashMistralModel { let batch_size = batch.cumulative_seq_lengths.len() - 1; let shape = batch.input_ids.len(); - // Create Cuda tensors - let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?; - let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?; + // Create compact/unfold tensors and get embeddings + let (input_ids, compact_tensors) = CompactUnfoldTensors::from_batch(&batch, &self.device)?; + let mut hidden_states = self.embeddings.forward(&input_ids)?.contiguous()?; + let cu_seqlens = Tensor::from_vec( batch.cumulative_seq_lengths.clone(), batch_size + 1, &self.device, )?; - let mut hidden_states = self.embeddings.forward(&input_ids)?; - - let cos = self.cos_cache.index_select(&position_ids, 0)?; - let sin = self.sin_cache.index_select(&position_ids, 0)?; + let cos = self + .cos_cache + .index_select(&compact_tensors.position_ids_compact, 0)?; + let sin = self + .sin_cache + 
.index_select(&compact_tensors.position_ids_compact, 0)?; let mut residual = None; for layer in &self.layers { @@ -319,6 +335,7 @@ impl FlashMistralModel { &cos, &sin, batch.max_length as usize, + &compact_tensors, )?; hidden_states = h; residual = Some(r); @@ -326,6 +343,8 @@ impl FlashMistralModel { let (outputs, _) = self.norm.forward(&hidden_states, residual.as_ref())?; + let outputs = compact_tensors.scatter_unfold(&outputs)?; + let has_pooling_requests = !batch.pooled_indices.is_empty(); let has_raw_requests = !batch.raw_indices.is_empty(); @@ -442,6 +461,11 @@ impl Model for FlashMistralModel { fn is_padded(&self) -> bool { false } + + fn supports_radix_mlp(&self) -> bool { + true + } + fn embed(&self, batch: Batch) -> Result<(Option, Option)> { self.forward(batch) } diff --git a/backends/candle/src/models/flash_qwen2.rs b/backends/candle/src/models/flash_qwen2.rs index c9116311a..220225ff6 100644 --- a/backends/candle/src/models/flash_qwen2.rs +++ b/backends/candle/src/models/flash_qwen2.rs @@ -111,6 +111,9 @@ impl Qwen2Attention { max_s, max_s, self.softmax_scale, + // TODO: Qwen2 models are generally not causal, this is a bug. + // e.g. https://huggingface.co/jinaai/jina-code-embeddings-0.5b + // breaks for this reason. false, None, None, diff --git a/backends/candle/src/models/flash_qwen3.rs b/backends/candle/src/models/flash_qwen3.rs index 10f27bddf..61cc37be0 100644 --- a/backends/candle/src/models/flash_qwen3.rs +++ b/backends/candle/src/models/flash_qwen3.rs @@ -1,5 +1,5 @@ use crate::flash_attn::flash_attn_varlen; -use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm}; +use crate::layers::{get_cos_sin, get_inv_freqs, CompactUnfoldTensors, HiddenAct, Linear, RMSNorm}; use crate::models::{Model, Qwen3Config}; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; @@ -109,6 +109,7 @@ impl Qwen3Attention { cos: &Tensor, sin: &Tensor, max_s: usize, + compact_tensors: &CompactUnfoldTensors, ) -> Result { let _enter = self.span.enter(); @@ -146,8 +147,14 @@ impl Qwen3Attention { let (q, _) = self.q_norm.forward(&q, None)?; let (k, _) = self.k_norm.forward(&k, None)?; + // Apply RoPE in COMPACT space apply_rotary_inplace(&q, &k, &cos, &sin, true)?; + // Expand Q, K, V to ORIGINAL layout for attention + let q = compact_tensors.scatter_unfold(&q)?; + let k = compact_tensors.scatter_unfold(&k)?; + let v = compact_tensors.scatter_unfold(&v)?; + let attention = flash_attn_varlen( &q, &k, @@ -164,6 +171,9 @@ impl Qwen3Attention { )?; let attention = attention.flatten_from(candle::D::Minus2)?; + // Compact attention output back to COMPACT layout before o_proj + let attention = compact_tensors.fold_gather(&attention)?; + self.o_proj.forward(&attention) } } @@ -262,14 +272,20 @@ impl Qwen3Layer { cos: &Tensor, sin: &Tensor, max_s: usize, + compact_tensors: &CompactUnfoldTensors, ) -> Result<(Tensor, Tensor)> { let _enter = self.span.enter(); let (normed_hidden_states, res) = self.input_layer_norm.forward(hidden_states, residual)?; - let attn_output = - self.attention - .forward(&normed_hidden_states, cu_seqlens, cos, sin, max_s)?; + let attn_output = self.attention.forward( + &normed_hidden_states, + cu_seqlens, + cos, + sin, + max_s, + compact_tensors, + )?; let (normed_attn_res_output, attn_res) = self .post_attention_layer_norm @@ -360,21 +376,24 @@ impl FlashQwen3Model { let _enter = self.span.enter(); let batch_size = batch.cumulative_seq_lengths.len() - 1; - let shape = batch.input_ids.len(); - // Create Cuda tensors 
- let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?; - let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?; + // Create compact/unfold tensors and get embeddings + let (input_ids, compact_tensors) = CompactUnfoldTensors::from_batch(&batch, &self.device)?; + let mut hidden_states = self.embeddings.forward(&input_ids)?.contiguous()?; + let cu_seqlens = Tensor::from_vec( batch.cumulative_seq_lengths.clone(), batch_size + 1, &self.device, )?; - let mut hidden_states = self.embeddings.forward(&input_ids)?; - - let cos = self.cos_cache.index_select(&position_ids, 0)?; - let sin = self.sin_cache.index_select(&position_ids, 0)?; + // sin and cos are applied on the compact formation, therefore should be on the compact array + let cos = self + .cos_cache + .index_select(&compact_tensors.position_ids_compact, 0)?; + let sin = self + .sin_cache + .index_select(&compact_tensors.position_ids_compact, 0)?; let mut residual = None; for layer in &self.layers { @@ -385,6 +404,7 @@ impl FlashQwen3Model { &cos, &sin, batch.max_length as usize, + &compact_tensors, )?; hidden_states = h; residual = Some(r); @@ -392,6 +412,8 @@ impl FlashQwen3Model { let (outputs, _) = self.norm.forward(&hidden_states, residual.as_ref())?; + // Expand final outputs to original layout for pooling/raw extraction + let outputs = compact_tensors.scatter_unfold(&outputs)?; let has_pooling_requests = !batch.pooled_indices.is_empty(); let has_raw_requests = !batch.raw_indices.is_empty(); @@ -474,6 +496,7 @@ impl FlashQwen3Model { let raw_embeddings = if has_raw_requests { if batch_size > 1 && has_pooling_requests { // Create indexing vector for the embeddings + let shape = batch.input_ids.len(); let mut final_indices: Vec = Vec::with_capacity(shape); for i in batch.raw_indices.into_iter() { let i = i as usize; @@ -509,6 +532,10 @@ impl Model for FlashQwen3Model { false } + fn supports_radix_mlp(&self) -> bool { + true + } + fn embed(&self, batch: Batch) -> Result<(Option, Option)> { self.forward(batch) } diff --git a/backends/candle/src/models/mod.rs b/backends/candle/src/models/mod.rs index 424f4e984..fb9a942fc 100644 --- a/backends/candle/src/models/mod.rs +++ b/backends/candle/src/models/mod.rs @@ -98,6 +98,10 @@ pub use flash_qwen3::FlashQwen3Model; pub(crate) trait Model { fn is_padded(&self) -> bool; + fn supports_radix_mlp(&self) -> bool { + false + } + fn embed(&self, _batch: Batch) -> Result<(Option, Option)> { candle::bail!("`embed` is not implemented for this model"); } diff --git a/backends/candle/tests/common.rs b/backends/candle/tests/common.rs index 896e0f65b..10c27fb29 100644 --- a/backends/candle/tests/common.rs +++ b/backends/candle/tests/common.rs @@ -85,7 +85,7 @@ impl Deref for SnapshotEmbeddings { impl From>> for SnapshotEmbeddings { fn from(value: Vec>) -> Self { - Self(value.into_iter().map(|v| SnapEmbedding(v)).collect()) + Self(value.into_iter().map(SnapEmbedding).collect()) } } @@ -181,7 +181,7 @@ pub fn download_artifacts( } _ => { for path in &paths { - download_dense_module(&api_repo, &path)?; + download_dense_module(&api_repo, path)?; } Some(paths) } @@ -350,5 +350,9 @@ pub fn batch(encodings: Vec, pooled_indices: Vec, raw_indices: Ve max_length, pooled_indices, raw_indices, + compact_input_ids: None, + compact_position_ids: None, + scatter_unfold: None, + fold_gather: None, } } diff --git a/backends/core/src/lib.rs b/backends/core/src/lib.rs index 8e134d2be..e9b9cc5eb 100644 --- a/backends/core/src/lib.rs +++ b/backends/core/src/lib.rs @@ -14,6 +14,10 @@ 
pub struct Batch { pub max_length: u32, pub pooled_indices: Vec, pub raw_indices: Vec, + pub compact_input_ids: Option>, + pub compact_position_ids: Option>, + pub scatter_unfold: Option>, + pub fold_gather: Option>, } impl Batch { @@ -42,6 +46,10 @@ pub trait Backend { fn is_padded(&self) -> bool; + fn supports_radix_mlp(&self) -> bool { + false + } + fn embed(&self, batch: Batch) -> Result; fn predict(&self, batch: Batch) -> Result; diff --git a/backends/src/lib.rs b/backends/src/lib.rs index 245715b38..9e323864f 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -75,6 +75,7 @@ pub struct Backend { health_receiver: watch::Receiver, _backend_thread: Arc, pub padded_model: bool, + pub radix_mlp_supported: bool, pub max_batch_size: Option, pub model_type: ModelType, } @@ -105,6 +106,7 @@ impl Backend { ) .await?; let padded_model = backend.is_padded(); + let radix_mlp_supported = backend.supports_radix_mlp(); let max_batch_size = backend.max_batch_size(); let (health_sender, health_receiver) = watch::channel(false); @@ -116,6 +118,7 @@ impl Backend { health_receiver, _backend_thread, padded_model, + radix_mlp_supported, max_batch_size, model_type, }) @@ -223,6 +226,10 @@ impl Backend { max_length: tmp_length, pooled_indices, raw_indices: vec![], + compact_input_ids: None, + compact_position_ids: None, + fold_gather: None, + scatter_unfold: None, } } @@ -280,6 +287,10 @@ impl Backend { max_length, pooled_indices, raw_indices: vec![], + compact_input_ids: None, + compact_position_ids: None, + fold_gather: None, + scatter_unfold: None, }; match &self.model_type { @@ -314,6 +325,10 @@ impl Backend { max_length: 1, pooled_indices: vec![0], raw_indices: vec![], + compact_input_ids: None, + compact_position_ids: None, + fold_gather: None, + scatter_unfold: None, }; match &self.model_type { ModelType::Classifier => self.predict(batch).await.map(|_| ()), @@ -611,6 +626,7 @@ async fn download_safetensors(api: &ApiRepo) -> Result, ApiError> { } // Download weight files + // TODO: Parallelize all files. 
let mut safetensors_files = Vec::new(); for n in safetensors_filenames { tracing::info!("Downloading `{}`", n); diff --git a/core/src/lib.rs b/core/src/lib.rs index 4c41f4f34..c0e3b35f0 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,6 +1,7 @@ pub mod download; pub mod infer; pub mod queue; +pub mod radix_mlp; pub mod tokenization; use text_embeddings_backend::BackendError; diff --git a/core/src/queue.rs b/core/src/queue.rs index 3fd8b7715..8afe8b3ba 100644 --- a/core/src/queue.rs +++ b/core/src/queue.rs @@ -1,4 +1,5 @@ use crate::infer::InferResult; +use crate::radix_mlp; use crate::tokenization::ValidEncoding; use std::cmp::max; use std::collections::VecDeque; @@ -43,6 +44,7 @@ impl Queue { padded_model: bool, max_batch_tokens: usize, max_batch_requests: Option, + radix_mlp_threshold: f32, max_concurrent_requests: usize, ) -> Self { // Create channels @@ -54,6 +56,7 @@ impl Queue { padded_model, max_batch_tokens, max_batch_requests, + radix_mlp_threshold, max_concurrent_requests, queue_receiver, ) @@ -98,10 +101,14 @@ fn queue_blocking_task( padded_model: bool, max_batch_tokens: usize, max_batch_requests: Option, + radix_mlp_threshold: f32, max_concurrent_requests: usize, mut queue_receiver: mpsc::Receiver, ) { let capacity = max_batch_requests.unwrap_or(max_concurrent_requests); + let radix_mlp_pad = std::env::var("RADIX_MLP_PAD") + .map(|s| s.to_lowercase() == "true") + .unwrap_or(false); let mut entries: VecDeque = VecDeque::with_capacity(max_concurrent_requests); @@ -179,6 +186,41 @@ fn queue_blocking_task( } } + // Compute RadixMLP compact representation with BOTH mappings + let (compact_input_ids, compact_position_ids, scatter_unfold, fold_gather) = + if radix_mlp_threshold > 1e-6 && !input_ids.is_empty() { + let (compact_ids, compact_pos, scatter, fold) = + radix_mlp::compute_fold_and_scatter( + &input_ids, + &position_ids, + &cu_seq_lengths, + radix_mlp_pad, + ); + + // Only use if we achieved meaningful compression + let compression_ratio = compact_ids.len() as f32 / input_ids.len() as f32; + tracing::info!( + "RadixMLP compression ratio: {:.2} ({} -> {})", + compression_ratio, + input_ids.len(), + compact_ids.len() + ); + metrics::histogram!("te_radix_mlp_compression_ratio") + .record(compression_ratio as f64); + if radix_mlp_threshold < 1.0 && compression_ratio < radix_mlp_threshold { + ( + Some(compact_ids), + Some(compact_pos), + Some(scatter), + Some(fold), + ) + } else { + (None, None, None, None) + } + } else { + (None, None, None, None) + }; + let batch_size = metadata.len(); let next_batch = if metadata.is_empty() { None @@ -193,6 +235,10 @@ fn queue_blocking_task( max_length, pooled_indices, raw_indices, + compact_input_ids, + compact_position_ids, + scatter_unfold, + fold_gather, // Add the second mapping }, )) }; diff --git a/core/src/radix_mlp.rs b/core/src/radix_mlp.rs new file mode 100644 index 000000000..39a713dba --- /dev/null +++ b/core/src/radix_mlp.rs @@ -0,0 +1,1006 @@ +// SPDX-License-Identifier: MIT +// Published under RadixMLP by Michael Feil +// Copyright (c) 2025 michaelfeil + +/// Computes indices for RadixMLP-style folding and scattering to enable prefix-based computation sharing. +/// +/// This function identifies shared prefixes among sequences in a batch. For a batch of token +/// sequences, it produces a "compacted" representation containing only the unique subsequences +/// encountered. 
It also generates index maps to "scatter" (unfold) results from the compact +/// representation back to the original batch structure and to "gather" (fold) the original +/// inputs into the compact form. +/// +/// The core idea is to build a prefix tree (trie) over the sequences, where each node represents +/// a unique `(token_id, position_id)` pair in a specific path. This allows deduplication of +/// identical sub-sequences across the batch. +/// +/// # Arguments +/// +/// * `input_ids`: A flattened vector of token IDs for all sequences in the batch. +/// * `position_ids`: A flattened vector of position IDs corresponding to each token in `input_ids`. +/// * `cu_seq_lengths`: Cumulative sequence lengths, e.g., `[0, len_seq1, len_seq1 + len_seq2, ...]`. +/// This defines the boundaries of each sequence in the flattened `input_ids` and `position_ids`. +/// * `pad_multiple_of`: If `true`, the output compact vectors are padded to a multiple of 8 (for +/// small tensors) or 64 (for larger ones) to improve performance on certain hardware (e.g., cuBLAS). +/// +/// # Returns +/// +/// A tuple containing four vectors: +/// +/// 1. `compact_input_ids`: A vector of the unique token IDs, representing the compacted data. +/// Each unique prefix path from the input sequences appears only once. +/// 2. `compact_position_ids`: The corresponding position IDs for `compact_input_ids`. +/// 3. `scatter_indices`: An index map to unfold data from the compact space to the original +/// batch space. It has the same length as the original `input_ids`. +/// `unfolded[i] = compact[scatter_indices[i]]`. +/// 4. `fold_gather`: An index map to gather data from the original batch space to the compact +/// space. It has the same length as the `compact_input_ids`. Each index points to the +/// *first occurrence* of that unique `(token, position)` pair in the original `input_ids`. +/// `compact[j] = original[fold_gather[j]]`. +pub fn compute_fold_and_scatter( + input_ids: &[u32], + position_ids: &[u32], + cu_seq_lengths: &[u32], + pad_multiple_of: bool, +) -> (Vec, Vec, Vec, Vec) { + // Empty fast-path + if input_ids.is_empty() { + return (Vec::new(), Vec::new(), Vec::new(), Vec::new()); + } + + // Single-sequence fast-path: identity + if cu_seq_lengths.len() == 2 { + let mut compact_input_ids = input_ids.to_vec(); + let mut compact_position_ids = position_ids.to_vec(); + let mut fold_gather: Vec = (0..input_ids.len() as u32).collect(); + let scatter_indices = fold_gather.clone(); + + if pad_multiple_of { + pad_to_multiple( + &mut compact_input_ids, + &mut compact_position_ids, + &mut fold_gather, + ); + } + + return ( + compact_input_ids, + compact_position_ids, + scatter_indices, + fold_gather, + ); + } + + #[inline] + fn make_key(token: u32, pos: u32) -> u64 { + ((pos as u64) << 32) | (token as u64) + } + + // Pad to a multiple of 8 or 64 for performance if requested. 
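+    // A sketch of the rounding rule implemented below:
+    //   padded_len = ceil(len / multiple) * multiple,
+    // where multiple is 8 while len < 1024 and 64 otherwise (e.g. 4 -> 8, 2047 -> 2048).
+    // Padding entries use token 0, position 0 and gather index 0.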
+ #[inline] + fn pad_to_multiple( + compact_input_ids: &mut Vec, + compact_position_ids: &mut Vec, + fold_gather: &mut Vec, + ) { + let current_len = compact_input_ids.len(); + if current_len == 0 { + return; + } + + let multiple = if current_len < 1024 { 8 } else { 64 }; + let remainder = current_len % multiple; + + if remainder != 0 { + let padding_needed = multiple - remainder; + compact_input_ids.reserve(padding_needed); + compact_position_ids.reserve(padding_needed); + fold_gather.reserve(padding_needed); + for _ in 0..padding_needed { + compact_input_ids.push(0); // Pad with token 0 + compact_position_ids.push(0); // Pad with position 0 + fold_gather.push(0); // Pad with index 0 + } + } + } + + #[derive(Debug)] + struct Node { + compact: u32, // u32::MAX => not assigned yet + children: Vec<(u64, usize)>, // sorted by key + } + + let n = input_ids.len(); + + // Arena of nodes; index 0 is a synthetic root. + let mut nodes: Vec = Vec::with_capacity(n + 1); + nodes.push(Node { + compact: u32::MAX, + children: Vec::new(), + }); + + // Outputs (pre-reserve generously to avoid reallocs) + let mut compact_input_ids: Vec = Vec::with_capacity(n); + let mut compact_position_ids: Vec = Vec::with_capacity(n); + let mut fold_gather: Vec = Vec::with_capacity(n); + let mut scatter_indices: Vec = Vec::with_capacity(n); + + let mut next_compact: u32 = 0; + + // -------- Single pass: build trie + produce all mappings -------- + for s in 0..cu_seq_lengths.len().saturating_sub(1) { + let start = cu_seq_lengths[s] as usize; + let end = cu_seq_lengths[s + 1] as usize; + + let mut parent = 0usize; // start from root + for i in start..end { + let t = input_ids[i]; + let p = position_ids[i]; + let k = make_key(t, p); + + // immutable lookup to find child or insertion point + let (exists, val) = { + let children = &nodes[parent].children; + match children.binary_search_by_key(&k, |&(key, _)| key) { + Ok(pos) => (true, children[pos].1), + Err(pos) => (false, pos), + } + }; + + let child_idx = if exists { + val + } else { + // create new node + let insert_pos = val; + let idx = nodes.len(); + nodes.push(Node { + compact: next_compact, // assign compact immediately + children: Vec::new(), + }); + // insert into parent's sorted children + nodes[parent].children.insert(insert_pos, (k, idx)); + + // record compact stream + first occurrence position + compact_input_ids.push(t); + compact_position_ids.push(p); + fold_gather.push(i as u32); + + next_compact += 1; + idx + }; + + // scatter: original position -> compact index + scatter_indices.push(nodes[child_idx].compact); + + parent = child_idx; + } + } + + // If no reduction happened, the streams equal identity (creation order == input order). + // That already satisfies your tests, so just return what we built. + + // Pad to a multiple of 8 for cublas performance if requested. 
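+    // (The multiple is 8 for small compact lengths and 64 once the compact length
+    // reaches 1024; see `pad_to_multiple` above.)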
+ if pad_multiple_of { + pad_to_multiple( + &mut compact_input_ids, + &mut compact_position_ids, + &mut fold_gather, + ); + } + + ( + compact_input_ids, + compact_position_ids, + scatter_indices, + fold_gather, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_compute_fold_and_scatter_empty() { + let input_ids: Vec = vec![]; + let position_ids: Vec = vec![]; + let cu_seq_lengths: Vec = vec![]; + + let (compact_input_ids, compact_position_ids, scatter_indices, fold_gather) = + compute_fold_and_scatter(&input_ids, &position_ids, &cu_seq_lengths, false); + + assert_eq!(compact_input_ids, vec![] as Vec); + assert_eq!(compact_position_ids, vec![] as Vec); + assert_eq!(scatter_indices, vec![] as Vec); + assert_eq!(fold_gather, vec![] as Vec); + } + + #[test] + fn test_compute_fold_and_scatter_single_sequence() { + // Single sequence: [a, b, c] + let input_ids = vec![1, 2, 3]; + let position_ids = vec![0, 1, 2]; + let cu_seq_lengths = vec![0, 3]; + + let (compact_input_ids, compact_position_ids, scatter_indices, fold_gather) = + compute_fold_and_scatter(&input_ids, &position_ids, &cu_seq_lengths, false); + + // No deduplication possible with single sequence + assert_eq!(compact_input_ids, vec![1, 2, 3]); + assert_eq!(compact_position_ids, vec![0, 1, 2]); + assert_eq!(scatter_indices, vec![0, 1, 2]); + assert_eq!(fold_gather, vec![0, 1, 2]); + } + + #[test] + fn test_compute_fold_and_scatter_example_from_comments() { + // Example from comments: + // tokens = [a,b,c,d,e,f,g, a,b,c, e,f,g,h,i] + // pos = [0,1,2,3,4,5,6, 0,1,2, 3,4,5,6,7] + // cu_seqlen = [0,7,10,15] + // Expected folded: + // tokens = [a,b,c, d,e,f,g, e,f,g,h,i] + // pos = [0,1,2, 3,4,5,6, 3,4,5,6,7] + + let input_ids = vec![1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 5, 6, 7, 8, 9]; + let position_ids = vec![0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7]; + let cu_seq_lengths = vec![0, 7, 10, 15]; + + let (compact_input_ids, compact_position_ids, scatter_indices, fold_gather) = + compute_fold_and_scatter(&input_ids, &position_ids, &cu_seq_lengths, false); + + // Should deduplicate shared prefix [a,b,c] at positions [0,1,2] + // and shared subsequence [e,f,g] at positions [3,4,5] + assert_eq!(compact_input_ids.len(), 12); // Reduced from 15 to 12 + assert_eq!(compact_position_ids.len(), 12); + assert_eq!(scatter_indices.len(), 15); // Original length preserved + assert_eq!(fold_gather.len(), 12); // Same as compact length + + // Verify that we can reconstruct original sequences using scatter indices + for i in 0..input_ids.len() { + let compact_idx = scatter_indices[i] as usize; + assert_eq!(input_ids[i], compact_input_ids[compact_idx]); + assert_eq!(position_ids[i], compact_position_ids[compact_idx]); + } + } + + #[test] + fn test_compute_fold_and_scatter_identical_sequences() { + // Two identical sequences: [a,b,c] and [a,b,c] + let input_ids = vec![1, 2, 3, 1, 2, 3]; + let position_ids = vec![0, 1, 2, 0, 1, 2]; + let cu_seq_lengths = vec![0, 3, 6]; + + let (compact_input_ids, compact_position_ids, scatter_indices, fold_gather) = + compute_fold_and_scatter(&input_ids, &position_ids, &cu_seq_lengths, false); + + // Should completely deduplicate to single sequence + assert_eq!(compact_input_ids, vec![1, 2, 3]); + assert_eq!(compact_position_ids, vec![0, 1, 2]); + assert_eq!(scatter_indices, vec![0, 1, 2, 0, 1, 2]); + assert_eq!(fold_gather, vec![0, 1, 2]); + } + + #[test] + fn test_fold_gather_points_to_first_occurrence() { + // Two sequences with overlapping prefixes/suffixes + // S1: a b c d + // S2: a b e f + let 
input_ids = vec![1, 2, 3, 4, 1, 2, 5, 6]; + let position_ids = vec![0, 1, 2, 3, 0, 1, 2, 3]; + let cu = vec![0, 4, 8]; + + let (compact_ids, compact_pos, scatter, fold_gather) = + compute_fold_and_scatter(&input_ids, &position_ids, &cu, false); + + // For each compact index, compute the minimal original position that maps to it. + let mut mins = vec![u32::MAX; compact_ids.len()]; + for (orig_idx, &cidx) in scatter.iter().enumerate() { + mins[cidx as usize] = mins[cidx as usize].min(orig_idx as u32); + } + + assert_eq!(mins.len(), fold_gather.len()); + for (i, (&m, &fg)) in mins.iter().zip(fold_gather.iter()).enumerate() { + assert_eq!(m, fg, "fold_gather[{}] should be first occurrence index", i); + // sanity: the pair at fold_gather matches compact pair at i + let fi = fg as usize; + assert_eq!(input_ids[fi], compact_ids[i]); + assert_eq!(position_ids[fi], compact_pos[i]); + } + } + + #[test] + fn test_compute_fold_and_scatter_no_overlap() { + // Two sequences with no overlap: [a,b] and [c,d] + let input_ids = vec![1, 2, 3, 4]; + let position_ids = vec![0, 1, 0, 1]; + let cu_seq_lengths = vec![0, 2, 4]; + + let (compact_input_ids, compact_position_ids, scatter_indices, fold_gather) = + compute_fold_and_scatter(&input_ids, &position_ids, &cu_seq_lengths, false); + + // No deduplication possible + assert_eq!(compact_input_ids, vec![1, 2, 3, 4]); + assert_eq!(compact_position_ids, vec![0, 1, 0, 1]); + assert_eq!(scatter_indices, vec![0, 1, 2, 3]); + assert_eq!(fold_gather, vec![0, 1, 2, 3]); + } + + #[test] + fn test_compute_fold_and_scatter_partial_overlap() { + // Sequences: [a,b,c] and [a,b,d] + let input_ids = vec![1, 2, 3, 1, 2, 4]; + let position_ids = vec![0, 1, 2, 0, 1, 2]; + let cu_seq_lengths = vec![0, 3, 6]; + + let (compact_input_ids, compact_position_ids, scatter_indices, fold_gather) = + compute_fold_and_scatter(&input_ids, &position_ids, &cu_seq_lengths, false); + + // Should deduplicate shared prefix [a,b] at positions [0,1] + assert_eq!(compact_input_ids.len(), 4); // [a,b,c,d] in some order + assert_eq!(compact_position_ids.len(), 4); + assert_eq!(scatter_indices.len(), 6); + assert_eq!(fold_gather.len(), 4); + + // Verify reconstruction + for i in 0..input_ids.len() { + let compact_idx = scatter_indices[i] as usize; + assert_eq!(input_ids[i], compact_input_ids[compact_idx]); + assert_eq!(position_ids[i], compact_position_ids[compact_idx]); + } + } + + #[test] + fn test_compute_fold_and_scatter_different_positions() { + // Same tokens but different positions: [a,b] at [0,1] and [a,b] at [2,3] + let input_ids = vec![1, 2, 1, 2]; + let position_ids = vec![0, 1, 2, 3]; + let cu_seq_lengths = vec![0, 2, 4]; + + let (compact_input_ids, compact_position_ids, scatter_indices, fold_gather) = + compute_fold_and_scatter(&input_ids, &position_ids, &cu_seq_lengths, false); + + // Should NOT deduplicate because positions are different + assert_eq!(compact_input_ids.len(), 4); + assert_eq!(compact_position_ids.len(), 4); + assert_eq!(scatter_indices, vec![0, 1, 2, 3]); + assert_eq!(fold_gather, vec![0, 1, 2, 3]); + } + + #[test] + fn test_compute_fold_and_scatter_three_sequences_complex() { + // Three sequences with various overlaps: + // Seq1: [a,b,c,d] at [0,1,2,3] + // Seq2: [a,b,e,f] at [0,1,2,3] + // Seq3: [a,b,c,g] at [0,1,2,3] + let input_ids = vec![1, 2, 3, 4, 1, 2, 5, 6, 1, 2, 3, 7]; + let position_ids = vec![0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]; + let cu_seq_lengths = vec![0, 4, 8, 12]; + + let (compact_input_ids, compact_position_ids, scatter_indices, fold_gather) = + 
compute_fold_and_scatter(&input_ids, &position_ids, &cu_seq_lengths, false); + + // Should deduplicate: + // - [a,b] at [0,1] shared by all three + // - [c] at [2] shared by seq1 and seq3 + assert!(compact_input_ids.len() < 12); // Some deduplication should occur + assert_eq!(scatter_indices.len(), 12); + assert_eq!(fold_gather.len(), compact_input_ids.len()); + + // Verify reconstruction + for i in 0..input_ids.len() { + let compact_idx = scatter_indices[i] as usize; + assert_eq!(input_ids[i], compact_input_ids[compact_idx]); + assert_eq!(position_ids[i], compact_position_ids[compact_idx]); + } + } + + #[test] + fn test_compute_fold_and_scatter_edge_case_single_token() { + // Multiple single-token sequences + let input_ids = vec![1, 2, 1]; + let position_ids = vec![0, 0, 0]; + let cu_seq_lengths = vec![0, 1, 2, 3]; + + let (compact_input_ids, _compact_position_ids, scatter_indices, fold_gather) = + compute_fold_and_scatter(&input_ids, &position_ids, &cu_seq_lengths, false); + + // Should deduplicate token 1 at position 0 + assert_eq!(compact_input_ids.len(), 2); // [1, 2] + assert_eq!(scatter_indices, vec![0, 1, 0]); // First and third map to same compact index + assert_eq!(fold_gather.len(), 2); + } + + #[test] + fn test_compute_fold_and_scatter_deterministic_ordering() { + // Test that the function produces consistent results + let input_ids = vec![1, 2, 3, 1, 2, 4]; + let position_ids = vec![0, 1, 2, 0, 1, 2]; + let cu_seq_lengths = vec![0, 3, 6]; + + let result1 = compute_fold_and_scatter(&input_ids, &position_ids, &cu_seq_lengths, false); + let result2 = compute_fold_and_scatter(&input_ids, &position_ids, &cu_seq_lengths, false); + + assert_eq!(result1, result2); + } + + #[test] + fn test_padding_logic() { + // Test case 1: Compact size < 1024, needs padding to multiple of 8 + let input_ids_1 = vec![1, 2, 3, 1, 2, 4]; // compact size = 4 + let position_ids_1 = vec![0, 1, 2, 0, 1, 2]; + let cu_seq_lengths_1 = vec![0, 3, 6]; + let (compact_ids_1, _, _, _) = + compute_fold_and_scatter(&input_ids_1, &position_ids_1, &cu_seq_lengths_1, true); + assert_eq!(compact_ids_1.len(), 8, "Should pad from 4 to 8"); + + // Test case 2: Compact size < 1024, already a multiple of 8 + let input_ids_2 = (0..8).collect::>(); + let position_ids_2 = (0..8).collect::>(); + let cu_seq_lengths_2 = vec![0, 8]; + let (compact_ids_2, _, _, _) = + compute_fold_and_scatter(&input_ids_2, &position_ids_2, &cu_seq_lengths_2, true); + assert_eq!( + compact_ids_2.len(), + 8, + "Should not pad when already multiple of 8 and small input?" + ); + + // Test case 3: Compact size > 1024, needs padding to multiple of 64 + let n = 2047; + let input_ids_3 = (0..n).collect::>(); + let position_ids_3 = (0..n).collect::>(); + let cu_seq_lengths_3 = vec![0, n]; + let (compact_ids_3, _, _, _) = + compute_fold_and_scatter(&input_ids_3, &position_ids_3, &cu_seq_lengths_3, true); + assert_eq!(compact_ids_3.len(), 2048, "Should pad from 2047 to 2048"); + + // Test case 4: Compact size > 1024, already a multiple of 64 + let n = 1024; + let input_ids_4 = (0..n).collect::>(); + let position_ids_4 = (0..n).collect::>(); + let cu_seq_lengths_4 = vec![0, n]; + let (compact_ids_4, _, _, _) = + compute_fold_and_scatter(&input_ids_4, &position_ids_4, &cu_seq_lengths_4, true); + assert_eq!( + compact_ids_4.len(), + 1024, + "Should not pad when already multiple of 64" + ); + } + + #[test] + fn test_padding_to_multiple_of_8() { + // Compact size will be 4, padding should bring it to 8. 
+ let input_ids = vec![1, 2, 3, 1, 2, 4]; // compact: [1,2,3,4] + let position_ids = vec![0, 1, 2, 0, 1, 2]; + let cu_seq_lengths = vec![0, 3, 6]; + + let (compact_input_ids, compact_position_ids, _scatter, fold_gather) = + compute_fold_and_scatter(&input_ids, &position_ids, &cu_seq_lengths, true); + + assert_eq!(compact_input_ids.len(), 8, "Should be padded to 8"); + assert_eq!(compact_position_ids.len(), 8, "Should be padded to 8"); + assert_eq!(fold_gather.len(), 8, "Should be padded to 8"); + + // Check that the first part is correct + assert_eq!(&compact_input_ids[0..4], &[1, 2, 3, 4]); + // Check that padding is zeros + assert_eq!(&compact_input_ids[4..8], &[0, 0, 0, 0]); + assert_eq!(&compact_position_ids[4..8], &[0, 0, 0, 0]); + assert_eq!(&fold_gather[4..8], &[0, 0, 0, 0]); + + // Test case where no padding is needed (compact size is already a multiple of 8) + // Let's create a case that compacts to 8 tokens + let input_ids_no_pad = vec![1, 2, 3, 4, 5, 6, 7, 8]; + let position_ids_no_pad = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let cu_seq_lengths_no_pad = vec![0, 8]; + let (compact_ids_no_pad, _, _, _) = compute_fold_and_scatter( + &input_ids_no_pad, + &position_ids_no_pad, + &cu_seq_lengths_no_pad, + true, + ); + assert_eq!(compact_ids_no_pad.len(), 8, "Should not be padded"); + } + + // also, add some tests that allow you to reconstruct. e.g. do a function where we do the following function. + // this test is more sophisticated. + // impagine the baseline is + // input_ids = [..] + // position_ids = + + // def positional_embeddings()? e.g. add for each position 0.01 to the input_ids. + // optional + + // def dummy_mlp(input_tensors: Vec[f32]): + // input_ids *= 2 // simulates the mlp part, input ids get embedded. + // + /// def dummy_attention(transformed_ids: Vec[f32], cu_seq_lengths): + /// let final_values = [] + /// for start, end in cu_seq_lengths.take_two(): // unsure how + /// sequence_only = vector[slice(start, end)] + /// // attention part: + /// attention = cumsum(sequence_only) + /// final_values.push(attention) + /// final_values + /// + /// now do a range of input_ids, with and without prefix. + /// + /// attn_orig = attention(dummy_mlp(input_ids, cu_seq_length) + /// + /// for radix mlp approach + /// fold_ids, fold_positions, scatter = compute_fold_and_scatter() + /// compact_input_ids = input_ids.index_select(fold_ids) + /// compact_positions_ids = position_ids.index_select(fold_ids) + /// + /// compact_mlp_out = mlp(compact_input_ids) // output and input are len_compact + /// mlp_unfolded = compact_mlp_out.index_select(compact_mlp_out) // unfolded len is OG length before compact + /// attention_folded = dummy_attention(mlp_unfolded) + /// + /// test with various instances, and always assert that attention_folded and unfolded are always the same. + /// you could just implement it in plain rust, but also use helpers. + /// run over a large range of possible samples and interesting range of inputs. 
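+    // The helpers below implement that plan: a toy positional embedding, a dummy MLP, and a
+    // per-sequence cumulative-sum "attention". The pipeline is run once on the original layout
+    // and once through compute_fold_and_scatter + index_select, then compared element-wise.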
+ // Helper functions for simulation + fn apply_positional_embeddings(input_ids: &[u32], position_ids: &[u32]) -> Vec { + input_ids + .iter() + .zip(position_ids.iter()) + .map(|(&token, &pos)| { + let base = token as f32; + let pos_embed = (pos as f32 * 0.1).sin() * 0.01; + base + pos_embed + }) + .collect() + } + + fn dummy_mlp(input_embeddings: &[f32]) -> Vec { + // Simple MLP: multiply by 2 and add small nonlinearity + input_embeddings + .iter() + .map(|&x| x * 2.0 + (x * 0.1).tanh() * 0.1) + .collect() + } + + fn dummy_attention(mlp_outputs: &[f32], cu_seq_lengths: &[u32]) -> Vec { + let mut final_values = Vec::new(); + + for i in 0..cu_seq_lengths.len().saturating_sub(1) { + let start = cu_seq_lengths[i] as usize; + let end = cu_seq_lengths[i + 1] as usize; + + if start < end && end <= mlp_outputs.len() { + let sequence_slice = &mlp_outputs[start..end]; + + // Cumulative sum (simplified attention) + let mut cumsum = 0.0; + for &value in sequence_slice { + cumsum += value; + final_values.push(cumsum); + } + } + } + + final_values + } + + fn index_select_f32(source: &[f32], indices: &[u32]) -> Vec { + indices.iter().map(|&idx| source[idx as usize]).collect() + } + + // Parameterized comparison function + #[derive(Debug)] + struct RadixMLPTestResult { + baseline_output: Vec, + radix_output: Vec, + compression_ratio: f32, + original_tokens: usize, + compact_tokens: usize, + } + + fn run_radix_mlp_comparison( + input_ids: &[u32], + position_ids: &[u32], + cu_seq_lengths: &[u32], + pad_multiple_of_8: bool, + ) -> RadixMLPTestResult { + // Baseline computation pipeline + let embeddings = apply_positional_embeddings(input_ids, position_ids); + let mlp_outputs = dummy_mlp(&embeddings); + let attention_baseline = dummy_attention(&mlp_outputs, cu_seq_lengths); + + // RadixMLP computation pipeline + let (compact_input_ids, compact_position_ids, scatter_indices, _fold_gather) = + compute_fold_and_scatter(input_ids, position_ids, cu_seq_lengths, pad_multiple_of_8); + + let compact_embeddings = + apply_positional_embeddings(&compact_input_ids, &compact_position_ids); + let compact_mlp_outputs = dummy_mlp(&compact_embeddings); + let unfolded_mlp_outputs = index_select_f32(&compact_mlp_outputs, &scatter_indices); + let attention_radix = dummy_attention(&unfolded_mlp_outputs, cu_seq_lengths); + + // Calculate metrics + let original_tokens = input_ids.len(); + let compact_tokens = compact_input_ids.len(); + let compression_ratio = if original_tokens > 0 { + compact_tokens as f32 / original_tokens as f32 + } else { + 1.0 + }; + + RadixMLPTestResult { + baseline_output: attention_baseline, + radix_output: attention_radix, + compression_ratio, + original_tokens, + compact_tokens, + } + } + + fn assert_outputs_equal(result: &RadixMLPTestResult, test_name: &str, tolerance: f32) { + assert_eq!( + result.baseline_output.len(), + result.radix_output.len(), + "{}: Output length mismatch", + test_name + ); + + for (i, (baseline, radix)) in result + .baseline_output + .iter() + .zip(result.radix_output.iter()) + .enumerate() + { + assert!( + (baseline - radix).abs() < tolerance, + "{}: Mismatch at index {}: baseline={}, radix={}, diff={}", + test_name, + i, + baseline, + radix, + (baseline - radix).abs() + ); + } + } + + fn assert_compression_achieved( + result: &RadixMLPTestResult, + test_name: &str, + expected_compression: bool, + pad_multiple_of_8: bool, + ) { + if expected_compression { + // When padding is enabled, we might not strictly achieve compression + // if the overhead of padding > the gain from 
deduplication. + // But generally for these tests we construct cases where deduplication is significant. + // We can relax this check or make it context aware, but for now let's keep it simple. + // NOTE: logic kept as is, might fail if padding > savings. + let addition = if pad_multiple_of_8 { + 8 - (result.compact_tokens % 8) + } else { + 0 + }; + assert!( + result.compact_tokens < result.original_tokens + addition, + "{}: Expected compression but got {} -> {} tokens", + test_name, + result.original_tokens, + result.compact_tokens + ); + } else if pad_multiple_of_8 { + // With padding, we might not achieve compression if the compact size is already a multiple of 8. + assert!( + result.compact_tokens >= result.original_tokens, + "{}: Expected no compression (>=) but got {} -> {} tokens", + test_name, + result.original_tokens, + result.compact_tokens + ); + } else { + // Without padding, we should not have fewer tokens than original. + assert_eq!( + result.compact_tokens, result.original_tokens, + "{}: Expected no compression but got {} -> {} tokens", + test_name, result.original_tokens, result.compact_tokens + ); + } + } + + // Test case structure for parameterized tests + #[derive(Debug)] + struct TestCase { + name: &'static str, + input_ids: Vec, + position_ids: Vec, + cu_seq_lengths: Vec, + expect_compression: bool, + expected_compression_ratio: Option, // None means don't check specific ratio + pad_multiple_of_8: bool, + } + + // ...existing basic tests... + #[test] + fn test_radix_mlp_reconstruction_parameterized() { + let test_cases = vec![ + TestCase { + name: "identical_sequences", + input_ids: vec![5, 10, 15, 5, 10, 15], + position_ids: vec![0, 1, 2, 0, 1, 2], + cu_seq_lengths: vec![0, 3, 6], + expect_compression: true, + expected_compression_ratio: Some(0.5), // 6 -> 3 tokens + pad_multiple_of_8: false, + }, + TestCase { + name: "identical_sequences_padded", + input_ids: vec![5, 10, 15, 5, 10, 15], + position_ids: vec![0, 1, 2, 0, 1, 2], + cu_seq_lengths: vec![0, 3, 6], + expect_compression: false, // 6 -> 3 -> padded to 8. 8 > 6. So strictly no compression in terms of count. + expected_compression_ratio: None, + pad_multiple_of_8: true, + }, + TestCase { + name: "shared_prefix", + input_ids: vec![1, 2, 3, 1, 2, 4], + position_ids: vec![0, 1, 2, 0, 1, 2], + cu_seq_lengths: vec![0, 3, 6], + expect_compression: true, + expected_compression_ratio: Some(4.0 / 6.0), // 6 -> 4 tokens + pad_multiple_of_8: false, + }, + TestCase { + name: "shared_prefix_padded", + input_ids: vec![1, 2, 3, 1, 2, 4], + position_ids: vec![0, 1, 2, 0, 1, 2], + cu_seq_lengths: vec![0, 3, 6], + expect_compression: false, // 6 -> 4 -> padded to 8. 
+ expected_compression_ratio: None, + pad_multiple_of_8: true, + }, + TestCase { + name: "no_overlap", + input_ids: vec![1, 2, 3, 4, 5, 6], + position_ids: vec![0, 1, 2, 0, 1, 2], + cu_seq_lengths: vec![0, 3, 6], + expect_compression: false, + expected_compression_ratio: Some(1.0), + pad_multiple_of_8: false, + }, + TestCase { + name: "complex_three_sequences", + input_ids: vec![1, 2, 3, 4, 1, 2, 5, 6, 1, 2, 3, 7], + position_ids: vec![0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], + cu_seq_lengths: vec![0, 4, 8, 12], + expect_compression: true, + expected_compression_ratio: None, // Don't check specific ratio + pad_multiple_of_8: false, + }, + TestCase { + name: "complex_three_sequences_padded", + input_ids: vec![1, 2, 3, 4, 1, 2, 5, 6, 1, 2, 3, 7], + position_ids: vec![0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], + cu_seq_lengths: vec![0, 4, 8, 12], + expect_compression: true, // 12 -> something < 12. If small enough, even with padding it's < 12. + // Actual unique: [1,2,3,4,5,6,7] -> 7 unique tokens. Padded to 8. 8 < 12. + expected_compression_ratio: None, + pad_multiple_of_8: true, + }, + TestCase { + name: "single_tokens", + input_ids: vec![1, 2, 1], + position_ids: vec![0, 0, 0], + cu_seq_lengths: vec![0, 1, 2, 3], + expect_compression: true, + expected_compression_ratio: Some(2.0 / 3.0), // 3 -> 2 tokens + pad_multiple_of_8: false, + }, + TestCase { + name: "different_positions", + input_ids: vec![1, 2, 1, 2], + position_ids: vec![0, 1, 2, 3], + cu_seq_lengths: vec![0, 2, 4], + expect_compression: false, + expected_compression_ratio: Some(1.0), + pad_multiple_of_8: false, + }, + ]; + + for test_case in test_cases { + let result = run_radix_mlp_comparison( + &test_case.input_ids, + &test_case.position_ids, + &test_case.cu_seq_lengths, + test_case.pad_multiple_of_8, + ); + + // Assert outputs are numerically identical + assert_outputs_equal(&result, test_case.name, 1e-6); + + // Assert compression expectations + assert_compression_achieved( + &result, + test_case.name, + test_case.expect_compression, + test_case.pad_multiple_of_8, + ); + + // Assert specific compression ratio if provided + if let Some(expected_ratio) = test_case.expected_compression_ratio { + assert!( + (result.compression_ratio - expected_ratio).abs() < 1e-6, + "{}: Expected compression ratio {}, got {}", + test_case.name, + expected_ratio, + result.compression_ratio + ); + } + + println!( + "{}: {} -> {} tokens (ratio: {:.3})", + test_case.name, + result.original_tokens, + result.compact_tokens, + result.compression_ratio + ); + } + } + + #[test] + fn test_radix_mlp_edge_cases_parameterized() { + let edge_cases = vec![ + TestCase { + name: "empty", + input_ids: vec![], + position_ids: vec![], + cu_seq_lengths: vec![], + expect_compression: false, + expected_compression_ratio: None, + pad_multiple_of_8: false, + }, + TestCase { + name: "single_token_single_sequence", + input_ids: vec![42], + position_ids: vec![0], + cu_seq_lengths: vec![0, 1], + expect_compression: false, + expected_compression_ratio: Some(1.0), + pad_multiple_of_8: false, + }, + TestCase { + name: "single_token_single_sequence_padded", + input_ids: vec![42], + position_ids: vec![0], + cu_seq_lengths: vec![0, 1], + expect_compression: false, + expected_compression_ratio: None, + pad_multiple_of_8: true, + }, + TestCase { + name: "long_identical_sequences", + input_ids: vec![1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5], + position_ids: vec![0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4], + cu_seq_lengths: vec![0, 5, 10, 15], + expect_compression: true, + 
expected_compression_ratio: Some(1.0 / 3.0), + pad_multiple_of_8: false, + }, + TestCase { + name: "long_identical_sequences_with_padding", + input_ids: vec![1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5], + position_ids: vec![0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4], + cu_seq_lengths: vec![0, 5, 10, 15], + expect_compression: true, + expected_compression_ratio: Some(8.0 / 15.0), // 15 -> 8 (with padding), ratio = 8/15 ~ 0.5333 + pad_multiple_of_8: true, + }, + ]; + + for test_case in edge_cases { + if test_case.input_ids.is_empty() { + let (compact_input_ids, compact_position_ids, scatter_indices, fold_gather) = + compute_fold_and_scatter( + &test_case.input_ids, + &test_case.position_ids, + &test_case.cu_seq_lengths, + test_case.pad_multiple_of_8, + ); + assert!(compact_input_ids.is_empty()); + assert!(compact_position_ids.is_empty()); + assert!(scatter_indices.is_empty()); + assert!(fold_gather.is_empty()); + continue; + } + + let result = run_radix_mlp_comparison( + &test_case.input_ids, + &test_case.position_ids, + &test_case.cu_seq_lengths, + test_case.pad_multiple_of_8, + ); + + assert_outputs_equal(&result, test_case.name, 1e-6); + assert_compression_achieved( + &result, + test_case.name, + test_case.expect_compression, + test_case.pad_multiple_of_8, + ); + + if let Some(expected_ratio) = test_case.expected_compression_ratio { + assert!( + (result.compression_ratio - expected_ratio).abs() < 1e-6, + "{}: Expected compression ratio {}, got {}", + test_case.name, + expected_ratio, + result.compression_ratio + ); + } + } + } + + #[test] + fn fail_and_report_time_large_batch() { + use std::time::Instant; + + // Relevant-sized problem: + // - batch = 32 sequences + // - each sequence has a shared prefix of 128 tokens (max dedup) + // - plus a unique tail of 200 tokens + // -> total ~ 10,496 tokens + let batch: usize = 32; + let shared_prefix: usize = 128; + let tail_len: usize = 200; + let seq_len: usize = shared_prefix + tail_len; + let total_tokens: usize = batch * seq_len; + + let mut input_ids: Vec = Vec::with_capacity(total_tokens); + let mut position_ids: Vec = Vec::with_capacity(total_tokens); + let mut cu_seq_lengths: Vec = Vec::with_capacity(batch + 1); + cu_seq_lengths.push(0); + + for seq_idx in 0..batch { + // Shared prefix across all sequences: same tokens, same positions + for j in 0..shared_prefix { + let token = (j as u32 % 1000) + 1; + input_ids.push(token); + position_ids.push(j as u32); + } + // Unique tail per sequence to keep the problem realistic + for k in 0..tail_len { + let token = 1_000_000u32 + (seq_idx as u32) * 10_000 + (k as u32); + input_ids.push(token); + position_ids.push((shared_prefix + k) as u32); + } + cu_seq_lengths.push(input_ids.len() as u32); + } + + let t0 = Instant::now(); + let (compact_ids, _compact_pos, _scatter, _fold) = + super::compute_fold_and_scatter(&input_ids, &position_ids, &cu_seq_lengths, false); + let dt = t0.elapsed(); + let dt_ms = dt.as_secs_f64() * 1000.0; + + let ratio = (compact_ids.len() as f64) / (input_ids.len() as f64); + + // Use println! so you also see this under --nocapture; include details in the panic too. + println!( + "compute_fold_and_scatter:\n batch={}\n seq_len={}\n total_tokens={}\n compact_tokens={}\n ratio={:.3}\n elapsed_ms={:.3}", + batch, + seq_len, + input_ids.len(), + compact_ids.len(), + ratio, + dt_ms + ); + + // Intentionally fail so the timing and stats are printed in default test runs. 
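+        // NOTE: the panic below is left commented out, so this test currently passes;
+        // run with `cargo test -- --nocapture` to see the report printed above.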
+ // panic!( + // "TIMING REPORT (intentional failure to show output): \ + // batch={}, seq_len={}, total_tokens={}, compact_tokens={}, ratio={:.3}, elapsed_ms={:.3}\n\ + // scatter_len={}, fold_len={}, compact_pos_len={}", + // batch, seq_len, input_ids.len(), compact_ids.len(), ratio, dt_ms, + // scatter.len(), fold.len(), compact_pos.len() + // ); + } +} diff --git a/router/src/lib.rs b/router/src/lib.rs index d83bd95c5..2b1ec4d69 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -50,6 +50,7 @@ pub async fn run( max_concurrent_requests: usize, max_batch_tokens: usize, max_batch_requests: Option, + radix_mlp_threshold: f32, max_client_batch_size: usize, auto_truncate: bool, default_prompt: Option, @@ -294,10 +295,36 @@ pub async fn run( .or(max_batch_requests); // Queue logic + let radix_mlp_threshold = if config.model_type == "bert" + || config.model_type == "xlm-roberta" + || config.model_type == "camembert" + || config.model_type == "roberta" + || config.model_type == "distilbert" + || config.model_type == "modernbert" + || config.use_bidirectional_attention.unwrap_or(false) + || !backend.radix_mlp_supported + { + if radix_mlp_threshold > 0.0 { + tracing::warn!("`--radix-mlp-threshold` is only supported for Causal LM's Qwen2.5, Qwen3 and LLaMA models. Disabling RadixMLP."); + } + 0.0 + } else { + radix_mlp_threshold + }; + if radix_mlp_threshold > 0.0 { + tracing::info!( + "RadixMLP enabled with compression ratio threshold: {}", + radix_mlp_threshold + ); + } else { + tracing::info!("RadixMLP disabled"); + } + let queue = Queue::new( backend.padded_model, max_batch_tokens, max_batch_requests, + radix_mlp_threshold, max_concurrent_requests, ); @@ -449,6 +476,7 @@ pub struct ModelConfig { pub pad_token_id: usize, pub id2label: Option>, pub label2id: Option>, + pub use_bidirectional_attention: Option, } #[derive(Debug, Clone, PartialEq, Deserialize)] diff --git a/router/src/main.rs b/router/src/main.rs index 52bb8e9b5..9b43a6a5d 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -8,6 +8,17 @@ use veil::Redact; #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; +fn pct_parser(s: &str) -> Result { + let v = s.parse::().map_err(|e| e.to_string())?; + if !(0.0..=1.0).contains(&v) { + return Err(format!( + "The value must be between 0.0 and 1.0, but got {}", + v + )); + } + Ok(v) +} + /// App Configuration #[derive(Parser, Redact)] #[clap(author, version, about, long_about = None)] @@ -69,6 +80,14 @@ struct Args { #[clap(long, env)] max_batch_requests: Option, + /// RadixMLP threshold. + /// + /// Set the threshold for RadixMLP. + /// If the compression ratio is lower than the threshold, RadixMLP will be used. + /// The default is 0.85 for most models, and 0.0 (force disabled) for bidirectional models. 
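+    ///
+    /// Example (illustrative): with `--radix-mlp-threshold 0.85`, a batch whose compact form
+    /// is 70% of the original token count (ratio 0.70 < 0.85) takes the folded path, while a
+    /// batch with ratio 0.95 falls back to the regular path.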
+ #[clap(long, env, default_value = "0.85", value_parser = pct_parser)] + radix_mlp_threshold: f32, + /// Control the maximum number of inputs that a client can send in a single request #[clap(default_value = "32", long, env)] max_client_batch_size: usize, @@ -232,6 +251,7 @@ async fn main() -> Result<()> { args.max_concurrent_requests, args.max_batch_tokens, args.max_batch_requests, + args.radix_mlp_threshold, args.max_client_batch_size, args.auto_truncate, args.default_prompt, diff --git a/router/src/prometheus.rs b/router/src/prometheus.rs index d011efbad..0e1c2c4fb 100644 --- a/router/src/prometheus.rs +++ b/router/src/prometheus.rs @@ -37,11 +37,18 @@ pub(crate) fn prometheus_builer( let batch_tokens_matcher = Matcher::Full(String::from("te_batch_next_tokens")); let batch_tokens_buckets: Vec = (0..21).map(|x| 2.0_f64.powi(x)).collect(); + // Compression ratio buckets (for values between 0 and 1) + let compression_ratio_matcher = Matcher::Full(String::from("te_radix_mlp_compression_ratio")); + let compression_ratio_buckets: Vec = vec![ + 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.8, 0.85, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99, 1.0 + ]; + // Prometheus handler PrometheusBuilder::new() .with_http_listener(addr) .set_buckets_for_metric(duration_matcher, &duration_buckets)? .set_buckets_for_metric(input_length_matcher, &input_length_buckets)? .set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)? - .set_buckets_for_metric(batch_tokens_matcher, &batch_tokens_buckets) + .set_buckets_for_metric(batch_tokens_matcher, &batch_tokens_buckets)? + .set_buckets_for_metric(compression_ratio_matcher, &compression_ratio_buckets) } diff --git a/router/tests/common.rs b/router/tests/common.rs index 476211764..7dc40714b 100644 --- a/router/tests/common.rs +++ b/router/tests/common.rs @@ -54,6 +54,7 @@ pub async fn start_server(model_id: String, revision: Option, dtype: DTy 4, 1024, None, + 0.0, 32, false, None,