Skip to content

Commit 979a7b1

Browse files
authored
Expose async collectives to Python (#685)
This adds bindings for both the streaming node versions, and the object-based implementation. - Closes #682 Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - Mads R. B. Kristensen (https://github.com/madsbk) - Richard (Rick) Zamora (https://github.com/rjzamora) URL: #685
1 parent cce7378 commit 979a7b1

File tree

22 files changed

+1122
-105
lines changed

22 files changed

+1122
-105
lines changed

cpp/include/rapidsmpf/streaming/chunks/packed_data.hpp

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,46 +12,33 @@
1212
namespace rapidsmpf::streaming {
1313

1414
/**
15-
* @brief Chunk of `PackedData`.
16-
*/
17-
struct PackedDataChunk {
18-
/**
19-
* @brief Packed data payload.
20-
*/
21-
PackedData data;
22-
};
23-
24-
/**
25-
* @brief Generate a content description for a `PackedDataChunk`.
15+
* @brief Generate a content description for `PackedData`.
2616
*
2717
* @param obj The object's content to describe.
2818
* @return A new content description.
2919
*/
30-
inline ContentDescription get_content_description(PackedDataChunk const& obj) {
20+
inline ContentDescription get_content_description(PackedData const& obj) {
3121
return ContentDescription{
32-
{{obj.data.data->mem_type(), obj.data.data->size}},
33-
ContentDescription::Spillable::YES
22+
{{obj.data->mem_type(), obj.data->size}}, ContentDescription::Spillable::YES
3423
};
3524
}
3625

3726
/**
38-
* @brief Wrap a `PackedDataChunk` into a `Message`.
27+
* @brief Wrap `PackedData` into a `Message`.
3928
*
4029
* @param sequence_number Ordering identifier for the message.
4130
* @param chunk The chunk to wrap into a message.
4231
* @return A `Message` encapsulating the provided chunk as its payload.
4332
*/
44-
Message to_message(
45-
std::uint64_t sequence_number, std::unique_ptr<PackedDataChunk> chunk
46-
) {
33+
Message to_message(std::uint64_t sequence_number, std::unique_ptr<PackedData> chunk) {
4734
auto cd = get_content_description(*chunk);
4835
return Message{
4936
sequence_number,
5037
std::move(chunk),
5138
cd,
5239
[](Message const& msg, MemoryReservation& reservation) -> Message {
53-
auto const& self = msg.get<PackedDataChunk>();
54-
auto chunk = std::make_unique<PackedDataChunk>(self.data.copy(reservation));
40+
auto const& self = msg.get<PackedData>();
41+
auto chunk = std::make_unique<PackedData>(self.copy(reservation));
5542
auto cd = get_content_description(*chunk);
5643
return Message{msg.sequence_number(), std::move(chunk), cd, msg.copy_cb()};
5744
}

cpp/include/rapidsmpf/streaming/coll/allgather.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include <rapidsmpf/allgather/allgather.hpp>
1313
#include <rapidsmpf/communicator/communicator.hpp>
1414
#include <rapidsmpf/memory/packed_data.hpp>
15-
#include <rapidsmpf/streaming/chunks/packed_data.hpp>
1615
#include <rapidsmpf/streaming/core/channel.hpp>
1716
#include <rapidsmpf/streaming/core/context.hpp>
1817

@@ -61,7 +60,7 @@ class AllGather {
6160
* @param sequence_number The sequence number for this chunk.
6261
* @param chunk The chunk to insert.
6362
*/
64-
void insert(std::uint64_t sequence_number, PackedDataChunk&& chunk);
63+
void insert(std::uint64_t sequence_number, PackedData&& chunk);
6564

6665
/// @copydoc rapidsmpf::allgather::AllGather::insert_finished()
6766
void insert_finished();
@@ -76,7 +75,7 @@ class AllGather {
7675
* @return Coroutine that completes when all data is available for extraction and
7776
* returns the data.
7877
*/
79-
coro::task<std::vector<PackedDataChunk>> extract_all(Ordered ordered = Ordered::YES);
78+
coro::task<std::vector<PackedData>> extract_all(Ordered ordered = Ordered::YES);
8079

8180
private:
8281
coro::event
@@ -94,8 +93,8 @@ namespace node {
9493
* packed data received through `Channel`s.
9594
*
9695
* @param ctx The streaming context to use.
97-
* @param ch_in Input channel providing `PackedDataChunk`s to be gathered.
98-
* @param ch_out Output channel where the gathered `PackedDataChunk`s are sent.
96+
* @param ch_in Input channel providing `PackedData`s to be gathered.
97+
* @param ch_out Output channel where the gathered `PackedData`s are sent.
9998
* @param op_id Unique identifier for the operation.
10099
* @param ordered If the extracted data should be sent to the output channel with sequence
101100
* numbers corresponding to the global total order of input chunks. If yes, then the

cpp/include/rapidsmpf/streaming/coll/shuffler.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ class ShufflerAsync {
125125
*
126126
* @return A coroutine that, when awaited, indicates the shuffle has completed.
127127
*/
128-
[[nodiscard]] Node insert_finished(std::vector<shuffler::PartID>&& pids);
128+
[[nodiscard]] Node insert_finished();
129129

130130
/**
131131
* @brief Asynchronously extracts all data for a specific partition.

cpp/src/streaming/coll/allgather.cpp

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
#include <algorithm>
7-
86
#include <rapidsmpf/memory/packed_data.hpp>
97
#include <rapidsmpf/streaming/chunks/packed_data.hpp>
108
#include <rapidsmpf/streaming/coll/allgather.hpp>
@@ -41,27 +39,19 @@ AllGather::~AllGather() {
4139
return ctx_;
4240
}
4341

44-
void AllGather::insert(std::uint64_t sequence_number, PackedDataChunk&& packed_data) {
45-
gatherer_.insert(sequence_number, std::move(packed_data.data));
42+
void AllGather::insert(std::uint64_t sequence_number, PackedData&& packed_data) {
43+
gatherer_.insert(sequence_number, std::move(packed_data));
4644
}
4745

4846
void AllGather::insert_finished() {
4947
gatherer_.insert_finished();
5048
}
5149

52-
coro::task<std::vector<PackedDataChunk>> AllGather::extract_all(
53-
AllGather::Ordered ordered
54-
) {
50+
coro::task<std::vector<PackedData>> AllGather::extract_all(AllGather::Ordered ordered) {
5551
// Wait until we're notified that everything is done.
5652
co_await event_;
5753
// And now this will not block.
58-
auto data = gatherer_.wait_and_extract(ordered);
59-
std::vector<PackedDataChunk> result;
60-
result.reserve(data.size());
61-
std::ranges::transform(data, std::back_inserter(result), [](auto&& pd) {
62-
return PackedDataChunk{.data = std::move(pd)};
63-
});
64-
co_return result;
54+
co_return gatherer_.wait_and_extract(ordered);
6555
}
6656

6757
namespace node {
@@ -80,14 +70,14 @@ Node allgather(
8070
if (msg.empty()) {
8171
break;
8272
}
83-
gatherer.insert(msg.sequence_number(), msg.release<PackedDataChunk>());
73+
gatherer.insert(msg.sequence_number(), msg.release<PackedData>());
8474
}
8575
gatherer.insert_finished();
8676
auto data = co_await gatherer.extract_all(ordered);
8777
std::uint64_t sequence{0};
8878
for (auto&& chunk : data) {
8979
co_await ch_out->send(
90-
to_message(sequence++, std::make_unique<PackedDataChunk>(std::move(chunk)))
80+
to_message(sequence++, std::make_unique<PackedData>(std::move(chunk)))
9181
);
9282
}
9383
co_await ch_out->drain(ctx->executor());

cpp/src/streaming/coll/shuffler.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,9 @@ void ShufflerAsync::insert(std::unordered_map<shuffler::PartID, PackedData>&& ch
153153
shuffler_.insert(std::move(chunks));
154154
}
155155

156-
Node ShufflerAsync::insert_finished(std::vector<shuffler::PartID>&& pids) {
156+
Node ShufflerAsync::insert_finished() {
157+
std::vector<shuffler::PartID> pids(total_num_partitions());
158+
std::iota(pids.begin(), pids.end(), shuffler::PartID{0});
157159
shuffler_.insert_finished(std::move(pids));
158160
return finished_drain();
159161
}
@@ -258,12 +260,7 @@ Node shuffler(
258260
shuffler_async.insert(std::move(partition_map.data));
259261
}
260262

261-
// Tell the shuffler that we have no more input data.
262-
std::vector<rapidsmpf::shuffler::PartID> finished(
263-
shuffler_async.total_num_partitions()
264-
);
265-
std::iota(finished.begin(), finished.end(), 0);
266-
auto finish_token = shuffler_async.insert_finished(std::move(finished));
263+
auto finish_token = shuffler_async.insert_finished();
267264

268265
for ([[maybe_unused]] auto& _ : shuffler_async.local_partitions()) {
269266
auto finished = co_await shuffler_async.extract_any_async();

cpp/tests/streaming/test_allgather.cpp

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,7 @@ TEST_P(StreamingAllGather, basic) {
8787
});
8888
auto meta = std::make_unique<std::vector<std::uint8_t>>(sizeof(int));
8989
std::memcpy(meta->data(), &size, sizeof(int));
90-
allgather.insert(
91-
sequence, streaming::PackedDataChunk{{std::move(meta), std::move(buf)}}
92-
);
90+
allgather.insert(sequence, PackedData{std::move(meta), std::move(buf)});
9391
latch.count_down();
9492
co_return;
9593
};
@@ -104,20 +102,20 @@ TEST_P(StreamingAllGather, basic) {
104102
std::size_t offset{0};
105103
for (auto& pd : data) {
106104
RAPIDSMPF_EXPECTS(
107-
pd.data.metadata->size() == sizeof(int), "Invalid metadata buffer size"
105+
pd.metadata->size() == sizeof(int), "Invalid metadata buffer size"
108106
);
109107
int msize;
110-
std::memcpy(&msize, pd.data.metadata->data(), sizeof(int));
108+
std::memcpy(&msize, pd.metadata->data(), sizeof(int));
111109
RAPIDSMPF_EXPECTS(msize == size, "Corrupted metadata value");
112110
RAPIDSMPF_CUDA_TRY(cudaMemcpyAsync(
113111
result.data() + offset,
114-
pd.data.data->data(),
115-
pd.data.data->size,
112+
pd.data->data(),
113+
pd.data->size,
116114
cudaMemcpyDefault,
117-
pd.data.data->stream()
115+
pd.data->stream()
118116
));
119117
offset += msize;
120-
pd.data.data->stream().synchronize();
118+
pd.data->stream().synchronize();
121119
}
122120
};
123121

@@ -170,9 +168,7 @@ TEST_P(StreamingAllGather, streaming_node) {
170168
input_messages.emplace_back(
171169
streaming::to_message(
172170
insertion_id,
173-
std::make_unique<streaming::PackedDataChunk>(streaming::PackedDataChunk{
174-
PackedData{std::move(meta), std::move(buf)}
175-
})
171+
std::make_unique<PackedData>(std::move(meta), std::move(buf))
176172
)
177173
);
178174
}
@@ -191,8 +187,7 @@ TEST_P(StreamingAllGather, streaming_node) {
191187
std::vector<int> actual(size * size * n_inserts);
192188
std::size_t offset{0};
193189
for (auto& msg : output_messages) {
194-
auto chunk = msg.release<streaming::PackedDataChunk>();
195-
auto& pd = chunk.data;
190+
auto pd = msg.release<PackedData>();
196191
RAPIDSMPF_EXPECTS(
197192
pd.metadata->size() == sizeof(int), "Invalid metadata buffer size"
198193
);

cpp/tests/streaming/test_shuffler.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,7 @@ TEST_P(ShufflerAsyncTest, multi_consumer_extract) {
232232
shuffler->insert(std::move(data));
233233
}
234234

235-
auto finish_token =
236-
shuffler->insert_finished(iota_vector<shuffler::PartID>(n_partitions));
235+
auto finish_token = shuffler->insert_finished();
237236

238237
std::mutex mtx;
239238
std::vector<shuffler::PartID> finished_pids;
@@ -264,8 +263,7 @@ TEST_F(BaseStreamingShuffle, extract_any_before_extract) {
264263
auto shuffler = std::make_unique<ShufflerAsync>(ctx, op_id, n_partitions);
265264

266265
// all empty partitions
267-
auto finish_token =
268-
shuffler->insert_finished(iota_vector<shuffler::PartID>(n_partitions));
266+
auto finish_token = shuffler->insert_finished();
269267

270268
auto local_pids = shuffler::Shuffler::local_partitions(
271269
ctx->comm(), n_partitions, shuffler::Shuffler::round_robin
@@ -313,8 +311,7 @@ class CompetingShufflerAsyncTest : public BaseStreamingShuffle {
313311

314312
auto shuffler = std::make_unique<ShufflerAsync>(ctx, op_id, n_partitions);
315313

316-
auto finish_token =
317-
shuffler->insert_finished(iota_vector<shuffler::PartID>(n_partitions));
314+
auto finish_token = shuffler->insert_finished();
318315
coro::sync_wait(finish_token);
319316
auto [extract_any_result, extract_result] =
320317
produce_results_fn(shuffler.get(), this_pid);

python/rapidsmpf/rapidsmpf/shuffler.pxd

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,24 @@ cdef extern from "<rapidsmpf/shuffler/shuffler.hpp>" nogil:
3737
uint32_t wait_any() except +
3838
void wait_on(uint32_t pid) except +
3939
string str() except +
40+
# Insert PackedData into a partition map. We implement this in C++ because
41+
# PackedData doesn't have a default ctor.
42+
cdef extern from *:
43+
"""
44+
void cpp_insert_chunk_into_partition_map(
45+
std::unordered_map<std::uint32_t, rapidsmpf::PackedData> &partition_map,
46+
std::uint32_t pid,
47+
std::unique_ptr<rapidsmpf::PackedData> packed_data
48+
) {
49+
partition_map.insert({pid, std::move(*packed_data)});
50+
}
51+
"""
52+
void cpp_insert_chunk_into_partition_map(
53+
unordered_map[uint32_t, cpp_PackedData] &partition_map,
54+
uint32_t pid,
55+
unique_ptr[cpp_PackedData] packed_data,
56+
) except + nogil
57+
4058

4159
cdef class Shuffler:
4260
cdef unique_ptr[cpp_Shuffler] _handle

python/rapidsmpf/rapidsmpf/shuffler.pyx

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,6 @@ from rapidsmpf.progress_thread cimport ProgressThread
1515
from rapidsmpf.statistics cimport Statistics
1616

1717

18-
# Insert PackedData into a partition map. We implement this in C++ because
19-
# PackedData doesn't have a default ctor.
20-
cdef extern from *:
21-
"""
22-
void cpp_insert_chunk_into_partition_map(
23-
std::unordered_map<std::uint32_t, rapidsmpf::PackedData> &partition_map,
24-
std::uint32_t pid,
25-
std::unique_ptr<rapidsmpf::PackedData> packed_data
26-
) {
27-
partition_map.insert({pid, std::move(*packed_data)});
28-
}
29-
"""
30-
void cpp_insert_chunk_into_partition_map(
31-
unordered_map[uint32_t, cpp_PackedData] &partition_map,
32-
uint32_t pid,
33-
unique_ptr[cpp_PackedData] packed_data,
34-
) except + nogil
35-
36-
3718
cdef class Shuffler:
3819
"""
3920
Shuffle service for partitioned data.

python/rapidsmpf/rapidsmpf/streaming/chunks/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# cmake-format: on
66
# =================================================================================
77

8-
set(cython_modules arbitrary.pyx partition.pyx utils.pyx)
8+
set(cython_modules arbitrary.pyx packed_data.pyx partition.pyx utils.pyx)
99

1010
rapids_cython_create_modules(
1111
CXX

0 commit comments

Comments
 (0)