Skip to content

Commit f648e06

Browse files
committed
Simplifying cudaq::state support.
This PR greatly simplifies the use of cudaq::get_state by limiting its applicability strictly to local simulation execution contexts. Earlier attempts tried to make cudaq::get_state generally available regardless of the target execution context; that level of support is being discontinued.
- Move the qvector ctor overloads to the cudaq::state class. This simplifies qvector: all of the raw-data overloads are dropped, so a qvector can be initialized from state data only via a cudaq::state object.
- Fix tests. Adapt the regression tests.
- Drop support for cudaq::state with non-simulation targets.
- Add code to distinguish between floats and complex floats, and remove redundant code. Also, update tests.
Signed-off-by: Eric Schweitz <[email protected]>
1 parent bb4678c commit f648e06

32 files changed

+1043
-1224
lines changed

include/cudaq/Frontend/nvqpp/ASTBridge.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,10 @@ class QuakeBridgeVisitor
288288
bool TraverseCXXConstructExpr(clang::CXXConstructExpr *x,
289289
DataRecursionQueue *q = nullptr);
290290
bool VisitCXXConstructExpr(clang::CXXConstructExpr *x);
291+
bool TraverseCXXTemporaryObjectExpr(clang::CXXTemporaryObjectExpr *x,
292+
DataRecursionQueue *q = nullptr) {
293+
return TraverseCXXConstructExpr(x, q);
294+
}
291295
bool VisitCXXOperatorCallExpr(clang::CXXOperatorCallExpr *x);
292296
bool VisitCXXParenListInitExpr(clang::CXXParenListInitExpr *x);
293297
bool WalkUpFromCXXOperatorCallExpr(clang::CXXOperatorCallExpr *x);

include/cudaq/Optimizer/Builder/Intrinsics.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,14 @@ static constexpr const char getNumQubitsFromCudaqState[] =
5353
"__nvqpp_cudaq_state_numberOfQubits";
5454

5555
// Create a new state from data.
56-
static constexpr const char createCudaqStateFromDataFP64[] =
57-
"__nvqpp_cudaq_state_createFromData_fp64";
58-
static constexpr const char createCudaqStateFromDataFP32[] =
59-
"__nvqpp_cudaq_state_createFromData_fp32";
56+
static constexpr const char createCudaqStateFromDataComplexF64[] =
57+
"__nvqpp_cudaq_state_createFromData_complex_f64";
58+
static constexpr const char createCudaqStateFromDataComplexF32[] =
59+
"__nvqpp_cudaq_state_createFromData_complex_f32";
60+
static constexpr const char createCudaqStateFromDataF64[] =
61+
"__nvqpp_cudaq_state_createFromData_f64";
62+
static constexpr const char createCudaqStateFromDataF32[] =
63+
"__nvqpp_cudaq_state_createFromData_f32";
6064

6165
// Delete a state created by the runtime functions above.
6266
static constexpr const char deleteCudaqState[] = "__nvqpp_cudaq_state_delete";

include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def quake_InitializeStateOp : QuakeOp<"init_state",
121121
}];
122122

123123
let arguments = (ins
124-
VeqType:$targets,
124+
NonStruqRefType:$targets,
125125
AnyStateInitType:$state
126126
);
127127
let results = (outs VeqType);

lib/Frontend/nvqpp/ConvertExpr.cpp

