Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions cpp/include/ucxx/request.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <chrono>
#include <memory>
#include <string>
#include <utility>

#include <ucp/api/ucp.h>

Expand Down Expand Up @@ -37,6 +38,13 @@ namespace ucxx {
*/
class Request : public Component {
protected:
/**
 * @brief Cached attributes of a request.
 *
 * Plain aggregate holding the request status, the memory type and the debug string
 * associated with a request.
 */
struct RequestAttributes {
  /// Status of the request
  ucs_status_t status = UCS_INPROGRESS;
  /// Memory type of the request
  ucs_memory_type memoryType = UCS_MEMORY_TYPE_UNKNOWN;
  /// Stored debug string
  std::string debugString;
};

ucs_status_t _status{UCS_INPROGRESS}; ///< Requests status
std::string _status_msg{}; ///< Human-readable status message
void* _request{nullptr}; ///< Pointer to UCP request
Expand All @@ -54,6 +62,8 @@ class Request : public Component {
bool _enablePythonFuture{true}; ///< Whether Python future is enabled for this request
RequestCallbackUserFunction _callback{nullptr}; ///< Completion callback
RequestCallbackUserData _callbackData{nullptr}; ///< Completion callback data
RequestAttributes _requestAttr{}; ///< Request attributes queried when request is posted
bool _isRequestAttrValid{false}; ///< Whether the request attributes are valid

/**
* @brief Protected constructor of an abstract `ucxx::Request`.
Expand Down Expand Up @@ -224,6 +234,33 @@ class Request : public Component {
* @return The received buffer (if applicable) or `nullptr`.
*/
[[nodiscard]] virtual std::shared_ptr<Buffer> getRecvBuffer();

/**
 * @brief Get the cached request attributes.
 *
 * Return a copy of the attributes (status, memory type and debug string) that were
 * cached for this request. If the attributes have not been successfully queried yet,
 * this method throws instead of returning default values.
 *
 * @throw ucxx::Error if the request attributes are not available yet.
 *
 * @return A `RequestAttributes` containing the cached request attributes.
 */
[[nodiscard]] RequestAttributes getRequestAttributes();

protected:
/**
 * @brief Query the UCP request attributes and cache them.
 *
 * Helper method that queries the UCP request for its attributes using
 * `ucp_request_query`. Currently queries for:
 * - Request status
 * - Memory type
 * - Debug string
 *
 * This method does not return a value: on a successful query the results are stored in
 * `_requestAttr` and `_isRequestAttrValid` is set, after which they can be retrieved
 * via `getRequestAttributes()`. If the attributes are already cached, or the underlying
 * request is not a valid UCP request pointer, the cached state is left untouched.
 */
void queryRequestAttributes();
};

} // namespace ucxx
13 changes: 12 additions & 1 deletion cpp/include/ucxx/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ class Worker : public Component {
*
* Using a Python future may be requested by specifying `enablePythonFuture`. If a
* Python future is requested, the Python application must then await on this future to
* ensure the transfer has completed. Requires UCXX Python support.
* ensure the transfer has completed.
*
* @note If a `callbackFunction` is specified, the lifetime of `callbackData` and of any
* other objects used in the scope of `callbackFunction` must be guaranteed by the caller
Expand All @@ -997,6 +997,17 @@ class Worker : public Component {
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

/**
 * @brief Query worker attributes.
 *
 * Queries the worker attributes using `ucp_worker_query`. The fields requested are the
 * worker's thread mode and the maximum size of a request debug string; only those
 * fields of the returned structure should be relied upon.
 *
 * @returns The worker attributes structure.
 * @throws ucxx::Error if an error occurred while querying worker attributes.
 */
[[nodiscard]] ucp_worker_attr_t queryAttributes() const;
};

/**
Expand Down
45 changes: 45 additions & 0 deletions cpp/src/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <ucp/api/ucp.h>

Expand Down Expand Up @@ -239,6 +240,50 @@ void Request::setStatus(ucs_status_t status)

// Accessor returning a reference to the stored owner string (no copy is made).
const std::string& Request::getOwnerString() const
{
  return _ownerString;
}

void Request::queryRequestAttributes()
{
std::lock_guard<std::recursive_mutex> lock(_mutex);

if (_isRequestAttrValid) return;

ucp_request_attr_t result;

// Get the debug string size from worker attributes
auto worker_attr = _worker->queryAttributes();

// Allocate buffer for debug string with size from worker attributes
std::vector<char> debug_str(worker_attr.max_debug_string, '\0');

result.field_mask = UCP_REQUEST_ATTR_FIELD_STATUS | // Request status
UCP_REQUEST_ATTR_FIELD_MEM_TYPE | // Memory type
UCP_REQUEST_ATTR_FIELD_INFO_STRING | // Debug string
UCP_REQUEST_ATTR_FIELD_INFO_STRING_SIZE; // Debug string size

// Set up the debug string buffer
result.debug_string = debug_str.data();
result.debug_string_size = debug_str.size();

if (UCS_PTR_IS_PTR(_request)) {
result.status = ucp_request_query(_request, &result);
if (result.status == UCS_OK && result.debug_string != nullptr) {
_requestAttr.debugString = std::string(result.debug_string);
_requestAttr.memoryType = result.mem_type;
_requestAttr.status = result.status;
_isRequestAttrValid = true;
}
}
}

// Return a copy of the cached request attributes, throwing if none have been cached.
Request::RequestAttributes Request::getRequestAttributes()
{
  std::lock_guard<std::recursive_mutex> lock(_mutex);

  // Guard clause: nothing to return until a query has succeeded.
  if (!_isRequestAttrValid) { throw ucxx::Error("Request attributes not available yet"); }

  return _requestAttr;
}

// Base implementation: there is no receive buffer, so a null pointer is returned.
std::shared_ptr<Buffer> Request::getRecvBuffer()
{
  return nullptr;
}

} // namespace ucxx
1 change: 1 addition & 0 deletions cpp/src/request_tag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ void RequestTag::request()

std::lock_guard<std::recursive_mutex> lock(_mutex);
_request = request;
queryRequestAttributes();
}

void RequestTag::populateDelayedSubmission()
Expand Down
11 changes: 11 additions & 0 deletions cpp/src/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,17 @@ std::string Worker::getInfo()
return utils::decodeTextFileDescriptor(TextFileDescriptor);
}

// Query the UCP worker for its thread mode and maximum debug string size, throwing
// via `utils::ucsErrorThrow` if `ucp_worker_query` reports an error.
ucp_worker_attr_t Worker::queryAttributes() const
{
  // Zero-initialize every field, then select the ones UCX should fill in.
  ucp_worker_attr_t attr{};
  attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE |      // Request thread mode info
                    UCP_WORKER_ATTR_FIELD_MAX_INFO_STRING;   // Request debug string size

  const auto queryStatus = ucp_worker_query(_handle, &attr);
  utils::ucsErrorThrow(queryStatus);

  return attr;
}

bool Worker::isDelayedRequestSubmissionEnabled() const
{
return _delayedSubmissionCollection->isDelayedRequestSubmissionEnabled();
Expand Down
10 changes: 10 additions & 0 deletions cpp/tests/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <algorithm>
#include <memory>
#include <numeric>
#include <string>
#include <tuple>
#include <ucp/api/ucp.h>
#include <ucs/type/status.h>
Expand Down Expand Up @@ -295,6 +296,15 @@ TEST_P(RequestTest, ProgressTag)
requests.push_back(_ep->tagRecv(_recvPtr[0], _messageSize, ucxx::Tag{0}, ucxx::TagMaskFull));
waitRequests(_worker, requests, _progressWorker);

for (const auto& request : requests) {
auto debugString = request->getRequestAttributes().debugString;
// Check that debugString contains the expected host memory length substring
std::string expectedSubstring = "length " + std::to_string(_messageSize);
ASSERT_THAT(debugString, ::testing::HasSubstr(expectedSubstring));
ASSERT_THAT(debugString,
::testing::HasSubstr(_memoryType == UCS_MEMORY_TYPE_HOST ? "host" : "cuda"));
}

copyResults();

// Assert data correctness
Expand Down
11 changes: 11 additions & 0 deletions cpp/tests/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@ class WorkerGenericCallbackSingleTest : public WorkerProgressTest {};

// The fixture's worker must expose a non-null handle.
TEST_F(WorkerTest, HandleIsValid)
{
  const auto handle = _worker->getHandle();
  ASSERT_TRUE(handle != nullptr);
}

// Checks that querying the worker yields the thread mode field with the expected value.
TEST_F(WorkerTest, QueryAttributes)
{
  const auto workerAttr = _worker->queryAttributes();

  // The thread mode bit must be present in the returned field mask
  const bool hasThreadMode = workerAttr.field_mask & UCP_WORKER_ATTR_FIELD_THREAD_MODE;
  ASSERT_TRUE(hasThreadMode);

  // The worker is expected to have been created with UCS_THREAD_MODE_MULTI
  ASSERT_EQ(workerAttr.thread_mode, UCS_THREAD_MODE_MULTI);
}

TEST_P(WorkerCapabilityTest, CheckCapability)
{
ASSERT_EQ(_worker->isDelayedRequestSubmissionEnabled(), _enableDelayedSubmission);
Expand Down