Skip to content

Commit 4f962e6

Browse files
Strict enforcing on sampled kernels when targeting hardware backends -
no explicit measurement operation allowed. For simulators, issue a warning. Signed-off-by: Pradnya Khalate <[email protected]>
1 parent 9db4859 commit 4f962e6

File tree

7 files changed

+89
-14
lines changed

7 files changed

+89
-14
lines changed

lib/Optimizer/Transforms/QuakeAddMetadata.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ namespace {
2727

2828
/// Define a type to contain the Quake Function Metadata
2929
struct QuakeMetadata {
30+
bool hasMeasurements = false;
3031
bool hasConditionalsOnMeasure = false;
3132

3233
// If the following flag is set, it means we've detected quantum to classical
@@ -99,6 +100,13 @@ struct QuakeFunctionAnalysis {
99100
LLVM_DEBUG(llvm::dbgs()
100101
<< "Function to analyze: " << funcOp.getName() << '\n');
101102
QuakeMetadata data;
103+
104+
// Check for any measurements
105+
funcOp->walk([&](quake::MeasurementInterface meas) {
106+
data.hasMeasurements = true;
107+
return WalkResult::interrupt();
108+
});
109+
102110
SmallPtrSet<Operation *, 8> dirtySet;
103111
funcOp->walk([&](quake::DiscriminateOp disc) {
104112
dirtySet.insert(disc.getOperation());
@@ -209,6 +217,11 @@ class QuakeAddMetadataPass
209217
assert(iter != funcAnalysisInfo.end());
210218
const auto &info = iter->second;
211219

220+
if (info.hasMeasurements) {
221+
auto builder = OpBuilder::atBlockBegin(&funcOp.getBody().front());
222+
funcOp->setAttr("hasMeasurements", builder.getBoolAttr(true));
223+
}
224+
212225
// Did this function have conditionals on measures?
213226
if (info.hasConditionalsOnMeasure) {
214227
// if so, add a function attribute

lib/Optimizer/Transforms/QuakePropagateMetadata.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class QuakePropagateMetadataPass
3737
/// positives. expand-measurements and loop-unrolling may further reduce
3838
/// false positives.
3939
void runOnOperation() override {
40+
static const std::vector<StringRef> attributesToPropagate = {
41+
"qubitMeasurementFeedback", "hasMeasurements"};
4042
ModuleOp moduleOp = getOperation();
4143
/// NOTE: If the module has an occurrence of `quake.apply` then the step to
4244
/// build call graph fails. Hence, we skip the pass in such cases.
@@ -90,15 +92,15 @@ class QuakePropagateMetadataPass
9092
LLVM_DEBUG(llvm::dbgs()
9193
<< "Visiting callee: " << callee.getName() << "\n\n");
9294
for (auto caller : callers) {
93-
9495
LLVM_DEBUG(llvm::dbgs() << " Caller: " << caller.getName() << "\n\n");
95-
if (auto boolAttr = callee->getAttr("qubitMeasurementFeedback")
96-
.dyn_cast_or_null<mlir::BoolAttr>()) {
97-
if (boolAttr.getValue()) {
98-
LLVM_DEBUG(llvm::dbgs()
99-
<< " Propagating qubitMeasurementFeedback attr: "
100-
<< boolAttr << "\n");
101-
caller->setAttr("qubitMeasurementFeedback", boolAttr);
96+
for (auto attribute : attributesToPropagate) {
97+
if (auto boolAttr = callee->getAttr(attribute)
98+
.dyn_cast_or_null<mlir::BoolAttr>()) {
99+
if (boolAttr.getValue()) {
100+
LLVM_DEBUG(llvm::dbgs() << " Propagating " << attribute
101+
<< " attr: " << boolAttr << "\n");
102+
caller->setAttr(attribute, boolAttr);
103+
}
102104
}
103105
}
104106
}

python/cudaq/kernel/kernel_builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ def __init__(self, argTypeList):
273273
cc.register_dialect(context=self.ctx)
274274
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)
275275

276+
self.hasMeasurements = False
276277
self.conditionalOnMeasure = False
277278
self.regCounter = 0
278279
self.loc = Location.unknown(context=self.ctx)
@@ -1135,6 +1136,7 @@ def mz(self, target, regName=None):
11351136
kernel.mz(target=qubit))
11361137
```
11371138
"""
1139+
self.hasMeasurements = True
11381140
with self.ctx, self.insertPoint, self.loc:
11391141
i1Ty = IntegerType.get_signless(1)
11401142
qubitTy = target.mlirValue.type
@@ -1182,6 +1184,7 @@ def mx(self, target, regName=None):
11821184
kernel.mx(qubit))
11831185
```
11841186
"""
1187+
self.hasMeasurements = True
11851188
with self.ctx, self.insertPoint, self.loc:
11861189
i1Ty = IntegerType.get_signless(1)
11871190
qubitTy = target.mlirValue.type
@@ -1230,6 +1233,7 @@ def my(self, target, regName=None):
12301233
kernel.my(qubit))
12311234
```
12321235
"""
1236+
self.hasMeasurements = True
12331237
with self.ctx, self.insertPoint, self.loc:
12341238
i1Ty = IntegerType.get_signless(1)
12351239
qubitTy = target.mlirValue.type

python/cudaq/runtime/sample.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def sample(kernel,
7171
or a list of such results in the case of `sample` function broadcasting.
7272
"""
7373

74+
has_measurements = False
7475
has_conditionals_on_measure_result = False
7576

7677
if isinstance(kernel, PyKernelDecorator):
@@ -85,20 +86,40 @@ def sample(kernel,
8586
if not hasattr(operation, 'name'):
8687
continue
8788
if nvqppPrefix + kernel.name == operation.name.value:
89+
has_measurements = 'hasMeasurements' in operation.attributes
8890
has_conditionals_on_measure_result = 'qubitMeasurementFeedback' in operation.attributes
8991
break
90-
elif isinstance(kernel, PyKernel) and kernel.conditionalOnMeasure:
91-
has_conditionals_on_measure_result = True
92+
elif isinstance(kernel, PyKernel):
93+
if kernel.hasMeasurements:
94+
has_measurements = True
95+
if kernel.conditionalOnMeasure:
96+
has_conditionals_on_measure_result = True
9297

9398
if explicit_measurements:
9499
if not cudaq_runtime.supportsExplicitMeasurements():
95100
raise RuntimeError(
96-
"The sampling option `explicit_measurements` is not supported on this target."
97-
)
101+
"The sampling option `explicit_measurements` is not supported "
102+
"on this target.")
98103
if has_conditionals_on_measure_result:
99104
raise RuntimeError(
100-
"The sampling option `explicit_measurements` is not supported on kernel with conditional logic on a measurement result."
101-
)
105+
"The sampling option `explicit_measurements` is not supported "
106+
"on kernel with conditional logic on a measurement result.")
107+
if has_measurements:
108+
if cudaq_runtime.isQuantumDevice():
109+
raise RuntimeError(
110+
"Kernels with explicit measurement operations cannot be used with "
111+
"`cudaq.sample` on hardware targets. Please remove all measurements "
112+
"from the kernel, qubits will be automatically measured at the end "
113+
"when sampling a kernel.\n"
114+
"Alternatively, use `cudaq.run` API, if supported on this target.")
115+
elif not explicit_measurements:
116+
print(
117+
"WARNING: Using `cudaq.sample` with a kernel that contains explicit "
118+
"measurements is deprecated and will be disallowed in a future release. "
119+
"Please remove all measurements from the kernel, qubits will be "
120+
"automatically measured at the end when sampling a kernel.\n"
121+
"Alternatively, use `cudaq.run` API which preserves individual "
122+
"measurement results.")
102123

103124
if noise_model != None:
104125
cudaq_runtime.set_noise(noise_model)

runtime/common/DeviceCodeRegistry.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,4 +203,11 @@ bool kernelHasConditionalFeedback(const std::string &kernelName) {
203203
return !quakeCode.empty() &&
204204
quakeCode.find("qubitMeasurementFeedback = true") != std::string::npos;
205205
}
206+
207+
bool kernelHasMeasurements(const std::string &kernelName) {
208+
auto quakeCode = get_quake_by_name(kernelName, false);
209+
return !quakeCode.empty() &&
210+
quakeCode.find("hasMeasurements = true") != std::string::npos;
211+
}
212+
206213
} // namespace cudaq

runtime/cudaq.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ KernelArgsCreator getArgsCreator(const std::string &kernelName);
216216

217217
bool kernelHasConditionalFeedback(const std::string &kernelName);
218218

219+
bool kernelHasMeasurements(const std::string &kernelName);
220+
219221
/// @brief Provide a hook to set the target backend.
220222
void set_target_backend(const char *backend);
221223

runtime/cudaq/algorithms/sample.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
namespace cudaq {
1818
bool kernelHasConditionalFeedback(const std::string &);
19+
bool kernelHasMeasurements(const std::string &);
1920
namespace detail {
2021
bool isKernelGenerated(const std::string &);
2122
}
@@ -99,6 +100,31 @@ runSampling(KernelFunctor &&wrappedKernel, quantum_platform &platform,
99100
auto isQuantumDevice =
100101
!isRemoteSimulator && (platform.is_remote() || platform.is_emulated());
101102

103+
auto hasMeasurements = cudaq::kernelHasMeasurements(kernelName);
104+
if (hasMeasurements) {
105+
if (isQuantumDevice) {
106+
// Hardware: Error immediately
107+
throw std::runtime_error(
108+
"Kernels with explicit measurement operations cannot be used with "
109+
"`cudaq::sample` on hardware targets. Please remove all "
110+
"measurements from the kernel, qubits will be automatically measured "
111+
"at the end when sampling a kernel."
112+
"Alternatively, use `cudaq::run` API, if supported on this target.");
113+
} else {
114+
// Simulators: Warning for now, but indicate future deprecation
115+
if (!explicitMeasurements) {
116+
printf(
117+
"WARNING: Using `cudaq::sample` with a kernel that contains "
118+
"explicit measurements is deprecated and will be disallowed in a "
119+
"future release. Please remove all measurements from the kernel, "
120+
"qubits will be automatically measured at the end when sampling a "
121+
"kernel.\n"
122+
"Alternatively, use `cudaq::run` which preserves individual "
123+
"measurement results.");
124+
}
125+
}
126+
}
127+
102128
// Loop until all shots are returned.
103129
cudaq::sample_result counts;
104130
while (counts.get_total_shots() < static_cast<std::size_t>(shots)) {

0 commit comments

Comments
 (0)