Skip to content

Commit 7dfa496

Browse files
thomasywang authored and facebook-github-bot committed
Limit cast fanout (meta-pytorch#1196)
Summary: ActorMesh's shape might have large extents on some dimensions. Those dimensions would cause large fanout in our comm actor implementation. To avoid that, we reshape it by increasing dimensionality and limiting the extent of each dimension. Note: the reshape is only visible to the internal algorithm. The shape that the user sees remains intact. For example, a typical shape is [hosts=1024, gpus=8]. By using limit 8, it becomes [8, 8, 8, 2, 8] during casting. In other words, it adds 3 extra layers to the comm actor tree, while keeping the fanout in each layer at 8 or smaller. The limit for cast fanouts will be configured by the key `CASTING_FANOUT_SIZE`, which is currently set to 0 by default, disabling the feature. Differential Revision: D82320948
1 parent 5913990 commit 7dfa496

File tree

4 files changed

+265
-5
lines changed

4 files changed

+265
-5
lines changed

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 224 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use hyperactor::attrs::Attrs;
2929
use hyperactor::attrs::declare_attrs;
3030
use hyperactor::cap;
3131
use hyperactor::cap::CanSend;
32+
use hyperactor::config;
3233
use hyperactor::mailbox::MailboxSenderError;
3334
use hyperactor::mailbox::PortReceiver;
3435
use hyperactor::message::Castable;
@@ -39,6 +40,8 @@ use ndslice::Selection;
3940
use ndslice::Shape;
4041
use ndslice::ShapeError;
4142
use ndslice::SliceError;
43+
use ndslice::reshape::Limit;
44+
use ndslice::reshape::ReshapeSliceExt;
4245
use ndslice::selection;
4346
use ndslice::selection::EvalOpts;
4447
use ndslice::selection::ReifySlice;
@@ -53,6 +56,7 @@ use crate::Mesh;
5356
use crate::comm::multicast::CastMessage;
5457
use crate::comm::multicast::CastMessageEnvelope;
5558
use crate::comm::multicast::Uslice;
59+
use crate::config::MAX_CAST_DIMENSION_SIZE;
5660
use crate::metrics;
5761
use crate::proc_mesh::ProcMesh;
5862
use crate::reference::ActorMeshId;
@@ -93,13 +97,49 @@ where
9397
cast_mesh_shape.clone(),
9498
message,
9599
)?;
100+
101+
// Mesh's shape might have large extents on some dimensions. Those
102+
// dimensions would cause large fanout in our comm actor
103+
// implementation. To avoid that, we reshape it by increasing
104+
// dimensionality and limiting the extent of each dimension. Note
105+
// the reshape is only visible to the internal algorithm. The
106+
// shape that the user sees remains intact.
107+
//
108+
// For example, a typical shape is [hosts=1024, gpus=8]. By using
109+
// limit 8, it becomes [8, 8, 8, 2, 8] during casting. In other
110+
// words, it adds 3 extra layers to the comm actor tree, while
111+
// keeping the fanout in each layer per dimension at 8 or smaller.
112+
113+
let slice_of_root = root_mesh_shape.slice();
114+
115+
let max_cast_dimension_size = config::global::get(MAX_CAST_DIMENSION_SIZE);
116+
117+
let (selection_of_cast, slice_of_cast) =
118+
// A fanout limit of usize::MAX means that we have configured there to be no reshaping
119+
if max_cast_dimension_size < usize::MAX {
120+
let reshaped_slice = slice_of_root.reshape_with_limit(Limit::from(max_cast_dimension_size));
121+
122+
(
123+
if reshaped_slice != *slice_of_root {
124+
Selection::of_ranks(
125+
&reshaped_slice,
126+
&selection_of_root
127+
.eval(&selection::EvalOpts::strict(), slice_of_root)?
128+
.collect::<BTreeSet<_>>(),
129+
)?
130+
} else {
131+
selection_of_root
132+
},
133+
reshaped_slice,
134+
)
135+
} else {
136+
(selection_of_root, slice_of_root.clone())
137+
};
138+
96139
let cast_message = CastMessage {
97-
// Note: `dest` is on the root mesh' shape, which could be different
98-
// from the cast mesh's shape if the cast is on a view, e.g. a sliced
99-
// mesh.
100140
dest: Uslice {
101-
slice: root_mesh_shape.slice().clone(),
102-
selection: selection_of_root,
141+
slice: slice_of_cast,
142+
selection: selection_of_cast,
103143
},
104144
message,
105145
};
@@ -1469,4 +1509,183 @@ mod tests {
14691509

14701510
actor_mesh_test_suite!(SimAllocator::new_and_start_simnet());
14711511
}
1512+
1513+
mod reshape_cast {
1514+
use async_trait::async_trait;
1515+
use hyperactor::Actor;
1516+
use hyperactor::Context;
1517+
use hyperactor::Handler;
1518+
use hyperactor::channel::ChannelAddr;
1519+
use hyperactor::channel::ChannelTransport;
1520+
use hyperactor::channel::ChannelTx;
1521+
use hyperactor::channel::Rx;
1522+
use hyperactor::channel::Tx;
1523+
use hyperactor::channel::dial;
1524+
use hyperactor::channel::serve;
1525+
use hyperactor_mesh_macros::sel;
1526+
use ndslice::Selection;
1527+
use ndslice::extent;
1528+
1529+
use crate::Mesh;
1530+
use crate::ProcMesh;
1531+
use crate::RootActorMesh;
1532+
use crate::actor_mesh::ActorMesh;
1533+
use crate::alloc::AllocSpec;
1534+
use crate::alloc::Allocator;
1535+
use crate::alloc::LocalAllocator;
1536+
use crate::config::MAX_CAST_DIMENSION_SIZE;
1537+
1538+
/// Test-only actor wrapping a `ChannelTx<usize>`.
///
/// The sender is created in `Actor::new` by dialing the `ChannelAddr` passed
/// as spawn parameters, and is used by the `()` handler to report this
/// actor's rank back to the test.
///
/// NOTE(review): the `export` attribute presumably makes the actor remotely
/// spawnable (`spawn = true`) and marks the `()` handler as a valid cast
/// target (`cast = true`) — confirm against the `hyperactor::export` docs.
#[derive(Debug)]
1539+
#[hyperactor::export(
1540+
spawn = true,
1541+
handlers = [() { cast = true }],
1542+
)]
1543+
struct EchoActor(ChannelTx<usize>);
1544+
1545+
// Actor construction for `EchoActor`: the spawn parameter is the address of
// the channel on which the actor reports back.
#[async_trait]
1546+
impl Actor for EchoActor {
1547+
// Address of the channel to dial when the actor is created.
type Params = ChannelAddr;
1548+
1549+
// Dials `params` as a `usize` channel and stores the resulting sender.
// A dial failure is propagated to the caller through `?`.
//
// NOTE(review): the tests in this module serve the same address with
// `serve::<u64>` while this side dials `usize`; the element types differ.
// Presumably the wire encodings agree on 64-bit targets — confirm, or align
// the two types.
async fn new(params: ChannelAddr) -> Result<Self, anyhow::Error> {
1550+
Ok(Self(dial::<usize>(params)?))
1551+
}
1552+
}
1553+
1554+
// Handler for the unit message: every delivery makes the actor report its own
// rank on the channel it dialed at construction time.
#[async_trait]
1555+
impl Handler<()> for EchoActor {
1556+
// Posts `cx.self_id().rank()` on the stored sender and returns `Ok(())`.
// The message payload itself carries no information and is ignored.
async fn handle(
1557+
&mut self,
1558+
cx: &Context<Self>,
1559+
_message: (),
1560+
) -> Result<(), anyhow::Error> {
1561+
// Destructure the tuple struct to reach the sender.
let Self(port) = self;
1562+
// The result of `post` (if any) is not inspected here.
port.post(cx.self_id().rank());
1563+
Ok(())
1564+
}
1565+
}
1566+
1567+
/// Casts `()` to every actor (`sel!(*)`) of a `host = 16, gpu = 4` root actor
/// mesh while `MAX_CAST_DIMENSION_SIZE` is overridden to 2, and checks that
/// 16 * 4 messages arrive on the echo channel.
///
/// Only the number of successful receives is checked; the reported ranks are
/// not compared against expected values.
#[tokio::test]
1568+
async fn test_reshaped_actor_mesh_cast() {
1569+
// Take the global config lock and override the limit for this test.
// NOTE(review): `_guard` presumably restores the previous value when it is
// dropped at the end of the test — confirm against `override_key`.
let config = hyperactor::config::global::lock();
1570+
let _guard = config.override_key(MAX_CAST_DIMENSION_SIZE, 2);
1571+
1572+
let alloc = LocalAllocator
1573+
.allocate(AllocSpec {
1574+
extent: extent!(host = 16, gpu = 4),
1575+
constraints: Default::default(),
1576+
proc_name: None,
1577+
})
1578+
.await
1579+
.unwrap();
1580+
let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
1581+
1582+
// Channel on which each `EchoActor` reports its rank; the address is handed
// to the actors as their spawn parameter.
let addr = ChannelAddr::any(ChannelTransport::Unix);
1583+
let (_, mut rx) = serve::<u64>(addr.clone()).await.unwrap();
1584+
let actor_mesh: RootActorMesh<EchoActor> =
1585+
proc_mesh.spawn("echo", &addr).await.unwrap();
1586+
1587+
actor_mesh.cast(proc_mesh.client(), sel!(*), ()).unwrap();
1588+
1589+
// One report is expected per actor in the mesh.
for _ in 0..(16 * 4) {
1590+
assert!(rx.recv().await.is_ok());
1591+
}
1592+
}
1593+
1594+
/// Same scenario as the root-mesh cast test, but the cast goes through the
/// reference returned by `actor_mesh.bind()` instead of the mesh itself.
///
/// Uses a `host = 16, gpu = 4` mesh with `MAX_CAST_DIMENSION_SIZE` overridden
/// to 2 and expects 16 * 4 successful receives; the reported ranks are not
/// checked.
#[tokio::test]
1595+
async fn test_reshaped_actor_mesh_ref_cast() {
1596+
// Take the global config lock and override the limit for this test.
// NOTE(review): `_guard` presumably restores the previous value when it is
// dropped at the end of the test — confirm against `override_key`.
let config = hyperactor::config::global::lock();
1597+
let _guard = config.override_key(MAX_CAST_DIMENSION_SIZE, 2);
1598+
1599+
let alloc = LocalAllocator
1600+
.allocate(AllocSpec {
1601+
extent: extent!(host = 16, gpu = 4),
1602+
constraints: Default::default(),
1603+
proc_name: None,
1604+
})
1605+
.await
1606+
.unwrap();
1607+
let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
1608+
1609+
// Channel on which each `EchoActor` reports its rank.
let addr = ChannelAddr::any(ChannelTransport::Unix);
1610+
let (_, mut rx) = serve::<u64>(addr.clone()).await.unwrap();
1611+
1612+
let actor_mesh: RootActorMesh<EchoActor> =
1613+
proc_mesh.spawn("echo", &addr).await.unwrap();
1614+
1615+
// Cast through the bound mesh reference rather than the root mesh.
let mesh_ref = actor_mesh.bind();
1616+
mesh_ref.cast(proc_mesh.client(), sel!(*), ()).unwrap();
1617+
1618+
// One report is expected per actor in the mesh.
for _ in 0..(16 * 4) {
1619+
assert!(rx.recv().await.is_ok());
1620+
}
1621+
}
1622+
1623+
/// Casts `()` to every actor (`sel!(*)`) of a sliced view of a
/// `host = 8, gpu = 4` mesh while `MAX_CAST_DIMENSION_SIZE` is overridden
/// to 2, and checks exactly which ranks report back.
///
/// The view is built with `select("host", 2..6)` followed by
/// `select("gpu", 2..6)`. The asserted ranks
/// `[10, 11, 14, 15, 18, 19, 22, 23]` equal `host * 4 + gpu` for hosts 2..=5
/// and gpus 2..=3.
///
/// NOTE(review): the `gpu` range `2..6` exceeds the extent of 4; the expected
/// ranks imply it is clamped to the dimension size — confirm against `select`.
#[tokio::test]
1624+
async fn test_reshaped_actor_mesh_slice_cast() {
1625+
// Take the global config lock and override the limit for this test.
// NOTE(review): `_guard` presumably restores the previous value when it is
// dropped at the end of the test — confirm against `override_key`.
let config = hyperactor::config::global::lock();
1626+
let _guard = config.override_key(MAX_CAST_DIMENSION_SIZE, 2);
1627+
1628+
let alloc = LocalAllocator
1629+
.allocate(AllocSpec {
1630+
extent: extent!(host = 8, gpu = 4),
1631+
constraints: Default::default(),
1632+
proc_name: None,
1633+
})
1634+
.await
1635+
.unwrap();
1636+
let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
1637+
1638+
// Channel on which each `EchoActor` reports its rank.
let addr = ChannelAddr::any(ChannelTransport::Unix);
1639+
let (_, mut rx) = serve::<u64>(addr.clone()).await.unwrap();
1640+
1641+
let actor_mesh: RootActorMesh<EchoActor> =
1642+
proc_mesh.spawn("echo", &addr).await.unwrap();
1643+
// Narrow the mesh along both dimensions before casting.
let slice = actor_mesh.select("host", 2..6).unwrap();
1644+
let slice = slice.select("gpu", 2..6).unwrap();
1645+
1646+
slice.cast(proc_mesh.client(), sel!(*), ()).unwrap();
1647+
1648+
// Collect the eight expected reports; they are sorted before comparison
// because the code does not rely on arrival order.
let mut received_ranks = vec![];
1649+
for _ in 0..8 {
1650+
let rank = rx.recv().await.unwrap();
1651+
received_ranks.push(rank);
1652+
}
1653+
received_ranks.sort();
1654+
assert_eq!(received_ranks, vec![10, 11, 14, 15, 18, 19, 22, 23]);
1655+
}
1656+
1657+
/// Casts `()` to the selection `sel!(2:6, 2:6)` of a full
/// `host = 8, gpu = 4` root mesh while `MAX_CAST_DIMENSION_SIZE` is
/// overridden to 2, and checks exactly which ranks report back.
///
/// The asserted ranks `[10, 11, 14, 15, 18, 19, 22, 23]` equal
/// `host * 4 + gpu` for hosts 2..=5 and gpus 2..=3, i.e. the same set the
/// sliced-view test in this module expects.
#[tokio::test]
1658+
async fn test_reshaped_actor_mesh_cast_with_selection() {
1659+
// Take the global config lock and override the limit for this test.
// NOTE(review): `_guard` presumably restores the previous value when it is
// dropped at the end of the test — confirm against `override_key`.
let config = hyperactor::config::global::lock();
1660+
let _guard = config.override_key(MAX_CAST_DIMENSION_SIZE, 2);
1661+
1662+
let alloc = LocalAllocator
1663+
.allocate(AllocSpec {
1664+
extent: extent!(host = 8, gpu = 4),
1665+
constraints: Default::default(),
1666+
proc_name: None,
1667+
})
1668+
.await
1669+
.unwrap();
1670+
let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
1671+
1672+
// Channel on which each `EchoActor` reports its rank.
let addr = ChannelAddr::any(ChannelTransport::Unix);
1673+
let (_, mut rx) = serve::<u64>(addr.clone()).await.unwrap();
1674+
1675+
let actor_mesh: RootActorMesh<EchoActor> =
1676+
proc_mesh.spawn("echo", &addr).await.unwrap();
1677+
1678+
// Cast on the root mesh, restricting the targets with a range selection
// instead of slicing the mesh first.
actor_mesh
1679+
.cast(proc_mesh.client(), sel!(2:6, 2:6), ())
1680+
.unwrap();
1681+
1682+
// Collect the eight expected reports; they are sorted before comparison
// because the code does not rely on arrival order.
let mut received_ranks = vec![];
1683+
for _ in 0..8 {
1684+
let rank = rx.recv().await.unwrap();
1685+
received_ranks.push(rank);
1686+
}
1687+
received_ranks.sort();
1688+
assert_eq!(received_ranks, vec![10, 11, 14, 15, 18, 19, 22, 23]);
1689+
}
1690+
}
14721691
}

