Skip to content

Commit 0c43b70

Browse files
committed
introduced new operation cudaq::save_state and support
Signed-off-by: Kevin Mato <[email protected]>
1 parent 9269307 commit 0c43b70

File tree

13 files changed

+267
-0
lines changed

13 files changed

+267
-0
lines changed

include/cudaq/Optimizer/CodeGen/QIRFunctionNames.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ static constexpr const char QISApplyKrausChannel[] =
122122

123123
static constexpr const char QISTrap[] = "__quantum__qis__trap";
124124

125+
static constexpr const char QISSaveState[] = "__quantum__qis__save_state";
126+
125127
/// Since apply noise is actually a call back to `C++` code, the `QIR` data type
126128
/// `Array` of `Qubit*` must be converted into a `cudaq::qvector`, which is
127129
/// presently a `std::vector<cudaq::qubit>` but with an extremely restricted

include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,17 @@ def quake_ApplyNoiseOp : QuakeOp<"apply_noise", [AttrSizedOperandSegments]> {
530530
}];
531531
}
532532

533+
def quake_SaveStateOp : QuakeOp<"save_state"> {
534+
let summary = "Save quantum state representation compatible with the simulator.";
535+
let description = [{
536+
This operation provides support for the `cudaq::save_state`
537+
function. This function is only valid in simulation contexts where the
538+
simulator is part of the same process as the C++ host executable itself.
539+
}];
540+
541+
// No arguments are needed.
542+
}
543+
533544
//===----------------------------------------------------------------------===//
534545
// Memory and register conversion instructions: These operations are useful for
535546
// intermediate conversions between memory-SSA and value-SSA semantics and vice

lib/Frontend/nvqpp/ConvertExpr.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1634,6 +1634,13 @@ bool QuakeBridgeVisitor::VisitCallExpr(clang::CallExpr *x) {
16341634
return false;
16351635
}
16361636

1637+
1638+
1639+
if (funcName == "save_state") {
1640+
builder.create<quake::SaveStateOp>(loc, TypeRange{}, ValueRange{});
1641+
return true;
1642+
}
1643+
16371644
if (funcName == "mx" || funcName == "my" || funcName == "mz") {
16381645
// Measurements always return a bool or a std::vector<bool>.
16391646
bool useStdvec =

lib/Optimizer/Builder/Intrinsics.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,7 @@ static constexpr IntrinsicCode intrinsicTable[] = {
536536
func.func private @__quantum__qis__convert_array_to_stdvector(!qir_array) -> !qir_array
537537
func.func private @__quantum__qis__free_converted_stdvector(!qir_array)
538538
539+
func.func private @__quantum__qis__save_state()
539540
llvm.func @generalizedInvokeWithRotationsControlsTargets(i64, i64, i64, i64, !qir_llvmptr, ...) attributes {sym_visibility = "private"}
540541
llvm.func @__quantum__qis__apply_kraus_channel_generalized(i64, i64, i64, i64, i64, ...) attributes {sym_visibility = "private"}
541542
)#"},

lib/Optimizer/CodeGen/ConvertToQIRAPI.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,6 +1515,20 @@ struct AllocaOpPattern : public OpConversionPattern<cudaq::cc::AllocaOp> {
15151515
}
15161516
};
15171517

