|
15 | 15 | #include "Trace.h" |
16 | 16 | #include "cudaq/algorithms/optimizer.h" |
17 | 17 | #include "cudaq/operators.h" |
| 18 | +#include <iostream> |
18 | 19 | #include <optional> |
19 | 20 | #include <string_view> |
20 | | -#include <iostream> |
21 | 21 |
|
22 | 22 | #include "nvqir/stim/StimState.h" |
23 | 23 |
|
24 | 24 | namespace cudaq { |
25 | 25 |
|
26 | | -struct RecordStorage{ |
| 26 | +struct RecordStorage { |
27 | 27 |
|
28 | 28 | size_t memory_limit; |
29 | 29 | size_t current_memory; |
30 | 30 | RecordStorage(size_t limit = 1e9) : memory_limit(limit), current_memory(0) {} |
31 | 31 |
|
32 | 32 | std::vector<std::unique_ptr<SimulationState>> recordedStates; |
33 | 33 |
|
34 | | - void save_state(SimulationState *state) { |
35 | | - recordedStates.push_back(clone_state(state)); |
36 | | - } |
37 | | - const std::vector<std::unique_ptr<SimulationState>>& get_recorded_states() const { return recordedStates; } |
38 | | - |
39 | | - void clear() { recordedStates.clear(); } |
40 | | - void dump_recorded_states() const { |
41 | | - for (std::size_t i = 0; i < recordedStates.size(); i++) { |
42 | | - recordedStates[i]->dump(std::cout); |
43 | | - } |
| 34 | + void save_state(SimulationState *state) { |
| 35 | + recordedStates.push_back(clone_state(state)); |
| 36 | + } |
| 37 | + const std::vector<std::unique_ptr<SimulationState>> & |
| 38 | + get_recorded_states() const { |
| 39 | + return recordedStates; |
| 40 | + } |
| 41 | + |
| 42 | + void clear() { recordedStates.clear(); } |
| 43 | + void dump_recorded_states() const { |
| 44 | + for (std::size_t i = 0; i < recordedStates.size(); i++) { |
| 45 | + recordedStates[i]->dump(std::cout); |
44 | 46 | } |
45 | | - |
| 47 | + } |
| 48 | + |
46 | 49 | private: |
47 | | - std::unique_ptr<SimulationState> clone_state(SimulationState* state) { |
48 | | - if (state->isArrayLike()) { |
49 | | - // Handle array-like states (CusvState, etc.) |
50 | | - return clone_array_like_state(state); |
51 | | - } else { |
52 | | - // Handle specialized states (CuDensityMatState, StimState, etc.) |
53 | | - return clone_specialized_state(state); |
54 | | - } |
| 50 | + std::unique_ptr<SimulationState> clone_state(SimulationState *state) { |
| 51 | + if (state->isArrayLike()) { |
| 52 | + // Handle array-like states (CusvState, etc.) |
| 53 | + return clone_array_like_state(state); |
| 54 | + } else { |
| 55 | + // Handle specialized states (CuDensityMatState, StimState, etc.) |
| 56 | + return clone_specialized_state(state); |
| 57 | + } |
| 58 | + } |
| 59 | + |
| 60 | + std::unique_ptr<SimulationState> |
| 61 | + clone_array_like_state(SimulationState *state) { |
| 62 | + auto numQubits = state->getNumQubits(); |
| 63 | + if (numQubits > 20) { // Prevent exponential explosion |
| 64 | + throw std::runtime_error("State too large to clone via amplitudes"); |
| 65 | + } |
| 66 | + |
| 67 | + // Generate all basis states |
| 68 | + auto totalStates = 1ULL << numQubits; |
| 69 | + std::vector<std::vector<int>> basisStates; |
| 70 | + for (size_t i = 0; i < totalStates; ++i) { |
| 71 | + std::vector<int> basis(numQubits); |
| 72 | + for (size_t j = 0; j < numQubits; ++j) { |
| 73 | + basis[j] = (i >> j) & 1; |
| 74 | + } |
| 75 | + basisStates.push_back(basis); |
55 | 76 | } |
56 | | - |
57 | | - std::unique_ptr<SimulationState> clone_array_like_state(SimulationState* state) { |
58 | | - auto numQubits = state->getNumQubits(); |
59 | | - if (numQubits > 20) { // Prevent exponential explosion |
60 | | - throw std::runtime_error("State too large to clone via amplitudes"); |
61 | | - } |
62 | | - |
63 | | - // Generate all basis states |
64 | | - auto totalStates = 1ULL << numQubits; |
65 | | - std::vector<std::vector<int>> basisStates; |
66 | | - for (size_t i = 0; i < totalStates; ++i) { |
67 | | - std::vector<int> basis(numQubits); |
68 | | - for (size_t j = 0; j < numQubits; ++j) { |
69 | | - basis[j] = (i >> j) & 1; |
70 | | - } |
71 | | - basisStates.push_back(basis); |
72 | | - } |
73 | | - |
74 | | - auto amplitudes = state->getAmplitudes(basisStates); |
75 | | - |
76 | | - // Create new state with appropriate precision |
77 | | - if (state->getPrecision() == SimulationState::precision::fp32) { |
78 | | - std::vector<std::complex<float>> floatAmps; |
79 | | - for (const auto& amp : amplitudes) { |
80 | | - floatAmps.emplace_back(static_cast<float>(amp.real()), |
81 | | - static_cast<float>(amp.imag())); |
82 | | - } |
83 | | - return state->createFromData(floatAmps); |
84 | | - } else { |
85 | | - return state->createFromData(amplitudes); |
86 | | - } |
| 77 | + |
| 78 | + auto amplitudes = state->getAmplitudes(basisStates); |
| 79 | + |
| 80 | + // Create new state with appropriate precision |
| 81 | + if (state->getPrecision() == SimulationState::precision::fp32) { |
| 82 | + std::vector<std::complex<float>> floatAmps; |
| 83 | + for (const auto & : amplitudes) { |
| 84 | + floatAmps.emplace_back(static_cast<float>(amp.real()), |
| 85 | + static_cast<float>(amp.imag())); |
| 86 | + } |
| 87 | + return state->createFromData(floatAmps); |
| 88 | + } else { |
| 89 | + return state->createFromData(amplitudes); |
87 | 90 | } |
88 | | - |
89 | | - std::unique_ptr<SimulationState> clone_specialized_state(SimulationState* state) { |
90 | | - // Try dynamic_cast to known types that have clone methods |
91 | | - // this triggerd fatal error: library_types.h: No such file or directory |
92 | | - //if (auto* densityState = dynamic_cast<const CuDensityMatState*>(state)) { |
93 | | - // return CuDensityMatState::clone(*densityState); |
94 | | - //} |
95 | | - |
96 | | - if (auto* stimState = dynamic_cast<const StimState*>(state)) { |
97 | | - return stimState->clone(); |
98 | | - } |
99 | | - |
100 | | - // For unknown specialized types, try createFromSizeAndPtr as fallback |
101 | | - // This might work for some specialized states |
102 | | - // auto tensor = state->getTensor(0); |
103 | | - // return state->createFromSizeAndPtr(tensor.get_num_elements(), tensor.data, 1); |
| 91 | + } |
| 92 | + |
| 93 | + std::unique_ptr<SimulationState> |
| 94 | + clone_specialized_state(SimulationState *state) { |
| 95 | + // Try dynamic_cast to known types that have clone methods |
| 96 | + // this triggerd fatal error: library_types.h: No such file or directory |
| 97 | + // if (auto* densityState = dynamic_cast<const CuDensityMatState*>(state)) { |
| 98 | + // return CuDensityMatState::clone(*densityState); |
| 99 | + //} |
| 100 | + |
| 101 | + if (auto *cloneable = dynamic_cast<ClonableState *>(state)) { |
| 102 | + return cloneable->clone(); |
104 | 103 | } |
| 104 | + |
| 105 | + // Fallback for non-cloneable specialized states |
| 106 | + throw std::runtime_error("Specialized state type does not support cloning"); |
| 107 | + // For unknown specialized types, try createFromSizeAndPtr as fallback |
| 108 | + // This might work for some specialized states |
| 109 | + // auto tensor = state->getTensor(0); |
| 110 | + // return state->createFromSizeAndPtr(tensor.get_num_elements(), |
| 111 | + // tensor.data, 1); |
| 112 | + } |
105 | 113 | }; |
106 | 114 |
|
107 | 115 | /// The ExecutionContext is an abstraction to indicate how a CUDA-Q kernel |
@@ -239,27 +247,24 @@ class ExecutionContext { |
239 | 247 | std::vector<std::vector<bool>> msm_z_flips; // msm_z_flips[error_id][qubit_id] |
240 | 248 |
|
241 | 249 | /// @brief For each shot, this is a vector of error IDs. |
242 | | - /// This is populated when using the "sample" mode (i.e. this->name == "sample") |
243 | | - std::vector<std::vector<std::size_t>> errors_per_shot; // errors_per_shot[shot][error_id] |
| 250 | + /// This is populated when using the "sample" mode (i.e. this->name == |
| 251 | + /// "sample") |
| 252 | + std::vector<std::vector<std::size_t>> |
| 253 | + errors_per_shot; // errors_per_shot[shot][error_id] |
244 | 254 |
|
245 | 255 | /// @brief Save the current simulation state in the recorded states storage. |
246 | | - void save_state(SimulationState *state){ |
247 | | - recordStorage.save_state(state); |
248 | | - } |
| 256 | + void save_state(SimulationState *state) { recordStorage.save_state(state); } |
249 | 257 |
|
250 | 258 | /// @brief Get the recorded states saved during execution. |
251 | | - const std::vector<std::unique_ptr<SimulationState>>& get_recorded_states() const { |
| 259 | + const std::vector<std::unique_ptr<SimulationState>> & |
| 260 | + get_recorded_states() const { |
252 | 261 | return recordStorage.get_recorded_states(); |
253 | 262 | } |
254 | 263 |
|
255 | 264 | /// @brief Clear the recorded states saved during execution. |
256 | | - void clear_recorded_states(){ |
257 | | - recordStorage.clear(); |
258 | | - } |
| 265 | + void clear_recorded_states() { recordStorage.clear(); } |
259 | 266 |
|
260 | 267 | /// @brief Dump the recorded states saved during execution. |
261 | | - void dump_recorded_states() const{ |
262 | | - recordStorage.dump_recorded_states(); |
263 | | - } |
| 268 | + void dump_recorded_states() const { recordStorage.dump_recorded_states(); } |
264 | 269 | }; |
265 | 270 | } // namespace cudaq |
0 commit comments