@@ -109,7 +109,7 @@ impl<'a> AmMsg<'a> {
109109 AmMsg { worker, msg }
110110 }
111111
112- /// Get the message ID.
112+ /// Get the ActiveStream id
113113 #[ inline]
114114 pub fn id ( & self ) -> u16 {
115115 self . msg . id
@@ -121,10 +121,10 @@ impl<'a> AmMsg<'a> {
121121 self . msg . header . as_ref ( )
122122 }
123123
124- /// Get the message header length .
124+ /// Returns `true` if the message contains data. Otherwise, `false` .
125125 #[ inline]
126126 pub fn contains_data ( & self ) -> bool {
127- self . data_type ( ) . is_some ( )
127+ self . msg . data . is_some ( )
128128 }
129129
130130 /// Get the message data type.
@@ -133,10 +133,14 @@ impl<'a> AmMsg<'a> {
133133 }
134134
135135 /// Get the message data.
136- /// Returns `None` if the message doesn't contain data.
136+ /// Returns `None` if needs to receive data.
137+ /// Returns `Some(slice)` if the message contains concrete data.
137138 #[ inline]
138139 pub fn get_data ( & self ) -> Option < & [ u8 ] > {
139- self . msg . data . as_ref ( ) . and_then ( |data| data. data ( ) )
140+ match self . msg . data {
141+ Some ( ref amdata) => amdata. data ( ) ,
142+ None => Some ( & [ ] ) ,
143+ }
140144 }
141145
142146 /// Get the message data length.
@@ -151,6 +155,11 @@ impl<'a> AmMsg<'a> {
151155 match self . msg . data . take ( ) {
152156 None => Ok ( Vec :: new ( ) ) ,
153157 Some ( AmData :: Eager ( vec) ) => Ok ( vec) ,
158+ Some ( AmData :: Data ( data) ) => {
159+ let v = data. to_vec ( ) ;
160+ self . drop_msg ( AmData :: Data ( data) ) ;
161+ Ok ( v)
162+ }
154163 Some ( data) => {
155164 self . msg . data = Some ( data) ;
156165 let mut buf = Vec :: with_capacity ( self . data_len ( ) ) ;
@@ -181,104 +190,110 @@ impl<'a> AmMsg<'a> {
181190
182191 /// Receive the message data.
183192 pub async fn recv_data_vectored ( & mut self , iov : & [ IoSliceMut < ' _ > ] ) -> Result < usize , Error > {
184- let data = self . msg . data . take ( ) ;
185- if let Some ( data) = data {
186- if let AmData :: Eager ( data) = data {
187- // return error if buffer size < data length, same with ucx
188- let cap = iov. iter ( ) . fold ( 0_usize , |cap, buf| cap + buf. len ( ) ) ;
189- if cap < data. len ( ) {
190- return Err ( Error :: MessageTruncated ) ;
191- }
193+ fn copy_data_to_iov ( data : & [ u8 ] , iov : & [ IoSliceMut < ' _ > ] ) -> Result < usize , Error > {
194+ // return error if buffer size < data length, same with ucx
195+ let cap = iov. iter ( ) . fold ( 0_usize , |cap, buf| cap + buf. len ( ) ) ;
196+ if cap < data. len ( ) {
197+ return Err ( Error :: MessageTruncated ) ;
198+ }
192199
193- let mut copied = 0_usize ;
194- for buf in iov {
195- let len = std:: cmp:: min ( data. len ( ) - copied, buf. len ( ) ) ;
196- if len == 0 {
197- break ;
198- }
200+ let mut copied = 0_usize ;
201+ for buf in iov {
202+ let len = std:: cmp:: min ( data. len ( ) - copied, buf. len ( ) ) ;
203+ if len == 0 {
204+ break ;
205+ }
199206
200- let buf = & buf[ ..len] ;
201- unsafe {
202- std:: ptr:: copy_nonoverlapping (
203- data[ copied..] . as_ptr ( ) ,
204- buf. as_ptr ( ) as _ ,
205- len,
206- )
207- }
208- copied += len;
207+ let buf = & buf[ ..len] ;
208+ unsafe {
209+ std:: ptr:: copy_nonoverlapping ( data[ copied..] . as_ptr ( ) , buf. as_ptr ( ) as _ , len)
209210 }
210- return Ok ( copied) ;
211+ copied += len ;
211212 }
213+ Ok ( copied)
214+ }
215+ let data = self . msg . data . take ( ) ;
212216
213- let ( data_desc, data_len) = match data {
214- AmData :: Data ( data) => ( data. as_ptr ( ) , data. len ( ) ) ,
215- AmData :: Rndv ( data) => ( data. as_ptr ( ) , data. len ( ) ) ,
216- _ => unreachable ! ( ) ,
217- } ;
218-
219- unsafe extern "C" fn callback (
220- request : * mut c_void ,
221- status : ucs_status_t ,
222- _length : usize ,
223- _data : * mut c_void ,
224- ) {
225- // todo: handle error & fix real data length
217+ match data {
218+ Some ( AmData :: Eager ( data) ) => {
219+ // eager message, no need to receive
220+ copy_data_to_iov ( & data, iov)
221+ }
222+ Some ( AmData :: Data ( data) ) => {
223+ // data message, no need to receive
224+ let size = copy_data_to_iov ( & data, iov) ?;
225+ self . drop_msg ( AmData :: Data ( data) ) ;
226+ Ok ( size)
227+ }
228+ Some ( AmData :: Rndv ( desc) ) => {
229+ // rndv message, need to receive
230+ let ( data_desc, data_len) = ( desc. as_ptr ( ) , desc. len ( ) ) ;
231+
232+ unsafe extern "C" fn callback (
233+ request : * mut c_void ,
234+ status : ucs_status_t ,
235+ _length : usize ,
236+ _data : * mut c_void ,
237+ ) {
238+ // todo: handle error & fix real data length
239+ trace ! (
240+ "recv_data_vectored: complete, req={:?}, status={:?}" ,
241+ request,
242+ status
243+ ) ;
244+ let request = & mut * ( request as * mut Request ) ;
245+ request. waker . wake ( ) ;
246+ }
226247 trace ! (
227- "recv_data_vectored: complete, req ={:?}, status={:? }" ,
228- request ,
229- status
248+ "recv_data_vectored: worker ={:?} iov.len={ }" ,
249+ self . worker . handle ,
250+ iov . len ( )
230251 ) ;
231- let request = & mut * ( request as * mut Request ) ;
232- request. waker . wake ( ) ;
233- }
234- trace ! (
235- "recv_data_vectored: worker={:?} iov.len={}" ,
236- self . worker. handle,
237- iov. len( )
238- ) ;
239- let mut param = MaybeUninit :: < ucp_request_param_t > :: uninit ( ) ;
240- let ( buffer, count) = unsafe {
241- let param = & mut * param. as_mut_ptr ( ) ;
242- param. op_attr_mask = ucp_op_attr_t:: UCP_OP_ATTR_FIELD_CALLBACK as u32
243- | ucp_op_attr_t:: UCP_OP_ATTR_FIELD_DATATYPE as u32 ;
244- param. cb = ucp_request_param_t__bindgen_ty_1 {
245- recv_am : Some ( callback) ,
252+ let mut param = MaybeUninit :: < ucp_request_param_t > :: uninit ( ) ;
253+ let ( buffer, count) = unsafe {
254+ let param = & mut * param. as_mut_ptr ( ) ;
255+ param. op_attr_mask = ucp_op_attr_t:: UCP_OP_ATTR_FIELD_CALLBACK as u32
256+ | ucp_op_attr_t:: UCP_OP_ATTR_FIELD_DATATYPE as u32 ;
257+ param. cb = ucp_request_param_t__bindgen_ty_1 {
258+ recv_am : Some ( callback) ,
259+ } ;
260+
261+ if iov. len ( ) == 1 {
262+ param. datatype = ucp_dt_make_contig ( 1 ) ;
263+ ( iov[ 0 ] . as_ptr ( ) , iov[ 0 ] . len ( ) )
264+ } else {
265+ param. datatype = ucp_dt_type:: UCP_DATATYPE_IOV as _ ;
266+ ( iov. as_ptr ( ) as _ , iov. len ( ) )
267+ }
246268 } ;
247269
248- if iov. len ( ) == 1 {
249- param. datatype = ucp_dt_make_contig ( 1 ) ;
250- ( iov[ 0 ] . as_ptr ( ) , iov[ 0 ] . len ( ) )
270+ let status = unsafe {
271+ ucp_am_recv_data_nbx (
272+ self . worker . handle ,
273+ data_desc as _ ,
274+ buffer as _ ,
275+ count as _ ,
276+ param. as_ptr ( ) ,
277+ )
278+ } ;
279+ if status. is_null ( ) {
280+ trace ! ( "recv_data_vectored: complete" ) ;
281+ Ok ( data_len)
282+ } else if UCS_PTR_IS_PTR ( status) {
283+ RequestHandle {
284+ ptr : status,
285+ poll_fn : poll_recv,
286+ }
287+ . await ;
288+ Ok ( data_len)
251289 } else {
252- param. datatype = ucp_dt_type:: UCP_DATATYPE_IOV as _ ;
253- ( iov. as_ptr ( ) as _ , iov. len ( ) )
290+ Err ( Error :: from_ptr ( status) . unwrap_err ( ) )
254291 }
255- } ;
256-
257- let status = unsafe {
258- ucp_am_recv_data_nbx (
259- self . worker . handle ,
260- data_desc as _ ,
261- buffer as _ ,
262- count as _ ,
263- param. as_ptr ( ) ,
264- )
265- } ;
266- if status. is_null ( ) {
267- trace ! ( "recv_data_vectored: complete" ) ;
268- Ok ( data_len)
269- } else if UCS_PTR_IS_PTR ( status) {
270- RequestHandle {
271- ptr : status,
272- poll_fn : poll_recv,
273- }
274- . await ;
275- Ok ( data_len)
276- } else {
277- Err ( Error :: from_ptr ( status) . unwrap_err ( ) )
278292 }
279- } else {
280- // no data
281- Ok ( 0 )
293+ None => {
294+ // no data
295+ Ok ( 0 )
296+ }
282297 }
283298 }
284299
@@ -321,18 +336,24 @@ impl<'a> AmMsg<'a> {
321336 assert ! ( self . need_reply( ) ) ;
322337 am_send ( self . msg . reply_ep , id, header, data, need_reply, proto) . await
323338 }
339+
340+ fn drop_msg ( & mut self , data : AmData ) {
341+ match data {
342+ AmData :: Eager ( _) => ( ) ,
343+ AmData :: Data ( data) => unsafe {
344+ ucp_am_data_release ( self . worker . handle , data. as_ptr ( ) as _ ) ;
345+ } ,
346+ AmData :: Rndv ( data) => unsafe {
347+ ucp_am_data_release ( self . worker . handle , data. as_ptr ( ) as _ ) ;
348+ } ,
349+ }
350+ }
324351}
325352
326353impl < ' a > Drop for AmMsg < ' a > {
327354 fn drop ( & mut self ) {
328- match self . msg . data . take ( ) {
329- Some ( AmData :: Data ( desc) ) => unsafe {
330- ucp_am_data_release ( self . worker . handle , desc. as_ptr ( ) as _ ) ;
331- } ,
332- Some ( AmData :: Rndv ( desc) ) => unsafe {
333- ucp_am_data_release ( self . worker . handle , desc. as_ptr ( ) as _ ) ;
334- } ,
335- _ => ( ) ,
355+ if let Some ( data) = self . msg . data . take ( ) {
356+ self . drop_msg ( data) ;
336357 }
337358 }
338359}
@@ -502,6 +523,8 @@ impl Endpoint {
502523}
503524
504525/// Active message protocol
526+ #[ derive( Debug , Clone , Copy ) ]
527+ #[ repr( u32 ) ]
505528pub enum AmProto {
506529 /// Eager protocol
507530 Eager ,
@@ -594,12 +617,20 @@ mod tests {
594617
595618 #[ test_log:: test]
596619 fn am ( ) {
597- for i in 0 ..20_usize {
598- spawn_thread ! ( send_recv( 4 << i) ) . join ( ) . unwrap ( ) ;
620+ let protos = vec ! [ None , Some ( AmProto :: Eager ) , Some ( AmProto :: Rndv ) ] ;
621+ for block_size_shift in 0 ..20_usize {
622+ for p in protos. iter ( ) {
623+ let rt = tokio:: runtime:: Builder :: new_current_thread ( )
624+ . enable_time ( )
625+ . build ( )
626+ . unwrap ( ) ;
627+ let local = tokio:: task:: LocalSet :: new ( ) ;
628+ local. block_on ( & rt, send_recv ( 4 << block_size_shift, * p) ) ;
629+ }
599630 }
600631 }
601632
602- async fn send_recv ( data_size : usize ) {
633+ async fn send_recv ( data_size : usize , proto : Option < AmProto > ) {
603634 let context1 = Context :: new ( ) . unwrap ( ) ;
604635 let worker1 = context1. create_worker ( ) . unwrap ( ) ;
605636 let context2 = Context :: new ( ) . unwrap ( ) ;
@@ -631,13 +662,7 @@ mod tests {
631662 async {
632663 // send msg
633664 let result = endpoint2
634- . am_send(
635- 16 ,
636- header. as_slice( ) ,
637- data. as_slice( ) ,
638- true ,
639- Some ( AmProto :: Rndv ) ,
640- )
665+ . am_send( 16 , header. as_slice( ) , data. as_slice( ) , true , proto)
641666 . await ;
642667 assert!( result. is_ok( ) ) ;
643668 } ,
@@ -662,10 +687,7 @@ mod tests {
662687 tokio:: join!(
663688 async {
664689 // send reply
665- let result = unsafe {
666- msg. reply( 12 , & header, & data, false , Some ( AmProto :: Rndv ) )
667- . await
668- } ;
690+ let result = unsafe { msg. reply( 12 , & header, & data, false , proto) . await } ;
669691 assert!( result. is_ok( ) ) ;
670692 } ,
671693 async {
0 commit comments