1518+
struct SaveStateOpRewrite
1519+
: public OpConversionPattern<quake::SaveStateOp> {
1520+
using OpConversionPattern::OpConversionPattern;
1521+
1522+
LogicalResult
1523+
matchAndRewrite(quake::SaveStateOp saveState, OpAdaptor adaptor,
1524+
ConversionPatternRewriter &rewriter) const override {
1525+
rewriter.replaceOpWithNewOp<func::CallOp>(
1526+
saveState, TypeRange{}, cudaq::opt::QISSaveState, ValueRange{});
1527+
return success();
1528+
}
1529+
};
1530+
1531+
15181532
/// Convert the quake types in `func::FuncOp` signatures.
15191533
struct FuncSignaturePattern : public OpConversionPattern<func::FuncOp> {
15201534
using OpConversionPattern::OpConversionPattern;

lib/Optimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ add_cudaq_library(OptTransforms
2929
DistributedDeviceCall.cpp
3030
EraseNoise.cpp
3131
EraseNopCalls.cpp
32+
EraseSaveState.cpp
3233
EraseVectorCopyCtor.cpp
3334
ExpandControlVeqs.cpp
3435
ExpandMeasurements.cpp
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*******************************************************************************
2+
* Copyright (c) 2025 NVIDIA Corporation & Affiliates. *
3+
* All rights reserved. *
4+
* *
5+
* This source code and the accompanying materials are made available under *
6+
* the terms of the Apache License 2.0 which accompanies this distribution. *
7+
******************************************************************************/
8+
9+
#include "PassDetails.h"
10+
#include "cudaq/Optimizer/Builder/Intrinsics.h"
11+
#include "cudaq/Optimizer/Transforms/Passes.h"
12+
#include "mlir/IR/PatternMatch.h"
13+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14+
#include "mlir/Transforms/Passes.h"
15+
16+
namespace cudaq::opt {
17+
#define GEN_PASS_DEF_ERASESAVESTATE
18+
#include "cudaq/Optimizer/Transforms/Passes.h.inc"
19+
} // namespace cudaq::opt
20+
21+
#define DEBUG_TYPE "erase-save-state"
22+
23+
using namespace mlir;
24+
25+
/// \file
26+
/// This pass exists simply to remove all the quake.save_state (and related)
27+
/// Ops from the IR.
28+
29+
30+
namespace {
31+
template <typename Op>
32+
class EraseSaveStatePattern : public OpRewritePattern<Op> {
33+
public:
34+
using OpRewritePattern<Op>::OpRewritePattern;
35+
36+
LogicalResult matchAndRewrite(Op saveState,
37+
PatternRewriter &rewriter) const override {
38+
rewriter.eraseOp(saveState);
39+
return success();
40+
}
41+
};
42+
43+
} // namespace

python/cudaq/kernel/ast_bridge.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2626,6 +2626,11 @@ def bodyBuilder(iterVal):
26262626
quake.ApplyNoiseOp(params, [asVeq], key=key)
26272627
return
26282628

2629+
2630+
if node.func.attr == 'save_state':
2631+
quake.SaveStateOp()
2632+
return
2633+
26292634
if node.func.attr == 'compute_action':
26302635
# There can only be 2 arguments here.
26312636
action = None
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# ============================================================================ #
2+
# Copyright (c) 2025 NVIDIA Corporation & Affiliates. #
3+
# All rights reserved. #
4+
# #
5+
# This source code and the accompanying materials are made available under #
6+
# the terms of the Apache License 2.0 which accompanies this distribution. #
7+
# ============================================================================ #
8+
9+
import os
10+
11+
import pytest
12+
import numpy as np
13+
14+
import cudaq
15+
16+
@pytest.mark.parametrize('target', ['density-matrix-cpu', 'stim'])
17+
def test_save_state_builtin(target: str):
18+
cudaq.set_target(target)
19+
20+
noise = cudaq.NoiseModel()
21+
22+
@cudaq.kernel
23+
def bell_depol2(d: float, flag: bool):
24+
q, r = cudaq.qubit(), cudaq.qubit()
25+
h(q)
26+
cudaq.save_state()
27+
28+
x.ctrl(q, r)
29+
cudaq.save_state()
30+
31+
if flag:
32+
cudaq.apply_noise(cudaq.Depolarization2, d, q, r)
33+
else:
34+
cudaq.apply_noise(cudaq.Depolarization2, [d], q, r)
35+
36+
counts = cudaq.sample(bell_depol2, 0.2, True, noise_model=noise)
37+
assert len(counts) == 4
38+
print(counts)
39+

runtime/common/ExecutionContext.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,5 +142,19 @@ class ExecutionContext {
142142
/// Note: Measurement Syndrome Matrix is defined in
143143
/// https://arxiv.org/pdf/2407.13826.
144144
std::optional<std::pair<std::size_t, std::size_t>> msm_dimensions;
145+
146+
/// @brief For each possible error, this is a "flips" vector of length "num
147+
/// qubits", where "num qubits" is the number of qubits known to the simulator
148+
/// at the time of the error mechanism. This is populated when using the "msm"
149+
/// mode (i.e. this->name == "msm")
150+
std::vector<std::vector<bool>> msm_x_flips; // msm_x_flips[error_id][qubit_id]
151+
std::vector<std::vector<bool>> msm_z_flips; // msm_z_flips[error_id][qubit_id]
152+
153+
/// @brief For each shot, this is a vector of error IDs.
154+
/// This is populated when using the "sample" mode (i.e. this->name == "sample")
155+
std::vector<std::vector<std::size_t>> errors_per_shot; // errors_per_shot[shot][error_id]
156+
157+
/// @brief Save the current simulation state in the recorded states storage.
158+
void save_state(const SimulationState *state);
145159
};
146160
} // namespace cudaq

0 commit comments

Comments
 (0)