@@ -22,88 +22,92 @@ void PostBox<KeyType>::insert(Chunk&& chunk) {
2222 " PostBox.insert(): all messages in the chunk must map to the same key"
2323 );
2424 }
25- std::lock_guard const lock (mutex_);
25+ // std::lock_guard const lock(mutex_);
26+ auto & map_value = pigeonhole_.at (key);
27+ std::lock_guard lock (map_value.mutex );
28+ if (map_value.chunks .empty ()) {
29+ RAPIDSMPF_EXPECTS (
30+ n_non_empty_keys_.fetch_add (1 , std::memory_order_relaxed) + 1
31+ <= pigeonhole_.size (),
32+ " PostBox.insert(): n_non_empty_keys_ is already at the maximum"
33+ );
34+ }
2635 RAPIDSMPF_EXPECTS (
27- pigeonhole_[key]. emplace (chunk. chunk_id (), std::move (chunk)).second ,
36+ map_value. chunks . emplace (std::move (chunk)).second ,
2837 " PostBox.insert(): chunk already exist"
2938 );
3039}
3140
3241template <typename KeyType>
3342bool PostBox<KeyType>::is_empty(PartID pid) const {
34- std::lock_guard const lock (mutex_);
35- return !pigeonhole_.contains (key_map_fn_ (pid));
43+ auto & map_value = pigeonhole_.at (key_map_fn_ (pid));
44+ std::lock_guard lock (map_value.mutex );
45+ return map_value.chunks .empty ();
3646}
3747
3848template <typename KeyType>
39- Chunk PostBox<KeyType>::extract(PartID pid, ChunkID cid) {
40- std::lock_guard const lock (mutex_);
41- return extract_item (pigeonhole_[key_map_fn_ (pid)], cid).second ;
49+ std::vector<Chunk> PostBox<KeyType>::extract(PartID pid) {
50+ return extract_by_key (key_map_fn_ (pid));
4251}
4352
4453template <typename KeyType>
45- std::unordered_map<ChunkID, Chunk> PostBox<KeyType>::extract(PartID pid) {
46- std::lock_guard const lock (mutex_);
47- return extract_value (pigeonhole_, key_map_fn_ (pid));
48- }
54+ std::vector<Chunk> PostBox<KeyType>::extract_by_key(KeyType key) {
55+ auto & map_value = pigeonhole_.at (key);
56+ std::vector<Chunk> ret;
57+ std::lock_guard lock (map_value.mutex );
58+ RAPIDSMPF_EXPECTS (!map_value.chunks .empty (), " PostBox.extract(): partition is empty" );
59+ ret.reserve (map_value.chunks .size ());
4960
50- template <typename KeyType>
51- std::unordered_map<ChunkID, Chunk> PostBox<KeyType>::extract_by_key(KeyType key) {
52- std::lock_guard const lock (mutex_);
53- return extract_value (pigeonhole_, key);
61+ for (auto it = map_value.chunks .begin (); it != map_value.chunks .end ();) {
62+ auto node = map_value.chunks .extract (it++);
63+ ret.emplace_back (std::move (node.value ()));
64+ }
65+
66+ RAPIDSMPF_EXPECTS (
67+ n_non_empty_keys_.fetch_sub (1 , std::memory_order_relaxed) > 0 ,
68+ " PostBox.extract(): n_non_empty_keys_ is already 0"
69+ );
70+ return ret;
5471}
5572
5673template <typename KeyType>
5774std::vector<Chunk> PostBox<KeyType>::extract_all_ready() {
58- std::lock_guard const lock (mutex_);
5975 std::vector<Chunk> ret;
6076
6177 // Iterate through the outer map
62- auto pid_it = pigeonhole_.begin ();
63- while (pid_it != pigeonhole_.end ()) {
64- // Iterate through the inner map
65- auto & chunks = pid_it->second ;
66- auto chunk_it = chunks.begin ();
67- while (chunk_it != chunks.end ()) {
68- if (chunk_it->second .is_ready ()) {
69- ret.emplace_back (std::move (chunk_it->second ));
70- chunk_it = chunks.erase (chunk_it);
78+ for (auto & [key, map_value] : pigeonhole_) {
79+ std::lock_guard lock (map_value.mutex );
80+ bool chunks_available = !map_value.chunks .empty ();
81+ auto chunk_it = map_value.chunks .begin ();
82+ while (chunk_it != map_value.chunks .end ()) {
83+ if (chunk_it->is_ready ()) {
84+ auto node = map_value.chunks .extract (chunk_it++);
85+ ret.emplace_back (std::move (node.value ()));
7186 } else {
7287 ++chunk_it;
7388 }
7489 }
7590
76- // Remove the pid entry if its chunks map is empty
77- if (chunks.empty ()) {
78- pid_it = pigeonhole_.erase (pid_it);
79- } else {
80- ++pid_it;
91+ // if the chunks were available and are now empty, its fully extracted
92+ if (chunks_available && map_value.chunks .empty ()) {
93+ RAPIDSMPF_EXPECTS (
94+ n_non_empty_keys_.fetch_sub (1 , std::memory_order_relaxed) > 0 ,
95+ " PostBox.extract_all_ready(): n_non_empty_keys_ is already 0"
96+ );
8197 }
8298 }
83-
8499 return ret;
85100}
86101
87102template <typename KeyType>
88103bool PostBox<KeyType>::empty() const {
89- std::lock_guard const lock (mutex_);
90- return pigeonhole_.empty ();
104+ return n_non_empty_keys_.load (std::memory_order_acquire) == 0 ;
91105}
92106
93107template <typename KeyType>
94- std::vector<std::tuple<KeyType, ChunkID, std::size_t >> PostBox<KeyType>::search(
95- MemoryType mem_type
96- ) const {
97- std::lock_guard const lock (mutex_);
98- std::vector<std::tuple<KeyType, ChunkID, std::size_t >> ret;
99- for (auto & [key, chunks] : pigeonhole_) {
100- for (auto & [cid, chunk] : chunks) {
101- if (!chunk.is_control_message (0 ) && chunk.data_memory_type () == mem_type) {
102- ret.emplace_back (key, cid, chunk.concat_data_size ());
103- }
104- }
105- }
106- return ret;
108+ size_t PostBox<KeyType>::spill(BufferResource* /* br */ , size_t /* amount */ ) {
109+ // TODO: implement spill
110+ return 0 ;
107111}
108112
109113template <typename KeyType>
@@ -113,14 +117,14 @@ std::string PostBox<KeyType>::str() const {
113117 }
114118 std::stringstream ss;
115119 ss << " PostBox(" ;
116- for (auto const & [key, chunks ] : pigeonhole_) {
120+ for (auto const & [key, map_value ] : pigeonhole_) {
117121 ss << " k=" << key << " : [" ;
118- for (auto const & [cid, chunk] : chunks) {
119- assert (cid == chunk.chunk_id ());
122+ for (auto const & chunk : map_value. chunks ) {
123+ // assert(cid == chunk.chunk_id());
120124 if (chunk.is_control_message (0 )) {
121125 ss << " EOP" << chunk.expected_num_chunks (0 ) << " , " ;
122126 } else {
123- ss << cid << " , " ;
127+ ss << chunk. chunk_id () << " , " ;
124128 }
125129 }
126130 ss << " \b\b ], " ;
0 commit comments