@@ -302,7 +302,7 @@ std::vector<ComputationClient::DataPtr> XrtComputationClient::TransferToServer(
302302 }
303303 XLA_COUNTER (" XrtPartitionedTransferToServer" , 1 );
304304
305- util::MultiWait mwait (partitions.size ());
305+ auto mwait = std::make_shared< util::MultiWait> (partitions.size ());
306306 std::vector<DataPtr> results (tensors.size ());
307307 for (size_t i = 0 ; i < partitions.size (); ++i) {
308308 auto sender = [&, i]() {
@@ -316,9 +316,10 @@ std::vector<ComputationClient::DataPtr> XrtComputationClient::TransferToServer(
316316 results[base_index + r] = std::move (partitions_results[r]);
317317 }
318318 };
319- env::ScheduleIoClosure (mwait.Completer (std::move (sender)));
319+ env::ScheduleIoClosure (
320+ util::MultiWait::Completer (mwait, std::move (sender)));
320321 }
321- mwait. Wait ();
322+ mwait-> Wait ();
322323 return results;
323324}
324325
@@ -330,7 +331,7 @@ XrtComputationClient::TransferToServerInternal(
330331 std::mutex lock;
331332 XrtSessionCache::SessionMap session_map;
332333 int64 total_size = 0 ;
333- util::MultiWait mwait (tensors.size ());
334+ auto mwait = std::make_shared< util::MultiWait> (tensors.size ());
334335 std::map<XrtSession*, SessionWork> session_work_map;
335336 {
336337 metrics::TimedSection timed (TransferToServerTransformMetric ());
@@ -363,13 +364,14 @@ XrtComputationClient::TransferToServerInternal(
363364 total_size += tdata.size ();
364365 }
365366 };
366- env::ScheduleClosure (mwait.Completer (std::move (converter)));
367+ env::ScheduleClosure (
368+ util::MultiWait::Completer (mwait, std::move (converter)));
367369 }
368- mwait. Wait ();
370+ mwait-> Wait ();
369371 }
370372 OutboundDataMetric ()->AddSample (total_size);
371373
372- mwait. Reset (session_work_map.size ());
374+ mwait-> Reset (session_work_map.size ());
373375 std::vector<DataPtr> results (tensors.size ());
374376 for (auto & session_session_work : session_work_map) {
375377 XrtSession* session = session_session_work.first ;
@@ -388,9 +390,10 @@ XrtComputationClient::TransferToServerInternal(
388390 }
389391 CreateDataHandlesCounter ()->AddValue (outputs.size ());
390392 };
391- env::ScheduleIoClosure (mwait.Completer (std::move (runner)));
393+ env::ScheduleIoClosure (
394+ util::MultiWait::Completer (mwait, std::move (runner)));
392395 }
393- mwait. Wait ();
396+ mwait-> Wait ();
394397 return results;
395398}
396399
@@ -426,7 +429,7 @@ std::vector<Literal> XrtComputationClient::TransferFromServer(
426429 session_work->index_mapping .push_back (i);
427430 }
428431
429- util::MultiWait mwait (session_work_map.size ());
432+ auto mwait = std::make_shared< util::MultiWait> (session_work_map.size ());
430433 std::atomic<int64> total_size (0 );
431434 std::vector<Literal> results (handles.size ());
432435 for (auto & session_session_work : session_work_map) {
@@ -446,9 +449,10 @@ std::vector<Literal> XrtComputationClient::TransferFromServer(
446449 total_size += results[li].size_bytes ();
447450 }
448451 };
449- env::ScheduleIoClosure (mwait.Completer (std::move (runner)));
452+ env::ScheduleIoClosure (
453+ util::MultiWait::Completer (mwait, std::move (runner)));
450454 }
451- mwait. Wait ();
455+ mwait-> Wait ();
452456 InboundDataMetric ()->AddSample (total_size.load ());
453457 return results;
454458}
@@ -458,7 +462,7 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
458462 metrics::TimedSection timed (CompileMetric ());
459463
460464 std::mutex lock;
461- util::MultiWait mwait (instances.size ());
465+ auto mwait = std::make_shared< util::MultiWait> (instances.size ());
462466 std::vector<ProgramShape> program_shapes (instances.size ());
463467 std::vector<ComputationPtr> results (instances.size ());
464468 std::vector<CompilationCacheKey> cache_keys (instances.size ());
@@ -499,10 +503,10 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
499503 results[i] = computation_ptr;
500504 }
501505 };
502- env::ScheduleClosure (mwait. Completer (std::move (builder)));
506+ env::ScheduleClosure (util::MultiWait:: Completer (mwait, std::move (builder)));
503507 }
504- mwait. Wait ();
505- mwait. Reset (session_work_map.size ());
508+ mwait-> Wait ();
509+ mwait-> Reset (session_work_map.size ());
506510
507511 for (auto & session_and_work : session_work_map) {
508512 XrtSession* session = session_and_work.first ;
@@ -532,9 +536,10 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
532536 CreateCompileHandlesCounter ()->AddValue (1 );
533537 }
534538 };
535- env::ScheduleIoClosure (mwait.Completer (std::move (session_runner)));
539+ env::ScheduleIoClosure (
540+ util::MultiWait::Completer (mwait, std::move (session_runner)));
536541 }
537- mwait. Wait ();
542+ mwait-> Wait ();
538543 return results;
539544}
540545
@@ -626,7 +631,7 @@ XrtComputationClient::RunComputations(
626631 }
627632 XLA_CHECK_EQ (computations.size (), devices.size ());
628633
629- util::MultiWait mwait (session_replicas.size ());
634+ auto mwait = std::make_shared< util::MultiWait> (session_replicas.size ());
630635 std::vector<std::vector<DataPtr>> results (devices.size ());
631636 for (auto & sess_replica : session_replicas) {
632637 XrtSession* session = sess_replica.first ;
@@ -655,9 +660,10 @@ XrtComputationClient::RunComputations(
655660 GetEffectiveDevice (devices[replica]));
656661 }
657662 };
658- env::ScheduleIoClosure (mwait.Completer (std::move (session_runner)));
663+ env::ScheduleIoClosure (
664+ util::MultiWait::Completer (mwait, std::move (session_runner)));
659665 }
660- mwait. Wait ();
666+ mwait-> Wait ();
661667 return results;
662668}
663669
0 commit comments