@@ -234,6 +234,74 @@ static py::buffer_info getCupyBufferInfo(py::buffer cupy_buffer) {
234234 );
235235}
236236
237+ static cudaq::state createStateFromPyBuffer (py::buffer data,
238+ LinkedLibraryHolder &holder) {
239+ const bool isHostData = !py::hasattr (data, " __cuda_array_interface__" );
240+ // Check that the target is GPU-based, i.e., can handle device
241+ // pointer.
242+ if (!holder.getTarget ().config .GpuRequired && !isHostData)
243+ throw std::runtime_error (
244+ fmt::format (" Current target '{}' does not support CuPy arrays." ,
245+ holder.getTarget ().name ));
246+
247+ auto info = isHostData ? data.request () : getCupyBufferInfo (data);
248+ if (info.shape .size () > 2 )
249+ throw std::runtime_error (
250+ " state.from_data only supports 1D or 2D array data." );
251+ if (info.format != py::format_descriptor<std::complex <float >>::format () &&
252+ info.format != py::format_descriptor<std::complex <double >>::format ())
253+ throw std::runtime_error (
254+ " A numpy array with only floating point elements passed "
255+ " to "
256+ " state.from_data. input must be of complex float type, "
257+ " please "
258+ " add to your array creation `dtype=numpy.complex64` if "
259+ " simulation is FP32 and `dtype=numpy.complex128` if "
260+ " simulation if FP64, or dtype=cudaq.complex() for "
261+ " precision-agnostic code" );
262+
263+ if (!isHostData || info.shape .size () == 1 ) {
264+ if (info.format == py::format_descriptor<std::complex <float >>::format ()) {
265+ return state::from_data (std::make_pair (
266+ reinterpret_cast <std::complex <float > *>(info.ptr ), info.size ));
267+ } else {
268+ return state::from_data (std::make_pair (
269+ reinterpret_cast <std::complex <double > *>(info.ptr ), info.size ));
270+ }
271+ } else { // 2D array
272+ const std::size_t rows = info.shape [0 ];
273+ const std::size_t cols = info.shape [1 ];
274+ if (rows != cols)
275+ throw std::runtime_error (
276+ " state.from_data 2D array (density matrix) input must be "
277+ " square matrix data." );
278+ const bool isDoublePrecision =
279+ info.format == py::format_descriptor<std::complex <double >>::format ();
280+ const int64_t dataSize = isDoublePrecision ? sizeof (std::complex <double >)
281+ : sizeof (std::complex <float >);
282+ const bool rowMajor =
283+ info.strides [1 ] ==
284+ dataSize; // check row-major: second stride == element size
285+ const cudaq::complex_matrix::order matOrder =
286+ rowMajor ? cudaq::complex_matrix::order::row_major
287+ : cudaq::complex_matrix::order::column_major;
288+ const cudaq::complex_matrix::Dimensions dim = {rows, cols};
289+ if (isDoublePrecision) {
290+ return state::from_data (cudaq::complex_matrix (
291+ std::vector<cudaq::complex_matrix::value_type>(
292+ reinterpret_cast <std::complex <double > *>(info.ptr ),
293+ reinterpret_cast <std::complex <double > *>(info.ptr ) + info.size ),
294+ dim, matOrder));
295+ } else {
296+ return state::from_data (cudaq::complex_matrix (
297+ std::vector<cudaq::complex_matrix::value_type>(
298+ reinterpret_cast <std::complex <float > *>(info.ptr ),
299+ reinterpret_cast <std::complex <float > *>(info.ptr ) + info.size ),
300+ dim, matOrder));
301+ }
302+ }
303+ }
304+
237305// / @brief Bind the get_state cudaq function
238306void bindPyState (py::module &mod, LinkedLibraryHolder &holder) {
239307 py::enum_<cudaq::InitialState>(mod, " InitialStateType" ,
@@ -344,36 +412,7 @@ void bindPyState(py::module &mod, LinkedLibraryHolder &holder) {
344412 .def_static (
345413 " from_data" ,
346414 [&](py::buffer data) {
347- const bool isHostData =
348- !py::hasattr (data, " __cuda_array_interface__" );
349- // Check that the target is GPU-based, i.e., can handle device
350- // pointer.
351- if (!holder.getTarget ().config .GpuRequired && !isHostData)
352- throw std::runtime_error (fmt::format (
353- " Current target '{}' does not support CuPy arrays." ,
354- holder.getTarget ().name ));
355-
356- auto info = isHostData ? data.request () : getCupyBufferInfo (data);
357- if (info.format ==
358- py::format_descriptor<std::complex <float >>::format ()) {
359- return state::from_data (std::make_pair (
360- reinterpret_cast <std::complex <float > *>(info.ptr ),
361- info.size ));
362- }
363- if (info.format ==
364- py::format_descriptor<std::complex <double >>::format ()) {
365- return state::from_data (std::make_pair (
366- reinterpret_cast <std::complex <double > *>(info.ptr ),
367- info.size ));
368- }
369- throw std::runtime_error (
370- " A numpy array with only floating point elements passed to "
371- " state.from_data. input must be of complex float type, "
372- " please "
373- " add to your array creation `dtype=numpy.complex64` if "
374- " simulation is FP32 and `dtype=numpy.complex128` if "
375- " simulation if FP64, or dtype=cudaq.complex() for "
376- " precision-agnostic code" );
415+ return createStateFromPyBuffer (data, holder);
377416 },
378417 " Return a state from data." )
379418 .def_static (
@@ -658,63 +697,8 @@ index pair.
658697 if (self.get_num_tensors () != 1 )
659698 throw std::runtime_error (" overlap NumPy interop only supported "
660699 " for vector and matrix state data." );
661-
662- const bool isHostData =
663- !py::hasattr (other, " __cuda_array_interface__" );
664- // Check that the target is GPU-based, i.e., can handle device
665- // pointer.
666- if (!holder.getTarget ().config .GpuRequired && !isHostData)
667- throw std::runtime_error (fmt::format (
668- " Current target '{}' does not support CuPy arrays." ,
669- holder.getTarget ().name ));
670- py::buffer_info info =
671- isHostData ? other.request () : getCupyBufferInfo (other);
672-
673- if (info.shape .size () > 2 )
674- throw std::runtime_error (
675- " overlap NumPy/CuPy interop only supported "
676- " for vector and matrix state data." );
677-
678- // Check that the shapes are compatible
679- std::size_t otherNumElements = 1 ;
680- for (std::size_t i = 0 ; std::size_t shapeElement : info.shape ) {
681- otherNumElements *= shapeElement;
682- if (shapeElement != self.get_tensor ().extents [i++])
683- throw std::runtime_error (
684- " overlap error - invalid shape of input buffer." );
685- }
686-
687- // Compute the overlap in the case that the
688- // input buffer is FP64
689- if (info.itemsize == 16 ) {
690- // if this state is FP32, then we have to throw an error
691- if (self.get_precision () == SimulationState::precision::fp32)
692- throw std::runtime_error (
693- " simulation state is FP32 but provided state buffer for "
694- " overlap is FP64." );
695-
696- auto otherState = state::from_data (std::make_pair (
697- reinterpret_cast <complex *>(info.ptr ), otherNumElements));
698- return self.overlap (otherState);
699- }
700-
701- // Compute the overlap in the case that the
702- // input buffer is FP32
703- if (info.itemsize == 8 ) {
704- // if this state is FP64, then we have to throw an error
705- if (self.get_precision () == SimulationState::precision::fp64)
706- throw std::runtime_error (
707- " simulation state is FP64 but provided state buffer for "
708- " overlap is FP32." );
709- auto otherState = state::from_data (std::make_pair (
710- reinterpret_cast <std::complex <float > *>(info.ptr ),
711- otherNumElements));
712- return self.overlap (otherState);
713- }
714-
715- // We only support complex f32 and f64 types
716- throw std::runtime_error (
717- " invalid buffer element type size for overlap computation." );
700+ auto otherState = createStateFromPyBuffer (other, holder);
701+ return self.overlap (otherState);
718702 },
719703 " Compute the overlap between the provided :class:`State`'s." )
720704 .def (
0 commit comments