Skip to content

Commit 0df3eda

Browse files
thomasywang authored and facebook-github-bot committed
Limit cast fanout
Summary: ActorMesh's shape might have large extents on some dimensions. Those dimensions would cause large fanout in our comm actor implementation. To avoid that, we reshape it by increasing dimensionality and limiting the extent of each dimension. Note: the reshape is only visible to the internal algorithm. The shape that the user sees remains intact. For example, a typical shape is [hosts=1024, gpus=8]. By using limit 8, it becomes [8, 8, 8, 2, 8] during casting. In other words, it adds 3 extra layers to the comm actor tree, while keeping the fanout in each layer at 8 or smaller. The limit for cast fanouts is configured by the key `CASTING_FANOUT_SIZE`, which currently defaults to 0, disabling the feature. Differential Revision: D82320948
1 parent 45b18bd commit 0df3eda

File tree

2 files changed

+233
-5
lines changed

2 files changed

+233
-5
lines changed

hyperactor/src/config.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ declare_attrs! {
5959

6060
/// How often to check for full MPSC channel on NetRx.
6161
pub attr CHANNEL_NET_RX_BUFFER_FULL_CHECK_INTERVAL: Duration = Duration::from_secs(5);
62+
63+
/// The reshaping limit used by casting. Zero means no reshaping.
64+
pub attr CASTING_FANOUT_SIZE: usize = 0;
6265
}
6366

6467
/// Load configuration from environment variables
@@ -132,6 +135,14 @@ pub fn from_env() -> Attrs {
132135
}
133136
}
134137

