1717#include " cudaq/operators.h"
1818#include < optional>
1919#include < string_view>
20+ #include < iostream>
21+
22+ #include " nvqir/stim/StimState.h"
2023
2124namespace cudaq {
2225
@@ -27,18 +30,79 @@ struct RecordStorage{
2730 RecordStorage (size_t limit = 1e9 ) : memory_limit(limit), current_memory(0 ) {}
2831
2932 std::vector<std::unique_ptr<SimulationState>> recordedStates;
30- void save_state (const SimulationState *state) {
31- recordedStates.push_back (std::make_unique<SimulationState>(*state););
33+
34+ void save_state (SimulationState *state) {
35+ recordedStates.push_back (clone_state (state));
3236 }
33- std::vector<std::unique_ptr<SimulationState>> get_recorded_states () const { return recordedStates; }
37+ const std::vector<std::unique_ptr<SimulationState>>& get_recorded_states () const { return recordedStates; }
3438
3539 void clear () { recordedStates.clear (); }
3640 void dump_recorded_states () const {
3741 for (std::size_t i = 0 ; i < recordedStates.size (); i++) {
3842 recordedStates[i]->dump (std::cout);
3943 }
4044 }
41- }
45+
46+ private:
47+ std::unique_ptr<SimulationState> clone_state (SimulationState* state) {
48+ if (state->isArrayLike ()) {
49+ // Handle array-like states (CusvState, etc.)
50+ return clone_array_like_state (state);
51+ } else {
52+ // Handle specialized states (CuDensityMatState, StimState, etc.)
53+ return clone_specialized_state (state);
54+ }
55+ }
56+
57+ std::unique_ptr<SimulationState> clone_array_like_state (SimulationState* state) {
58+ auto numQubits = state->getNumQubits ();
59+ if (numQubits > 20 ) { // Prevent exponential explosion
60+ throw std::runtime_error (" State too large to clone via amplitudes" );
61+ }
62+
63+ // Generate all basis states
64+ auto totalStates = 1ULL << numQubits;
65+ std::vector<std::vector<int >> basisStates;
66+ for (size_t i = 0 ; i < totalStates; ++i) {
67+ std::vector<int > basis (numQubits);
68+ for (size_t j = 0 ; j < numQubits; ++j) {
69+ basis[j] = (i >> j) & 1 ;
70+ }
71+ basisStates.push_back (basis);
72+ }
73+
74+ auto amplitudes = state->getAmplitudes (basisStates);
75+
76+ // Create new state with appropriate precision
77+ if (state->getPrecision () == SimulationState::precision::fp32) {
78+ std::vector<std::complex <float >> floatAmps;
79+ for (const auto & amp : amplitudes) {
80+ floatAmps.emplace_back (static_cast <float >(amp.real ()),
81+ static_cast <float >(amp.imag ()));
82+ }
83+ return state->createFromData (floatAmps);
84+ } else {
85+ return state->createFromData (amplitudes);
86+ }
87+ }
88+
89+ std::unique_ptr<SimulationState> clone_specialized_state (SimulationState* state) {
90+ // Try dynamic_cast to known types that have clone methods
91+ // this triggerd fatal error: library_types.h: No such file or directory
92+ // if (auto* densityState = dynamic_cast<const CuDensityMatState*>(state)) {
93+ // return CuDensityMatState::clone(*densityState);
94+ // }
95+
96+ if (auto * stimState = dynamic_cast <const StimState*>(state)) {
97+ return stimState->clone ();
98+ }
99+
100+ // For unknown specialized types, try createFromSizeAndPtr as fallback
101+ // This might work for some specialized states
102+ // auto tensor = state->getTensor(0);
103+ // return state->createFromSizeAndPtr(tensor.get_num_elements(), tensor.data, 1);
104+ }
105+ };
42106
43107// / The ExecutionContext is an abstraction to indicate how a CUDA-Q kernel
44108// / should be executed.
@@ -179,12 +243,12 @@ class ExecutionContext {
179243 std::vector<std::vector<std::size_t >> errors_per_shot; // errors_per_shot[shot][error_id]
180244
181245 // / @brief Save the current simulation state in the recorded states storage.
182- void save_state (const SimulationState *state){
246+ void save_state (SimulationState *state){
183247 recordStorage.save_state (state);
184248 }
185249
186250 // / @brief Get the recorded states saved during execution.
187- std::vector<std::unique_ptr<SimulationState>> get_recorded_states () const {
251+ const std::vector<std::unique_ptr<SimulationState>>& get_recorded_states () const {
188252 return recordStorage.get_recorded_states ();
189253 }
190254
0 commit comments