Skip to content

Commit 13f22f0

Browse files
committed
Tidy up state/pystate
Signed-off-by: Thien Nguyen <[email protected]>
1 parent 1345427 commit 13f22f0

File tree

12 files changed

+135
-98
lines changed

12 files changed

+135
-98
lines changed

python/runtime/cudaq/algorithms/py_state.cpp

Lines changed: 71 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,74 @@ static py::buffer_info getCupyBufferInfo(py::buffer cupy_buffer) {
234234
);
235235
}
236236

237+
static cudaq::state createStateFromPyBuffer(py::buffer data,
238+
LinkedLibraryHolder &holder) {
239+
const bool isHostData = !py::hasattr(data, "__cuda_array_interface__");
240+
// Check that the target is GPU-based, i.e., can handle device
241+
// pointer.
242+
if (!holder.getTarget().config.GpuRequired && !isHostData)
243+
throw std::runtime_error(
244+
fmt::format("Current target '{}' does not support CuPy arrays.",
245+
holder.getTarget().name));
246+
247+
auto info = isHostData ? data.request() : getCupyBufferInfo(data);
248+
if (info.shape.size() > 2)
249+
throw std::runtime_error(
250+
"state.from_data only supports 1D or 2D array data.");
251+
if (info.format != py::format_descriptor<std::complex<float>>::format() &&
252+
info.format != py::format_descriptor<std::complex<double>>::format())
253+
throw std::runtime_error(
254+
"A numpy array with only floating point elements passed "
255+
"to "
256+
"state.from_data. input must be of complex float type, "
257+
"please "
258+
"add to your array creation `dtype=numpy.complex64` if "
259+
"simulation is FP32 and `dtype=numpy.complex128` if "
260+
"simulation if FP64, or dtype=cudaq.complex() for "
261+
"precision-agnostic code");
262+
263+
if (!isHostData || info.shape.size() == 1) {
264+
if (info.format == py::format_descriptor<std::complex<float>>::format()) {
265+
return state::from_data(std::make_pair(
266+
reinterpret_cast<std::complex<float> *>(info.ptr), info.size));
267+
} else {
268+
return state::from_data(std::make_pair(
269+
reinterpret_cast<std::complex<double> *>(info.ptr), info.size));
270+
}
271+
} else { // 2D array
272+
const std::size_t rows = info.shape[0];
273+
const std::size_t cols = info.shape[1];
274+
if (rows != cols)
275+
throw std::runtime_error(
276+
"state.from_data 2D array (density matrix) input must be "
277+
"square matrix data.");
278+
const bool isDoublePrecision =
279+
info.format == py::format_descriptor<std::complex<double>>::format();
280+
const int64_t dataSize = isDoublePrecision ? sizeof(std::complex<double>)
281+
: sizeof(std::complex<float>);
282+
const bool rowMajor =
283+
info.strides[1] ==
284+
dataSize; // check row-major: second stride == element size
285+
const cudaq::complex_matrix::order matOrder =
286+
rowMajor ? cudaq::complex_matrix::order::row_major
287+
: cudaq::complex_matrix::order::column_major;
288+
const cudaq::complex_matrix::Dimensions dim = {rows, cols};
289+
if (isDoublePrecision) {
290+
return state::from_data(cudaq::complex_matrix(
291+
std::vector<cudaq::complex_matrix::value_type>(
292+
reinterpret_cast<std::complex<double> *>(info.ptr),
293+
reinterpret_cast<std::complex<double> *>(info.ptr) + info.size),
294+
dim, matOrder));
295+
} else {
296+
return state::from_data(cudaq::complex_matrix(
297+
std::vector<cudaq::complex_matrix::value_type>(
298+
reinterpret_cast<std::complex<float> *>(info.ptr),
299+
reinterpret_cast<std::complex<float> *>(info.ptr) + info.size),
300+
dim, matOrder));
301+
}
302+
}
303+
}
304+
237305
/// @brief Bind the get_state cudaq function
238306
void bindPyState(py::module &mod, LinkedLibraryHolder &holder) {
239307
py::enum_<cudaq::InitialState>(mod, "InitialStateType",
@@ -344,36 +412,7 @@ void bindPyState(py::module &mod, LinkedLibraryHolder &holder) {
344412
.def_static(
345413
"from_data",
346414
[&](py::buffer data) {
347-
const bool isHostData =
348-
!py::hasattr(data, "__cuda_array_interface__");
349-
// Check that the target is GPU-based, i.e., can handle device
350-
// pointer.
351-
if (!holder.getTarget().config.GpuRequired && !isHostData)
352-
throw std::runtime_error(fmt::format(
353-
"Current target '{}' does not support CuPy arrays.",
354-
holder.getTarget().name));
355-
356-
auto info = isHostData ? data.request() : getCupyBufferInfo(data);
357-
if (info.format ==
358-
py::format_descriptor<std::complex<float>>::format()) {
359-
return state::from_data(std::make_pair(
360-
reinterpret_cast<std::complex<float> *>(info.ptr),
361-
info.size));
362-
}
363-
if (info.format ==
364-
py::format_descriptor<std::complex<double>>::format()) {
365-
return state::from_data(std::make_pair(
366-
reinterpret_cast<std::complex<double> *>(info.ptr),
367-
info.size));
368-
}
369-
throw std::runtime_error(
370-
"A numpy array with only floating point elements passed to "
371-
"state.from_data. input must be of complex float type, "
372-
"please "
373-
"add to your array creation `dtype=numpy.complex64` if "
374-
"simulation is FP32 and `dtype=numpy.complex128` if "
375-
"simulation if FP64, or dtype=cudaq.complex() for "
376-
"precision-agnostic code");
415+
return createStateFromPyBuffer(data, holder);
377416
},
378417
"Return a state from data.")
379418
.def_static(
@@ -658,63 +697,8 @@ index pair.
658697
if (self.get_num_tensors() != 1)
659698
throw std::runtime_error("overlap NumPy interop only supported "
660699
"for vector and matrix state data.");
661-
662-
const bool isHostData =
663-
!py::hasattr(other, "__cuda_array_interface__");
664-
// Check that the target is GPU-based, i.e., can handle device
665-
// pointer.
666-
if (!holder.getTarget().config.GpuRequired && !isHostData)
667-
throw std::runtime_error(fmt::format(
668-
"Current target '{}' does not support CuPy arrays.",
669-
holder.getTarget().name));
670-
py::buffer_info info =
671-
isHostData ? other.request() : getCupyBufferInfo(other);
672-
673-
if (info.shape.size() > 2)
674-
throw std::runtime_error(
675-
"overlap NumPy/CuPy interop only supported "
676-
"for vector and matrix state data.");
677-
678-
// Check that the shapes are compatible
679-
std::size_t otherNumElements = 1;
680-
for (std::size_t i = 0; std::size_t shapeElement : info.shape) {
681-
otherNumElements *= shapeElement;
682-
if (shapeElement != self.get_tensor().extents[i++])
683-
throw std::runtime_error(
684-
"overlap error - invalid shape of input buffer.");
685-
}
686-
687-
// Compute the overlap in the case that the
688-
// input buffer is FP64
689-
if (info.itemsize == 16) {
690-
// if this state is FP32, then we have to throw an error
691-
if (self.get_precision() == SimulationState::precision::fp32)
692-
throw std::runtime_error(
693-
"simulation state is FP32 but provided state buffer for "
694-
"overlap is FP64.");
695-
696-
auto otherState = state::from_data(std::make_pair(
697-
reinterpret_cast<complex *>(info.ptr), otherNumElements));
698-
return self.overlap(otherState);
699-
}
700-
701-
// Compute the overlap in the case that the
702-
// input buffer is FP32
703-
if (info.itemsize == 8) {
704-
// if this state is FP64, then we have to throw an error
705-
if (self.get_precision() == SimulationState::precision::fp64)
706-
throw std::runtime_error(
707-
"simulation state is FP64 but provided state buffer for "
708-
"overlap is FP32.");
709-
auto otherState = state::from_data(std::make_pair(
710-
reinterpret_cast<std::complex<float> *>(info.ptr),
711-
otherNumElements));
712-
return self.overlap(otherState);
713-
}
714-
715-
// We only support complex f32 and f64 types
716-
throw std::runtime_error(
717-
"invalid buffer element type size for overlap computation.");
700+
auto otherState = createStateFromPyBuffer(other, holder);
701+
return self.overlap(otherState);
718702
},
719703
"Compute the overlap between the provided :class:`State`'s.")
720704
.def(

runtime/common/SimulationState.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ class SimulationState {
100100
reinterpret_cast<void *>(const_cast<complex_matrix &>(cMat).get_data(
101101
complex_matrix::order::row_major)));
102102
}
103-
throw std::runtime_error("unsupported data type for state vector.");
103+
throw std::runtime_error(
104+
"unsupported data type for state vector/density matrix.");
104105
}
105106

106107
/// @brief Subclass-specific creator method for

runtime/cudaq/CMakeLists.txt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ add_library(${LIBRARY_NAME}
2323
qis/execution_manager.cpp
2424
qis/remote_state.cpp
2525
qis/state.cpp
26-
utils/cudaq_utils.cpp
27-
utils/matrix.cpp
2826
distributed/mpi_plugin.cpp)
2927

3028
set_property(GLOBAL APPEND PROPERTY CUDAQ_RUNTIME_LIBS ${LIBRARY_NAME})
@@ -41,7 +39,7 @@ if (CUDA_FOUND)
4139
PRIVATE .)
4240

