@@ -29,6 +29,7 @@ use hyperactor::attrs::Attrs;
29
29
use hyperactor:: attrs:: declare_attrs;
30
30
use hyperactor:: cap;
31
31
use hyperactor:: cap:: CanSend ;
32
+ use hyperactor:: config;
32
33
use hyperactor:: mailbox:: MailboxSenderError ;
33
34
use hyperactor:: mailbox:: PortReceiver ;
34
35
use hyperactor:: message:: Castable ;
@@ -39,6 +40,8 @@ use ndslice::Selection;
39
40
use ndslice:: Shape ;
40
41
use ndslice:: ShapeError ;
41
42
use ndslice:: SliceError ;
43
+ use ndslice:: reshape:: Limit ;
44
+ use ndslice:: reshape:: ReshapeSliceExt ;
42
45
use ndslice:: selection;
43
46
use ndslice:: selection:: EvalOpts ;
44
47
use ndslice:: selection:: ReifySlice ;
@@ -93,13 +96,49 @@ where
93
96
cast_mesh_shape. clone ( ) ,
94
97
message,
95
98
) ?;
99
+
100
+ // Mesh's shape might have large extents on some dimensions. Those
101
+ // dimensions would cause large fanout in our comm actor
102
+ // implementation. To avoid that, we reshape it by increasing
103
+ // dimensionality and limiting the extent of each dimension. Note
104
+ // the reshape is only visibility to the internal algorithom. The
105
+ // shape that user sees maintains intact.
106
+ //
107
+ // For example, a typical shape is [hosts=1024, gpus=8]. By using
108
+ // limit 8, it becomes [8, 8, 8, 2, 8] during casting. In other
109
+ // words, it adds 3 extra layers to the comm actor tree, while
110
+ // keeping the fanout in each layer at 8 or smaller.
111
+
112
+ let slice_of_root = root_mesh_shape. slice ( ) ;
113
+
114
+ let fanout_limit = config:: global:: get ( config:: CASTING_FANOUT_SIZE ) ;
115
+
116
+ let ( selection_of_cast, slice_of_cast) =
117
+ // A fanout limit of 0 means that we have configured there to be no reshaping
118
+ if fanout_limit > 0 {
119
+ let reshaped_slice = slice_of_root. reshape_with_limit ( Limit :: from ( fanout_limit) ) ;
120
+
121
+ (
122
+ if reshaped_slice != * slice_of_root {
123
+ Selection :: of_ranks (
124
+ & reshaped_slice,
125
+ & selection_of_root
126
+ . eval ( & selection:: EvalOpts :: strict ( ) , slice_of_root) ?
127
+ . collect :: < BTreeSet < _ > > ( ) ,
128
+ ) ?
129
+ } else {
130
+ selection_of_root
131
+ } ,
132
+ reshaped_slice,
133
+ )
134
+ } else {
135
+ ( selection_of_root, slice_of_root. clone ( ) )
136
+ } ;
137
+
96
138
let cast_message = CastMessage {
97
- // Note: `dest` is on the root mesh' shape, which could be different
98
- // from the cast mesh's shape if the cast is on a view, e.g. a sliced
99
- // mesh.
100
139
dest : Uslice {
101
- slice : root_mesh_shape . slice ( ) . clone ( ) ,
102
- selection : selection_of_root ,
140
+ slice : slice_of_cast ,
141
+ selection : selection_of_cast ,
103
142
} ,
104
143
message,
105
144
} ;
@@ -1466,4 +1505,182 @@ mod tests {
1466
1505
1467
1506
actor_mesh_test_suite ! ( SimAllocator :: new_and_start_simnet( ) ) ;
1468
1507
}
1508
+
1509
+ mod reshape_cast {
1510
+ use async_trait:: async_trait;
1511
+ use hyperactor:: Actor ;
1512
+ use hyperactor:: Context ;
1513
+ use hyperactor:: Handler ;
1514
+ use hyperactor:: channel:: ChannelAddr ;
1515
+ use hyperactor:: channel:: ChannelTransport ;
1516
+ use hyperactor:: channel:: ChannelTx ;
1517
+ use hyperactor:: channel:: Rx ;
1518
+ use hyperactor:: channel:: Tx ;
1519
+ use hyperactor:: channel:: dial;
1520
+ use hyperactor:: channel:: serve;
1521
+ use hyperactor_mesh_macros:: sel;
1522
+ use ndslice:: Selection ;
1523
+ use ndslice:: extent;
1524
+
1525
+ use crate :: Mesh ;
1526
+ use crate :: ProcMesh ;
1527
+ use crate :: RootActorMesh ;
1528
+ use crate :: actor_mesh:: ActorMesh ;
1529
+ use crate :: alloc:: AllocSpec ;
1530
+ use crate :: alloc:: Allocator ;
1531
+ use crate :: alloc:: LocalAllocator ;
1532
+
1533
+ #[ derive( Debug ) ]
1534
+ #[ hyperactor:: export(
1535
+ spawn = true ,
1536
+ handlers = [ ( ) { cast = true } ] ,
1537
+ ) ]
1538
+ struct EchoActor ( ChannelTx < usize > ) ;
1539
+
1540
+ #[ async_trait]
1541
+ impl Actor for EchoActor {
1542
+ type Params = ChannelAddr ;
1543
+
1544
+ async fn new ( params : ChannelAddr ) -> Result < Self , anyhow:: Error > {
1545
+ Ok ( Self ( dial :: < usize > ( params) ?) )
1546
+ }
1547
+ }
1548
+
1549
+ #[ async_trait]
1550
+ impl Handler < ( ) > for EchoActor {
1551
+ async fn handle (
1552
+ & mut self ,
1553
+ cx : & Context < Self > ,
1554
+ _message : ( ) ,
1555
+ ) -> Result < ( ) , anyhow:: Error > {
1556
+ let Self ( port) = self ;
1557
+ port. post ( cx. self_id ( ) . rank ( ) ) ;
1558
+ Ok ( ( ) )
1559
+ }
1560
+ }
1561
+
1562
+ #[ tokio:: test]
1563
+ async fn test_reshaped_actor_mesh_cast ( ) {
1564
+ let config = hyperactor:: config:: global:: lock ( ) ;
1565
+ let _guard = config. override_key ( hyperactor:: config:: CASTING_FANOUT_SIZE , 2 ) ;
1566
+
1567
+ let alloc = LocalAllocator
1568
+ . allocate ( AllocSpec {
1569
+ extent : extent ! ( host = 16 , gpu = 4 ) ,
1570
+ constraints : Default :: default ( ) ,
1571
+ proc_name : None ,
1572
+ } )
1573
+ . await
1574
+ . unwrap ( ) ;
1575
+ let proc_mesh = ProcMesh :: allocate ( alloc) . await . unwrap ( ) ;
1576
+
1577
+ let addr = ChannelAddr :: any ( ChannelTransport :: Unix ) ;
1578
+ let ( _, mut rx) = serve :: < u64 > ( addr. clone ( ) ) . await . unwrap ( ) ;
1579
+ let actor_mesh: RootActorMesh < EchoActor > =
1580
+ proc_mesh. spawn ( "echo" , & addr) . await . unwrap ( ) ;
1581
+
1582
+ actor_mesh. cast ( proc_mesh. client ( ) , sel ! ( * ) , ( ) ) . unwrap ( ) ;
1583
+
1584
+ for _ in 0 ..( 16 * 4 ) {
1585
+ assert ! ( rx. recv( ) . await . is_ok( ) ) ;
1586
+ }
1587
+ }
1588
+
1589
+ #[ tokio:: test]
1590
+ async fn test_reshaped_actor_mesh_ref_cast ( ) {
1591
+ let config = hyperactor:: config:: global:: lock ( ) ;
1592
+ let _guard = config. override_key ( hyperactor:: config:: CASTING_FANOUT_SIZE , 2 ) ;
1593
+
1594
+ let alloc = LocalAllocator
1595
+ . allocate ( AllocSpec {
1596
+ extent : extent ! ( host = 16 , gpu = 4 ) ,
1597
+ constraints : Default :: default ( ) ,
1598
+ proc_name : None ,
1599
+ } )
1600
+ . await
1601
+ . unwrap ( ) ;
1602
+ let proc_mesh = ProcMesh :: allocate ( alloc) . await . unwrap ( ) ;
1603
+
1604
+ let addr = ChannelAddr :: any ( ChannelTransport :: Unix ) ;
1605
+ let ( _, mut rx) = serve :: < u64 > ( addr. clone ( ) ) . await . unwrap ( ) ;
1606
+
1607
+ let actor_mesh: RootActorMesh < EchoActor > =
1608
+ proc_mesh. spawn ( "echo" , & addr) . await . unwrap ( ) ;
1609
+
1610
+ let mesh_ref = actor_mesh. bind ( ) ;
1611
+ mesh_ref. cast ( proc_mesh. client ( ) , sel ! ( * ) , ( ) ) . unwrap ( ) ;
1612
+
1613
+ for _ in 0 ..( 16 * 4 ) {
1614
+ assert ! ( rx. recv( ) . await . is_ok( ) ) ;
1615
+ }
1616
+ }
1617
+
1618
+ #[ tokio:: test]
1619
+ async fn test_reshaped_actor_mesh_slice_cast ( ) {
1620
+ let config = hyperactor:: config:: global:: lock ( ) ;
1621
+ let _guard = config. override_key ( hyperactor:: config:: CASTING_FANOUT_SIZE , 2 ) ;
1622
+
1623
+ let alloc = LocalAllocator
1624
+ . allocate ( AllocSpec {
1625
+ extent : extent ! ( host = 8 , gpu = 4 ) ,
1626
+ constraints : Default :: default ( ) ,
1627
+ proc_name : None ,
1628
+ } )
1629
+ . await
1630
+ . unwrap ( ) ;
1631
+ let proc_mesh = ProcMesh :: allocate ( alloc) . await . unwrap ( ) ;
1632
+
1633
+ let addr = ChannelAddr :: any ( ChannelTransport :: Unix ) ;
1634
+ let ( _, mut rx) = serve :: < u64 > ( addr. clone ( ) ) . await . unwrap ( ) ;
1635
+
1636
+ let actor_mesh: RootActorMesh < EchoActor > =
1637
+ proc_mesh. spawn ( "echo" , & addr) . await . unwrap ( ) ;
1638
+ let slice = actor_mesh. select ( "host" , 2 ..6 ) . unwrap ( ) ;
1639
+ let slice = slice. select ( "gpu" , 2 ..6 ) . unwrap ( ) ;
1640
+
1641
+ slice. cast ( proc_mesh. client ( ) , sel ! ( * ) , ( ) ) . unwrap ( ) ;
1642
+
1643
+ let mut received_ranks = vec ! [ ] ;
1644
+ for _ in 0 ..8 {
1645
+ let rank = rx. recv ( ) . await . unwrap ( ) ;
1646
+ received_ranks. push ( rank) ;
1647
+ }
1648
+ received_ranks. sort ( ) ;
1649
+ assert_eq ! ( received_ranks, vec![ 10 , 11 , 14 , 15 , 18 , 19 , 22 , 23 ] ) ;
1650
+ }
1651
+
1652
+ #[ tokio:: test]
1653
+ async fn test_reshaped_actor_mesh_cast_with_selection ( ) {
1654
+ let config = hyperactor:: config:: global:: lock ( ) ;
1655
+ let _guard = config. override_key ( hyperactor:: config:: CASTING_FANOUT_SIZE , 2 ) ;
1656
+
1657
+ let alloc = LocalAllocator
1658
+ . allocate ( AllocSpec {
1659
+ extent : extent ! ( host = 8 , gpu = 4 ) ,
1660
+ constraints : Default :: default ( ) ,
1661
+ proc_name : None ,
1662
+ } )
1663
+ . await
1664
+ . unwrap ( ) ;
1665
+ let proc_mesh = ProcMesh :: allocate ( alloc) . await . unwrap ( ) ;
1666
+
1667
+ let addr = ChannelAddr :: any ( ChannelTransport :: Unix ) ;
1668
+ let ( _, mut rx) = serve :: < u64 > ( addr. clone ( ) ) . await . unwrap ( ) ;
1669
+
1670
+ let actor_mesh: RootActorMesh < EchoActor > =
1671
+ proc_mesh. spawn ( "echo" , & addr) . await . unwrap ( ) ;
1672
+
1673
+ actor_mesh
1674
+ . cast ( proc_mesh. client ( ) , sel ! ( 2 : 6 , 2 : 6 ) , ( ) )
1675
+ . unwrap ( ) ;
1676
+
1677
+ let mut received_ranks = vec ! [ ] ;
1678
+ for _ in 0 ..8 {
1679
+ let rank = rx. recv ( ) . await . unwrap ( ) ;
1680
+ received_ranks. push ( rank) ;
1681
+ }
1682
+ received_ranks. sort ( ) ;
1683
+ assert_eq ! ( received_ranks, vec![ 10 , 11 , 14 , 15 , 18 , 19 , 22 , 23 ] ) ;
1684
+ }
1685
+ }
1469
1686
}
0 commit comments