Add cached UCP request attributes (status, memory type, debug string) to ucxx::Request
and expose Worker::queryAttributes().

Review notes:
- Request::queryRequestAttributes() keeps the return code of ucp_request_query() in its
  own local instead of writing it into the `status` output field, so the cached status is
  the status of the request rather than the status of the query.
- Text inside angle brackets (header names, template arguments) is missing from the
  context lines of this patch; they are reproduced as found. On added lines it is filled
  in only where the added code itself determines it (presumed: <string>, <vector>,
  <gmock/gmock.h>, std::vector<char>) -- confirm against the target files.

diff --git a/cpp/include/ucxx/request.h b/cpp/include/ucxx/request.h
index 9e237bf0..748306a6 100644
--- a/cpp/include/ucxx/request.h
+++ b/cpp/include/ucxx/request.h
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include <string>
 
 #include
 
@@ -37,6 +38,13 @@ namespace ucxx {
  */
 class Request : public Component {
  protected:
+  /// Structure to hold cached request attributes including the debug string
+  struct RequestAttributes {
+    ucs_status_t status{UCS_INPROGRESS};                  ///< Status of the request
+    ucs_memory_type memoryType{UCS_MEMORY_TYPE_UNKNOWN};  ///< Memory type of the request
+    std::string debugString{};                            ///< Stored debug string
+  };
+
   ucs_status_t _status{UCS_INPROGRESS};  ///< Requests status
   std::string _status_msg{};  ///< Human-readable status message
   void* _request{nullptr};  ///< Pointer to UCP request
@@ -54,6 +62,8 @@ class Request : public Component {
   bool _enablePythonFuture{true};  ///< Whether Python future is enabled for this request
   RequestCallbackUserFunction _callback{nullptr};  ///< Completion callback
   RequestCallbackUserData _callbackData{nullptr};  ///< Completion callback data
+  RequestAttributes _requestAttr{};  ///< Request attributes queried when request is posted
+  bool _isRequestAttrValid{false};   ///< Whether the request attributes are valid
 
   /**
    * @brief Protected constructor of an abstract `ucxx::Request`.
@@ -224,6 +234,33 @@ class Request : public Component {
    * @return The received buffer (if applicable) or `nullptr`.
    */
   [[nodiscard]] virtual std::shared_ptr getRecvBuffer();
+
+  /**
+   * @brief Get the request attributes.
+   *
+   * Get the request attributes. If the request attributes are not available yet, this
+   * method will throw an error.
+   *
+   * @throw ucxx::Error if the request attributes are not available yet.
+   *
+   * @return A RequestAttributes containing the request attributes.
+   */
+  [[nodiscard]] RequestAttributes getRequestAttributes();
+
+ protected:
+  /**
+   * @brief Query the UCP request attributes.
+   *
+   * Helper method that queries the UCP request for its attributes using ucp_request_query.
+   * Currently queries for:
+   * - Request status
+   * - Memory type
+   * - Debug string
+   *
+   * On success the values are cached in `_requestAttr` and `_isRequestAttrValid` is set,
+   * making later calls no-ops. Nothing is returned; use `getRequestAttributes()` to read.
+   */
+  void queryRequestAttributes();
 };
 
 }  // namespace ucxx
diff --git a/cpp/include/ucxx/worker.h b/cpp/include/ucxx/worker.h
index 2165d623..496171b8 100644
--- a/cpp/include/ucxx/worker.h
+++ b/cpp/include/ucxx/worker.h
@@ -977,7 +977,7 @@ class Worker : public Component {
    *
    * Using a Python future may be requested by specifying `enablePythonFuture`. If a
    * Python future is requested, the Python application must then await on this future to
-   * ensure the transfer has completed. Requires UCXX Python support.
+   * ensure the transfer has completed.
    *
    * @note If a `callbackFunction` is specified, the lifetime of `callbackData` and of any
    * other objects used in the scope of `callbackFunction` must be guaranteed by the caller
@@ -997,6 +997,17 @@ class Worker : public Component {
     const bool enablePythonFuture = false,
     RequestCallbackUserFunction callbackFunction = nullptr,
     RequestCallbackUserData callbackData = nullptr);
+
+  /**
+   * @brief Query worker attributes.
+   *
+   * Queries the worker attributes using ucp_worker_query. This provides information about
+   * the worker's thread mode and the maximum size of a request debug string.
+   *
+   * @returns The worker attributes structure.
+   * @throws ucxx::Error if an error occurred while querying worker attributes.
+   */
+  [[nodiscard]] ucp_worker_attr_t queryAttributes() const;
 };
 
 /**
diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp
index 41b6f96e..049c8582 100644
--- a/cpp/src/request.cpp
+++ b/cpp/src/request.cpp
@@ -7,6 +7,7 @@
 #include
 #include
 #include
+#include <vector>
 
 #include
 
@@ -239,6 +240,50 @@ void Request::setStatus(ucs_status_t status)
 
 const std::string& Request::getOwnerString() const { return _ownerString; }
 
+void Request::queryRequestAttributes()
+{
+  // NOTE(review): callers such as `RequestTag::request()` already hold `_mutex` when they
+  // call this; that is only safe if `_mutex` is a recursive mutex -- confirm.
+  std::lock_guard lock(_mutex);
+
+  // Attributes are cached after the first successful query
+  if (_isRequestAttrValid) return;
+
+  // Only a real UCP request handle can be queried
+  if (!UCS_PTR_IS_PTR(_request)) return;
+
+  // Size the debug string buffer from the worker attributes
+  const auto workerAttr = _worker->queryAttributes();
+  std::vector<char> debugStr(workerAttr.max_debug_string, '\0');
+
+  ucp_request_attr_t attr{};
+  attr.field_mask = UCP_REQUEST_ATTR_FIELD_STATUS |            // Request status
+                    UCP_REQUEST_ATTR_FIELD_MEM_TYPE |          // Memory type
+                    UCP_REQUEST_ATTR_FIELD_INFO_STRING |       // Debug string
+                    UCP_REQUEST_ATTR_FIELD_INFO_STRING_SIZE;   // Debug string size
+  attr.debug_string      = debugStr.data();
+  attr.debug_string_size = debugStr.size();
+
+  // Keep the return code of the query apart from `attr.status`: the latter is an output
+  // field holding the status of the request itself and must not be overwritten.
+  const ucs_status_t queryStatus = ucp_request_query(_request, &attr);
+  if (queryStatus != UCS_OK || attr.debug_string == nullptr) return;
+
+  _requestAttr.status      = attr.status;
+  _requestAttr.memoryType  = attr.mem_type;
+  _requestAttr.debugString = std::string(attr.debug_string);
+  _isRequestAttrValid      = true;
+}
+
+Request::RequestAttributes Request::getRequestAttributes()
+{
+  std::lock_guard lock(_mutex);
+
+  if (!_isRequestAttrValid) throw ucxx::Error("Request attributes not available yet");
+
+  return _requestAttr;
+}
+
 std::shared_ptr Request::getRecvBuffer() { return nullptr; }
 
 }  // namespace ucxx
diff --git a/cpp/src/request_tag.cpp b/cpp/src/request_tag.cpp
index a4424a4f..1253478a 100644
--- a/cpp/src/request_tag.cpp
+++ b/cpp/src/request_tag.cpp
@@ -163,6 +163,7 @@ void RequestTag::request()
 
   std::lock_guard lock(_mutex);
   _request = request;
+  queryRequestAttributes();
 }
 
 void RequestTag::populateDelayedSubmission()
diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp
index ca49de12..11de6d63 100644
--- a/cpp/src/worker.cpp
+++ b/cpp/src/worker.cpp
@@ -196,6 +196,17 @@ std::string Worker::getInfo()
   return utils::decodeTextFileDescriptor(TextFileDescriptor);
 }
 
+ucp_worker_attr_t Worker::queryAttributes() const
+{
+  ucp_worker_attr_t attr = {
+    .field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE |    // Request thread mode info
+                  UCP_WORKER_ATTR_FIELD_MAX_INFO_STRING  // Request debug string size
+  };
+
+  utils::ucsErrorThrow(ucp_worker_query(_handle, &attr));
+  return attr;
+}
+
 bool Worker::isDelayedRequestSubmissionEnabled() const
 {
   return _delayedSubmissionCollection->isDelayedRequestSubmissionEnabled();
diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp
index fa7fb1c1..30cacb33 100644
--- a/cpp/tests/request.cpp
+++ b/cpp/tests/request.cpp
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include <gmock/gmock.h>
 #include
 #include
 #include
@@ -295,6 +296,15 @@ TEST_P(RequestTest, ProgressTag)
   requests.push_back(_ep->tagRecv(_recvPtr[0], _messageSize, ucxx::Tag{0}, ucxx::TagMaskFull));
   waitRequests(_worker, requests, _progressWorker);
 
+  for (const auto& request : requests) {
+    auto debugString = request->getRequestAttributes().debugString;
+    // Check that debugString contains the expected host memory length substring
+    std::string expectedSubstring = "length " + std::to_string(_messageSize);
+    ASSERT_THAT(debugString, ::testing::HasSubstr(expectedSubstring));
+    ASSERT_THAT(debugString,
+                ::testing::HasSubstr(_memoryType == UCS_MEMORY_TYPE_HOST ? "host" : "cuda"));
+  }
+
   copyResults();
 
   // Assert data correctness
diff --git a/cpp/tests/worker.cpp b/cpp/tests/worker.cpp
index 1d33e574..68b4a456 100644
--- a/cpp/tests/worker.cpp
+++ b/cpp/tests/worker.cpp
@@ -108,6 +108,17 @@ class WorkerGenericCallbackSingleTest : public WorkerProgressTest {};
 
 TEST_F(WorkerTest, HandleIsValid) { ASSERT_TRUE(_worker->getHandle() != nullptr); }
 
+TEST_F(WorkerTest, QueryAttributes)
+{
+  auto attrs = _worker->queryAttributes();
+
+  // Verify that the thread mode field was requested and returned
+  ASSERT_TRUE(attrs.field_mask & UCP_WORKER_ATTR_FIELD_THREAD_MODE);
+
+  // The worker was created with UCS_THREAD_MODE_MULTI in the constructor
+  ASSERT_EQ(attrs.thread_mode, UCS_THREAD_MODE_MULTI);
+}
+
 TEST_P(WorkerCapabilityTest, CheckCapability)
 {
   ASSERT_EQ(_worker->isDelayedRequestSubmissionEnabled(), _enableDelayedSubmission);