4341
target_link_libraries(${LIBRARY_NAME}
44-
PUBLIC dl cudaq-operator cudaq-common cudaq-nlopt cudaq-ensmallen
42+
PUBLIC dl cudaq-operator cudaq-common cudaq-nlopt cudaq-ensmallen cudaq-utils
4543
PRIVATE nvqir fmt::fmt-header-only CUDA::cudart_static CUDAQTargetConfigUtil)
4644

4745
target_compile_definitions(${LIBRARY_NAME} PRIVATE CUDAQ_HAS_CUDA)
@@ -53,7 +51,7 @@ else()
5351
PRIVATE .)
5452

5553
target_link_libraries(${LIBRARY_NAME}
56-
PUBLIC dl cudaq-operator cudaq-common cudaq-nlopt cudaq-ensmallen
54+
PUBLIC dl cudaq-operator cudaq-common cudaq-nlopt cudaq-ensmallen cudaq-utils
5755
PRIVATE nvqir fmt::fmt-header-only CUDAQTargetConfigUtil)
5856
endif()
5957

@@ -63,6 +61,7 @@ add_subdirectory(platform)
6361
add_subdirectory(builder)
6462
add_subdirectory(domains)
6563
add_subdirectory(operators)
64+
add_subdirectory(utils)
6665

