Skip to content

Commit ac8ca5e

Browse files
1tnguyenbettinaheim
authored andcommitted
Fix a potential race condition in quantum_platform thread Id to QPU Id map (#3291)
* Fix potential race condition in quantum_platform threadId to QPU Id map Signed-off-by: Thien Nguyen <[email protected]> * Code review: guard the qpu id assignement as well Signed-off-by: Thien Nguyen <[email protected]> --------- Signed-off-by: Thien Nguyen <[email protected]>
1 parent e32fe22 commit ac8ca5e

File tree

2 files changed

+36
-21
lines changed

2 files changed

+36
-21
lines changed

runtime/cudaq/platform/quantum_platform.cpp

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,16 @@ void quantum_platform::set_current_qpu(const std::size_t device_id) {
9191
" is not valid (greater than number of available QPUs: " +
9292
std::to_string(platformNumQPUs) + ").");
9393
}
94-
platformCurrentQPU = device_id;
9594
auto tid = std::hash<std::thread::id>{}(std::this_thread::get_id());
96-
auto iter = threadToQpuId.find(tid);
97-
if (iter != threadToQpuId.end())
98-
iter->second = device_id;
99-
else
100-
threadToQpuId.emplace(tid, device_id);
95+
{
96+
std::unique_lock lock(threadToQpuIdMutex);
97+
platformCurrentQPU = device_id;
98+
auto iter = threadToQpuId.find(tid);
99+
if (iter != threadToQpuId.end())
100+
iter->second = device_id;
101+
else
102+
threadToQpuId.emplace(tid, device_id);
103+
}
101104
}
102105

103106
std::size_t quantum_platform::get_current_qpu() { return platformCurrentQPU; }
@@ -150,9 +153,12 @@ void quantum_platform::launchVQE(const std::string kernelName,
150153
std::size_t qpu_id = 0;
151154

152155
auto tid = std::hash<std::thread::id>{}(std::this_thread::get_id());
153-
auto iter = threadToQpuId.find(tid);
154-
if (iter != threadToQpuId.end())
155-
qpu_id = iter->second;
156+
{
157+
std::shared_lock lock(threadToQpuIdMutex);
158+
auto iter = threadToQpuId.find(tid);
159+
if (iter != threadToQpuId.end())
160+
qpu_id = iter->second;
161+
}
156162

157163
auto &qpu = platformQPUs[qpu_id];
158164
qpu->launchVQE(kernelName, kernelArgs, gradient, H, optimizer, n_params,
@@ -173,10 +179,12 @@ KernelThunkResultType quantum_platform::launchKernel(
173179
std::size_t qpu_id = 0;
174180

175181
auto tid = std::hash<std::thread::id>{}(std::this_thread::get_id());
176-
auto iter = threadToQpuId.find(tid);
177-
if (iter != threadToQpuId.end())
178-
qpu_id = iter->second;
179-
182+
{
183+
std::shared_lock lock(threadToQpuIdMutex);
184+
auto iter = threadToQpuId.find(tid);
185+
if (iter != threadToQpuId.end())
186+
qpu_id = iter->second;
187+
}
180188
auto &qpu = platformQPUs[qpu_id];
181189
return qpu->launchKernel(kernelName, kernelFunc, args, voidStarSize,
182190
resultOffset, rawArgs);
@@ -187,10 +195,12 @@ void quantum_platform::launchKernel(std::string kernelName,
187195
std::size_t qpu_id = 0;
188196

189197
auto tid = std::hash<std::thread::id>{}(std::this_thread::get_id());
190-
auto iter = threadToQpuId.find(tid);
191-
if (iter != threadToQpuId.end())
192-
qpu_id = iter->second;
193-
198+
{
199+
std::shared_lock lock(threadToQpuIdMutex);
200+
auto iter = threadToQpuId.find(tid);
201+
if (iter != threadToQpuId.end())
202+
qpu_id = iter->second;
203+
}
194204
auto &qpu = platformQPUs[qpu_id];
195205
qpu->launchKernel(kernelName, rawArgs);
196206
}
@@ -201,10 +211,12 @@ void quantum_platform::launchSerializedCodeExecution(
201211
std::size_t qpu_id = 0;
202212

203213
auto tid = std::hash<std::thread::id>{}(std::this_thread::get_id());
204-
auto iter = threadToQpuId.find(tid);
205-
if (iter != threadToQpuId.end())
206-
qpu_id = iter->second;
207-
214+
{
215+
std::shared_lock lock(threadToQpuIdMutex);
216+
auto iter = threadToQpuId.find(tid);
217+
if (iter != threadToQpuId.end())
218+
qpu_id = iter->second;
219+
}
208220
auto &qpu = platformQPUs[qpu_id];
209221
qpu->launchSerializedCodeExecution(name, serializeCodeExecutionObject);
210222
}

runtime/cudaq/platform/quantum_platform.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ class quantum_platform {
204204
/// that it is running in a multi-QPU context.
205205
std::unordered_map<std::size_t, std::size_t> threadToQpuId;
206206

207+
/// @brief Mutex to protect access to the thread-QPU map.
208+
std::shared_mutex threadToQpuIdMutex;
209+
207210
/// Optional number of shots.
208211
std::optional<int> platformNumShots;
209212

0 commit comments

Comments
 (0)