|
| 1 | +/*************************************************************** -*- C++ -*- *** |
| 2 | + * Copyright (c) 2025 NVIDIA Corporation & Affiliates. * |
| 3 | + * All rights reserved. * |
| 4 | + * * |
| 5 | + * This source code and the accompanying materials are made available under * |
| 6 | + * the terms of the Apache License 2.0 which accompanies this distribution. * |
| 7 | + ******************************************************************************/ |
| 8 | +#pragma once |
| 9 | + |
| 10 | +#include "common/SimulationState.h" |
| 11 | +#include <iostream> |
| 12 | +#include <memory> |
| 13 | +#include <variant> |
| 14 | +#include <stdexcept> |
| 15 | + |
| 16 | +namespace cudaq { |
| 17 | + |
| 18 | +/// @brief Provides stabilizer simulation state representation using StimData. |
| 19 | +class StimState : public SimulationState { |
| 20 | +public: |
| 21 | + /// @brief Construct from StimData (may copy). |
| 22 | + explicit StimState(const StimData& d) : data_(d.copy()) {} |
| 23 | + |
| 24 | + /// @brief Construct from an rvalue StimData |
| 25 | + explicit StimState(StimData&& d) : data_(std::move(d)) {} |
| 26 | + |
| 27 | + /// @brief Factory for this type from state_data. |
| 28 | + std::unique_ptr<SimulationState> |
| 29 | + createFromData(const state_data& d) override { |
| 30 | + if (!std::holds_alternative<StimData>(d)) |
| 31 | + throw std::runtime_error("[StimState] only supports StimData for initialization."); |
| 32 | + return std::make_unique<StimState>(std::get<StimData>(d)); |
| 33 | + } |
| 34 | + |
| 35 | +protected: |
| 36 | + /// @brief Create from data pointer. |
| 37 | + std::unique_ptr<SimulationState> |
| 38 | + createFromSizeAndPtr(std::size_t, void* ptr, std::size_t dataType) override { |
| 39 | + if (dataType != state_data::variant_type_index<StimData>()) |
| 40 | + throw std::runtime_error("[StimState] only supports StimData for initialization."); |
| 41 | + auto stim_data = static_cast<StimData*>(ptr); |
| 42 | + return std::make_unique<StimState>(*stim_data); |
| 43 | + } |
| 44 | + |
| 45 | +public: |
| 46 | + /// @brief This simulator is not array-like (must use Pauli frame/tableau APIs). |
| 47 | + bool isArrayLike() const override { return false; } |
| 48 | + |
| 49 | + /// @brief Return the number of qubits. |
| 50 | + std::size_t getNumQubits() const override { return data_.num_qubits; } |
| 51 | + |
| 52 | + /// @brief Tensor interface not supported for StimState. |
| 53 | + Tensor getTensor(std::size_t idx = 0) const override { |
| 54 | + throw std::runtime_error("[StimState] Tensor interface not supported."); |
| 55 | + } |
| 56 | + |
| 57 | + std::vector<Tensor> getTensors() const override { return {}; } |
| 58 | + std::size_t getNumTensors() const override { return 0; } |
| 59 | + |
| 60 | + /// @brief Overlap is not implemented for stabilizer states. |
| 61 | + std::complex<double> overlap(const SimulationState& other) override { |
| 62 | + throw std::runtime_error("[StimState] overlap not implemented for stabilizer data."); |
| 63 | + } |
| 64 | + |
| 65 | + /// @brief Amplitude access not supported for StimState. |
| 66 | + std::complex<double> getAmplitude(const std::vector<int>&) override { |
| 67 | + throw std::runtime_error("[StimState] amplitudes not supported for stabilizer states."); |
| 68 | + } |
| 69 | + |
| 70 | + /// @brief Dump stabilizer state summary. |
| 71 | + void dump(std::ostream &os) const override { |
| 72 | + os << "StimState { qubits=" << data_.num_qubits |
| 73 | + << ", msm_err_count=" << data_.msm_err_count |
| 74 | + << ", current_size=" << data_.current_size << " }"; |
| 75 | + // Optionally list the tableau or Pauli frame if desired |
| 76 | + os << "\nTableau X_output:\n"; |
| 77 | + for (const auto& row : data_.tableau.x_output) { |
| 78 | + for (bool b : row) os << (b ? '1' : '0'); |
| 79 | + os << "\n"; |
| 80 | + } |
| 81 | + os << "Tableau Z_output:\n"; |
| 82 | + for (const auto& row : data_.tableau.z_output) { |
| 83 | + for (bool b : row) os << (b ? '1' : '0'); |
| 84 | + os << "\n"; |
| 85 | + } |
| 86 | + os << "PauliFrame X:\n"; |
| 87 | + for (bool b : data_.frame.x) os << (b ? '1' : '0'); |
| 88 | + os << "\nPauliFrame Z:\n"; |
| 89 | + for (bool b : data_.frame.z) os << (b ? '1' : '0'); |
| 90 | + os << "\n"; |
| 91 | + } |
| 92 | + |
| 93 | + /// @brief Precision is always double for stabilizer/Stim data. |
| 94 | + precision getPrecision() const override { return precision::fp64; } |
| 95 | + |
| 96 | + /// @brief Destroy any resources (none needed here). |
| 97 | + void destroyState() override { |
| 98 | + // No-op: All managed by RAII. |
| 99 | + } |
| 100 | + |
| 101 | + /// @brief Returns a const reference to the tableau (stabilizer generator). |
| 102 | + const StimData::TableauClone& getTableau() const { return data_.tableau; } |
| 103 | + |
| 104 | + /// @brief Returns a const reference to the Pauli frame. |
| 105 | + const StimData::PauliFrameClone& getPauliFrame() const { return data_.frame; } |
| 106 | + |
| 107 | + /// @brief Access StimData internals, if needed. |
| 108 | + const StimData& stim_data() const { return data_; } |
| 109 | + |
| 110 | + void set_tableau(const StimData::TableauClone& t) { data_.set_tableau(t); } |
| 111 | + void set_pauli_frame(const StimData::PauliFrameClone& f) { data_.set_pauli_frame(f); } |
| 112 | + void set_current_size(std::size_t s) { data_.set_current_size(s); } |
| 113 | + void set_msm_err_count(std::size_t c) { data_.set_msm_err_count(c); } |
| 114 | + void set_num_qubits(uint64_t n) { data_.set_num_qubits(n); } |
| 115 | + |
| 116 | + |
| 117 | + |
| 118 | + |
| 119 | +private: |
| 120 | + StimData data_; |
| 121 | +}; |
| 122 | + |
| 123 | +} // namespace cudaq |
0 commit comments