container/Dockerfile.vllm: 1 addition & 1 deletion
@@ -359,4 +359,4 @@ RUN uv pip install maturin[patchelf]
ENV PYTHONPATH=${WORKSPACE_DIR}/components/metrics/src:${WORKSPACE_DIR}/components/frontend/src:${WORKSPACE_DIR}/components/planner/src:${WORKSPACE_DIR}/components/backends/mocker/src:${WORKSPACE_DIR}/components/backends/trtllm/src:${WORKSPACE_DIR}/components/backends/vllm/src:${WORKSPACE_DIR}/components/backends/sglang/src:${WORKSPACE_DIR}/components/backends/llama_cpp/src

ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
CMD []
CMD []
@@ -9,4 +9,4 @@ mod worker;

pub use leader::KvbmLeader;
pub use utils::get_barrier_id_prefix;
pub use worker::{KvbmWorker, VllmTensor};
pub use worker::{KvbmWorker, PyLayoutType, VllmTensor};
@@ -10,8 +10,44 @@ use llm_rs::block_manager::distributed::{
BlockTransferHandler as RustBlockTransferHandler, KvbmWorker as KvbmWorkerImpl,
KvbmWorkerConfig,
};
use llm_rs::block_manager::layout::LayoutType;
use llm_rs::block_manager::storage::torch::{TorchDevice, TorchTensor};

/// A wrapper around a layout type.
/// This is used to convert between the Python and Rust layout types.
#[pyclass(eq, eq_int)]
#[derive(Clone, PartialEq, Eq)]
pub enum PyLayoutType {
FullyContiguous,
LayerSeparate,
}

#[pymethods]
impl PyLayoutType {
/// String representation of the layout type
fn __str__(&self) -> &'static str {
match self {
PyLayoutType::FullyContiguous => "FullyContiguous",
PyLayoutType::LayerSeparate => "LayerSeparate",
}
}

/// Representation for debugging
fn __repr__(&self) -> String {
format!("PyLayoutType.{}", self.__str__())
}
}

impl From<PyLayoutType> for LayoutType {
fn from(py_layout: PyLayoutType) -> Self {
match py_layout {
PyLayoutType::FullyContiguous => LayoutType::FullyContiguous,
// Layout (outer_contiguous vs block_contiguous) is auto-detected from tensor shapes
PyLayoutType::LayerSeparate => LayoutType::layer_separate_auto_default(),
}
}
}

/// A wrapper around a Torch tensor.
/// We hold onto the py object to ensure it doesn't get GCed.
#[derive(Clone, Debug)]
@@ -107,7 +143,7 @@ impl KvbmWorker
#[pymethods]
impl KvbmWorker {
#[new]
#[pyo3(signature = (num_device_blocks, page_size, tensors, device_id=0, dtype_width_bytes=2, drt=None, layout_blocking=false))]
#[pyo3(signature = (num_device_blocks, page_size, tensors, device_id=0, dtype_width_bytes=2, drt=None, layout_blocking=false, device_layout_type=None, host_layout_type=None, disk_layout_type=None))]
fn new(
num_device_blocks: usize,
page_size: usize,
@@ -116,6 +152,9 @@ impl KvbmWorker
dtype_width_bytes: usize,
drt: Option<DistributedRuntime>,
layout_blocking: bool,
device_layout_type: Option<PyLayoutType>,
host_layout_type: Option<PyLayoutType>,
disk_layout_type: Option<PyLayoutType>,
) -> PyResult<Self> {
let py_drt = drt.ok_or_else(|| {
pyo3::exceptions::PyValueError::new_err("DistributedRuntime (drt) must be provided")
@@ -142,6 +181,21 @@ impl KvbmWorker {
.device_id(device_id)
.dtype_width_bytes(dtype_width_bytes)
.barrier_id_prefix(barrier_id_prefix)
.device_layout_type(
device_layout_type
.map(|py_layout| py_layout.into())
.unwrap_or(LayoutType::FullyContiguous),
)
.host_layout_type(
host_layout_type
.map(|py_layout| py_layout.into())
.unwrap_or(LayoutType::FullyContiguous),
)
.disk_layout_type(
disk_layout_type
.map(|py_layout| py_layout.into())
.unwrap_or(LayoutType::FullyContiguous),
)
.build()
.map_err(to_pyerr)?;

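For reference, the handling of the new optional layout arguments in the builder calls above reduces to a small helper. This is an illustrative sketch only, not part of the diff; it assumes `PyLayoutType` and `LayoutType` are in scope as defined in this file.

```rust
// Illustrative sketch (not part of this PR). Mirrors the builder calls above:
// an explicit Python-side choice wins, otherwise the worker falls back to the
// fully contiguous layout.
fn resolve_layout(requested: Option<PyLayoutType>) -> LayoutType {
    requested
        .map(|py_layout| py_layout.into())
        .unwrap_or(LayoutType::FullyContiguous)
}
```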
@@ -19,6 +19,7 @@ use crate::{

use anyhow;
use dynamo_llm::block_manager::distributed::{KvbmWorker, KvbmWorkerConfig};
use dynamo_llm::block_manager::layout::LayoutType;
use dynamo_llm::block_manager::storage::torch::TorchTensor;
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;
@@ -134,7 +135,9 @@ impl Worker for KvConnectorWorker {
.tensors(kv_cache_tensors)
.device_id(device_id)
.dtype_width_bytes(dtype_width_bytes)
.is_fully_contiguous_layout(true)
.device_layout_type(LayoutType::FullyContiguous)
.host_layout_type(LayoutType::FullyContiguous)
.disk_layout_type(LayoutType::FullyContiguous)
.barrier_id_prefix(get_barrier_id_prefix())
.scheduler_client(Some(self.transfer_client.clone()))
.build()?;
@@ -18,8 +18,10 @@ use crate::{
};
use dynamo_runtime::metrics::prometheus_names::kvbm_connector;

use crate::llm::block_manager::distributed::PyLayoutType;
use anyhow;
use dynamo_llm::block_manager::distributed::{KvbmWorker, KvbmWorkerConfig};
use dynamo_llm::block_manager::layout::LayoutType;
use dynamo_llm::block_manager::storage::torch::TorchTensor;
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;
@@ -33,6 +35,9 @@ pub trait Worker: Send + Sync {
dtype_width_bytes: usize,
kv_caches: Vec<(String, Arc<VllmTensor>)>,
raw_event_handles: Vec<u64>,
device_layout_type: Option<LayoutType>,
host_layout_type: Option<LayoutType>,
disk_layout_type: Option<LayoutType>,
) -> anyhow::Result<()>;

fn bind_connector_metadata(&mut self, metadata: Vec<u8>) -> anyhow::Result<()>;
@@ -133,6 +138,9 @@ impl Worker for KvConnectorWorker {
dtype_width_bytes: usize,
kv_caches: Vec<(String, Arc<VllmTensor>)>,
raw_event_handles: Vec<u64>,
device_layout_type: Option<LayoutType>,
host_layout_type: Option<LayoutType>,
disk_layout_type: Option<LayoutType>,
) -> anyhow::Result<()> {
if self.kvbm_worker.get().is_some() {
tracing::warn!("kvbm worker already registered");
@@ -147,9 +155,16 @@

// Process kv_caches in layer execution order (already sorted by layer index)
let mut vllm_tensors = Vec::new();
let mut first_tensor_shape: Option<Vec<usize>> = None;

for (layer_name, vllm_tensor) in kv_caches {
tracing::trace!("Registering KV cache layer: {layer_name}, tensor: {vllm_tensor:?}");

// Capture the shape of the first tensor for layout detection
if first_tensor_shape.is_none() {
first_tensor_shape = Some(vllm_tensor.shape());
}

// Store for later lookup by name
self.kv_cache_layers.push((layer_name, vllm_tensor.clone()));

@@ -159,6 +174,35 @@

self.layer_events = raw_event_handles;

// Auto-detect device layout type if not explicitly provided
let detected_device_layout_type = match device_layout_type {
Some(layout) => layout,
None => {
if let Some(ref shape) = first_tensor_shape {
match LayoutType::layer_separate_auto(shape, num_device_blocks) {
Ok(detected) => {
tracing::info!(
"Auto-detected device layout from tensor shape: {:?}",
detected
);
detected
}
Err(e) => {
tracing::warn!(
"Failed to auto-detect layout from shape {:?}: {}. Using default.",
shape,
e
);
LayoutType::layer_separate_auto_default()
}
}
} else {
tracing::warn!("No tensors available for layout detection. Using default.");
LayoutType::layer_separate_auto_default()
}
}
};

let config = KvbmWorkerConfig::builder()
.drt(self.drt.clone())
.num_device_blocks(num_device_blocks)
@@ -168,6 +212,9 @@
.dtype_width_bytes(dtype_width_bytes)
.barrier_id_prefix(get_barrier_id_prefix())
.scheduler_client(Some(self.transfer_client.clone()))
.device_layout_type(detected_device_layout_type)
.host_layout_type(host_layout_type.unwrap_or(LayoutType::FullyContiguous))
.disk_layout_type(disk_layout_type.unwrap_or(LayoutType::FullyContiguous))
.build()?;

let worker = self.drt.runtime().primary().block_on(async move {
@@ -416,6 +463,7 @@ impl PyKvConnectorWorker {
Ok(Self { connector_worker })
}

#[pyo3(signature = (num_device_blocks, page_size, device_id, dtype_width_bytes, kv_caches, raw_event_handles, device_layout_type=None, host_layout_type=None, disk_layout_type=None))]
pub fn register_kv_caches(
&mut self,
num_device_blocks: usize,
@@ -424,6 +472,9 @@
dtype_width_bytes: usize,
kv_caches: Vec<(String, Py<PyAny>)>,
raw_event_handles: Vec<u64>,
device_layout_type: Option<PyLayoutType>,
host_layout_type: Option<PyLayoutType>,
disk_layout_type: Option<PyLayoutType>,
) -> PyResult<()> {
// Convert Python tensors to Rust VllmTensor objects
let mut rust_kv_caches = Vec::new();
@@ -440,6 +491,9 @@
dtype_width_bytes,
rust_kv_caches,
raw_event_handles,
device_layout_type.map(|py_layout| py_layout.into()),
host_layout_type.map(|py_layout| py_layout.into()),
disk_layout_type.map(|py_layout| py_layout.into()),
)
.map_err(to_pyerr)
}
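To make the "auto-detected from tensor shapes" comment above concrete: the real check lives in `LayoutType::layer_separate_auto`, whose internals are not shown in this diff. The sketch below is a hypothetical stand-in for the kind of shape test such detection might perform; the function name and heuristic are assumptions, not the actual implementation.

```rust
// Hypothetical illustration only; the actual detection is implemented by
// LayoutType::layer_separate_auto in dynamo_llm and may use different rules.
// A per-layer KV cache tensor typically carries the device block count in one
// of its leading dimensions, whereas a fully contiguous allocation folds all
// layers into a single tensor with a different leading shape.
fn looks_layer_separate(shape: &[usize], num_device_blocks: usize) -> bool {
    shape
        .iter()
        .take(2)
        .any(|&dim| dim == num_device_blocks)
}
```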