Skip to content

Commit 77530f7

Browse files
committed
wip, I need to make a flattened stim_data type
Signed-off-by: Kevin Mato <[email protected]>
1 parent 0c43b70 commit 77530f7

File tree

3 files changed

+220
-2
lines changed

3 files changed

+220
-2
lines changed

runtime/common/ExecutionContext.h

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,33 @@
2020

2121
namespace cudaq {
2222

23+
struct RecordStorage{
24+
25+
size_t memory_limit;
26+
size_t current_memory;
27+
RecordStorage(size_t limit = 1e9) : memory_limit(limit), current_memory(0) {}
28+
29+
std::vector<std::unique_ptr<SimulationState>> recordedStates;
30+
void save_state(const SimulationState *state) {
31+
recordedStates.push_back(std::make_unique<SimulationState>(*state););
32+
}
33+
std::vector<std::unique_ptr<SimulationState>> get_recorded_states() const { return recordedStates; }
34+
35+
void clear() { recordedStates.clear(); }
36+
void dump_recorded_states() const {
37+
for (std::size_t i = 0; i < recordedStates.size(); i++) {
38+
recordedStates[i]->dump(std::cout);
39+
}
40+
}
41+
}
42+
2343
/// The ExecutionContext is an abstraction to indicate how a CUDA-Q kernel
2444
/// should be executed.
2545
class ExecutionContext {
46+
47+
///@brief record storage for the states saved during execution
48+
RecordStorage recordStorage;
49+
2650
public:
2751
/// @brief The Constructor, takes the name of the context
2852
/// @param n The name of the context
@@ -155,6 +179,23 @@ class ExecutionContext {
155179
std::vector<std::vector<std::size_t>> errors_per_shot; // errors_per_shot[shot][error_id]
156180

157181
/// @brief Save the current simulation state in the recorded states storage.
158-
void save_state(const SimulationState *state);
182+
void save_state(const SimulationState *state){
183+
recordStorage.save_state(state);
184+
}
185+
186+
/// @brief Get the recorded states saved during execution.
187+
std::vector<std::unique_ptr<SimulationState>> get_recorded_states() const {
188+
return recordStorage.get_recorded_states();
189+
}
190+
191+
/// @brief Clear the recorded states saved during execution.
192+
void clear_recorded_states(){
193+
recordStorage.clear();
194+
}
195+
196+
/// @brief Dump the recorded states saved during execution.
197+
void dump_recorded_states() const{
198+
recordStorage.dump_recorded_states();
199+
}
159200
};
160201
} // namespace cudaq

runtime/common/SimulationState.h

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,51 @@ class SimulationState;
2121
/// Enum to specify the initial quantum state.
2222
enum class InitialState { ZERO, UNIFORM };
2323

24+
struct StimData {
25+
struct TableauClone {
26+
std::vector<std::vector<bool>> x_output;
27+
std::vector<std::vector<bool>> z_output;
28+
std::vector<bool> r_output;
29+
30+
TableauClone copy() const {
31+
TableauClone t;
32+
t.x_output = x_output;
33+
t.z_output = z_output;
34+
t.r_output = r_output;
35+
return t;
36+
}
37+
} tableau;
38+
39+
struct PauliFrameClone {
40+
std::vector<bool> x;
41+
std::vector<bool> z;
42+
43+
PauliFrameClone copy() const {
44+
PauliFrameClone f;
45+
f.x = x;
46+
f.z = z;
47+
return f;
48+
}
49+
} frame;
50+
51+
std::size_t current_size = 0;
52+
std::size_t msm_err_count = 0;
53+
uint64_t num_qubits = 0;
54+
void* data() { return nullptr; }
55+
56+
// Copy helper for the entire StimData
57+
StimData copy() const {
58+
StimData s;
59+
s.tableau = tableau.copy();
60+
s.frame = frame.copy();
61+
s.current_size = current_size;
62+
s.msm_err_count = msm_err_count;
63+
s.num_qubits = num_qubits;
64+
return s;
65+
}
66+
};
67+
68+
2469
/// @brief Encapsulates a list of tensors (data pointer and dimensions).
2570
// Note: tensor data is expected in column-major.
2671
using TensorStateData =
@@ -31,7 +76,8 @@ using TensorStateData =
3176
using state_data = std::variant<
3277
std::vector<std::complex<double>>, std::vector<std::complex<float>>,
3378
std::pair<std::complex<double> *, std::size_t>,
34-
std::pair<std::complex<float> *, std::size_t>, TensorStateData>;
79+
std::pair<std::complex<float> *, std::size_t>, TensorStateData,
80+
StimData>;
3581

3682
/// @brief The `SimulationState` interface provides and extension point
3783
/// for concrete circuit simulation sub-types to describe their
@@ -131,6 +177,14 @@ class SimulationState {
131177
const_cast<TensorStateData::value_type *>(dataCasted.data()),
132178
data.index());
133179
}
180+
if (std::holds_alternative<StimData>(data)) {
181+
if (isArrayLike())
182+
throw std::runtime_error(
183+
"Cannot initialize state vector/density matrix state by stim "
184+
"data. Please use stabilizer simulator backends.");
185+
auto &dataCasted = std::get<StimData>(data);
186+
return createFromSizeAndPtr(1, const_cast<StimData*>(&dataCasted), data.index());
187+
}
134188
// Flat array state data
135189
// Check the precision first. Get the size and
136190
// data pointer from the input data.

