@@ -29,6 +29,7 @@ use hyperactor::attrs::Attrs;
29
29
use hyperactor:: attrs:: declare_attrs;
30
30
use hyperactor:: cap;
31
31
use hyperactor:: cap:: CanSend ;
32
+ use hyperactor:: config;
32
33
use hyperactor:: mailbox:: MailboxSenderError ;
33
34
use hyperactor:: mailbox:: PortReceiver ;
34
35
use hyperactor:: message:: Castable ;
@@ -39,6 +40,8 @@ use ndslice::Selection;
39
40
use ndslice:: Shape ;
40
41
use ndslice:: ShapeError ;
41
42
use ndslice:: SliceError ;
43
+ use ndslice:: reshape:: Limit ;
44
+ use ndslice:: reshape:: ReshapeSliceExt ;
42
45
use ndslice:: selection;
43
46
use ndslice:: selection:: EvalOpts ;
44
47
use ndslice:: selection:: ReifySlice ;
@@ -53,6 +56,7 @@ use crate::Mesh;
53
56
use crate :: comm:: multicast:: CastMessage ;
54
57
use crate :: comm:: multicast:: CastMessageEnvelope ;
55
58
use crate :: comm:: multicast:: Uslice ;
59
+ use crate :: config:: MAX_CAST_DIMENSION_SIZE ;
56
60
use crate :: metrics;
57
61
use crate :: proc_mesh:: ProcMesh ;
58
62
use crate :: reference:: ActorMeshId ;
@@ -93,13 +97,49 @@ where
93
97
cast_mesh_shape. clone ( ) ,
94
98
message,
95
99
) ?;
100
+
101
+ // Mesh's shape might have large extents on some dimensions. Those
102
+ // dimensions would cause large fanout in our comm actor
103
+ // implementation. To avoid that, we reshape it by increasing
104
+ // dimensionality and limiting the extent of each dimension. Note
105
+ // the reshape is only visible to the internal algorithm. The
106
+ // shape that the user sees remains intact.
107
+ //
108
+ // For example, a typical shape is [hosts=1024, gpus=8]. By using
109
+ // limit 8, it becomes [8, 8, 8, 2, 8] during casting. In other
110
+ // words, it adds 3 extra layers to the comm actor tree, while
111
+ // keeping the fanout in each layer per dimension at 8 or smaller.
112
+
113
+ let slice_of_root = root_mesh_shape. slice ( ) ;
114
+
115
+ let max_cast_dimension_size = config:: global:: get ( MAX_CAST_DIMENSION_SIZE ) ;
116
+
117
+ let ( selection_of_cast, slice_of_cast) =
118
+ // A fanout limit of usize::MAX means that we have configured there to be no reshaping
119
+ if max_cast_dimension_size < usize:: MAX {
120
+ let reshaped_slice = slice_of_root. reshape_with_limit ( Limit :: from ( max_cast_dimension_size) ) ;
121
+
122
+ (
123
+ if reshaped_slice != * slice_of_root {
124
+ Selection :: of_ranks (
125
+ & reshaped_slice,
126
+ & selection_of_root
127
+ . eval ( & selection:: EvalOpts :: strict ( ) , slice_of_root) ?
128
+ . collect :: < BTreeSet < _ > > ( ) ,
129
+ ) ?
130
+ } else {
131
+ selection_of_root
132
+ } ,
133
+ reshaped_slice,
134
+ )
135
+ } else {
136
+ ( selection_of_root, slice_of_root. clone ( ) )
137
+ } ;
138
+
96
139
let cast_message = CastMessage {
97
- // Note: `dest` is on the root mesh' shape, which could be different
98
- // from the cast mesh's shape if the cast is on a view, e.g. a sliced
99
- // mesh.
100
140
dest : Uslice {
101
- slice : root_mesh_shape . slice ( ) . clone ( ) ,
102
- selection : selection_of_root ,
141
+ slice : slice_of_cast ,
142
+ selection : selection_of_cast ,
103
143
} ,
104
144
message,
105
145
} ;
@@ -1469,4 +1509,183 @@ mod tests {
1469
1509
1470
1510
actor_mesh_test_suite ! ( SimAllocator :: new_and_start_simnet( ) ) ;
1471
1511
}
1512
+
1513
+ mod reshape_cast {
1514
+ use async_trait:: async_trait;
1515
+ use hyperactor:: Actor ;
1516
+ use hyperactor:: Context ;
1517
+ use hyperactor:: Handler ;
1518
+ use hyperactor:: channel:: ChannelAddr ;
1519
+ use hyperactor:: channel:: ChannelTransport ;
1520
+ use hyperactor:: channel:: ChannelTx ;
1521
+ use hyperactor:: channel:: Rx ;
1522
+ use hyperactor:: channel:: Tx ;
1523
+ use hyperactor:: channel:: dial;
1524
+ use hyperactor:: channel:: serve;
1525
+ use hyperactor_mesh_macros:: sel;
1526
+ use ndslice:: Selection ;
1527
+ use ndslice:: extent;
1528
+
1529
+ use crate :: Mesh ;
1530
+ use crate :: ProcMesh ;
1531
+ use crate :: RootActorMesh ;
1532
+ use crate :: actor_mesh:: ActorMesh ;
1533
+ use crate :: alloc:: AllocSpec ;
1534
+ use crate :: alloc:: Allocator ;
1535
+ use crate :: alloc:: LocalAllocator ;
1536
+ use crate :: config:: MAX_CAST_DIMENSION_SIZE ;
1537
+
1538
+ #[ derive( Debug ) ]
1539
+ #[ hyperactor:: export(
1540
+ spawn = true ,
1541
+ handlers = [ ( ) { cast = true } ] ,
1542
+ ) ]
1543
+ struct EchoActor ( ChannelTx < usize > ) ;
1544
+
1545
+ #[ async_trait]
1546
+ impl Actor for EchoActor {
1547
+ type Params = ChannelAddr ;
1548
+
1549
+ async fn new ( params : ChannelAddr ) -> Result < Self , anyhow:: Error > {
1550
+ Ok ( Self ( dial :: < usize > ( params) ?) )
1551
+ }
1552
+ }
1553
+
1554
+ #[ async_trait]
1555
+ impl Handler < ( ) > for EchoActor {
1556
+ async fn handle (
1557
+ & mut self ,
1558
+ cx : & Context < Self > ,
1559
+ _message : ( ) ,
1560
+ ) -> Result < ( ) , anyhow:: Error > {
1561
+ let Self ( port) = self ;
1562
+ port. post ( cx. self_id ( ) . rank ( ) ) ;
1563
+ Ok ( ( ) )
1564
+ }
1565
+ }
1566
+
1567
+ #[ tokio:: test]
1568
+ async fn test_reshaped_actor_mesh_cast ( ) {
1569
+ let config = hyperactor:: config:: global:: lock ( ) ;
1570
+ let _guard = config. override_key ( MAX_CAST_DIMENSION_SIZE , 2 ) ;
1571
+
1572
+ let alloc = LocalAllocator
1573
+ . allocate ( AllocSpec {
1574
+ extent : extent ! ( host = 16 , gpu = 4 ) ,
1575
+ constraints : Default :: default ( ) ,
1576
+ proc_name : None ,
1577
+ } )
1578
+ . await
1579
+ . unwrap ( ) ;
1580
+ let proc_mesh = ProcMesh :: allocate ( alloc) . await . unwrap ( ) ;
1581
+
1582
+ let addr = ChannelAddr :: any ( ChannelTransport :: Unix ) ;
1583
+ let ( _, mut rx) = serve :: < u64 > ( addr. clone ( ) ) . await . unwrap ( ) ;
1584
+ let actor_mesh: RootActorMesh < EchoActor > =
1585
+ proc_mesh. spawn ( "echo" , & addr) . await . unwrap ( ) ;
1586
+
1587
+ actor_mesh. cast ( proc_mesh. client ( ) , sel ! ( * ) , ( ) ) . unwrap ( ) ;
1588
+
1589
+ for _ in 0 ..( 16 * 4 ) {
1590
+ assert ! ( rx. recv( ) . await . is_ok( ) ) ;
1591
+ }
1592
+ }
1593
+
1594
+ #[ tokio:: test]
1595
+ async fn test_reshaped_actor_mesh_ref_cast ( ) {
1596
+ let config = hyperactor:: config:: global:: lock ( ) ;
1597
+ let _guard = config. override_key ( MAX_CAST_DIMENSION_SIZE , 2 ) ;
1598
+
1599
+ let alloc = LocalAllocator
1600
+ . allocate ( AllocSpec {
1601
+ extent : extent ! ( host = 16 , gpu = 4 ) ,
1602
+ constraints : Default :: default ( ) ,
1603
+ proc_name : None ,
1604
+ } )
1605
+ . await
1606
+ . unwrap ( ) ;
1607
+ let proc_mesh = ProcMesh :: allocate ( alloc) . await . unwrap ( ) ;
1608
+
1609
+ let addr = ChannelAddr :: any ( ChannelTransport :: Unix ) ;
1610
+ let ( _, mut rx) = serve :: < u64 > ( addr. clone ( ) ) . await . unwrap ( ) ;
1611
+
1612
+ let actor_mesh: RootActorMesh < EchoActor > =
1613
+ proc_mesh. spawn ( "echo" , & addr) . await . unwrap ( ) ;
1614
+
1615
+ let mesh_ref = actor_mesh. bind ( ) ;
1616
+ mesh_ref. cast ( proc_mesh. client ( ) , sel ! ( * ) , ( ) ) . unwrap ( ) ;
1617
+
1618
+ for _ in 0 ..( 16 * 4 ) {
1619
+ assert ! ( rx. recv( ) . await . is_ok( ) ) ;
1620
+ }
1621
+ }
1622
+
1623
+ #[ tokio:: test]
1624
+ async fn test_reshaped_actor_mesh_slice_cast ( ) {
1625
+ let config = hyperactor:: config:: global:: lock ( ) ;
1626
+ let _guard = config. override_key ( MAX_CAST_DIMENSION_SIZE , 2 ) ;
1627
+
1628
+ let alloc = LocalAllocator
1629
+ . allocate ( AllocSpec {
1630
+ extent : extent ! ( host = 8 , gpu = 4 ) ,
1631
+ constraints : Default :: default ( ) ,
1632
+ proc_name : None ,
1633
+ } )
1634
+ . await
1635
+ . unwrap ( ) ;
1636
+ let proc_mesh = ProcMesh :: allocate ( alloc) . await . unwrap ( ) ;
1637
+
1638
+ let addr = ChannelAddr :: any ( ChannelTransport :: Unix ) ;
1639
+ let ( _, mut rx) = serve :: < u64 > ( addr. clone ( ) ) . await . unwrap ( ) ;
1640
+
1641
+ let actor_mesh: RootActorMesh < EchoActor > =
1642
+ proc_mesh. spawn ( "echo" , & addr) . await . unwrap ( ) ;
1643
+ let slice = actor_mesh. select ( "host" , 2 ..6 ) . unwrap ( ) ;
1644
+ let slice = slice. select ( "gpu" , 2 ..6 ) . unwrap ( ) ;
1645
+
1646
+ slice. cast ( proc_mesh. client ( ) , sel ! ( * ) , ( ) ) . unwrap ( ) ;
1647
+
1648
+ let mut received_ranks = vec ! [ ] ;
1649
+ for _ in 0 ..8 {
1650
+ let rank = rx. recv ( ) . await . unwrap ( ) ;
1651
+ received_ranks. push ( rank) ;
1652
+ }
1653
+ received_ranks. sort ( ) ;
1654
+ assert_eq ! ( received_ranks, vec![ 10 , 11 , 14 , 15 , 18 , 19 , 22 , 23 ] ) ;
1655
+ }
1656
+
1657
+ #[ tokio:: test]
1658
+ async fn test_reshaped_actor_mesh_cast_with_selection ( ) {
1659
+ let config = hyperactor:: config:: global:: lock ( ) ;
1660
+ let _guard = config. override_key ( MAX_CAST_DIMENSION_SIZE , 2 ) ;
1661
+
1662
+ let alloc = LocalAllocator
1663
+ . allocate ( AllocSpec {
1664
+ extent : extent ! ( host = 8 , gpu = 4 ) ,
1665
+ constraints : Default :: default ( ) ,
1666
+ proc_name : None ,
1667
+ } )
1668
+ . await
1669
+ . unwrap ( ) ;
1670
+ let proc_mesh = ProcMesh :: allocate ( alloc) . await . unwrap ( ) ;
1671
+
1672
+ let addr = ChannelAddr :: any ( ChannelTransport :: Unix ) ;
1673
+ let ( _, mut rx) = serve :: < u64 > ( addr. clone ( ) ) . await . unwrap ( ) ;
1674
+
1675
+ let actor_mesh: RootActorMesh < EchoActor > =
1676
+ proc_mesh. spawn ( "echo" , & addr) . await . unwrap ( ) ;
1677
+
1678
+ actor_mesh
1679
+ . cast ( proc_mesh. client ( ) , sel ! ( 2 : 6 , 2 : 6 ) , ( ) )
1680
+ . unwrap ( ) ;
1681
+
1682
+ let mut received_ranks = vec ! [ ] ;
1683
+ for _ in 0 ..8 {
1684
+ let rank = rx. recv ( ) . await . unwrap ( ) ;
1685
+ received_ranks. push ( rank) ;
1686
+ }
1687
+ received_ranks. sort ( ) ;
1688
+ assert_eq ! ( received_ranks, vec![ 10 , 11 , 14 , 15 , 18 , 19 , 22 , 23 ] ) ;
1689
+ }
1690
+ }
1472
1691
}
0 commit comments