6766
install(TARGETS ${LIBRARY_NAME} EXPORT cudaq-targets DESTINATION lib)
6867

runtime/cudaq/qis/managers/photonics/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ target_include_directories(${LIBRARY_NAME}
1717
$<INSTALL_INTERFACE:include>)
1818

1919
set (PHOTONICS_DEPENDENCIES "")
20-
list(APPEND PHOTONICS_DEPENDENCIES cudaq cudaq-common libqpp fmt::fmt-header-only)
20+
list(APPEND PHOTONICS_DEPENDENCIES cudaq-common cudaq-utils libqpp fmt::fmt-header-only)
2121
add_openmp_configurations(${LIBRARY_NAME} PHOTONICS_DEPENDENCIES)
2222

2323
target_link_libraries(${LIBRARY_NAME}

runtime/cudaq/utils/CMakeLists.txt

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# ============================================================================ #
2+
# Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. #
3+
# All rights reserved. #
4+
# #
5+
# This source code and the accompanying materials are made available under #
6+
# the terms of the Apache License 2.0 which accompanies this distribution. #
7+
# ============================================================================ #
8+
9+
set(LIBRARY_NAME cudaq-utils)
10+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-ctad-maybe-unsupported")
11+
set(INTERFACE_POSITION_INDEPENDENT_CODE ON)
12+
13+
# Create the CUDA-Q utils library
14+
add_library(${LIBRARY_NAME} SHARED cudaq_utils.cpp matrix.cpp)
15+
16+
target_include_directories(${LIBRARY_NAME}
17+
PUBLIC $<INSTALL_INTERFACE:include>
18+
$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/tpls/eigen>
19+
PRIVATE .)
20+
21+
set_property(GLOBAL APPEND PROPERTY CUDAQ_RUNTIME_LIBS ${LIBRARY_NAME})
22+
23+
install(TARGETS ${LIBRARY_NAME} EXPORT cudaq-utils-targets DESTINATION lib)
24+
25+
install(EXPORT cudaq-utils-targets
26+
FILE CUDAQUtilsTargets.cmake
27+
NAMESPACE cudaq::
28+
DESTINATION lib/cmake/cudaq)

runtime/cudaq/utils/cudaq_utils.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
******************************************************************************/
88

99
#include "cudaq_utils.h"
10-
#include "cudaq/platform.h"
11-
1210
#include <random>
1311

1412
namespace cudaq {

runtime/cudaq/utils/matrix.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* the terms of the Apache License 2.0 which accompanies this distribution. *
77
******************************************************************************/
88

9-
#include "cudaq/utils/matrix.h"
9+
#include "matrix.h"
1010
#include <cmath>
1111
#include <iostream>
1212
#include <sstream>

runtime/nvqir/custatevec/CuStateVecState.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,12 @@ class CusvState : public cudaq::SimulationState {
207207

208208
std::unique_ptr<SimulationState>
209209
createFromSizeAndPtr(std::size_t size, void *ptr, std::size_t type) override {
210+
// custatevec sim doesn't support density matrix states
211+
if (type ==
212+
cudaq::detail::variant_index<cudaq::state_data, complex_matrix>())
213+
throw std::runtime_error(
214+
"[custatevec-state] density matrix state data not supported.");
215+
210216
// If the data is provided as a pointer / size, then
211217
// we assume we do not own it.
212218
bool weOwnTheData = type < 2 ? true : false;

runtime/nvqir/cutensornet/mps_simulation_state.inc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,13 @@ MPSSimulationState<ScalarType>::createFromSizeAndPtr(std::size_t size,
566566
return std::make_unique<MPSSimulationState>(
567567
std::move(state), mpsTensors, scratchPad, m_cutnHandle, m_randomEngine);
568568
}
569+
570+
// Don't allow density matrix
571+
if (dataType ==
572+
cudaq::detail::variant_index<cudaq::state_data, cudaq::complex_matrix>())
573+
throw std::runtime_error(
574+
"[MPSSimulationState] density matrix state data not supported.");
575+
569576
auto [state, mpsTensors] =
570577
createFromStateVec(m_cutnHandle, scratchPad, size,
571578
reinterpret_cast<std::complex<ScalarType> *>(ptr),

runtime/nvqir/cutensornet/tn_simulation_state.inc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,13 @@ TensorNetSimulationState<ScalarType>::createFromSizeAndPtr(
293293
throw std::runtime_error(
294294
"Cannot create tensornet backend's simulation state with MPS tensors.");
295295
}
296+
297+
// Don't allow density matrix
298+
if (dataType ==
299+
cudaq::detail::variant_index<cudaq::state_data, cudaq::complex_matrix>())
300+
throw std::runtime_error(
301+
"[TensorNetSimulationState] density matrix state data not supported.");
302+
296303
std::vector<std::complex<ScalarType>> vec(
297304
reinterpret_cast<std::complex<ScalarType> *>(ptr),
298305
reinterpret_cast<std::complex<ScalarType> *>(ptr) + size);

0 commit comments

Comments
 (0)