Skip to content

Commit 1345427

Browse files
committed
Simplifying cudaq::state support.
This PR greatly simplifies the use of cudaq::get_state by limiting its applicability strictly to local simulation execution contexts. Prior attempts were made to make cudaq::get_state generally available regardless of the target execution context. This level of support is being discontinued. - Move the qvector ctor overloads to the cudaq::state class. This simplifies qvector by making it possible to initialize a qvector only with a cudaq::state object and dropping all the other overloads. - Fix tests. Adapt the regression tests. - Drop state with non-simulation targets. - Add code to distinguish between floats and complex floats and get rid of a bunch of redundant code. Also, update tests. - Make all command line arguments be in the same order. - Don't implicitly handle a density matrix as a flattened vector (#6) Signed-off-by: Eric Schweitz <[email protected]>
1 parent 9db4859 commit 1345427

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+1419
-1243
lines changed

include/cudaq/Frontend/nvqpp/ASTBridge.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,10 @@ class QuakeBridgeVisitor
288288
bool TraverseCXXConstructExpr(clang::CXXConstructExpr *x,
289289
DataRecursionQueue *q = nullptr);
290290
bool VisitCXXConstructExpr(clang::CXXConstructExpr *x);
291+
bool TraverseCXXTemporaryObjectExpr(clang::CXXTemporaryObjectExpr *x,
292+
DataRecursionQueue *q = nullptr) {
293+
return TraverseCXXConstructExpr(x, q);
294+
}
291295
bool VisitCXXOperatorCallExpr(clang::CXXOperatorCallExpr *x);
292296
bool VisitCXXParenListInitExpr(clang::CXXParenListInitExpr *x);
293297
bool WalkUpFromCXXOperatorCallExpr(clang::CXXOperatorCallExpr *x);

include/cudaq/Optimizer/Builder/Intrinsics.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,14 @@ static constexpr const char getNumQubitsFromCudaqState[] =
5353
"__nvqpp_cudaq_state_numberOfQubits";
5454

5555
// Create a new state from data.
56-
static constexpr const char createCudaqStateFromDataFP64[] =
57-
"__nvqpp_cudaq_state_createFromData_fp64";
58-
static constexpr const char createCudaqStateFromDataFP32[] =
59-
"__nvqpp_cudaq_state_createFromData_fp32";
56+
static constexpr const char createCudaqStateFromDataComplexF64[] =
57+
"__nvqpp_cudaq_state_createFromData_complex_f64";
58+
static constexpr const char createCudaqStateFromDataComplexF32[] =
59+
"__nvqpp_cudaq_state_createFromData_complex_f32";
60+
static constexpr const char createCudaqStateFromDataF64[] =
61+
"__nvqpp_cudaq_state_createFromData_f64";
62+
static constexpr const char createCudaqStateFromDataF32[] =
63+
"__nvqpp_cudaq_state_createFromData_f32";
6064

6165
// Delete a state created by the runtime functions above.
6266
static constexpr const char deleteCudaqState[] = "__nvqpp_cudaq_state_delete";

include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def quake_InitializeStateOp : QuakeOp<"init_state",
121121
}];
122122

123123
let arguments = (ins
124-
VeqType:$targets,
124+
NonStruqRefType:$targets,
125125
AnyStateInitType:$state
126126
);
127127
let results = (outs VeqType);

lib/Frontend/nvqpp/ConvertExpr.cpp

