Skip to content

Commit a0143e6

Browse files
committed
Tidy up state/pystate
Signed-off-by: Thien Nguyen <[email protected]>
1 parent 1345427 commit a0143e6

File tree

6 files changed

+101
-89
lines changed

6 files changed

+101
-89
lines changed

python/runtime/cudaq/algorithms/py_state.cpp

Lines changed: 71 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,74 @@ static py::buffer_info getCupyBufferInfo(py::buffer cupy_buffer) {
234234
);
235235
}
236236

237+
static cudaq::state createStateFromPyBuffer(py::buffer data,
238+
LinkedLibraryHolder &holder) {
239+
const bool isHostData = !py::hasattr(data, "__cuda_array_interface__");
240+
// Check that the target is GPU-based, i.e., can handle device
241+
// pointer.
242+
if (!holder.getTarget().config.GpuRequired && !isHostData)
243+
throw std::runtime_error(
244+
fmt::format("Current target '{}' does not support CuPy arrays.",
245+
holder.getTarget().name));
246+
247+
auto info = isHostData ? data.request() : getCupyBufferInfo(data);
248+
if (info.shape.size() > 2)
249+
throw std::runtime_error(
250+
"state.from_data only supports 1D or 2D array data.");
251+
if (info.format != py::format_descriptor<std::complex<float>>::format() &&
252+
info.format != py::format_descriptor<std::complex<double>>::format())
253+
throw std::runtime_error(
254+
"A numpy array with only floating point elements passed "
255+
"to "
256+
"state.from_data. input must be of complex float type, "
257+
"please "
258+
"add to your array creation `dtype=numpy.complex64` if "
259+
"simulation is FP32 and `dtype=numpy.complex128` if "
260+
"simulation if FP64, or dtype=cudaq.complex() for "
261+
"precision-agnostic code");
262+
263+
if (!isHostData || info.shape.size() == 1) {
264+
if (info.format == py::format_descriptor<std::complex<float>>::format()) {
265+
return state::from_data(std::make_pair(
266+
reinterpret_cast<std::complex<float> *>(info.ptr), info.size));
267+
} else {
268+
return state::from_data(std::make_pair(
269+
reinterpret_cast<std::complex<double> *>(info.ptr), info.size));
270+
}
271+
} else { // 2D array
272+
const std::size_t rows = info.shape[0];
273+
const std::size_t cols = info.shape[1];
274+
if (rows != cols)
275+
throw std::runtime_error(
276+
"state.from_data 2D array (density matrix) input must be "
277+
"square matrix data.");
278+
const bool isDoublePrecision =
279+
info.format == py::format_descriptor<std::complex<double>>::format();
280+
const int64_t dataSize = isDoublePrecision ? sizeof(std::complex<double>)
281+
: sizeof(std::complex<float>);
282+
const bool rowMajor =
283+
info.strides[1] ==
284+
dataSize; // check row-major: second stride == element size
285+
const cudaq::complex_matrix::order matOrder =
286+
rowMajor ? cudaq::complex_matrix::order::row_major
287+
: cudaq::complex_matrix::order::column_major;
288+
const cudaq::complex_matrix::Dimensions dim = {rows, cols};
289+
if (isDoublePrecision) {
290+
return state::from_data(cudaq::complex_matrix(
291+
std::vector<cudaq::complex_matrix::value_type>(
292+
reinterpret_cast<std::complex<double> *>(info.ptr),
293+
reinterpret_cast<std::complex<double> *>(info.ptr) + info.size),
294+
dim, matOrder));
295+
} else {
296+
return state::from_data(cudaq::complex_matrix(
297+
std::vector<cudaq::complex_matrix::value_type>(
298+
reinterpret_cast<std::complex<float> *>(info.ptr),
299+
reinterpret_cast<std::complex<float> *>(info.ptr) + info.size),
300+
dim, matOrder));
301+
}
302+
}
303+
}
304+
237305
/// @brief Bind the get_state cudaq function
238306
void bindPyState(py::module &mod, LinkedLibraryHolder &holder) {
239307
py::enum_<cudaq::InitialState>(mod, "InitialStateType",
@@ -344,36 +412,7 @@ void bindPyState(py::module &mod, LinkedLibraryHolder &holder) {
344412
.def_static(
345413
"from_data",
346414
[&](py::buffer data) {
347-
const bool isHostData =
348-
!py::hasattr(data, "__cuda_array_interface__");
349-
// Check that the target is GPU-based, i.e., can handle device
350-
// pointer.
351-
if (!holder.getTarget().config.GpuRequired && !isHostData)
352-
throw std::runtime_error(fmt::format(
353-
"Current target '{}' does not support CuPy arrays.",
354-
holder.getTarget().name));
355-
356-
auto info = isHostData ? data.request() : getCupyBufferInfo(data);
357-
if (info.format ==
358-
py::format_descriptor<std::complex<float>>::format()) {
359-
return state::from_data(std::make_pair(
360-
reinterpret_cast<std::complex<float> *>(info.ptr),
361-
info.size));
362-
}
363-
if (info.format ==
364-
py::format_descriptor<std::complex<double>>::format()) {
365-
return state::from_data(std::make_pair(
366-
reinterpret_cast<std::complex<double> *>(info.ptr),
367-
info.size));
368-
}
369-
throw std::runtime_error(
370-
"A numpy array with only floating point elements passed to "
371-
"state.from_data. input must be of complex float type, "
372-
"please "
373-
"add to your array creation `dtype=numpy.complex64` if "
374-
"simulation is FP32 and `dtype=numpy.complex128` if "
375-
"simulation if FP64, or dtype=cudaq.complex() for "
376-
"precision-agnostic code");
415+
return createStateFromPyBuffer(data, holder);
377416
},
378417
"Return a state from data.")
379418
.def_static(
@@ -658,63 +697,8 @@ index pair.
658697
if (self.get_num_tensors() != 1)
659698
throw std::runtime_error("overlap NumPy interop only supported "
660699
"for vector and matrix state data.");
661-
662-
const bool isHostData =
663-
!py::hasattr(other, "__cuda_array_interface__");
664-
// Check that the target is GPU-based, i.e., can handle device
665-
// pointer.
666-
if (!holder.getTarget().config.GpuRequired && !isHostData)
667-
throw std::runtime_error(fmt::format(
668-
"Current target '{}' does not support CuPy arrays.",
669-
holder.getTarget().name));
670-
py::buffer_info info =
671-
isHostData ? other.request() : getCupyBufferInfo(other);
672-
673-
if (info.shape.size() > 2)
674-
throw std::runtime_error(
675-
"overlap NumPy/CuPy interop only supported "
676-
"for vector and matrix state data.");
677-
678-
// Check that the shapes are compatible
679-
std::size_t otherNumElements = 1;
680-
for (std::size_t i = 0; std::size_t shapeElement : info.shape) {
681-
otherNumElements *= shapeElement;
682-
if (shapeElement != self.get_tensor().extents[i++])
683-
throw std::runtime_error(
684-
"overlap error - invalid shape of input buffer.");
685-
}
686-
687-
// Compute the overlap in the case that the
688-
// input buffer is FP64
689-
if (info.itemsize == 16) {
690-
// if this state is FP32, then we have to throw an error
691-
if (self.get_precision() == SimulationState::precision::fp32)
692-
throw std::runtime_error(
693-
"simulation state is FP32 but provided state buffer for "
694-
"overlap is FP64.");
695-
696-
auto otherState = state::from_data(std::make_pair(
697-
reinterpret_cast<complex *>(info.ptr), otherNumElements));
698-
return self.overlap(otherState);
699-
}
700-
701-
// Compute the overlap in the case that the
702-
// input buffer is FP32
703-
if (info.itemsize == 8) {
704-
// if this state is FP64, then we have to throw an error
705-
if (self.get_precision() == SimulationState::precision::fp64)
706-
throw std::runtime_error(
707-
"simulation state is FP64 but provided state buffer for "
708-
"overlap is FP32.");
709-
auto otherState = state::from_data(std::make_pair(
710-
reinterpret_cast<std::complex<float> *>(info.ptr),
711-
otherNumElements));
712-
return self.overlap(otherState);
713-
}
714-
715-
// We only support complex f32 and f64 types
716-
throw std::runtime_error(
717-
"invalid buffer element type size for overlap computation.");
700+
auto otherState = createStateFromPyBuffer(other, holder);
701+
return self.overlap(otherState);
718702
},
719703
"Compute the overlap between the provided :class:`State`'s.")
720704
.def(

runtime/common/SimulationState.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ class SimulationState {
100100
reinterpret_cast<void *>(const_cast<complex_matrix &>(cMat).get_data(
101101
complex_matrix::order::row_major)));
102102
}
103-
throw std::runtime_error("unsupported data type for state vector.");
103+
throw std::runtime_error(
104+
"unsupported data type for state vector/density matrix.");
104105
}
105106

106107
/// @brief Subclass-specific creator method for

runtime/nvqir/custatevec/CuStateVecState.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,12 @@ class CusvState : public cudaq::SimulationState {
207207

208208
std::unique_ptr<SimulationState>
209209
createFromSizeAndPtr(std::size_t size, void *ptr, std::size_t type) override {
210+
// custatevec sim doesn't support density matrix states
211+
if (type ==
212+
cudaq::detail::variant_index<cudaq::state_data, complex_matrix>())
213+
throw std::runtime_error(
214+
"[custatevec-state] density matrix state data not supported.");
215+
210216
// If the data is provided as a pointer / size, then
211217
// we assume we do not own it.
212218
bool weOwnTheData = type < 2 ? true : false;

runtime/nvqir/cutensornet/mps_simulation_state.inc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,13 @@ MPSSimulationState<ScalarType>::createFromSizeAndPtr(std::size_t size,
566566
return std::make_unique<MPSSimulationState>(
567567
std::move(state), mpsTensors, scratchPad, m_cutnHandle, m_randomEngine);
568568
}
569+
570+
// Don't allow density matrix
571+
if (dataType ==
572+
cudaq::detail::variant_index<cudaq::state_data, cudaq::complex_matrix>())
573+
throw std::runtime_error(
574+
"[MPSSimulationState] density matrix state data not supported.");
575+
569576
auto [state, mpsTensors] =
570577
createFromStateVec(m_cutnHandle, scratchPad, size,
571578
reinterpret_cast<std::complex<ScalarType> *>(ptr),

runtime/nvqir/cutensornet/tn_simulation_state.inc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,13 @@ TensorNetSimulationState<ScalarType>::createFromSizeAndPtr(
293293
throw std::runtime_error(
294294
"Cannot create tensornet backend's simulation state with MPS tensors.");
295295
}
296+
297+
// Don't allow density matrix
298+
if (dataType ==
299+
cudaq::detail::variant_index<cudaq::state_data, cudaq::complex_matrix>())
300+
throw std::runtime_error(
301+
"[TensorNetSimulationState] density matrix state data not supported.");
302+
296303
std::vector<std::complex<ScalarType>> vec(
297304
reinterpret_cast<std::complex<ScalarType> *>(ptr),
298305
reinterpret_cast<std::complex<ScalarType> *>(ptr) + size);

runtime/nvqir/qpp/QppCircuitSimulator.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,14 @@ struct QppState : public cudaq::SimulationState {
102102
}
103103

104104
std::unique_ptr<SimulationState>
105-
createFromSizeAndPtr(std::size_t size, void *ptr, std::size_t) override {
105+
createFromSizeAndPtr(std::size_t size, void *ptr,
106+
std::size_t dataType) override {
107+
// Don't allow density matrix
108+
if (dataType ==
109+
cudaq::detail::variant_index<cudaq::state_data, complex_matrix>())
110+
throw std::runtime_error(
111+
"[QppState] density matrix state data not supported.");
112+
106113
return std::make_unique<QppState>(Eigen::Map<qpp::ket>(
107114
reinterpret_cast<std::complex<double> *>(ptr), size));
108115
}

0 commit comments

Comments
 (0)