Skip to content

Commit e7e649e

Browse files
committed
Tweaks.
Signed-off-by: Eric Schweitz <[email protected]>
1 parent 0710e67 commit e7e649e

File tree

1 file changed

+13
-24
lines changed

1 file changed

+13
-24
lines changed

runtime/common/SimulationState.h

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -73,26 +73,17 @@ class SimulationState {
7373
/// and extract the data pointer and size.
7474
template <typename ScalarType = double>
7575
auto getSizeAndPtr(const state_data &data) {
76-
auto type = data.index();
77-
std::tuple<std::size_t, void *> sizeAndPtr;
78-
if (type ==
79-
cudaq::detail::variant_index<cudaq::state_data,
80-
std::vector<std::complex<double>>>())
81-
sizeAndPtr = getSizeAndPtrFromVec<double, ScalarType>(data);
82-
else if (type ==
83-
cudaq::detail::variant_index<cudaq::state_data,
84-
std::vector<std::complex<float>>>())
85-
sizeAndPtr = getSizeAndPtrFromVec<float, ScalarType>(data);
86-
else if (type == cudaq::detail::variant_index<
87-
cudaq::state_data,
88-
std::pair<std::complex<double> *, std::size_t>>())
89-
sizeAndPtr = getSizeAndPtrFromPair<double, ScalarType>(data);
90-
else if (type == cudaq::detail::variant_index<
91-
cudaq::state_data,
92-
std::pair<std::complex<float> *, std::size_t>>())
93-
sizeAndPtr = getSizeAndPtrFromPair<float, ScalarType>(data);
94-
else if (type == cudaq::detail::variant_index<cudaq::state_data,
95-
complex_matrix>()) {
76+
if (std::holds_alternative<std::vector<std::complex<double>>>(data))
77+
return getSizeAndPtrFromVec<double, ScalarType>(data);
78+
if (std::holds_alternative<std::vector<std::complex<float>>>(data))
79+
return getSizeAndPtrFromVec<float, ScalarType>(data);
80+
if (std::holds_alternative<std::pair<std::complex<double> *, std::size_t>>(
81+
data))
82+
return getSizeAndPtrFromPair<double, ScalarType>(data);
83+
if (std::holds_alternative<std::pair<std::complex<float> *, std::size_t>>(
84+
data))
85+
return getSizeAndPtrFromPair<float, ScalarType>(data);
86+
if (std::holds_alternative<complex_matrix>(data)) {
9687
// Complex matrix is double precision only
9788
if constexpr (!std::is_same_v<double, ScalarType>)
9889
throw std::runtime_error("[sim-state] invalid data precision.");
@@ -108,10 +99,8 @@ class SimulationState {
10899
cMat.size(),
109100
reinterpret_cast<void *>(const_cast<complex_matrix &>(cMat).get_data(
110101
complex_matrix::order::row_major)));
111-
} else
112-
throw std::runtime_error("unsupported data type for state vector.");
113-
114-
return sizeAndPtr;
102+
}
103+
throw std::runtime_error("unsupported data type for state vector.");
115104
}
116105

117106
/// @brief Subclass-specific creator method for

0 commit comments

Comments
 (0)