Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions cpp/include/ucxx/endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
#include <netdb.h>

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include <ucp/api/ucp.h>
Expand All @@ -18,6 +20,7 @@
#include <ucxx/inflight_requests.h>
#include <ucxx/listener.h>
#include <ucxx/request.h>
#include <ucxx/request_tag_params.h>
#include <ucxx/typedefs.h>
#include <ucxx/utils/sockaddr.h>
#include <ucxx/worker.h>
Expand Down Expand Up @@ -611,6 +614,61 @@ class Endpoint : public Component {
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

// Template version with named parameters
template <typename... Options>
[[nodiscard]] std::enable_if_t<
detail::contains_type<request_tag_params::EndpointParam, Options...>::value &&
detail::contains_type<request_tag_params::RequestDataParam, Options...>::value &&
detail::has_unique_types<detail::remove_cvref<Options>...>::value,
std::shared_ptr<Request>>
tagSend(Options&&... opts)
{
// Default values for optional parameters
std::shared_ptr<Component> endpoint = nullptr;
std::optional<std::variant<data::TagSend, data::TagReceive>> requestData;
bool enablePythonFuture = false;
RequestCallbackUserFunction callbackFunction = nullptr;
RequestCallbackUserData callbackData = nullptr;

// Helper to set parameters
auto setParam = [&](auto&& param) {
using ParamType = std::decay_t<decltype(param)>;
if constexpr (std::is_same_v<ParamType, request_tag_params::EndpointParam>) {
endpoint = std::move(param.value);
} else if constexpr (std::is_same_v<ParamType, request_tag_params::RequestDataParam>) {
requestData.emplace(std::move(param.value));
} else if constexpr (std::is_same_v<ParamType, request_tag_params::EnablePythonFutureParam>) {
enablePythonFuture = param.value;
} else if constexpr (std::is_same_v<ParamType, request_tag_params::CallbackFunctionParam>) {
callbackFunction = param.value;
} else if constexpr (std::is_same_v<ParamType, request_tag_params::CallbackDataParam>) {
callbackData = param.value;
}
};

// Set all parameters
(setParam(std::forward<Options>(opts)), ...);

// Ensure required parameters are present
if (!endpoint || !requestData) {
throw std::runtime_error("Missing required parameters for tagSend");
}

// Create the request with the collected parameters and register it
return registerInflightRequest(createRequestTag(std::forward<Options>(opts)...));
}

// Overload for template-style parameters (deprecated)
[[nodiscard]] std::shared_ptr<Request> tagSend(
request_tag_params::EndpointParam&& endpointParam,
request_tag_params::RequestDataParam&& requestDataParam,
request_tag_params::EnablePythonFutureParam&& enablePythonFutureParam =
request_tag_params::EnablePythonFutureParam{false},
request_tag_params::CallbackFunctionParam&& callbackFunctionParam =
request_tag_params::CallbackFunctionParam{nullptr},
request_tag_params::CallbackDataParam&& callbackDataParam =
request_tag_params::CallbackDataParam{nullptr});

/**
* @brief Enqueue a tag receive operation.
*
Expand Down Expand Up @@ -644,6 +702,61 @@ class Endpoint : public Component {
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

// Template version with named parameters
template <typename... Options>
[[nodiscard]] std::enable_if_t<
detail::contains_type<request_tag_params::EndpointParam, Options...>::value &&
detail::contains_type<request_tag_params::RequestDataParam, Options...>::value &&
detail::has_unique_types<detail::remove_cvref<Options>...>::value,
std::shared_ptr<Request>>
tagRecv(Options&&... opts)
{
// Default values for optional parameters
std::shared_ptr<Component> endpoint = nullptr;
std::optional<std::variant<data::TagSend, data::TagReceive>> requestData;
bool enablePythonFuture = false;
RequestCallbackUserFunction callbackFunction = nullptr;
RequestCallbackUserData callbackData = nullptr;

// Helper to set parameters
auto setParam = [&](auto&& param) {
using ParamType = std::decay_t<decltype(param)>;
if constexpr (std::is_same_v<ParamType, request_tag_params::EndpointParam>) {
endpoint = std::move(param.value);
} else if constexpr (std::is_same_v<ParamType, request_tag_params::RequestDataParam>) {
requestData.emplace(std::move(param.value));
} else if constexpr (std::is_same_v<ParamType, request_tag_params::EnablePythonFutureParam>) {
enablePythonFuture = param.value;
} else if constexpr (std::is_same_v<ParamType, request_tag_params::CallbackFunctionParam>) {
callbackFunction = param.value;
} else if constexpr (std::is_same_v<ParamType, request_tag_params::CallbackDataParam>) {
callbackData = param.value;
}
};

// Set all parameters
(setParam(std::forward<Options>(opts)), ...);

// Ensure required parameters are present
if (!endpoint || !requestData) {
throw std::runtime_error("Missing required parameters for tagRecv");
}

// Create the request with the collected parameters and register it
return registerInflightRequest(createRequestTag(std::forward<Options>(opts)...));
}

// Overload for template-style parameters (deprecated)
[[nodiscard]] std::shared_ptr<Request> tagRecv(
request_tag_params::EndpointParam&& endpointParam,
request_tag_params::RequestDataParam&& requestDataParam,
request_tag_params::EnablePythonFutureParam&& enablePythonFutureParam =
request_tag_params::EnablePythonFutureParam{false},
request_tag_params::CallbackFunctionParam&& callbackFunctionParam =
request_tag_params::CallbackFunctionParam{nullptr},
request_tag_params::CallbackDataParam&& callbackDataParam =
request_tag_params::CallbackDataParam{nullptr});

/**
* @brief Enqueue a multi-buffer tag send operation.
*
Expand Down
100 changes: 72 additions & 28 deletions cpp/include/ucxx/request_tag.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: BSD-3-Clause
*/
#pragma once
#include <memory>
#include <string>
#include <type_traits>
#include <utility>

#include <ucp/api/ucp.h>
Expand Down Expand Up @@ -53,44 +54,30 @@ class RequestTag : public Request {
* @param[in] callbackData user-defined data to pass to the `callbackFunction`.
*/
RequestTag(std::shared_ptr<Component> endpointOrWorker,
const std::variant<data::TagSend, data::TagReceive> requestData,
const std::string operationName,
const std::variant<data::TagSend, data::TagReceive>& requestData,
const std::string& operationName,
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

public:
/**
* @brief Constructor for `std::shared_ptr<ucxx::RequestTag>`.
*
* The constructor for a `std::shared_ptr<ucxx::RequestTag>` object, creating a send or
* receive tag request, returning a pointer to a request object that can be later awaited
* and checked for errors. This is a non-blocking operation, and the status of the
* transfer must be verified from the resulting request object before the data can be
* released (for a send operation) or consumed (for a receive operation).
*
* @throws ucxx::Error if send is `true` and `endpointOrWorker` is not a
* `std::shared_ptr<ucxx::Endpoint>`.
*
* @param[in] endpointOrWorker the parent component, which may either be a
* `std::shared_ptr<Endpoint>` or
* `std::shared_ptr<Worker>`.
* @param[in] requestData container of the specified message type, including all
* type-specific data.
* @param[in] enablePythonFuture whether a python future should be created and
* subsequently notified.
* @param[in] callbackFunction user-defined callback function to call upon completion.
* @param[in] callbackData user-defined data to pass to the `callbackFunction`.
*
* @returns The `shared_ptr<ucxx::RequestTag>` object
*/
// Friend declarations for both createRequestTag functions
friend std::shared_ptr<RequestTag> createRequestTag(
std::shared_ptr<Component> endpointOrWorker,
const std::variant<data::TagSend, data::TagReceive> requestData,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData);

// Friend the templated version
template <typename... Options>
friend std::enable_if_t<
detail::contains_type<request_tag_params::EndpointParam, Options...>::value &&
detail::contains_type<request_tag_params::RequestDataParam, Options...>::value &&
detail::has_unique_types<detail::remove_cvref<Options>...>::value,
std::shared_ptr<RequestTag>>
createRequestTag(Options&&... opts);

public:
virtual void populateDelayedSubmission();

/**
Expand Down Expand Up @@ -160,4 +147,61 @@ class RequestTag : public Request {
void callback(void* request, ucs_status_t status, const ucp_tag_recv_info_t* info);
};

// Implementation of the templated createRequestTag function
template <typename... Options>
std::enable_if_t<detail::contains_type<request_tag_params::EndpointParam, Options...>::value &&
detail::contains_type<request_tag_params::RequestDataParam, Options...>::value &&
detail::has_unique_types<detail::remove_cvref<Options>...>::value,
std::shared_ptr<RequestTag>>
createRequestTag(Options&&... opts)
{
// Default values for optional parameters
std::shared_ptr<Component> endpointOrWorker;
std::optional<std::variant<data::TagSend, data::TagReceive>> requestData;
bool enablePythonFuture = false;
RequestCallbackUserFunction callbackFunction = nullptr;
RequestCallbackUserData callbackData = nullptr;
std::string operationName = "tagOp";

// Helper to set parameters
auto setParam = [&](auto&& param) {
using ParamType = std::decay_t<decltype(param)>;
if constexpr (std::is_same_v<ParamType, request_tag_params::EndpointParam>) {
endpointOrWorker = std::move(param.value);
} else if constexpr (std::is_same_v<ParamType, request_tag_params::RequestDataParam>) {
requestData.emplace(std::move(param.value));
} else if constexpr (std::is_same_v<ParamType, request_tag_params::EnablePythonFutureParam>) {
enablePythonFuture = param.value;
} else if constexpr (std::is_same_v<ParamType, request_tag_params::CallbackFunctionParam>) {
callbackFunction = param.value;
} else if constexpr (std::is_same_v<ParamType, request_tag_params::CallbackDataParam>) {
callbackData = param.value;
} else if constexpr (std::is_same_v<ParamType, request_tag_params::OperationNameParam>) {
operationName = std::move(param.value);
}
};

// Set all parameters
(setParam(std::forward<Options>(opts)), ...);

// Ensure required parameters are present
if (!endpointOrWorker || !requestData) {
throw std::runtime_error("Missing required parameters for RequestTag creation");
}

// Create the RequestTag with the collected parameters
auto req = std::shared_ptr<RequestTag>(new RequestTag(std::move(endpointOrWorker),
std::move(*requestData),
std::move(operationName),
enablePythonFuture,
callbackFunction,
callbackData));

// Register delayed submission
req->_worker->registerDelayedSubmission(
req, std::bind(std::mem_fn(&Request::populateDelayedSubmission), req.get()));

return req;
}

} // namespace ucxx
Loading