@@ -73,26 +73,17 @@ class SimulationState {
7373 // / and extract the data pointer and size.
7474 template <typename ScalarType = double >
7575 auto getSizeAndPtr (const state_data &data) {
76- auto type = data.index ();
77- std::tuple<std::size_t , void *> sizeAndPtr;
78- if (type ==
79- cudaq::detail::variant_index<cudaq::state_data,
80- std::vector<std::complex <double >>>())
81- sizeAndPtr = getSizeAndPtrFromVec<double , ScalarType>(data);
82- else if (type ==
83- cudaq::detail::variant_index<cudaq::state_data,
84- std::vector<std::complex <float >>>())
85- sizeAndPtr = getSizeAndPtrFromVec<float , ScalarType>(data);
86- else if (type == cudaq::detail::variant_index<
87- cudaq::state_data,
88- std::pair<std::complex <double > *, std::size_t >>())
89- sizeAndPtr = getSizeAndPtrFromPair<double , ScalarType>(data);
90- else if (type == cudaq::detail::variant_index<
91- cudaq::state_data,
92- std::pair<std::complex <float > *, std::size_t >>())
93- sizeAndPtr = getSizeAndPtrFromPair<float , ScalarType>(data);
94- else if (type == cudaq::detail::variant_index<cudaq::state_data,
95- complex_matrix>()) {
76+ if (std::holds_alternative<std::vector<std::complex <double >>>(data))
77+ return getSizeAndPtrFromVec<double , ScalarType>(data);
78+ if (std::holds_alternative<std::vector<std::complex <float >>>(data))
79+ return getSizeAndPtrFromVec<float , ScalarType>(data);
80+ if (std::holds_alternative<std::pair<std::complex <double > *, std::size_t >>(
81+ data))
82+ return getSizeAndPtrFromPair<double , ScalarType>(data);
83+ if (std::holds_alternative<std::pair<std::complex <float > *, std::size_t >>(
84+ data))
85+ return getSizeAndPtrFromPair<float , ScalarType>(data);
86+ if (std::holds_alternative<complex_matrix>(data)) {
9687 // Complex matrix is double precision only
9788 if constexpr (!std::is_same_v<double , ScalarType>)
9889 throw std::runtime_error (" [sim-state] invalid data precision." );
@@ -108,10 +99,8 @@ class SimulationState {
10899 cMat.size (),
109100 reinterpret_cast <void *>(const_cast <complex_matrix &>(cMat).get_data (
110101 complex_matrix::order::row_major)));
111- } else
112- throw std::runtime_error (" unsupported data type for state vector." );
113-
114- return sizeAndPtr;
102+ }
103+ throw std::runtime_error (" unsupported data type for state vector." );
115104 }
116105
117106 // / @brief Subclass-specific creator method for
0 commit comments