88
99#pragma once
1010
11+ #include " cudaq/utils/cudaq_utils.h"
12+ #include " cudaq/utils/matrix.h"
1113#include < algorithm>
14+ #include < bitset>
1215#include < complex>
1316#include < memory>
1417#include < optional>
@@ -28,10 +31,11 @@ using TensorStateData =
2831// / @brief state_data is a variant type
2932// / encoding different forms of user state vector data
3033// / we support.
31- using state_data = std::variant<
32- std::vector<std::complex <double >>, std::vector<std::complex <float >>,
33- std::pair<std::complex <double > *, std::size_t >,
34- std::pair<std::complex <float > *, std::size_t >, TensorStateData>;
34+ using state_data = std::variant<std::vector<std::complex <double >>,
35+ std::vector<std::complex <float >>,
36+ std::pair<std::complex <double > *, std::size_t >,
37+ std::pair<std::complex <float > *, std::size_t >,
38+ complex_matrix, TensorStateData>;
3539
3640// / @brief The `SimulationState` interface provides and extension point
3741// / for concrete circuit simulation sub-types to describe their
@@ -71,16 +75,41 @@ class SimulationState {
7175 auto getSizeAndPtr (const state_data &data) {
7276 auto type = data.index ();
7377 std::tuple<std::size_t , void *> sizeAndPtr;
74- if (type == 0 )
78+ if (type ==
79+ cudaq::detail::variant_index<cudaq::state_data,
80+ std::vector<std::complex <double >>>())
7581 sizeAndPtr = getSizeAndPtrFromVec<double , ScalarType>(data);
76- else if (type == 1 )
82+ else if (type ==
83+ cudaq::detail::variant_index<cudaq::state_data,
84+ std::vector<std::complex <float >>>())
7785 sizeAndPtr = getSizeAndPtrFromVec<float , ScalarType>(data);
78- else if (type == 2 )
86+ else if (type == cudaq::detail::variant_index<
87+ cudaq::state_data,
88+ std::pair<std::complex <double > *, std::size_t >>())
7989 sizeAndPtr = getSizeAndPtrFromPair<double , ScalarType>(data);
80- else if (type == 3 )
90+ else if (type == cudaq::detail::variant_index<
91+ cudaq::state_data,
92+ std::pair<std::complex <float > *, std::size_t >>())
8193 sizeAndPtr = getSizeAndPtrFromPair<float , ScalarType>(data);
82- else
83- throw std::runtime_error (" unsupported data type for state." );
94+ else if (type == cudaq::detail::variant_index<cudaq::state_data,
95+ complex_matrix>()) {
96+ // Complex matrix is double precision only
97+ if constexpr (!std::is_same_v<double , ScalarType>)
98+ throw std::runtime_error (" [sim-state] invalid data precision." );
99+ auto &cMat = std::get<complex_matrix>(data);
100+ if (cMat.rows () != cMat.cols ())
101+ throw std::runtime_error (
102+ " [sim-state] complex matrix must be square for density matrix." );
103+ // Check that it must be a power of 2
104+ if (std::bitset<64 >(cMat.rows ()).count () != 1 )
105+ throw std::runtime_error (" [sim-state] complex matrix size must be a "
106+ " power of 2 for density matrix." );
107+ return std::make_tuple (
108+ cMat.size (),
109+ reinterpret_cast <void *>(const_cast <complex_matrix &>(cMat).get_data (
110+ complex_matrix::order::row_major)));
111+ } else
112+ throw std::runtime_error (" unsupported data type for state vector." );
84113
85114 return sizeAndPtr;
86115 }
0 commit comments