
Commit d2e3b66

oandreeva-nv authored
feat: Transition to FullyContiguous Host and Disk layouts (#3090)
Signed-off-by: Olga Andreeva <[email protected]>
Co-authored-by: oandreeva-nv <[email protected]>
1 parent a5e1d45 commit d2e3b66

File tree

10 files changed: +1995 −61 lines

container/Dockerfile.vllm
Lines changed: 1 addition & 1 deletion

@@ -343,4 +343,4 @@ RUN uv pip install maturin[patchelf] && \
     uv pip install --no-deps -e .
 
 ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
-CMD []
+CMD []

lib/bindings/python/rust/llm/block_manager/distributed.rs
Lines changed: 1 addition & 1 deletion

@@ -9,4 +9,4 @@ mod worker;
 
 pub use leader::KvbmLeader;
 pub use utils::get_barrier_id_prefix;
-pub use worker::{KvbmWorker, VllmTensor};
+pub use worker::{KvbmWorker, PyLayoutType, VllmTensor};
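Re-exporting PyLayoutType here is what makes the new enum available to the Python bindings alongside KvbmWorker and VllmTensor. A minimal sketch of how the enum behaves from Python, assuming it is registered in the bindings' Python module (the import path below is hypothetical, not confirmed by this diff):

    # Hypothetical import path; the actual module name depends on how the
    # bindings register their classes, which this diff does not show.
    from dynamo.llm.block_manager import PyLayoutType

    layout = PyLayoutType.FullyContiguous
    print(str(layout))   # "FullyContiguous"              (from __str__)
    print(repr(layout))  # "PyLayoutType.FullyContiguous" (from __repr__)

    # #[pyclass(eq, eq_int)] derives equality, so variants compare directly.
    assert layout == PyLayoutType.FullyContiguous
    assert layout != PyLayoutType.LayerSeparate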

lib/bindings/python/rust/llm/block_manager/distributed/worker.rs
Lines changed: 55 additions & 1 deletion

@@ -10,8 +10,44 @@ use llm_rs::block_manager::distributed::{
     BlockTransferHandler as RustBlockTransferHandler, KvbmWorker as KvbmWorkerImpl,
     KvbmWorkerConfig,
 };
+use llm_rs::block_manager::layout::LayoutType;
 use llm_rs::block_manager::storage::torch::{TorchDevice, TorchTensor};
 
+/// A wrapper around a layout type.
+/// This is used to convert between the Python and Rust layout types.
+#[pyclass(eq, eq_int)]
+#[derive(Clone, PartialEq, Eq)]
+pub enum PyLayoutType {
+    FullyContiguous,
+    LayerSeparate,
+}
+
+#[pymethods]
+impl PyLayoutType {
+    /// String representation of the layout type
+    fn __str__(&self) -> &'static str {
+        match self {
+            PyLayoutType::FullyContiguous => "FullyContiguous",
+            PyLayoutType::LayerSeparate => "LayerSeparate",
+        }
+    }
+
+    /// Representation for debugging
+    fn __repr__(&self) -> String {
+        format!("PyLayoutType.{}", self.__str__())
+    }
+}
+
+impl From<PyLayoutType> for LayoutType {
+    fn from(py_layout: PyLayoutType) -> Self {
+        match py_layout {
+            PyLayoutType::FullyContiguous => LayoutType::FullyContiguous,
+            // Layout (outer_contiguous vs block_contiguous) is auto-detected from tensor shapes
+            PyLayoutType::LayerSeparate => LayoutType::layer_separate_auto_default(),
+        }
+    }
+}
+
 /// A wrapper around a Torch tensor.
 /// We hold onto the py object to ensure it doesn't get GCed.
 #[derive(Clone, Debug)]

@@ -107,7 +143,7 @@ impl KvbmWorker {
 #[pymethods]
 impl KvbmWorker {
     #[new]
-    #[pyo3(signature = (num_device_blocks, page_size, tensors, device_id=0, dtype_width_bytes=2, drt=None, layout_blocking=false))]
+    #[pyo3(signature = (num_device_blocks, page_size, tensors, device_id=0, dtype_width_bytes=2, drt=None, layout_blocking=false, device_layout_type=None, host_layout_type=None, disk_layout_type=None))]
     fn new(
         num_device_blocks: usize,
         page_size: usize,

@@ -116,6 +152,9 @@ impl KvbmWorker {
         dtype_width_bytes: usize,
         drt: Option<DistributedRuntime>,
         layout_blocking: bool,
+        device_layout_type: Option<PyLayoutType>,
+        host_layout_type: Option<PyLayoutType>,
+        disk_layout_type: Option<PyLayoutType>,
     ) -> PyResult<Self> {
         let py_drt = drt.ok_or_else(|| {
             pyo3::exceptions::PyValueError::new_err("DistributedRuntime (drt) must be provided")

@@ -142,6 +181,21 @@ impl KvbmWorker {
             .device_id(device_id)
             .dtype_width_bytes(dtype_width_bytes)
             .barrier_id_prefix(barrier_id_prefix)
+            .device_layout_type(
+                device_layout_type
+                    .map(|py_layout| py_layout.into())
+                    .unwrap_or(LayoutType::FullyContiguous),
+            )
+            .host_layout_type(
+                host_layout_type
+                    .map(|py_layout| py_layout.into())
+                    .unwrap_or(LayoutType::FullyContiguous),
+            )
+            .disk_layout_type(
+                disk_layout_type
+                    .map(|py_layout| py_layout.into())
+                    .unwrap_or(LayoutType::FullyContiguous),
+            )
             .build()
             .map_err(to_pyerr)?;
 
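With these keyword arguments in place, the Python-facing KvbmWorker constructor accepts a per-tier layout override for device, host, and disk; any tier left as None falls back to FullyContiguous on the Rust side. A hedged sketch of a call, reusing the hypothetical import path above with placeholder tensors and runtime (real shapes and the DistributedRuntime setup are engine-specific and not shown in this diff):

    import torch
    # Hypothetical import path, as above.
    from dynamo.llm.block_manager import KvbmWorker, PyLayoutType

    # Placeholder KV-cache tensors; actual shapes and dtypes depend on the model and engine.
    tensors = [torch.empty(64, 16, 8, 128, dtype=torch.float16, device="cuda:0")]
    drt = ...  # placeholder: a DistributedRuntime instance is required,
               # otherwise the constructor raises ValueError

    worker = KvbmWorker(
        num_device_blocks=64,
        page_size=16,
        tensors=tensors,
        device_id=0,
        dtype_width_bytes=2,
        drt=drt,
        layout_blocking=False,
        # New in this change: per-tier layout selection. Omitting any of
        # these keeps the FullyContiguous default.
        device_layout_type=PyLayoutType.FullyContiguous,
        host_layout_type=PyLayoutType.FullyContiguous,
        disk_layout_type=PyLayoutType.FullyContiguous,
    )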

lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs
Lines changed: 4 additions & 1 deletion

@@ -19,6 +19,7 @@ use crate::{
 
 use anyhow;
 use dynamo_llm::block_manager::distributed::{KvbmWorker, KvbmWorkerConfig};
+use dynamo_llm::block_manager::layout::LayoutType;
 use dynamo_llm::block_manager::metrics_kvbm::KvbmMetrics;
 use dynamo_llm::block_manager::storage::torch::TorchTensor;
 use dynamo_runtime::DistributedRuntime;

@@ -144,7 +145,9 @@ impl Worker for KvConnectorWorker {
             .tensors(kv_cache_tensors)
             .device_id(device_id)
             .dtype_width_bytes(dtype_width_bytes)
-            .is_fully_contiguous_layout(true)
+            .device_layout_type(LayoutType::FullyContiguous)
+            .host_layout_type(LayoutType::FullyContiguous)
+            .disk_layout_type(LayoutType::FullyContiguous)
             .barrier_id_prefix(get_barrier_id_prefix())
             .scheduler_client(Some(self.transfer_client.clone()))
             .build()?;

lib/bindings/python/rust/llm/block_manager/vllm/connector/worker.rs
Lines changed: 54 additions & 0 deletions

@@ -18,8 +18,10 @@ use crate::{
 };
 use dynamo_runtime::metrics::prometheus_names::kvbm_connector;
 
+use crate::llm::block_manager::distributed::PyLayoutType;
 use anyhow;
 use dynamo_llm::block_manager::distributed::{KvbmWorker, KvbmWorkerConfig};
+use dynamo_llm::block_manager::layout::LayoutType;
 use dynamo_llm::block_manager::storage::torch::TorchTensor;
 use dynamo_runtime::DistributedRuntime;
 use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;

@@ -33,6 +35,9 @@ pub trait Worker: Send + Sync {
         dtype_width_bytes: usize,
         kv_caches: Vec<(String, Arc<VllmTensor>)>,
         raw_event_handles: Vec<u64>,
+        device_layout_type: Option<LayoutType>,
+        host_layout_type: Option<LayoutType>,
+        disk_layout_type: Option<LayoutType>,
     ) -> anyhow::Result<()>;
 
     fn bind_connector_metadata(&mut self, metadata: Vec<u8>) -> anyhow::Result<()>;

@@ -133,6 +138,9 @@ impl Worker for KvConnectorWorker {
         dtype_width_bytes: usize,
         kv_caches: Vec<(String, Arc<VllmTensor>)>,
         raw_event_handles: Vec<u64>,
+        device_layout_type: Option<LayoutType>,
+        host_layout_type: Option<LayoutType>,
+        disk_layout_type: Option<LayoutType>,
     ) -> anyhow::Result<()> {
         if self.kvbm_worker.get().is_some() {
             tracing::warn!("kvbm worker already registered");

@@ -147,9 +155,16 @@
 
         // Process kv_caches in layer execution order (already sorted by layer index)
         let mut vllm_tensors = Vec::new();
+        let mut first_tensor_shape: Option<Vec<usize>> = None;
+
         for (layer_name, vllm_tensor) in kv_caches {
             tracing::trace!("Registering KV cache layer: {layer_name}, tensor: {vllm_tensor:?}");
 
+            // Capture the shape of the first tensor for layout detection
+            if first_tensor_shape.is_none() {
+                first_tensor_shape = Some(vllm_tensor.shape());
+            }
+
             // Store for later lookup by name
             self.kv_cache_layers.push((layer_name, vllm_tensor.clone()));
 

@@ -159,6 +174,35 @@
 
         self.layer_events = raw_event_handles;
 
+        // Auto-detect device layout type if not explicitly provided
+        let detected_device_layout_type = match device_layout_type {
+            Some(layout) => layout,
+            None => {
+                if let Some(ref shape) = first_tensor_shape {
+                    match LayoutType::layer_separate_auto(shape, num_device_blocks) {
+                        Ok(detected) => {
+                            tracing::info!(
+                                "Auto-detected device layout from tensor shape: {:?}",
+                                detected
+                            );
+                            detected
+                        }
+                        Err(e) => {
+                            tracing::warn!(
+                                "Failed to auto-detect layout from shape {:?}: {}. Using default.",
+                                shape,
+                                e
+                            );
+                            LayoutType::layer_separate_auto_default()
+                        }
+                    }
+                } else {
+                    tracing::warn!("No tensors available for layout detection. Using default.");
+                    LayoutType::layer_separate_auto_default()
+                }
+            }
+        };
+
         let config = KvbmWorkerConfig::builder()
             .drt(self.drt.clone())
             .num_device_blocks(num_device_blocks)

@@ -168,6 +212,9 @@
             .dtype_width_bytes(dtype_width_bytes)
             .barrier_id_prefix(get_barrier_id_prefix())
             .scheduler_client(Some(self.transfer_client.clone()))
+            .device_layout_type(detected_device_layout_type)
+            .host_layout_type(host_layout_type.unwrap_or(LayoutType::FullyContiguous))
+            .disk_layout_type(disk_layout_type.unwrap_or(LayoutType::FullyContiguous))
             .build()?;
 
         let worker = self.drt.runtime().primary().block_on(async move {

@@ -416,6 +463,7 @@ impl PyKvConnectorWorker {
         Ok(Self { connector_worker })
     }
 
+    #[pyo3(signature = (num_device_blocks, page_size, device_id, dtype_width_bytes, kv_caches, raw_event_handles, device_layout_type=None, host_layout_type=None, disk_layout_type=None))]
     pub fn register_kv_caches(
         &mut self,
         num_device_blocks: usize,

@@ -424,6 +472,9 @@
         dtype_width_bytes: usize,
         kv_caches: Vec<(String, Py<PyAny>)>,
         raw_event_handles: Vec<u64>,
+        device_layout_type: Option<PyLayoutType>,
+        host_layout_type: Option<PyLayoutType>,
+        disk_layout_type: Option<PyLayoutType>,
     ) -> PyResult<()> {
         // Convert Python tensors to Rust VllmTensor objects
         let mut rust_kv_caches = Vec::new();

@@ -440,6 +491,9 @@
             dtype_width_bytes,
             rust_kv_caches,
             raw_event_handles,
+            device_layout_type.map(|py_layout| py_layout.into()),
+            host_layout_type.map(|py_layout| py_layout.into()),
+            disk_layout_type.map(|py_layout| py_layout.into()),
         )
         .map_err(to_pyerr)
     }
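On the connector side, register_kv_caches gains the same three optional layout arguments. Leaving device_layout_type unset lets the worker inspect the first KV-cache tensor's shape and auto-detect the device layout, falling back to the LayerSeparate auto default if detection fails; host and disk layouts default to FullyContiguous. A hedged sketch of a call, again with a hypothetical import path and placeholder inputs (connector_worker stands for an already constructed connector worker object, whose setup this diff does not show):

    import torch
    # Hypothetical import path, as above.
    from dynamo.llm.block_manager import PyLayoutType

    # (layer_name, tensor) pairs in layer execution order; names and shapes are placeholders.
    kv_caches = [
        ("model.layers.0.self_attn", torch.empty(2, 64, 16, 8, 128, dtype=torch.float16, device="cuda:0")),
        ("model.layers.1.self_attn", torch.empty(2, 64, 16, 8, 128, dtype=torch.float16, device="cuda:0")),
    ]

    connector_worker.register_kv_caches(
        num_device_blocks=64,
        page_size=16,
        device_id=0,
        dtype_width_bytes=2,
        kv_caches=kv_caches,
        raw_event_handles=[0, 0],  # placeholder per-layer event handles
        # device_layout_type=None triggers shape-based auto-detection;
        # host/disk default to FullyContiguous when left as None.
        device_layout_type=None,
        host_layout_type=PyLayoutType.FullyContiguous,
        disk_layout_type=PyLayoutType.FullyContiguous,
    )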