Lines changed: 58 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -682,9 +682,12 @@ bool QuakeBridgeVisitor::VisitCastExpr(clang::CastExpr *x) {
682682
if (cxxExpr->getNumArgs() == 1)
683683
return true;
684684
}
685-
if (isa<ComplexType>(castToTy) && isa<ComplexType>(peekValue().getType())) {
685+
if (isa<ComplexType>(castToTy) && isa<ComplexType>(peekValue().getType()))
686686
return true;
687-
}
687+
if (isa<quake::StateType>(castToTy))
688+
if (auto ptrTy = dyn_cast<cudaq::cc::PointerType>(peekValue().getType()))
689+
if (isa<quake::StateType>(ptrTy.getElementType()))
690+
return pushValue(builder.create<cudaq::cc::LoadOp>(loc, popValue()));
688691
if (auto funcTy = peelPointerFromFunction(castToTy))
689692
if (auto fromTy = dyn_cast<cc::CallableType>(peekValue().getType())) {
690693
auto inputs = funcTy.getInputs();
@@ -1003,8 +1006,8 @@ bool QuakeBridgeVisitor::VisitMaterializeTemporaryExpr(
10031006
// The following cases are λ expressions, quantum data, or a std::vector view.
10041007
// In those cases, there is nothing to materialize, so we can just pass the
10051008
// Value on the top of the stack.
1006-
if (isa<cc::CallableType, quake::VeqType, quake::RefType, cc::SpanLikeType>(
1007-
ty))
1009+
if (isa<cc::CallableType, quake::VeqType, quake::RefType, cc::SpanLikeType,
1010+
quake::StateType>(ty))
10081011
return true;
10091012

10101013
// If not one of the above special cases, then materialize the value to a
@@ -2689,6 +2692,11 @@ bool QuakeBridgeVisitor::VisitInitListExpr(clang::InitListExpr *x) {
26892692
}
26902693

26912694
// List has 1 or more members.
2695+
if (size == 1 && isa<clang::MaterializeTemporaryExpr>(x->getInit(0)))
2696+
if (auto alloc = peekValue().getDefiningOp<cudaq::cc::AllocaOp>())
2697+
if (auto arrTy = dyn_cast<cudaq::cc::ArrayType>(initListTy))
2698+
if (alloc.getElementType() == arrTy.getElementType())
2699+
return true;
26922700
auto last = lastValues(size);
26932701
bool allRef = std::all_of(last.begin(), last.end(), [](auto v) {
26942702
return isa<quake::RefType, quake::VeqType>(v.getType());
@@ -2916,6 +2924,32 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
29162924
loc, quake::VeqType::getUnsized(builder.getContext()), sizeVal));
29172925
}
29182926

2927+
if (ctorName == "state") {
2928+
// cudaq::state ctor can be materialized when using local simulators and
2929+
// converting raw data to state vectors. Use a runtime helper function
2930+
// to perform the conversion.
2931+
Value stdvec = popValue();
2932+
auto stateTy = cudaq::cc::PointerType::get(
2933+
quake::StateType::get(builder.getContext()));
2934+
if (auto stdvecTy = dyn_cast<cudaq::cc::StdvecType>(stdvec.getType())) {
2935+
auto dataTy = cudaq::cc::PointerType::get(stdvecTy.getElementType());
2936+
Value data =
2937+
builder.create<cudaq::cc::StdvecDataOp>(loc, dataTy, stdvec);
2938+
auto i64Ty = builder.getI64Type();
2939+
Value size =
2940+
builder.create<cudaq::cc::StdvecSizeOp>(loc, i64Ty, stdvec);
2941+
return pushValue(builder.create<quake::CreateStateOp>(
2942+
loc, stateTy, ValueRange{data, size}));
2943+
}
2944+
if (auto alloc = stdvec.getDefiningOp<cudaq::cc::AllocaOp>()) {
2945+
Value size = alloc.getSeqSize();
2946+
return pushValue(builder.create<quake::CreateStateOp>(
2947+
loc, stateTy, ValueRange{alloc, size}));
2948+
}
2949+
TODO_loc(loc, "unhandled state constructor");
2950+
return false;
2951+
}
2952+
29192953
// lambda determines: is `t` a cudaq::state* ?
29202954
auto isStateType = [&](Type t) {
29212955
if (auto ptrTy = dyn_cast<cc::PointerType>(t))
@@ -2925,9 +2959,17 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
29252959

29262960
if (ctorName == "qudit") {
29272961
auto initials = popValue();
2962+
if (isa<quake::StateType>(initials.getType()))
2963+
if (auto load = initials.getDefiningOp<cudaq::cc::LoadOp>())
2964+
initials = load.getPtrvalue();
29282965
if (isStateType(initials.getType())) {
2929-
TODO_x(loc, x, mangler, "qudit(state) ctor");
2930-
return false;
2966+
Value alloca = builder.create<quake::AllocaOp>(loc);
2967+
auto veq1Ty = quake::VeqType::get(builder.getContext(), 1);
2968+
Value initSt = builder.create<quake::InitializeStateOp>(
2969+
loc, veq1Ty, ValueRange{alloca, initials});
2970+
if (auto initOp = initials.getDefiningOp<quake::CreateStateOp>())
2971+
builder.create<quake::DeleteStateOp>(loc, initOp);
2972+
return pushValue(builder.create<quake::ExtractRefOp>(loc, initSt, 0));
29312973
}
29322974
bool ok = false;
29332975
if (auto ptrTy = dyn_cast<cc::PointerType>(initials.getType()))
@@ -2953,57 +2995,26 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
29532995
return pushValue(builder.create<quake::AllocaOp>(
29542996
loc, quake::VeqType::getUnsized(ctx), initials));
29552997
}
2956-
if (isa<quake::StateType>(initials.getType())) {
2998+
if (isa<quake::StateType>(initials.getType()))
29572999
if (auto load = initials.getDefiningOp<cudaq::cc::LoadOp>())
29583000
initials = load.getPtrvalue();
2959-
}
29603001
if (isStateType(initials.getType())) {
29613002
Value state = initials;
29623003
auto i64Ty = builder.getI64Type();
29633004
auto numQubits =
29643005
builder.create<quake::GetNumberOfQubitsOp>(loc, i64Ty, state);
29653006
auto veqTy = quake::VeqType::getUnsized(ctx);
29663007
Value alloc = builder.create<quake::AllocaOp>(loc, veqTy, numQubits);
2967-
return pushValue(builder.create<quake::InitializeStateOp>(
2968-
loc, veqTy, alloc, state));
3008+
Value initSt = builder.create<quake::InitializeStateOp>(loc, veqTy,
3009+
alloc, state);
3010+
if (auto initOp = initials.getDefiningOp<quake::CreateStateOp>())
3011+
builder.create<quake::DeleteStateOp>(loc, initOp);
3012+
return pushValue(initSt);
29693013
}
2970-
// Otherwise, it is the cudaq::qvector(std::vector<complex>) ctor.
2971-
Value numQubits;
2972-
Type initialsTy = initials.getType();
2973-
if (auto ptrTy = dyn_cast<cc::PointerType>(initialsTy)) {
2974-
if (auto arrTy = dyn_cast<cc::ArrayType>(ptrTy.getElementType())) {
2975-
if (arrTy.isUnknownSize()) {
2976-
if (auto allocOp = initials.getDefiningOp<cc::AllocaOp>())
2977-
if (auto size = allocOp.getSeqSize())
2978-
numQubits =
2979-
builder.create<math::CountTrailingZerosOp>(loc, size);
2980-
} else {
2981-
std::size_t arraySize = arrTy.getSize();
2982-
if (!std::has_single_bit(arraySize)) {
2983-
reportClangError(x, mangler,
2984-
"state vector must be a power of 2 in length");
2985-
}
2986-
numQubits = builder.create<arith::ConstantIntOp>(
2987-
loc, std::countr_zero(arraySize), 64);
2988-
}
2989-
}
2990-
} else if (auto stdvecTy = dyn_cast<cc::StdvecType>(initialsTy)) {
2991-
Value vecLen = builder.create<cc::StdvecSizeOp>(
2992-
loc, builder.getI64Type(), initials);
2993-
numQubits = builder.create<math::CountTrailingZerosOp>(loc, vecLen);
2994-
auto ptrTy = cc::PointerType::get(stdvecTy.getElementType());
2995-
initials = builder.create<cc::StdvecDataOp>(loc, ptrTy, initials);
2996-
}
2997-
if (!numQubits) {
2998-
reportClangError(
2999-
x, mangler,
3000-
"internal error: could not determine the number of qubits");
3001-
return false;
3002-
}
3003-
auto veqTy = quake::VeqType::getUnsized(ctx);
3004-
auto alloc = builder.create<quake::AllocaOp>(loc, veqTy, numQubits);
3005-
return pushValue(builder.create<quake::InitializeStateOp>(
3006-
loc, veqTy, alloc, initials));
3014+
reportClangError(
3015+
x, mangler,
3016+
"internal error: could not determine the number of qubits");
3017+
return false;
30073018
}
30083019
if ((ctorName == "qspan" || ctorName == "qview") &&
30093020
isa<quake::VeqType>(peekValue().getType())) {

lib/Optimizer/Builder/Intrinsics.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,11 +304,17 @@ static constexpr IntrinsicCode intrinsicTable[] = {
304304
}
305305
)#"},
306306

307-
{cudaq::createCudaqStateFromDataFP32, {}, R"#(
308-
func.func private @__nvqpp_cudaq_state_createFromData_fp32(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
307+
{cudaq::createCudaqStateFromDataComplexF32, {}, R"#(
308+
func.func private @__nvqpp_cudaq_state_createFromData_complex_f32(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
309309
)#"},
310-
{cudaq::createCudaqStateFromDataFP64, {}, R"#(
311-
func.func private @__nvqpp_cudaq_state_createFromData_fp64(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
310+
{cudaq::createCudaqStateFromDataComplexF64, {}, R"#(
311+
func.func private @__nvqpp_cudaq_state_createFromData_complex_f64(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
312+
)#"},
313+
{cudaq::createCudaqStateFromDataF32, {}, R"#(
314+
func.func private @__nvqpp_cudaq_state_createFromData_f32(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
315+
)#"},
316+
{cudaq::createCudaqStateFromDataF64, {}, R"#(
317+
func.func private @__nvqpp_cudaq_state_createFromData_f64(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
312318
)#"},
313319

314320
{cudaq::deleteCudaqState, {}, R"#(

lib/Optimizer/CodeGen/ConvertToQIRAPI.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -785,8 +785,8 @@ struct QmemRAIIOpRewrite : public OpConversionPattern<cudaq::codegen::RAIIOp> {
785785

786786
Value sizeOperand;
787787
if (!adaptor.getAllocSize()) {
788-
auto type = cast<quake::VeqType>(allocTy);
789-
auto constantSize = type.getSize();
788+
auto type = dyn_cast<quake::VeqType>(allocTy);
789+
auto constantSize = type ? type.getSize() : 1;
790790
sizeOperand =
791791
rewriter.create<arith::ConstantIntOp>(loc, constantSize, 64);
792792
} else {

lib/Optimizer/CodeGen/QuakeToCodegen.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,27 @@ class CreateStateOpPattern : public OpRewritePattern<quake::CreateStateOp> {
8080

8181
auto bufferTy = buffer.getType();
8282
auto ptrTy = cast<cudaq::cc::PointerType>(bufferTy);
83-
auto arrTy = cast<cudaq::cc::ArrayType>(ptrTy.getElementType());
84-
auto eleTy = arrTy.getElementType();
83+
auto arrTy = dyn_cast<cudaq::cc::ArrayType>(ptrTy.getElementType());
84+
auto eleTy = arrTy ? arrTy.getElementType() : ptrTy.getElementType();
8585
auto is64Bit = isa<Float64Type>(eleTy);
86+
bool isComplex = false;
8687

87-
if (auto cTy = dyn_cast<ComplexType>(eleTy))
88+
if (auto cTy = dyn_cast<ComplexType>(eleTy)) {
8889
is64Bit = isa<Float64Type>(cTy.getElementType());
90+
isComplex = true;
91+
}
92+
93+
auto createStateFunc = [&]() {
94+
if (isComplex) {
95+
if (is64Bit)
96+
return cudaq::createCudaqStateFromDataComplexF64;
97+
return cudaq::createCudaqStateFromDataComplexF32;
98+
}
99+
if (is64Bit)
100+
return cudaq::createCudaqStateFromDataF64;
101+
return cudaq::createCudaqStateFromDataF32;
102+
}();
89103

90-
auto createStateFunc = is64Bit ? cudaq::createCudaqStateFromDataFP64
91-
: cudaq::createCudaqStateFromDataFP32;
92104
cudaq::IRBuilder irBuilder(ctx);
93105
auto result = irBuilder.loadIntrinsic(module, createStateFunc);
94106
assert(succeeded(result) && "loading intrinsic should never fail");
@@ -97,11 +109,9 @@ class CreateStateOpPattern : public OpRewritePattern<quake::CreateStateOp> {
97109
auto statePtrTy = cudaq::cc::PointerType::get(stateTy);
98110
auto i8PtrTy = cudaq::cc::PointerType::get(rewriter.getI8Type());
99111
auto cast = rewriter.create<cudaq::cc::CastOp>(loc, i8PtrTy, buffer);
100-
auto one = rewriter.create<arith::ConstantIntOp>(loc, 1, size.getType());
101-
auto powsz = rewriter.create<arith::ShLIOp>(loc, size.getType(), one, size);
102112

103113
rewriter.replaceOpWithNewOp<func::CallOp>(
104-
createStateOp, statePtrTy, createStateFunc, ValueRange{cast, powsz});
114+
createStateOp, statePtrTy, createStateFunc, ValueRange{cast, size});
105115
return success();
106116
}
107117
};

lib/Optimizer/CodeGen/VerifyNVQIRCalls.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ struct VerifyNVQIRCallOpsPass
4646
cudaq::opt::QIRArrayQubitAllocateArrayWithStateComplex32,
4747
cudaq::opt::QIRArrayQubitAllocateArrayWithStateComplex64,
4848
cudaq::getNumQubitsFromCudaqState,
49-
cudaq::createCudaqStateFromDataFP32,
50-
cudaq::createCudaqStateFromDataFP64,
49+
cudaq::createCudaqStateFromDataComplexF32,
50+
cudaq::createCudaqStateFromDataComplexF64,
51+
cudaq::createCudaqStateFromDataF32,
52+
cudaq::createCudaqStateFromDataF64,
5153
cudaq::deleteCudaqState};
5254
// It must be either NVQIR extension functions or in the allowed list.
5355
return std::find(NVQIR_FUNCS.begin(), NVQIR_FUNCS.end(), functionName) !=

runtime/common/ExecutionContext.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include "Future.h"
1212
#include "NoiseModel.h"
1313
#include "SampleResult.h"
14-
#include "SimulationState.h"
1514
#include "Trace.h"
1615
#include "cudaq/algorithms/optimizer.h"
1716
#include "cudaq/operators.h"
@@ -20,6 +19,8 @@
2019

2120
namespace cudaq {
2221

22+
class SimulationState;
23+
2324
/// The ExecutionContext is an abstraction to indicate how a CUDA-Q kernel
2425
/// should be executed.
2526
class ExecutionContext {

runtime/common/SimulationState.h

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99
#pragma once
1010

11+
#include "cudaq/utils/cudaq_utils.h"
12+
#include "cudaq/utils/matrix.h"
1113
#include <algorithm>
14+
#include <bitset>
1215
#include <complex>
1316
#include <memory>
1417
#include <optional>
@@ -28,10 +31,11 @@ using TensorStateData =
2831
/// @brief state_data is a variant type
2932
/// encoding different forms of user state vector data
3033
/// we support.
31-
using state_data = std::variant<
32-
std::vector<std::complex<double>>, std::vector<std::complex<float>>,
33-
std::pair<std::complex<double> *, std::size_t>,
34-
std::pair<std::complex<float> *, std::size_t>, TensorStateData>;
34+
using state_data = std::variant<std::vector<std::complex<double>>,
35+
std::vector<std::complex<float>>,
36+
std::pair<std::complex<double> *, std::size_t>,
37+
std::pair<std::complex<float> *, std::size_t>,
38+
complex_matrix, TensorStateData>;
3539

3640
/// @brief The `SimulationState` interface provides an extension point
3741
/// for concrete circuit simulation sub-types to describe their
@@ -69,20 +73,34 @@ class SimulationState {
6973
/// and extract the data pointer and size.
7074
template <typename ScalarType = double>
7175
auto getSizeAndPtr(const state_data &data) {
72-
auto type = data.index();
73-
std::tuple<std::size_t, void *> sizeAndPtr;
74-
if (type == 0)
75-
sizeAndPtr = getSizeAndPtrFromVec<double, ScalarType>(data);
76-
else if (type == 1)
77-
sizeAndPtr = getSizeAndPtrFromVec<float, ScalarType>(data);
78-
else if (type == 2)
79-
sizeAndPtr = getSizeAndPtrFromPair<double, ScalarType>(data);
80-
else if (type == 3)
81-
sizeAndPtr = getSizeAndPtrFromPair<float, ScalarType>(data);
82-
else
83-
throw std::runtime_error("unsupported data type for state.");
84-
85-
return sizeAndPtr;
76+
if (std::holds_alternative<std::vector<std::complex<double>>>(data))
77+
return getSizeAndPtrFromVec<double, ScalarType>(data);
78+
if (std::holds_alternative<std::vector<std::complex<float>>>(data))
79+
return getSizeAndPtrFromVec<float, ScalarType>(data);
80+
if (std::holds_alternative<std::pair<std::complex<double> *, std::size_t>>(
81+
data))
82+
return getSizeAndPtrFromPair<double, ScalarType>(data);
83+
if (std::holds_alternative<std::pair<std::complex<float> *, std::size_t>>(
84+
data))
85+
return getSizeAndPtrFromPair<float, ScalarType>(data);
86+
if (std::holds_alternative<complex_matrix>(data)) {
87+
// Complex matrix is double precision only
88+
if constexpr (!std::is_same_v<double, ScalarType>)
89+
throw std::runtime_error("[sim-state] invalid data precision.");
90+
auto &cMat = std::get<complex_matrix>(data);
91+
if (cMat.rows() != cMat.cols())
92+
throw std::runtime_error(
93+
"[sim-state] complex matrix must be square for density matrix.");
94+
// Check that it must be a power of 2
95+
if (std::bitset<64>(cMat.rows()).count() != 1)
96+
throw std::runtime_error("[sim-state] complex matrix size must be a "
97+
"power of 2 for density matrix.");
98+
return std::make_tuple(
99+
cMat.size(),
100+
reinterpret_cast<void *>(const_cast<complex_matrix &>(cMat).get_data(
101+
complex_matrix::order::row_major)));
102+
}
103+
throw std::runtime_error("unsupported data type for state vector.");
86104
}
87105

88106
/// @brief Subclass-specific creator method for

0 commit comments

Comments
 (0)