hyperactor_mesh/src/config.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
//! Configuration for Hyperactor Mesh.
10+
//!
11+
//! This module provides hyperactor_mesh-specific configuration attributes that extend
12+
//! the base hyperactor configuration system.
13+
14+
use std::env;
15+
16+
use hyperactor::attrs::declare_attrs;
17+
18+
// Declare hyperactor_mesh-specific configuration keys
19+
declare_attrs! {
20+
/// The maximum size allowed for any dimension of a folded shape
21+
/// when reshaping during casting to limit fanout.
22+
/// usize::MAX means no reshaping, as any shape will always be below
23+
/// the limit, so no dimension needs to be folded.
24+
pub attr MAX_CAST_DIMENSION_SIZE: usize = usize::MAX;
25+
}
26+
27+
/// Applies hyperactor_mesh configuration overrides read from environment
/// variables.
///
/// Currently reads `HYPERACTOR_MESH_MAX_CAST_DIMENSION_SIZE`. The value is
/// used only if it parses as a `usize` and is greater than zero; an unset
/// variable, an unparsable value, or `0` is ignored without any log output.
///
/// NOTE(review): the value returned by `override_key` is bound to `_guard`
/// inside the innermost `if` and is dropped at the end of that block. The
/// tests in this crate keep such a guard alive for the duration of the
/// override, which suggests it restores the previous value on drop. If so,
/// the override made here is reverted immediately and this function has no
/// lasting effect — confirm, and switch to a persistent setter if one exists.
pub fn init_from_env() {
28+
// Hold the global configuration lock while applying overrides.
let config = hyperactor::config::global::lock();
29+
30+
// Load max cast dimension size.
31+
if let Ok(val) = env::var("HYPERACTOR_MESH_MAX_CAST_DIMENSION_SIZE") {
32+
if let Ok(parsed) = val.parse::<usize>() {
33+
// Zero is rejected rather than treated as a limit.
if parsed > 0 {
34+
tracing::info!("overriding MAX_CAST_DIMENSION_SIZE to {}", parsed);
35+
let _guard = config.override_key(MAX_CAST_DIMENSION_SIZE, parsed);
36+
}
37+
}
38+
}
39+
}

hyperactor_mesh/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pub mod alloc;
1818
mod assign;
1919
pub mod bootstrap;
2020
pub mod comm;
21+
pub mod config;
2122
pub mod connect;
2223
pub mod logging;
2324
pub mod mesh;

monarch_hyperactor/src/config.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ declare_attrs! {
2727
pub fn reload_config_from_env() -> PyResult<()> {
2828
// Reload the hyperactor global configuration from environment variables
2929
hyperactor::config::global::init_from_env();
30+
// Then apply the hyperactor_mesh-specific settings read from environment
// variables (runs after the base hyperactor reload above).
hyperactor_mesh::config::init_from_env();
3031
Ok(())
3132
}
3233

0 commit comments

Comments
 (0)