runtime/nvqir/stim/StimState.h

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*************************************************************** -*- C++ -*- ***
2+
* Copyright (c) 2025 NVIDIA Corporation & Affiliates. *
3+
* All rights reserved. *
4+
* *
5+
* This source code and the accompanying materials are made available under *
6+
* the terms of the Apache License 2.0 which accompanies this distribution. *
7+
******************************************************************************/
8+
#pragma once
9+
10+
#include "common/SimulationState.h"
11+
#include <iostream>
12+
#include <memory>
13+
#include <variant>
14+
#include <stdexcept>
15+
16+
namespace cudaq {
17+
18+
/// @brief Provides stabilizer simulation state representation using StimData.
19+
class StimState : public SimulationState {
20+
public:
21+
/// @brief Construct from StimData (may copy).
22+
explicit StimState(const StimData& d) : data_(d.copy()) {}
23+
24+
/// @brief Construct from an rvalue StimData
25+
explicit StimState(StimData&& d) : data_(std::move(d)) {}
26+
27+
/// @brief Factory for this type from state_data.
28+
std::unique_ptr<SimulationState>
29+
createFromData(const state_data& d) override {
30+
if (!std::holds_alternative<StimData>(d))
31+
throw std::runtime_error("[StimState] only supports StimData for initialization.");
32+
return std::make_unique<StimState>(std::get<StimData>(d));
33+
}
34+
35+
protected:
36+
/// @brief Create from data pointer.
37+
std::unique_ptr<SimulationState>
38+
createFromSizeAndPtr(std::size_t, void* ptr, std::size_t dataType) override {
39+
if (dataType != state_data::variant_type_index<StimData>())
40+
throw std::runtime_error("[StimState] only supports StimData for initialization.");
41+
auto stim_data = static_cast<StimData*>(ptr);
42+
return std::make_unique<StimState>(*stim_data);
43+
}
44+
45+
public:
46+
/// @brief This simulator is not array-like (must use Pauli frame/tableau APIs).
47+
bool isArrayLike() const override { return false; }
48+
49+
/// @brief Return the number of qubits.
50+
std::size_t getNumQubits() const override { return data_.num_qubits; }
51+
52+
/// @brief Tensor interface not supported for StimState.
53+
Tensor getTensor(std::size_t idx = 0) const override {
54+
throw std::runtime_error("[StimState] Tensor interface not supported.");
55+
}
56+
57+
std::vector<Tensor> getTensors() const override { return {}; }
58+
std::size_t getNumTensors() const override { return 0; }
59+
60+
/// @brief Overlap is not implemented for stabilizer states.
61+
std::complex<double> overlap(const SimulationState& other) override {
62+
throw std::runtime_error("[StimState] overlap not implemented for stabilizer data.");
63+
}
64+
65+
/// @brief Amplitude access not supported for StimState.
66+
std::complex<double> getAmplitude(const std::vector<int>&) override {
67+
throw std::runtime_error("[StimState] amplitudes not supported for stabilizer states.");
68+
}
69+
70+
/// @brief Dump stabilizer state summary.
71+
void dump(std::ostream &os) const override {
72+
os << "StimState { qubits=" << data_.num_qubits
73+
<< ", msm_err_count=" << data_.msm_err_count
74+
<< ", current_size=" << data_.current_size << " }";
75+
// Optionally list the tableau or Pauli frame if desired
76+
os << "\nTableau X_output:\n";
77+
for (const auto& row : data_.tableau.x_output) {
78+
for (bool b : row) os << (b ? '1' : '0');
79+
os << "\n";
80+
}
81+
os << "Tableau Z_output:\n";
82+
for (const auto& row : data_.tableau.z_output) {
83+
for (bool b : row) os << (b ? '1' : '0');
84+
os << "\n";
85+
}
86+
os << "PauliFrame X:\n";
87+
for (bool b : data_.frame.x) os << (b ? '1' : '0');
88+
os << "\nPauliFrame Z:\n";
89+
for (bool b : data_.frame.z) os << (b ? '1' : '0');
90+
os << "\n";
91+
}
92+
93+
/// @brief Precision is always double for stabilizer/Stim data.
94+
precision getPrecision() const override { return precision::fp64; }
95+
96+
/// @brief Destroy any resources (none needed here).
97+
void destroyState() override {
98+
// No-op: All managed by RAII.
99+
}
100+
101+
/// @brief Returns a const reference to the tableau (stabilizer generator).
102+
const StimData::TableauClone& getTableau() const { return data_.tableau; }
103+
104+
/// @brief Returns a const reference to the Pauli frame.
105+
const StimData::PauliFrameClone& getPauliFrame() const { return data_.frame; }
106+
107+
/// @brief Access StimData internals, if needed.
108+
const StimData& stim_data() const { return data_; }
109+
110+
void set_tableau(const StimData::TableauClone& t) { data_.set_tableau(t); }
111+
void set_pauli_frame(const StimData::PauliFrameClone& f) { data_.set_pauli_frame(f); }
112+
void set_current_size(std::size_t s) { data_.set_current_size(s); }
113+
void set_msm_err_count(std::size_t c) { data_.set_msm_err_count(c); }
114+
void set_num_qubits(uint64_t n) { data_.set_num_qubits(n); }
115+
116+
117+
118+
119+
private:
120+
StimData data_;
121+
};
122+
123+
} // namespace cudaq

0 commit comments

Comments
 (0)