diff --git a/src/utils/libfabric/libfabric_common.cpp b/src/utils/libfabric/libfabric_common.cpp index 140694e11..f7b6e945b 100644 --- a/src/utils/libfabric/libfabric_common.cpp +++ b/src/utils/libfabric/libfabric_common.cpp @@ -50,6 +50,9 @@ getAvailableNetworkDevices() { hints->mode = FI_CONTEXT; hints->ep_attr->type = FI_EP_RDM; + // Add CXI-compatible memory registration mode + hints->domain_attr->mr_mode = FI_MR_LOCAL | FI_MR_ENDPOINT | FI_MR_ALLOCATED | FI_MR_PROV_KEY; + int ret = fi_getinfo(FI_VERSION(1, 18), NULL, NULL, 0, hints, &info); if (ret) { NIXL_ERROR << "fi_getinfo failed " << fi_strerror(-ret); @@ -85,7 +88,9 @@ getAvailableNetworkDevices() { } } - if (provider_device_map.find("efa") != provider_device_map.end()) { + if (provider_device_map.find("cxi") != provider_device_map.end()) { + return {"cxi", provider_device_map["cxi"]}; + } else if (provider_device_map.find("efa") != provider_device_map.end()) { return {"efa", provider_device_map["efa"]}; } else if (provider_device_map.find("sockets") != provider_device_map.end()) { return {"sockets", {provider_device_map["sockets"][0]}}; diff --git a/src/utils/libfabric/libfabric_rail.cpp b/src/utils/libfabric/libfabric_rail.cpp index f5b155c7b..643dd738a 100644 --- a/src/utils/libfabric/libfabric_rail.cpp +++ b/src/utils/libfabric/libfabric_rail.cpp @@ -421,10 +421,16 @@ nixlLibfabricRail::nixlLibfabricRail(const std::string &device, // TCP provider doesn't support FI_MR_PROV_KEY or FI_MR_VIRT_ADDR, use basic mode hints->domain_attr->mr_mode = FI_MR_LOCAL | FI_MR_ALLOCATED; hints->domain_attr->mr_key_size = 0; // Let provider decide + } else if (provider == "cxi") { + hints->caps |= FI_RMA_EVENT; + hints->domain_attr->mr_mode = + FI_MR_LOCAL | FI_MR_HMEM | FI_MR_VIRT_ADDR | + FI_MR_ALLOCATED | FI_MR_PROV_KEY | FI_MR_ENDPOINT; } else { // EFA and other providers support advanced memory registration hints->domain_attr->mr_mode = - FI_MR_LOCAL | FI_MR_HMEM | FI_MR_VIRT_ADDR | FI_MR_ALLOCATED | FI_MR_PROV_KEY; + FI_MR_LOCAL | FI_MR_HMEM | FI_MR_VIRT_ADDR | + FI_MR_ALLOCATED | FI_MR_PROV_KEY; hints->domain_attr->mr_key_size = 2; } hints->domain_attr->name = strdup(device_name.c_str()); @@ -1347,8 +1353,14 @@ nixlLibfabricRail::registerMemory(void *buffer, iov.iov_len = length; mr_attr.mr_iov = &iov; mr_attr.iov_count = 1; + int ret = 0; + + if (provider_name == "cxi") { + ret = fi_mr_regattr(domain, &mr_attr, FI_RMA_EVENT, &mr); + } else { + ret = fi_mr_regattr(domain, &mr_attr, 0, &mr); + } - int ret = fi_mr_regattr(domain, &mr_attr, 0, &mr); if (ret) { NIXL_ERROR << "fi_mr_reg failed on rail " << rail_id << ": " << fi_strerror(-ret) << " (buffer=" << buffer << ", length=" << length @@ -1357,6 +1369,24 @@ nixlLibfabricRail::registerMemory(void *buffer, } *mr_out = mr; + + if (info->domain_attr->mr_mode & FI_MR_ENDPOINT) { + ret = fi_mr_bind(mr, &endpoint->fid, 0); + if (ret) { + NIXL_ERROR << "fi_mr_bind failed on rail " << rail_id << ": " << fi_strerror(-ret); + fi_close(&mr->fid); + return NIXL_ERR_BACKEND; + } + + ret = fi_mr_enable(mr); + if (ret) { + NIXL_ERROR << "fi_mr_enable failed on rail " << rail_id << ": " << fi_strerror(-ret); + fi_close(&mr->fid); + return NIXL_ERR_BACKEND; + } + } + + *key_out = fi_mr_key(mr); NIXL_TRACE << "Memory Registration SUCCESS: rail=" << rail_id << " provider=" << provider_name diff --git a/src/utils/libfabric/libfabric_rail_manager.cpp b/src/utils/libfabric/libfabric_rail_manager.cpp index 5bf23ae16..56b9f2509 100644 --- a/src/utils/libfabric/libfabric_rail_manager.cpp +++ b/src/utils/libfabric/libfabric_rail_manager.cpp @@ -179,7 +179,8 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer( // For TCP providers, use offset 0 instead of virtual address // TCP providers don't support FI_MR_VIRT_ADDR and expect offset-based addressing - if (data_rails_[rail_id]->provider_name == "tcp" || + if (data_rails_[rail_id]->provider_name == "cxi" || + data_rails_[rail_id]->provider_name == "tcp" || data_rails_[rail_id]->provider_name == "sockets") { req->remote_addr = 0; // Use offset 0 for TCP providers NIXL_DEBUG << "TCP provider detected: using offset 0 instead of virtual address " @@ -258,7 +259,8 @@ nixlLibfabricRailManager::prepareAndSubmitTransfer( // For TCP providers, use offset instead of virtual address // TCP providers don't support FI_MR_VIRT_ADDR and expect offset-based addressing - if (data_rails_[rail_id]->provider_name == "tcp" || + if (data_rails_[rail_id]->provider_name == "cxi" || + data_rails_[rail_id]->provider_name == "tcp" || data_rails_[rail_id]->provider_name == "sockets") { req->remote_addr = chunk_offset; // Use chunk offset for TCP providers NIXL_DEBUG << "TCP provider detected: using chunk offset " << chunk_offset