Skip to content

Commit dcb0044

Browse files
authored
Don't implicitly handle density matrix as a flattened vector (#6)
Signed-off-by: Thien Nguyen <[email protected]>
1 parent 5d8fcf5 commit dcb0044

File tree

12 files changed

+106
-26
lines changed

12 files changed

+106
-26
lines changed

runtime/common/ExecutionContext.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include "Future.h"
1212
#include "NoiseModel.h"
1313
#include "SampleResult.h"
14-
#include "SimulationState.h"
1514
#include "Trace.h"
1615
#include "cudaq/algorithms/optimizer.h"
1716
#include "cudaq/operators.h"
@@ -20,6 +19,8 @@
2019

2120
namespace cudaq {
2221

22+
class SimulationState;
23+
2324
/// The ExecutionContext is an abstraction to indicate how a CUDA-Q kernel
2425
/// should be executed.
2526
class ExecutionContext {

runtime/common/SimulationState.h

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99
#pragma once
1010

11+
#include "cudaq/utils/cudaq_utils.h"
12+
#include "cudaq/utils/matrix.h"
1113
#include <algorithm>
14+
#include <bitset>
1215
#include <complex>
1316
#include <memory>
1417
#include <optional>
@@ -28,10 +31,11 @@ using TensorStateData =
2831
/// @brief state_data is a variant type
2932
/// encoding different forms of user state vector data
3033
/// we support.
31-
using state_data = std::variant<
32-
std::vector<std::complex<double>>, std::vector<std::complex<float>>,
33-
std::pair<std::complex<double> *, std::size_t>,
34-
std::pair<std::complex<float> *, std::size_t>, TensorStateData>;
34+
using state_data = std::variant<std::vector<std::complex<double>>,
35+
std::vector<std::complex<float>>,
36+
std::pair<std::complex<double> *, std::size_t>,
37+
std::pair<std::complex<float> *, std::size_t>,
38+
complex_matrix, TensorStateData>;
3539

3640
/// @brief The `SimulationState` interface provides an extension point
3741
/// for concrete circuit simulation sub-types to describe their
@@ -71,16 +75,41 @@ class SimulationState {
7175
auto getSizeAndPtr(const state_data &data) {
7276
auto type = data.index();
7377
std::tuple<std::size_t, void *> sizeAndPtr;
74-
if (type == 0)
78+
if (type ==
79+
cudaq::detail::variant_index<cudaq::state_data,
80+
std::vector<std::complex<double>>>())
7581
sizeAndPtr = getSizeAndPtrFromVec<double, ScalarType>(data);
76-
else if (type == 1)
82+
else if (type ==
83+
cudaq::detail::variant_index<cudaq::state_data,
84+
std::vector<std::complex<float>>>())
7785
sizeAndPtr = getSizeAndPtrFromVec<float, ScalarType>(data);
78-
else if (type == 2)
86+
else if (type == cudaq::detail::variant_index<
87+
cudaq::state_data,
88+
std::pair<std::complex<double> *, std::size_t>>())
7989
sizeAndPtr = getSizeAndPtrFromPair<double, ScalarType>(data);
80-
else if (type == 3)
90+
else if (type == cudaq::detail::variant_index<
91+
cudaq::state_data,
92+
std::pair<std::complex<float> *, std::size_t>>())
8193
sizeAndPtr = getSizeAndPtrFromPair<float, ScalarType>(data);
82-
else
83-
throw std::runtime_error("unsupported data type for state.");
94+
else if (type == cudaq::detail::variant_index<cudaq::state_data,
95+
complex_matrix>()) {
96+
// Complex matrix is double precision only
97+
if constexpr (!std::is_same_v<double, ScalarType>)
98+
throw std::runtime_error("[sim-state] invalid data precision.");
99+
auto &cMat = std::get<complex_matrix>(data);
100+
if (cMat.rows() != cMat.cols())
101+
throw std::runtime_error(
102+
"[sim-state] complex matrix must be square for density matrix.");
103+
// Check that the matrix dimension is a power of 2
104+
if (std::bitset<64>(cMat.rows()).count() != 1)
105+
throw std::runtime_error("[sim-state] complex matrix size must be a "
106+
"power of 2 for density matrix.");
107+
return std::make_tuple(
108+
cMat.size(),
109+
reinterpret_cast<void *>(const_cast<complex_matrix &>(cMat).get_data(
110+
complex_matrix::order::row_major)));
111+
} else
112+
throw std::runtime_error("unsupported data type for state vector.");
84113

85114
return sizeAndPtr;
86115
}

runtime/cudaq/qis/managers/photonics/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ target_include_directories(${LIBRARY_NAME}
1717
$<INSTALL_INTERFACE:include>)
1818

1919
set (PHOTONICS_DEPENDENCIES "")
20-
list(APPEND PHOTONICS_DEPENDENCIES cudaq-common libqpp fmt::fmt-header-only)
20+
list(APPEND PHOTONICS_DEPENDENCIES cudaq cudaq-common libqpp fmt::fmt-header-only)
2121
add_openmp_configurations(${LIBRARY_NAME} PHOTONICS_DEPENDENCIES)
2222

2323
target_link_libraries(${LIBRARY_NAME}

runtime/nvqir/cutensornet/simulator_mps.h

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,28 @@ class SimulatorMPS : public SimulatorTensorNetBase<ScalarType> {
5757
throw std::invalid_argument(
5858
"[SimulatorMPS simulator] Incompatible state input");
5959
if (!m_state) {
60+
std::vector<MPSTensor> copiedTensors;
61+
copiedTensors.reserve(casted->getMpsTensors().size());
62+
for (const auto &mpsTensor : casted->getMpsTensors()) {
63+
std::vector<int64_t> extents = mpsTensor.extents;
64+
const auto numElements =
65+
std::reduce(extents.begin(), extents.end(), 1, std::multiplies());
66+
const auto tensorSizeBytes =
67+
sizeof(std::complex<ScalarType>) * numElements;
68+
void *mpsTensorCopy{nullptr};
69+
HANDLE_CUDA_ERROR(cudaMalloc(&mpsTensorCopy, tensorSizeBytes));
70+
HANDLE_CUDA_ERROR(cudaMemcpy(mpsTensorCopy, mpsTensor.deviceData,
71+
tensorSizeBytes, cudaMemcpyDefault));
72+
copiedTensors.emplace_back(MPSTensor(mpsTensorCopy, extents));
73+
}
74+
6075
m_state = TensorNetState<ScalarType>::createFromMpsTensors(
61-
casted->getMpsTensors(), scratchPad, m_cutnHandle, m_randomEngine);
76+
copiedTensors, scratchPad, m_cutnHandle, m_randomEngine);
77+
for (const auto &mpsTensor : copiedTensors) {
78+
m_state->m_tempDevicePtrs.emplace_back(
79+
mpsTensor.deviceData,
80+
typename TensorNetState<ScalarType>::TempDevicePtrDeleter{});
81+
}
6282
} else {
6383
// Expand an existing state: Append MPS tensors
6484
// Factor the existing state

runtime/nvqir/cutensornet/simulator_tensornet.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ class SimulatorTensorNet : public SimulatorTensorNetBase<ScalarType> {
123123
m_state = TensorNetState<ScalarType>::createFromOpTensors(
124124
in_state.getNumQubits(), casted->getAppliedTensors(), scratchPad,
125125
m_cutnHandle, m_randomEngine);
126+
// Need to extend lifetime of all the device pointers stored in the input
127+
// state.
128+
m_state->m_tempDevicePtrs = casted->m_state->m_tempDevicePtrs;
126129
} else {
127130
// Expand an existing state:
128131
// (1) Create a blank tensor network with combined number of qubits
@@ -149,6 +152,11 @@ class SimulatorTensorNet : public SimulatorTensorNetBase<ScalarType> {
149152
m_state->applyQubitProjector(op.deviceData,
150153
mapQubitIdxs(op.targetQubitIds));
151154
}
155+
// Append the temp. pointer
156+
m_state->m_tempDevicePtrs.insert(
157+
m_state->m_tempDevicePtrs.end(),
158+
casted->m_state->m_tempDevicePtrs.begin(),
159+
casted->m_state->m_tempDevicePtrs.end());
152160
}
153161
}
154162
bool requireCacheWorkspace() const override { return true; }

runtime/nvqir/cutensornet/tensornet_state.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,15 @@ class TensorNetState {
7777
cutensornetState_t m_quantumState;
7878
/// Track id of gate tensors that are applied to the state tensors.
7979
std::int64_t m_tensorId = InvalidTensorIndexValue;
80+
struct TempDevicePtrDeleter {
81+
void operator()(void *ptr) const {
82+
if (ptr)
83+
cudaFree(ptr);
84+
}
85+
};
86+
8087
// Device memory pointers to be cleaned up.
81-
std::vector<void *> m_tempDevicePtrs;
88+
std::vector<std::shared_ptr<void>> m_tempDevicePtrs;
8289
// Tensor ops that have been applied to the state.
8390
std::vector<AppliedTensorOp> m_tensorOps;
8491
ScratchDeviceMem &scratchPad;
@@ -233,6 +240,8 @@ class TensorNetState {
233240
template <typename ScalarTy>
234241
friend class SimulatorMPS;
235242
template <typename ScalarTy>
243+
friend class SimulatorTensorNet;
244+
template <typename ScalarTy>
236245
friend class TensorNetSimulationState;
237246
/// Internal method to contract the tensor network.
238247
/// Returns device memory pointer and size (number of elements).

runtime/nvqir/cutensornet/tensornet_state.inc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ TensorNetState<ScalarType>::TensorNetState(const std::vector<int> &basisState,
7373
HANDLE_CUDA_ERROR(cudaMalloc(&d_gate, sizeBytes));
7474
HANDLE_CUDA_ERROR(
7575
cudaMemcpy(d_gate, h_xGate, sizeBytes, cudaMemcpyHostToDevice));
76-
m_tempDevicePtrs.emplace_back(d_gate);
76+
m_tempDevicePtrs.emplace_back(d_gate, TempDevicePtrDeleter{});
7777
for (int32_t qId = 0; const auto &bit : basisState) {
7878
if (bit == 1) {
7979
applyGate({}, {qId}, d_gate);
@@ -251,7 +251,7 @@ void TensorNetState<ScalarType>::addQubits(
251251

252252
// Project the state of those new qubits to the input state.
253253
applyQubitProjector(d_proj, qubitIdx);
254-
m_tempDevicePtrs.emplace_back(d_proj);
254+
m_tempDevicePtrs.emplace_back(d_proj, TempDevicePtrDeleter{});
255255
}
256256

257257
template <typename ScalarType>
@@ -1214,16 +1214,15 @@ TensorNetState<ScalarType>::createFromStateVector(
12141214
std::iota(qubitIdx.begin(), qubitIdx.end(), 0);
12151215
// Project the state to the input state.
12161216
state->applyQubitProjector(d_proj, qubitIdx);
1217-
state->m_tempDevicePtrs.emplace_back(d_proj);
1217+
state->m_tempDevicePtrs.emplace_back(d_proj, TempDevicePtrDeleter{});
12181218
return state;
12191219
}
12201220

12211221
template <typename ScalarType>
12221222
TensorNetState<ScalarType>::~TensorNetState() {
12231223
// Destroy the quantum circuit state
12241224
HANDLE_CUTN_ERROR(cutensornetDestroyState(m_quantumState));
1225-
for (auto *ptr : m_tempDevicePtrs)
1226-
HANDLE_CUDA_ERROR(cudaFree(ptr));
1225+
m_tempDevicePtrs.clear();
12271226
}
12281227

12291228
} // namespace nvqir

runtime/nvqir/cutensornet/tn_simulation_state.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ class TensorNetSimulationState : public cudaq::SimulationState {
8181
return m_state->m_tensorOps;
8282
}
8383

84+
template <typename ScalarTy>
85+
friend class SimulatorTensorNet;
86+
8487
protected:
8588
std::unique_ptr<TensorNetState<ScalarType>> m_state;
8689
ScratchDeviceMem &scratchPad;

runtime/nvqir/cutensornet/tn_simulation_state.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ TensorNetSimulationState<ScalarType>::createFromSizeAndPtr(
305305

306306
template <typename ScalarType>
307307
void TensorNetSimulationState<ScalarType>::destroyState() {
308-
CUDAQ_INFO("mps-state destroying state vector handle.");
308+
CUDAQ_INFO("tn-state destroying state vector handle.");
309309
m_state.reset();
310310
}
311311

runtime/nvqir/qpp/QppDMCircuitSimulator.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,21 @@ struct QppDmState : public cudaq::SimulationState {
109109
}
110110

111111
std::unique_ptr<SimulationState>
112-
createFromSizeAndPtr(std::size_t size, void *ptr, std::size_t) override {
113-
return std::make_unique<QppDmState>(
114-
Eigen::Map<qpp::cmat>(reinterpret_cast<std::complex<double> *>(ptr),
115-
std::sqrt(size), std::sqrt(size)));
112+
createFromSizeAndPtr(std::size_t size, void *ptr, std::size_t type) override {
113+
const bool isMatrixData =
114+
type == cudaq::detail::variant_index<cudaq::state_data,
115+
cudaq::complex_matrix>();
116+
117+
if (isMatrixData)
118+
return std::make_unique<QppDmState>(
119+
Eigen::Map<qpp::cmat>(reinterpret_cast<std::complex<double> *>(ptr),
120+
std::sqrt(size), std::sqrt(size)));
121+
// This is state vector data, convert it to density matrix: rho = |psi><psi|
122+
auto *stateData =
123+
reinterpret_cast<std::complex<double> *>(const_cast<void *>(ptr));
124+
qpp::ket psi = qpp::ket::Map(stateData, size);
125+
qpp::cmat dm = psi * psi.adjoint();
126+
return std::make_unique<QppDmState>(std::move(dm));
116127
}
117128

118129
void dump(std::ostream &os) const override { os << state << "\n"; }

0 commit comments

Comments
 (0)