Lines changed: 58 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -682,9 +682,12 @@ bool QuakeBridgeVisitor::VisitCastExpr(clang::CastExpr *x) {
682682
if (cxxExpr->getNumArgs() == 1)
683683
return true;
684684
}
685-
if (isa<ComplexType>(castToTy) && isa<ComplexType>(peekValue().getType())) {
685+
if (isa<ComplexType>(castToTy) && isa<ComplexType>(peekValue().getType()))
686686
return true;
687-
}
687+
if (isa<quake::StateType>(castToTy))
688+
if (auto ptrTy = dyn_cast<cudaq::cc::PointerType>(peekValue().getType()))
689+
if (isa<quake::StateType>(ptrTy.getElementType()))
690+
return pushValue(builder.create<cudaq::cc::LoadOp>(loc, popValue()));
688691
if (auto funcTy = peelPointerFromFunction(castToTy))
689692
if (auto fromTy = dyn_cast<cc::CallableType>(peekValue().getType())) {
690693
auto inputs = funcTy.getInputs();
@@ -1003,8 +1006,8 @@ bool QuakeBridgeVisitor::VisitMaterializeTemporaryExpr(
10031006
// The following cases are λ expressions, quantum data, or a std::vector view.
10041007
// In those cases, there is nothing to materialize, so we can just pass the
10051008
// Value on the top of the stack.
1006-
if (isa<cc::CallableType, quake::VeqType, quake::RefType, cc::SpanLikeType>(
1007-
ty))
1009+
if (isa<cc::CallableType, quake::VeqType, quake::RefType, cc::SpanLikeType,
1010+
quake::StateType>(ty))
10081011
return true;
10091012

10101013
// If not one of the above special cases, then materialize the value to a
@@ -2689,6 +2692,11 @@ bool QuakeBridgeVisitor::VisitInitListExpr(clang::InitListExpr *x) {
26892692
}
26902693

26912694
// List has 1 or more members.
2695+
if (size == 1 && isa<clang::MaterializeTemporaryExpr>(x->getInit(0)))
2696+
if (auto alloc = peekValue().getDefiningOp<cudaq::cc::AllocaOp>())
2697+
if (auto arrTy = dyn_cast<cudaq::cc::ArrayType>(initListTy))
2698+
if (alloc.getElementType() == arrTy.getElementType())
2699+
return true;
26922700
auto last = lastValues(size);
26932701
bool allRef = std::all_of(last.begin(), last.end(), [](auto v) {
26942702
return isa<quake::RefType, quake::VeqType>(v.getType());
@@ -2916,6 +2924,32 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
29162924
loc, quake::VeqType::getUnsized(builder.getContext()), sizeVal));
29172925
}
29182926

2927+
if (ctorName == "state") {
2928+
// cudaq::state ctor can be materialized when using local simulators and
2929+
// converting raw data to state vectors. Use a runtime helper function
2930+
// to perform the conversion.
2931+
Value stdvec = popValue();
2932+
auto stateTy = cudaq::cc::PointerType::get(
2933+
quake::StateType::get(builder.getContext()));
2934+
if (auto stdvecTy = dyn_cast<cudaq::cc::StdvecType>(stdvec.getType())) {
2935+
auto dataTy = cudaq::cc::PointerType::get(stdvecTy.getElementType());
2936+
Value data =
2937+
builder.create<cudaq::cc::StdvecDataOp>(loc, dataTy, stdvec);
2938+
auto i64Ty = builder.getI64Type();
2939+
Value size =
2940+
builder.create<cudaq::cc::StdvecSizeOp>(loc, i64Ty, stdvec);
2941+
return pushValue(builder.create<quake::CreateStateOp>(
2942+
loc, stateTy, ValueRange{data, size}));
2943+
}
2944+
if (auto alloc = stdvec.getDefiningOp<cudaq::cc::AllocaOp>()) {
2945+
Value size = alloc.getSeqSize();
2946+
return pushValue(builder.create<quake::CreateStateOp>(
2947+
loc, stateTy, ValueRange{alloc, size}));
2948+
}
2949+
TODO_loc(loc, "unhandled state constructor");
2950+
return false;
2951+
}
2952+
29192953
// lambda determines: is `t` a cudaq::state* ?
29202954
auto isStateType = [&](Type t) {
29212955
if (auto ptrTy = dyn_cast<cc::PointerType>(t))
@@ -2925,9 +2959,17 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
29252959

29262960
if (ctorName == "qudit") {
29272961
auto initials = popValue();
2962+
if (isa<quake::StateType>(initials.getType()))
2963+
if (auto load = initials.getDefiningOp<cudaq::cc::LoadOp>())
2964+
initials = load.getPtrvalue();
29282965
if (isStateType(initials.getType())) {
2929-
TODO_x(loc, x, mangler, "qudit(state) ctor");
2930-
return false;
2966+
Value alloca = builder.create<quake::AllocaOp>(loc);
2967+
auto veq1Ty = quake::VeqType::get(builder.getContext(), 1);
2968+
Value initSt = builder.create<quake::InitializeStateOp>(
2969+
loc, veq1Ty, ValueRange{alloca, initials});
2970+
if (auto initOp = initials.getDefiningOp<quake::CreateStateOp>())
2971+
builder.create<quake::DeleteStateOp>(loc, initOp);
2972+
return pushValue(builder.create<quake::ExtractRefOp>(loc, initSt, 0));
29312973
}
29322974
bool ok = false;
29332975
if (auto ptrTy = dyn_cast<cc::PointerType>(initials.getType()))
@@ -2953,57 +2995,26 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
29532995
return pushValue(builder.create<quake::AllocaOp>(
29542996
loc, quake::VeqType::getUnsized(ctx), initials));
29552997
}
2956-
if (isa<quake::StateType>(initials.getType())) {
2998+
if (isa<quake::StateType>(initials.getType()))
29572999
if (auto load = initials.getDefiningOp<cudaq::cc::LoadOp>())
29583000
initials = load.getPtrvalue();
2959-
}
29603001
if (isStateType(initials.getType())) {
29613002
Value state = initials;
29623003
auto i64Ty = builder.getI64Type();
29633004
auto numQubits =
29643005
builder.create<quake::GetNumberOfQubitsOp>(loc, i64Ty, state);
29653006
auto veqTy = quake::VeqType::getUnsized(ctx);
29663007
Value alloc = builder.create<quake::AllocaOp>(loc, veqTy, numQubits);
2967-
return pushValue(builder.create<quake::InitializeStateOp>(
2968-
loc, veqTy, alloc, state));
3008+
Value initSt = builder.create<quake::InitializeStateOp>(loc, veqTy,
3009+
alloc, state);
3010+
if (auto initOp = initials.getDefiningOp<quake::CreateStateOp>())
3011+
builder.create<quake::DeleteStateOp>(loc, initOp);
3012+
return pushValue(initSt);
29693013
}
2970-
// Otherwise, it is the cudaq::qvector(std::vector<complex>) ctor.
2971-
Value numQubits;
2972-
Type initialsTy = initials.getType();
2973-
if (auto ptrTy = dyn_cast<cc::PointerType>(initialsTy)) {
2974-
if (auto arrTy = dyn_cast<cc::ArrayType>(ptrTy.getElementType())) {
2975-
if (arrTy.isUnknownSize()) {
2976-
if (auto allocOp = initials.getDefiningOp<cc::AllocaOp>())
2977-
if (auto size = allocOp.getSeqSize())
2978-
numQubits =
2979-
builder.create<math::CountTrailingZerosOp>(loc, size);
2980-
} else {
2981-
std::size_t arraySize = arrTy.getSize();
2982-
if (!std::has_single_bit(arraySize)) {
2983-
reportClangError(x, mangler,
2984-
"state vector must be a power of 2 in length");
2985-
}
2986-
numQubits = builder.create<arith::ConstantIntOp>(
2987-
loc, std::countr_zero(arraySize), 64);
2988-
}
2989-
}
2990-
} else if (auto stdvecTy = dyn_cast<cc::StdvecType>(initialsTy)) {
2991-
Value vecLen = builder.create<cc::StdvecSizeOp>(
2992-
loc, builder.getI64Type(), initials);
2993-
numQubits = builder.create<math::CountTrailingZerosOp>(loc, vecLen);
2994-
auto ptrTy = cc::PointerType::get(stdvecTy.getElementType());
2995-
initials = builder.create<cc::StdvecDataOp>(loc, ptrTy, initials);
2996-
}
2997-
if (!numQubits) {
2998-
reportClangError(
2999-
x, mangler,
3000-
"internal error: could not determine the number of qubits");
3001-
return false;
3002-
}
3003-
auto veqTy = quake::VeqType::getUnsized(ctx);
3004-
auto alloc = builder.create<quake::AllocaOp>(loc, veqTy, numQubits);
3005-
return pushValue(builder.create<quake::InitializeStateOp>(
3006-
loc, veqTy, alloc, initials));
3014+
reportClangError(
3015+
x, mangler,
3016+
"internal error: could not determine the number of qubits");
3017+
return false;
30073018
}
30083019
if ((ctorName == "qspan" || ctorName == "qview") &&
30093020
isa<quake::VeqType>(peekValue().getType())) {

lib/Optimizer/Builder/Intrinsics.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,11 +304,17 @@ static constexpr IntrinsicCode intrinsicTable[] = {
304304
}
305305
)#"},
306306

307-
{cudaq::createCudaqStateFromDataFP32, {}, R"#(
308-
func.func private @__nvqpp_cudaq_state_createFromData_fp32(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
307+
{cudaq::createCudaqStateFromDataComplexF32, {}, R"#(
308+
func.func private @__nvqpp_cudaq_state_createFromData_complex_f32(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
309309
)#"},
310-
{cudaq::createCudaqStateFromDataFP64, {}, R"#(
311-
func.func private @__nvqpp_cudaq_state_createFromData_fp64(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
310+
{cudaq::createCudaqStateFromDataComplexF64, {}, R"#(
311+
func.func private @__nvqpp_cudaq_state_createFromData_complex_f64(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
312+
)#"},
313+
{cudaq::createCudaqStateFromDataF32, {}, R"#(
314+
func.func private @__nvqpp_cudaq_state_createFromData_f32(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
315+
)#"},
316+
{cudaq::createCudaqStateFromDataF64, {}, R"#(
317+
func.func private @__nvqpp_cudaq_state_createFromData_f64(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
312318
)#"},
313319

314320
{cudaq::deleteCudaqState, {}, R"#(

lib/Optimizer/CodeGen/ConvertToQIRAPI.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -785,8 +785,8 @@ struct QmemRAIIOpRewrite : public OpConversionPattern<cudaq::codegen::RAIIOp> {
785785

786786
Value sizeOperand;
787787
if (!adaptor.getAllocSize()) {
788-
auto type = cast<quake::VeqType>(allocTy);
789-
auto constantSize = type.getSize();
788+
auto type = dyn_cast<quake::VeqType>(allocTy);
789+
auto constantSize = type ? type.getSize() : 1;
790790
sizeOperand =
791791
rewriter.create<arith::ConstantIntOp>(loc, constantSize, 64);
792792
} else {

lib/Optimizer/CodeGen/QuakeToCodegen.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,27 @@ class CreateStateOpPattern : public OpRewritePattern<quake::CreateStateOp> {
8080

8181
auto bufferTy = buffer.getType();
8282
auto ptrTy = cast<cudaq::cc::PointerType>(bufferTy);
83-
auto arrTy = cast<cudaq::cc::ArrayType>(ptrTy.getElementType());
84-
auto eleTy = arrTy.getElementType();
83+
auto arrTy = dyn_cast<cudaq::cc::ArrayType>(ptrTy.getElementType());
84+
auto eleTy = arrTy ? arrTy.getElementType() : ptrTy.getElementType();
8585
auto is64Bit = isa<Float64Type>(eleTy);
86+
bool isComplex = false;
8687

87-
if (auto cTy = dyn_cast<ComplexType>(eleTy))
88+
if (auto cTy = dyn_cast<ComplexType>(eleTy)) {
8889
is64Bit = isa<Float64Type>(cTy.getElementType());
90+
isComplex = true;
91+
}
92+
93+
auto createStateFunc = [&]() {
94+
if (isComplex) {
95+
if (is64Bit)
96+
return cudaq::createCudaqStateFromDataComplexF64;
97+
return cudaq::createCudaqStateFromDataComplexF32;
98+
}
99+
if (is64Bit)
100+
return cudaq::createCudaqStateFromDataF64;
101+
return cudaq::createCudaqStateFromDataF32;
102+
}();
89103

90-
auto createStateFunc = is64Bit ? cudaq::createCudaqStateFromDataFP64
91-
: cudaq::createCudaqStateFromDataFP32;
92104
cudaq::IRBuilder irBuilder(ctx);
93105
auto result = irBuilder.loadIntrinsic(module, createStateFunc);
94106
assert(succeeded(result) && "loading intrinsic should never fail");
@@ -97,11 +109,9 @@ class CreateStateOpPattern : public OpRewritePattern<quake::CreateStateOp> {
97109
auto statePtrTy = cudaq::cc::PointerType::get(stateTy);
98110
auto i8PtrTy = cudaq::cc::PointerType::get(rewriter.getI8Type());
99111
auto cast = rewriter.create<cudaq::cc::CastOp>(loc, i8PtrTy, buffer);
100-
auto one = rewriter.create<arith::ConstantIntOp>(loc, 1, size.getType());
101-
auto powsz = rewriter.create<arith::ShLIOp>(loc, size.getType(), one, size);
102112

103113
rewriter.replaceOpWithNewOp<func::CallOp>(
104-
createStateOp, statePtrTy, createStateFunc, ValueRange{cast, powsz});
114+
createStateOp, statePtrTy, createStateFunc, ValueRange{cast, size});
105115
return success();
106116
}
107117
};

lib/Optimizer/CodeGen/VerifyNVQIRCalls.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ struct VerifyNVQIRCallOpsPass
4646
cudaq::opt::QIRArrayQubitAllocateArrayWithStateComplex32,
4747
cudaq::opt::QIRArrayQubitAllocateArrayWithStateComplex64,
4848
cudaq::getNumQubitsFromCudaqState,
49-
cudaq::createCudaqStateFromDataFP32,
50-
cudaq::createCudaqStateFromDataFP64,
49+
cudaq::createCudaqStateFromDataComplexF32,
50+
cudaq::createCudaqStateFromDataComplexF64,
51+
cudaq::createCudaqStateFromDataF32,
52+
cudaq::createCudaqStateFromDataF64,
5153
cudaq::deleteCudaqState};
5254
// It must be either NVQIR extension functions or in the allowed list.
5355
return std::find(NVQIR_FUNCS.begin(), NVQIR_FUNCS.end(), functionName) !=

runtime/cudaq/algorithms/get_state.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,12 @@ async_state_result get_state_async(QuantumKernel &&kernel, Args &&...args) {
202202

203203
extern "C" {
204204
std::int64_t __nvqpp_cudaq_state_numberOfQubits(state *);
205-
state *__nvqpp_cudaq_state_createFromData_fp64(void *, std::size_t);
206-
state *__nvqpp_cudaq_state_createFromData_fp32(void *, std::size_t);
205+
state *__nvqpp_cudaq_state_createFromData_f64(double *, std::size_t);
206+
state *__nvqpp_cudaq_state_createFromData_f32(float *, std::size_t);
207+
state *__nvqpp_cudaq_state_createFromData_complex_f64(std::complex<double> *,
208+
std::size_t);
209+
state *__nvqpp_cudaq_state_createFromData_complex_f32(std::complex<float> *,
210+
std::size_t);
207211
void __nvqpp_cudaq_state_delete(state *);
208212
}
209213

runtime/cudaq/qis/qudit.h

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
namespace cudaq {
1414

15+
class state;
16+
1517
/// The qudit models a general d-level quantum system.
1618
/// This type is templated on the number of levels d.
1719
template <std::size_t Levels>
@@ -28,29 +30,14 @@ class qudit {
2830
public:
2931
/// Construct a qudit, will allocated a new unique index
3032
qudit() : idx(getExecutionManager()->allocateQudit(n_levels())) {}
31-
qudit(const std::vector<complex> &state) : qudit() {
32-
if (state.size() != Levels)
33-
throw std::runtime_error(
34-
"Invalid number of state vector elements for qudit allocation (" +
35-
std::to_string(state.size()) + ").");
36-
37-
auto norm = std::inner_product(
38-
state.begin(), state.end(), state.begin(), complex{0., 0.},
39-
[](auto a, auto b) { return a + b; },
40-
[](auto a, auto b) { return std::conj(a) * b; })
41-
.real();
42-
if (std::fabs(1.0 - norm) > 1e-4)
43-
throw std::runtime_error("Invalid vector norm for qudit allocation.");
44-
45-
// Perform the initialization
46-
auto precision = std::is_same_v<complex::value_type, float>
47-
? simulation_precision::fp32
48-
: simulation_precision::fp64;
49-
getExecutionManager()->initializeState({QuditInfo(n_levels(), idx)},
50-
state.data(), precision);
33+
qudit(const state &state) : qudit() {
34+
// Note: the internal state data will be cloned by the simulator backend.
35+
std::vector<QuditInfo> v{QuditInfo{Levels, id()}};
36+
getExecutionManager()->initializeState(v, state.internal.get());
5137
}
52-
qudit(const std::initializer_list<complex> &list)
53-
: qudit({list.begin(), list.end()}) {}
38+
qudit(const state *s) : qudit(*s) {}
39+
qudit(state *s) : qudit(const_cast<const state *>(s)) {}
40+
qudit(state &s) : qudit(const_cast<const state &>(s)) {}
5441

5542
// Qudits cannot be copied
5643
qudit(const qudit &q) = delete;

0 commit comments

Comments
 (0)