Skip to content

Commit b873525

Browse files
committed
wip, save_state lowered. need for library mode and llvm translate
Signed-off-by: Kevin Mato <[email protected]>
1 parent d7f1551 commit b873525

File tree

15 files changed

+256
-218
lines changed

15 files changed

+256
-218
lines changed

lib/Frontend/nvqpp/ConvertExpr.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1634,8 +1634,6 @@ bool QuakeBridgeVisitor::VisitCallExpr(clang::CallExpr *x) {
16341634
return false;
16351635
}
16361636

1637-
1638-
16391637
if (funcName == "save_state") {
16401638
builder.create<quake::SaveStateOp>(loc, TypeRange{}, ValueRange{});
16411639
return true;

lib/Optimizer/CodeGen/ConvertToQIRAPI.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,19 +1516,18 @@ struct AllocaOpPattern : public OpConversionPattern<cudaq::cc::AllocaOp> {
15161516
}
15171517
};
15181518

1519-
struct SaveStateOpRewrite
1520-
: public OpConversionPattern<quake::SaveStateOp> {
1519+
struct SaveStateOpRewrite : public OpConversionPattern<quake::SaveStateOp> {
15211520
using OpConversionPattern::OpConversionPattern;
15221521

15231522
LogicalResult
15241523
matchAndRewrite(quake::SaveStateOp saveState, OpAdaptor adaptor,
15251524
ConversionPatternRewriter &rewriter) const override {
1526-
rewriter.replaceOpWithNewOp<func::CallOp>(saveState, TypeRange{}, cudaq::opt::QISSaveState, ValueRange{});
1525+
rewriter.replaceOpWithNewOp<func::CallOp>(
1526+
saveState, TypeRange{}, cudaq::opt::QISSaveState, ValueRange{});
15271527
return success();
15281528
}
15291529
};
15301530

1531-
15321531
/// Convert the quake types in `func::FuncOp` signatures.
15331532
struct FuncSignaturePattern : public OpConversionPattern<func::FuncOp> {
15341533
using OpConversionPattern::OpConversionPattern;

lib/Optimizer/Transforms/EraseSaveState.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include "mlir/IR/PatternMatch.h"
1313
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1414
#include "mlir/Transforms/Passes.h"
15-
15+
1616
namespace cudaq::opt {
1717
#define GEN_PASS_DEF_ERASESAVESTATE
1818
#include "cudaq/Optimizer/Transforms/Passes.h.inc"
@@ -26,7 +26,6 @@ using namespace mlir;
2626
/// This pass exists simply to remove all the quake.save_state (and related)
2727
/// Ops from the IR.
2828

29-
3029
namespace {
3130
template <typename Op>
3231
class EraseSaveStatePattern : public OpRewritePattern<Op> {
@@ -40,4 +39,4 @@ class EraseSaveStatePattern : public OpRewritePattern<Op> {
4039
}
4140
};
4241

43-
} // namespace
42+
} // namespace

runtime/common/ExecutionContext.h

Lines changed: 85 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -15,93 +15,101 @@
1515
#include "Trace.h"
1616
#include "cudaq/algorithms/optimizer.h"
1717
#include "cudaq/operators.h"
18+
#include <iostream>
1819
#include <optional>
1920
#include <string_view>
20-
#include <iostream>
2121

2222
#include "nvqir/stim/StimState.h"
2323

2424
namespace cudaq {
2525

26-
struct RecordStorage{
26+
struct RecordStorage {
2727

2828
size_t memory_limit;
2929
size_t current_memory;
3030
RecordStorage(size_t limit = 1e9) : memory_limit(limit), current_memory(0) {}
3131

3232
std::vector<std::unique_ptr<SimulationState>> recordedStates;
3333

34-
void save_state(SimulationState *state) {
35-
recordedStates.push_back(clone_state(state));
36-
}
37-
const std::vector<std::unique_ptr<SimulationState>>& get_recorded_states() const { return recordedStates; }
38-
39-
void clear() { recordedStates.clear(); }
40-
void dump_recorded_states() const {
41-
for (std::size_t i = 0; i < recordedStates.size(); i++) {
42-
recordedStates[i]->dump(std::cout);
43-
}
34+
void save_state(SimulationState *state) {
35+
recordedStates.push_back(clone_state(state));
36+
}
37+
const std::vector<std::unique_ptr<SimulationState>> &
38+
get_recorded_states() const {
39+
return recordedStates;
40+
}
41+
42+
void clear() { recordedStates.clear(); }
43+
void dump_recorded_states() const {
44+
for (std::size_t i = 0; i < recordedStates.size(); i++) {
45+
recordedStates[i]->dump(std::cout);
4446
}
45-
47+
}
48+
4649
private:
47-
std::unique_ptr<SimulationState> clone_state(SimulationState* state) {
48-
if (state->isArrayLike()) {
49-
// Handle array-like states (CusvState, etc.)
50-
return clone_array_like_state(state);
51-
} else {
52-
// Handle specialized states (CuDensityMatState, StimState, etc.)
53-
return clone_specialized_state(state);
54-
}
50+
std::unique_ptr<SimulationState> clone_state(SimulationState *state) {
51+
if (state->isArrayLike()) {
52+
// Handle array-like states (CusvState, etc.)
53+
return clone_array_like_state(state);
54+
} else {
55+
// Handle specialized states (CuDensityMatState, StimState, etc.)
56+
return clone_specialized_state(state);
57+
}
58+
}
59+
60+
std::unique_ptr<SimulationState>
61+
clone_array_like_state(SimulationState *state) {
62+
auto numQubits = state->getNumQubits();
63+
if (numQubits > 20) { // Prevent exponential explosion
64+
throw std::runtime_error("State too large to clone via amplitudes");
65+
}
66+
67+
// Generate all basis states
68+
auto totalStates = 1ULL << numQubits;
69+
std::vector<std::vector<int>> basisStates;
70+
for (size_t i = 0; i < totalStates; ++i) {
71+
std::vector<int> basis(numQubits);
72+
for (size_t j = 0; j < numQubits; ++j) {
73+
basis[j] = (i >> j) & 1;
74+
}
75+
basisStates.push_back(basis);
5576
}
56-
57-
std::unique_ptr<SimulationState> clone_array_like_state(SimulationState* state) {
58-
auto numQubits = state->getNumQubits();
59-
if (numQubits > 20) { // Prevent exponential explosion
60-
throw std::runtime_error("State too large to clone via amplitudes");
61-
}
62-
63-
// Generate all basis states
64-
auto totalStates = 1ULL << numQubits;
65-
std::vector<std::vector<int>> basisStates;
66-
for (size_t i = 0; i < totalStates; ++i) {
67-
std::vector<int> basis(numQubits);
68-
for (size_t j = 0; j < numQubits; ++j) {
69-
basis[j] = (i >> j) & 1;
70-
}
71-
basisStates.push_back(basis);
72-
}
73-
74-
auto amplitudes = state->getAmplitudes(basisStates);
75-
76-
// Create new state with appropriate precision
77-
if (state->getPrecision() == SimulationState::precision::fp32) {
78-
std::vector<std::complex<float>> floatAmps;
79-
for (const auto& amp : amplitudes) {
80-
floatAmps.emplace_back(static_cast<float>(amp.real()),
81-
static_cast<float>(amp.imag()));
82-
}
83-
return state->createFromData(floatAmps);
84-
} else {
85-
return state->createFromData(amplitudes);
86-
}
77+
78+
auto amplitudes = state->getAmplitudes(basisStates);
79+
80+
// Create new state with appropriate precision
81+
if (state->getPrecision() == SimulationState::precision::fp32) {
82+
std::vector<std::complex<float>> floatAmps;
83+
for (const auto &amp : amplitudes) {
84+
floatAmps.emplace_back(static_cast<float>(amp.real()),
85+
static_cast<float>(amp.imag()));
86+
}
87+
return state->createFromData(floatAmps);
88+
} else {
89+
return state->createFromData(amplitudes);
8790
}
88-
89-
std::unique_ptr<SimulationState> clone_specialized_state(SimulationState* state) {
90-
// Try dynamic_cast to known types that have clone methods
91-
// this triggerd fatal error: library_types.h: No such file or directory
92-
//if (auto* densityState = dynamic_cast<const CuDensityMatState*>(state)) {
93-
// return CuDensityMatState::clone(*densityState);
94-
//}
95-
96-
if (auto* stimState = dynamic_cast<const StimState*>(state)) {
97-
return stimState->clone();
98-
}
99-
100-
// For unknown specialized types, try createFromSizeAndPtr as fallback
101-
// This might work for some specialized states
102-
// auto tensor = state->getTensor(0);
103-
// return state->createFromSizeAndPtr(tensor.get_num_elements(), tensor.data, 1);
91+
}
92+
93+
std::unique_ptr<SimulationState>
94+
clone_specialized_state(SimulationState *state) {
95+
// Try dynamic_cast to known types that have clone methods
96+
// this triggered fatal error: library_types.h: No such file or directory
97+
// if (auto* densityState = dynamic_cast<const CuDensityMatState*>(state)) {
98+
// return CuDensityMatState::clone(*densityState);
99+
//}
100+
101+
if (auto *cloneable = dynamic_cast<ClonableState *>(state)) {
102+
return cloneable->clone();
104103
}
104+
105+
// Fallback for non-cloneable specialized states
106+
throw std::runtime_error("Specialized state type does not support cloning");
107+
// For unknown specialized types, try createFromSizeAndPtr as fallback
108+
// This might work for some specialized states
109+
// auto tensor = state->getTensor(0);
110+
// return state->createFromSizeAndPtr(tensor.get_num_elements(),
111+
// tensor.data, 1);
112+
}
105113
};
106114

107115
/// The ExecutionContext is an abstraction to indicate how a CUDA-Q kernel
@@ -239,27 +247,24 @@ class ExecutionContext {
239247
std::vector<std::vector<bool>> msm_z_flips; // msm_z_flips[error_id][qubit_id]
240248

241249
/// @brief For each shot, this is a vector of error IDs.
242-
/// This is populated when using the "sample" mode (i.e. this->name == "sample")
243-
std::vector<std::vector<std::size_t>> errors_per_shot; // errors_per_shot[shot][error_id]
250+
/// This is populated when using the "sample" mode (i.e. this->name ==
251+
/// "sample")
252+
std::vector<std::vector<std::size_t>>
253+
errors_per_shot; // errors_per_shot[shot][error_id]
244254

245255
/// @brief Save the current simulation state in the recorded states storage.
246-
void save_state(SimulationState *state){
247-
recordStorage.save_state(state);
248-
}
256+
void save_state(SimulationState *state) { recordStorage.save_state(state); }
249257

250258
/// @brief Get the recorded states saved during execution.
251-
const std::vector<std::unique_ptr<SimulationState>>& get_recorded_states() const {
259+
const std::vector<std::unique_ptr<SimulationState>> &
260+
get_recorded_states() const {
252261
return recordStorage.get_recorded_states();
253262
}
254263

255264
/// @brief Clear the recorded states saved during execution.
256-
void clear_recorded_states(){
257-
recordStorage.clear();
258-
}
265+
void clear_recorded_states() { recordStorage.clear(); }
259266

260267
/// @brief Dump the recorded states saved during execution.
261-
void dump_recorded_states() const{
262-
recordStorage.dump_recorded_states();
263-
}
268+
void dump_recorded_states() const { recordStorage.dump_recorded_states(); }
264269
};
265270
} // namespace cudaq

runtime/common/SimulationState.h

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,21 @@
1717

1818
namespace cudaq {
1919
class SimulationState;
20+
class ClonableState;
2021

2122
/// Enum to specify the initial quantum state.
2223
enum class InitialState { ZERO, UNIFORM };
2324

24-
2525
/// @brief StimData now stores a list of (pointer, size) pairs
2626
/// according to convention:
2727
/// 0: pointer to num_qubits, size = 1
2828
/// 1: pointer to msm_err_count, size = 1
2929
/// 2: pointer to num_stabilizers, size = 1
30-
/// 3: x_output array, size = x_output_size (X stabilizers + destabilisers + phase bits)
31-
/// 4: z_output array, size = z_output_size (Z stabilizers + destabilisers + phase bits)
32-
/// 5: frame array (Pauli frame), size = 2*num_qubits
33-
using StimData = std::vector<std::pair<void*, std::size_t>>;
30+
/// 3: x_output array, size = x_output_size (X stabilizers + destabilisers +
31+
/// phase bits) 4: z_output array, size = z_output_size (Z stabilizers +
32+
/// destabilisers + phase bits) 5: frame array (Pauli frame), size =
33+
/// 2*num_qubits
34+
using StimData = std::vector<std::pair<void *, std::size_t>>;
3435

3536
/// @brief Encapsulates a list of tensors (data pointer and dimensions).
3637
// Note: tensor data is expected in column-major.
@@ -42,8 +43,7 @@ using TensorStateData =
4243
using state_data = std::variant<
4344
std::vector<std::complex<double>>, std::vector<std::complex<float>>,
4445
std::pair<std::complex<double> *, std::size_t>,
45-
std::pair<std::complex<float> *, std::size_t>, TensorStateData,
46-
StimData>;
46+
std::pair<std::complex<float> *, std::size_t>, TensorStateData, StimData>;
4747

4848
/// @brief The `SimulationState` interface provides an extension point
4949
/// for concrete circuit simulation sub-types to describe their
@@ -149,7 +149,8 @@ class SimulationState {
149149
"Cannot initialize state vector/density matrix state by stim "
150150
"data. Please use stabilizer simulator backends.");
151151
auto &dataCasted = std::get<StimData>(data);
152-
return createFromSizeAndPtr(dataCasted.size(), const_cast<StimData*>(&dataCasted), data.index());
152+
return createFromSizeAndPtr(
153+
dataCasted.size(), const_cast<StimData *>(&dataCasted), data.index());
153154
}
154155
// Flat array state data
155156
// Check the precision first. Get the size and
@@ -269,4 +270,12 @@ class SimulationState {
269270
/// @brief Destructor
270271
virtual ~SimulationState() {}
271272
};
273+
274+
/// @brief Interface for SimulationState subtypes that support cloning.
275+
class ClonableState {
276+
public:
277+
virtual ~ClonableState() = default;
278+
virtual std::unique_ptr<SimulationState> clone() const = 0;
279+
};
280+
272281
} // namespace cudaq

runtime/cudaq/builder/kernels.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ std::vector<double> getAlphaY(const std::span<double> data,
5959
/// to its internal representation. This implementation follows the algorithm
6060
/// defined in `https://arxiv.org/pdf/quant-ph/0407010.pdf`.
6161
template <typename Kernel>
62-
void from_state(Kernel &&kernel, QuakeValue &qubits,
63-
const std::span<std::complex<double>> data,
64-
std::size_t inNumQubits = 0) {
62+
inline void from_state(Kernel &&kernel, QuakeValue &qubits,
63+
const std::span<std::complex<double>> data,
64+
std::size_t inNumQubits = 0) {
6565
std::make_signed_t<std::size_t> numQubits =
6666
qubits.constantSize().value_or(inNumQubits);
6767
if (numQubits <= 0)
@@ -113,7 +113,7 @@ void from_state(Kernel &&kernel, QuakeValue &qubits,
113113
/// @brief Construct a CUDA-Q kernel that produces the
114114
/// given state. This overload will return the `kernel_builder` as a
115115
/// `unique_ptr`.
116-
auto from_state(const std::span<std::complex<double>> data) {
116+
inline auto from_state(const std::span<std::complex<double>> data) {
117117
auto numQubits = std::log2(data.size());
118118
std::vector<details::KernelBuilderType> empty;
119119
auto kernel = std::make_unique<kernel_builder<>>(empty);

runtime/cudaq/qis/qubit_qis.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,6 +1386,8 @@ void apply_noise(Args &&...args) {
13861386
details::tuple_slice_last<qubit_arity>(std::forward_as_tuple(args...)));
13871387
}
13881388

1389+
inline void save_state() { return; }
1390+
13891391
} // namespace cudaq
13901392

13911393
#define __qop__ __attribute__((annotate("user_custom_quantum_operation")))

runtime/nvqir/CircuitSimulator.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,13 @@ class CircuitSimulator {
151151

152152
/// @brief Return the internal state representation. This
153153
/// is meant for subtypes to override
154-
virtual std::unique_ptr<cudaq::SimulationState> getSimulationState() = 0;
154+
virtual std::unique_ptr<cudaq::SimulationState> getSimulationState() = 0;
155155

156156
/// @brief Get the current simulation state.
157157
/// The method returns the current state of the simulation without flushing
158158
/// the gate queue.
159-
virtual std::unique_ptr<cudaq::SimulationState> getCurrentSimulationState() = 0;
159+
virtual std::unique_ptr<cudaq::SimulationState>
160+
getCurrentSimulationState() = 0;
160161

161162
/// @brief Apply exp(-i theta PauliTensorProd) to the underlying state.
162163
/// This must be provided by subclasses.
@@ -553,7 +554,8 @@ class CircuitSimulatorBase : public CircuitSimulator {
553554
/// The method returns the current state of the simulation without flushing
554555
/// the gate queue.
555556
virtual std::unique_ptr<cudaq::SimulationState> getCurrentSimulationState() {
556-
throw std::runtime_error("Simulation data not available for this simulator backend.");
557+
throw std::runtime_error(
558+
"Simulation data not available for this simulator backend.");
557559
}
558560

559561
/// @brief Handle basic sampling tasks by storing the qubit index for

runtime/nvqir/NVQIR.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -839,15 +839,14 @@ void __quantum__qis__apply_kraus_channel_generalized(
839839
va_end(args);
840840
}
841841

842-
843-
static void
844-
__quantum__qis__save_state() {
842+
static void __quantum__qis__save_state() {
845843

846844
auto *ctx = nvqir::getCircuitSimulatorInternal()->getExecutionContext();
847845
if (!ctx)
848846
return;
849847

850-
std::unique_ptr<cudaq::SimulationState> state = nvqir::getCircuitSimulatorInternal()->getCurrentSimulationState();
848+
std::unique_ptr<cudaq::SimulationState> state =
849+
nvqir::getCircuitSimulatorInternal()->getCurrentSimulationState();
851850
ctx->save_state(state.get());
852851
}
853852

0 commit comments

Comments
 (0)