Skip to content

Commit 06a0e37

Browse files
committed
wip, flattened type introduced
Signed-off-by: Kevin Mato <[email protected]>
1 parent 77530f7 commit 06a0e37

File tree

3 files changed

+264
-90
lines changed

3 files changed

+264
-90
lines changed

runtime/common/ExecutionContext.h

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
#include "cudaq/operators.h"
1818
#include <optional>
1919
#include <string_view>
20+
#include <iostream>
21+
22+
#include "nvqir/stim/StimState.h"
2023

2124
namespace cudaq {
2225

@@ -27,18 +30,79 @@ struct RecordStorage{
2730
RecordStorage(size_t limit = 1e9) : memory_limit(limit), current_memory(0) {}
2831

2932
std::vector<std::unique_ptr<SimulationState>> recordedStates;
30-
void save_state(const SimulationState *state) {
31-
recordedStates.push_back(std::make_unique<SimulationState>(*state););
33+
34+
void save_state(SimulationState *state) {
35+
recordedStates.push_back(clone_state(state));
3236
}
33-
std::vector<std::unique_ptr<SimulationState>> get_recorded_states() const { return recordedStates; }
37+
const std::vector<std::unique_ptr<SimulationState>>& get_recorded_states() const { return recordedStates; }
3438

3539
void clear() { recordedStates.clear(); }
3640
void dump_recorded_states() const {
3741
for (std::size_t i = 0; i < recordedStates.size(); i++) {
3842
recordedStates[i]->dump(std::cout);
3943
}
4044
}
41-
}
45+
46+
private:
47+
std::unique_ptr<SimulationState> clone_state(SimulationState* state) {
48+
if (state->isArrayLike()) {
49+
// Handle array-like states (CusvState, etc.)
50+
return clone_array_like_state(state);
51+
} else {
52+
// Handle specialized states (CuDensityMatState, StimState, etc.)
53+
return clone_specialized_state(state);
54+
}
55+
}
56+
57+
std::unique_ptr<SimulationState> clone_array_like_state(SimulationState* state) {
58+
auto numQubits = state->getNumQubits();
59+
if (numQubits > 20) { // Prevent exponential explosion
60+
throw std::runtime_error("State too large to clone via amplitudes");
61+
}
62+
63+
// Generate all basis states
64+
auto totalStates = 1ULL << numQubits;
65+
std::vector<std::vector<int>> basisStates;
66+
for (size_t i = 0; i < totalStates; ++i) {
67+
std::vector<int> basis(numQubits);
68+
for (size_t j = 0; j < numQubits; ++j) {
69+
basis[j] = (i >> j) & 1;
70+
}
71+
basisStates.push_back(basis);
72+
}
73+
74+
auto amplitudes = state->getAmplitudes(basisStates);
75+
76+
// Create new state with appropriate precision
77+
if (state->getPrecision() == SimulationState::precision::fp32) {
78+
std::vector<std::complex<float>> floatAmps;
79+
for (const auto& amp : amplitudes) {
80+
floatAmps.emplace_back(static_cast<float>(amp.real()),
81+
static_cast<float>(amp.imag()));
82+
}
83+
return state->createFromData(floatAmps);
84+
} else {
85+
return state->createFromData(amplitudes);
86+
}
87+
}
88+
89+
std::unique_ptr<SimulationState> clone_specialized_state(SimulationState* state) {
90+
// Try dynamic_cast to known types that have clone methods
91+
// this triggerd fatal error: library_types.h: No such file or directory
92+
//if (auto* densityState = dynamic_cast<const CuDensityMatState*>(state)) {
93+
// return CuDensityMatState::clone(*densityState);
94+
//}
95+
96+
if (auto* stimState = dynamic_cast<const StimState*>(state)) {
97+
return stimState->clone();
98+
}
99+
100+
// For unknown specialized types, try createFromSizeAndPtr as fallback
101+
// This might work for some specialized states
102+
// auto tensor = state->getTensor(0);
103+
// return state->createFromSizeAndPtr(tensor.get_num_elements(), tensor.data, 1);
104+
}
105+
};
42106

43107
/// The ExecutionContext is an abstraction to indicate how a CUDA-Q kernel
44108
/// should be executed.
@@ -179,12 +243,12 @@ class ExecutionContext {
179243
std::vector<std::vector<std::size_t>> errors_per_shot; // errors_per_shot[shot][error_id]
180244

181245
/// @brief Save the current simulation state in the recorded states storage.
182-
void save_state(const SimulationState *state){
246+
void save_state(SimulationState *state){
183247
recordStorage.save_state(state);
184248
}
185249

186250
/// @brief Get the recorded states saved during execution.
187-
std::vector<std::unique_ptr<SimulationState>> get_recorded_states() const {
251+
const std::vector<std::unique_ptr<SimulationState>>& get_recorded_states() const {
188252
return recordStorage.get_recorded_states();
189253
}
190254

runtime/common/SimulationState.h

Lines changed: 10 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -21,50 +21,16 @@ class SimulationState;
2121
/// Enum to specify the initial quantum state.
2222
enum class InitialState { ZERO, UNIFORM };
2323

24-
struct StimData {
25-
struct TableauClone {
26-
std::vector<std::vector<bool>> x_output;
27-
std::vector<std::vector<bool>> z_output;
28-
std::vector<bool> r_output;
29-
30-
TableauClone copy() const {
31-
TableauClone t;
32-
t.x_output = x_output;
33-
t.z_output = z_output;
34-
t.r_output = r_output;
35-
return t;
36-
}
37-
} tableau;
38-
39-
struct PauliFrameClone {
40-
std::vector<bool> x;
41-
std::vector<bool> z;
42-
43-
PauliFrameClone copy() const {
44-
PauliFrameClone f;
45-
f.x = x;
46-
f.z = z;
47-
return f;
48-
}
49-
} frame;
50-
51-
std::size_t current_size = 0;
52-
std::size_t msm_err_count = 0;
53-
uint64_t num_qubits = 0;
54-
void* data() { return nullptr; }
55-
56-
// Copy helper for the entire StimData
57-
StimData copy() const {
58-
StimData s;
59-
s.tableau = tableau.copy();
60-
s.frame = frame.copy();
61-
s.current_size = current_size;
62-
s.msm_err_count = msm_err_count;
63-
s.num_qubits = num_qubits;
64-
return s;
65-
}
66-
};
6724

25+
/// @brief StimData now stores a list of (pointer, size) pairs
26+
/// according to convention:
27+
/// 0: pointer to num_qubits, size = 1
28+
/// 1: pointer to msm_err_count, size = 1
29+
/// 2: pointer to num_stabilizers, size = 1
30+
/// 3: x_output array, size = x_output_size (X stabilizers + destabilisers + phase bits)
31+
/// 4: z_output array, size = z_output_size (Z stabilizers + destabilisers + phase bits)
32+
/// 5: frame array (Pauli frame), size = 2*num_qubits
33+
using StimData = std::vector<std::pair<void*, std::size_t>>;
6834

6935
/// @brief Encapsulates a list of tensors (data pointer and dimensions).
7036
// Note: tensor data is expected in column-major.
@@ -183,7 +149,7 @@ class SimulationState {
183149
"Cannot initialize state vector/density matrix state by stim "
184150
"data. Please use stabilizer simulator backends.");
185151
auto &dataCasted = std::get<StimData>(data);
186-
return createFromSizeAndPtr(1, const_cast<StimData*>(&dataCasted), data.index());
152+
return createFromSizeAndPtr(dataCasted.size(), const_cast<StimData*>(&dataCasted), data.index());
187153
}
188154
// Flat array state data
189155
// Check the precision first. Get the size and

0 commit comments

Comments
 (0)