138+
// Load channel cast fanout size
139+
if let Ok(val) = env::var("HYPERACTOR_CASTING_FANOUT_SIZE") {
140+
if let Ok(parsed) = val.parse::<usize>() {
141+
tracing::info!("overriding CASTING_FANOUT_SIZE to {}", parsed);
142+
config[CASTING_FANOUT_SIZE] = parsed;
143+
}
144+
}
145+
135146
config
136147
}
137148

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 222 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use hyperactor::attrs::Attrs;
2929
use hyperactor::attrs::declare_attrs;
3030
use hyperactor::cap;
3131
use hyperactor::cap::CanSend;
32+
use hyperactor::config;
3233
use hyperactor::mailbox::MailboxSenderError;
3334
use hyperactor::mailbox::PortReceiver;
3435
use hyperactor::message::Castable;
@@ -39,6 +40,8 @@ use ndslice::Selection;
3940
use ndslice::Shape;
4041
use ndslice::ShapeError;
4142
use ndslice::SliceError;
43+
use ndslice::reshape::Limit;
44+
use ndslice::reshape::ReshapeSliceExt;
4245
use ndslice::selection;
4346
use ndslice::selection::EvalOpts;
4447
use ndslice::selection::ReifySlice;
@@ -93,13 +96,49 @@ where
9396
cast_mesh_shape.clone(),
9497
message,
9598
)?;
99+
100+
// Mesh's shape might have large extents on some dimensions. Those
101+
// dimensions would cause large fanout in our comm actor
102+
// implementation. To avoid that, we reshape it by increasing
103+
// dimensionality and limiting the extent of each dimension. Note
104+
// the reshape is only visible to the internal algorithm. The
105+
// shape that the user sees remains intact.
106+
//
107+
// For example, a typical shape is [hosts=1024, gpus=8]. By using
108+
// limit 8, it becomes [8, 8, 8, 2, 8] during casting. In other
109+
// words, it adds 3 extra layers to the comm actor tree, while
110+
// keeping the fanout in each layer at 8 or smaller.
111+
112+
let slice_of_root = root_mesh_shape.slice();
113+
114+
let fanout_limit = config::global::get(config::CASTING_FANOUT_SIZE);
115+
116+
let (selection_of_cast, slice_of_cast) =
117+
// A fanout limit of 0 means that we have configured there to be no reshaping
118+
if fanout_limit > 0 {
119+
let reshaped_slice = slice_of_root.reshape_with_limit(Limit::from(fanout_limit));
120+
121+
(
122+
if reshaped_slice != *slice_of_root {
123+
Selection::of_ranks(
124+
&reshaped_slice,
125+
&selection_of_root
126+
.eval(&selection::EvalOpts::strict(), slice_of_root)?
127+
.collect::<BTreeSet<_>>(),
128+
)?
129+
} else {
130+
selection_of_root
131+
},
132+
reshaped_slice,
133+
)
134+
} else {
135+
(selection_of_root, slice_of_root.clone())
136+
};
137+
96138
let cast_message = CastMessage {
97-
// Note: `dest` is on the root mesh' shape, which could be different
98-
// from the cast mesh's shape if the cast is on a view, e.g. a sliced
99-
// mesh.
100139
dest: Uslice {
101-
slice: root_mesh_shape.slice().clone(),
102-
selection: selection_of_root,
140+
slice: slice_of_cast,
141+
selection: selection_of_cast,
103142
},
104143
message,
105144
};
@@ -1466,4 +1505,182 @@ mod tests {
14661505

14671506
actor_mesh_test_suite!(SimAllocator::new_and_start_simnet());
14681507
}
1508+
1509+
mod reshape_cast {
1510+
use async_trait::async_trait;
1511+
use hyperactor::Actor;
1512+
use hyperactor::Context;
1513+
use hyperactor::Handler;
1514+
use hyperactor::channel::ChannelAddr;
1515+
use hyperactor::channel::ChannelTransport;
1516+
use hyperactor::channel::ChannelTx;
1517+
use hyperactor::channel::Rx;
1518+
use hyperactor::channel::Tx;
1519+
use hyperactor::channel::dial;
1520+
use hyperactor::channel::serve;
1521+
use hyperactor_mesh_macros::sel;
1522+
use ndslice::Selection;
1523+
use ndslice::extent;
1524+
1525+
use crate::Mesh;
1526+
use crate::ProcMesh;
1527+
use crate::RootActorMesh;
1528+
use crate::actor_mesh::ActorMesh;
1529+
use crate::alloc::AllocSpec;
1530+
use crate::alloc::Allocator;
1531+
use crate::alloc::LocalAllocator;
1532+
1533+
#[derive(Debug)]
1534+
#[hyperactor::export(
1535+
spawn = true,
1536+
handlers = [() { cast = true }],
1537+
)]
1538+
struct EchoActor(ChannelTx<usize>);
1539+
1540+
#[async_trait]
1541+
impl Actor for EchoActor {
1542+
type Params = ChannelAddr;
1543+
1544+
async fn new(params: ChannelAddr) -> Result<Self, anyhow::Error> {
1545+
Ok(Self(dial::<usize>(params)?))
1546+
}
1547+
}
1548+
1549+
#[async_trait]
1550+
impl Handler<()> for EchoActor {
1551+
async fn handle(
1552+
&mut self,
1553+
cx: &Context<Self>,
1554+
_message: (),
1555+
) -> Result<(), anyhow::Error> {
1556+
let Self(port) = self;
1557+
port.post(cx.self_id().rank());
1558+
Ok(())
1559+
}
1560+
}
1561+
1562+
#[tokio::test]
1563+
async fn test_reshaped_actor_mesh_cast() {
1564+
let config = hyperactor::config::global::lock();
1565+
let _guard = config.override_key(hyperactor::config::CASTING_FANOUT_SIZE, 2);
1566+
1567+
let alloc = LocalAllocator
1568+
.allocate(AllocSpec {
1569+
extent: extent!(host = 16, gpu = 4),
1570+
constraints: Default::default(),
1571+
proc_name: None,
1572+
})
1573+
.await
1574+
.unwrap();
1575+
let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
1576+
1577+
let addr = ChannelAddr::any(ChannelTransport::Unix);
1578+
let (_, mut rx) = serve::<u64>(addr.clone()).await.unwrap();
1579+
let actor_mesh: RootActorMesh<EchoActor> =
1580+
proc_mesh.spawn("echo", &addr).await.unwrap();
1581+
1582+
actor_mesh.cast(proc_mesh.client(), sel!(*), ()).unwrap();
1583+
1584+
for _ in 0..(16 * 4) {
1585+
assert!(rx.recv().await.is_ok());
1586+
}
1587+
}
1588+
1589+
#[tokio::test]
1590+
async fn test_reshaped_actor_mesh_ref_cast() {
1591+
let config = hyperactor::config::global::lock();
1592+
let _guard = config.override_key(hyperactor::config::CASTING_FANOUT_SIZE, 2);
1593+
1594+
let alloc = LocalAllocator
1595+
.allocate(AllocSpec {
1596+
extent: extent!(host = 16, gpu = 4),
1597+
constraints: Default::default(),
1598+
proc_name: None,
1599+
})
1600+
.await
1601+
.unwrap();
1602+
let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
1603+
1604+
let addr = ChannelAddr::any(ChannelTransport::Unix);
1605+
let (_, mut rx) = serve::<u64>(addr.clone()).await.unwrap();
1606+
1607+
let actor_mesh: RootActorMesh<EchoActor> =
1608+
proc_mesh.spawn("echo", &addr).await.unwrap();
1609+
1610+
let mesh_ref = actor_mesh.bind();
1611+
mesh_ref.cast(proc_mesh.client(), sel!(*), ()).unwrap();
1612+
1613+
for _ in 0..(16 * 4) {
1614+
assert!(rx.recv().await.is_ok());
1615+
}
1616+
}
1617+
1618+
#[tokio::test]
1619+
async fn test_reshaped_actor_mesh_slice_cast() {
1620+
let config = hyperactor::config::global::lock();
1621+
let _guard = config.override_key(hyperactor::config::CASTING_FANOUT_SIZE, 2);
1622+
1623+
let alloc = LocalAllocator
1624+
.allocate(AllocSpec {
1625+
extent: extent!(host = 8, gpu = 4),
1626+
constraints: Default::default(),
1627+
proc_name: None,
1628+
})
1629+
.await
1630+
.unwrap();
1631+
let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
1632+
1633+
let addr = ChannelAddr::any(ChannelTransport::Unix);
1634+
let (_, mut rx) = serve::<u64>(addr.clone()).await.unwrap();
1635+
1636+
let actor_mesh: RootActorMesh<EchoActor> =
1637+
proc_mesh.spawn("echo", &addr).await.unwrap();
1638+
let slice = actor_mesh.select("host", 2..6).unwrap();
1639+
let slice = slice.select("gpu", 2..6).unwrap();
1640+
1641+
slice.cast(proc_mesh.client(), sel!(*), ()).unwrap();
1642+
1643+
let mut received_ranks = vec![];
1644+
for _ in 0..8 {
1645+
let rank = rx.recv().await.unwrap();
1646+
received_ranks.push(rank);
1647+
}
1648+
received_ranks.sort();
1649+
assert_eq!(received_ranks, vec![10, 11, 14, 15, 18, 19, 22, 23]);
1650+
}
1651+
1652+
#[tokio::test]
1653+
async fn test_reshaped_actor_mesh_cast_with_selection() {
1654+
let config = hyperactor::config::global::lock();
1655+
let _guard = config.override_key(hyperactor::config::CASTING_FANOUT_SIZE, 2);
1656+
1657+
let alloc = LocalAllocator
1658+
.allocate(AllocSpec {
1659+
extent: extent!(host = 8, gpu = 4),
1660+
constraints: Default::default(),
1661+
proc_name: None,
1662+
})
1663+
.await
1664+
.unwrap();
1665+
let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
1666+
1667+
let addr = ChannelAddr::any(ChannelTransport::Unix);
1668+
let (_, mut rx) = serve::<u64>(addr.clone()).await.unwrap();
1669+
1670+
let actor_mesh: RootActorMesh<EchoActor> =
1671+
proc_mesh.spawn("echo", &addr).await.unwrap();
1672+
1673+
actor_mesh
1674+
.cast(proc_mesh.client(), sel!(2:6, 2:6), ())
1675+
.unwrap();
1676+
1677+
let mut received_ranks = vec![];
1678+
for _ in 0..8 {
1679+
let rank = rx.recv().await.unwrap();
1680+
received_ranks.push(rank);
1681+
}
1682+
received_ranks.sort();
1683+
assert_eq!(received_ranks, vec![10, 11, 14, 15, 18, 19, 22, 23]);
1684+
}
1685+
}
14691686
}

0 commit comments

Comments
 (0)