Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/cudaq/Frontend/nvqpp/ASTBridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,10 @@ class QuakeBridgeVisitor
bool TraverseCXXConstructExpr(clang::CXXConstructExpr *x,
DataRecursionQueue *q = nullptr);
bool VisitCXXConstructExpr(clang::CXXConstructExpr *x);
bool TraverseCXXTemporaryObjectExpr(clang::CXXTemporaryObjectExpr *x,
DataRecursionQueue *q = nullptr) {
return TraverseCXXConstructExpr(x, q);
}
bool VisitCXXOperatorCallExpr(clang::CXXOperatorCallExpr *x);
bool VisitCXXParenListInitExpr(clang::CXXParenListInitExpr *x);
bool WalkUpFromCXXOperatorCallExpr(clang::CXXOperatorCallExpr *x);
Expand Down
12 changes: 8 additions & 4 deletions include/cudaq/Optimizer/Builder/Intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,14 @@ static constexpr const char getNumQubitsFromCudaqState[] =
"__nvqpp_cudaq_state_numberOfQubits";

// Create a new state from data.
static constexpr const char createCudaqStateFromDataFP64[] =
"__nvqpp_cudaq_state_createFromData_fp64";
static constexpr const char createCudaqStateFromDataFP32[] =
"__nvqpp_cudaq_state_createFromData_fp32";
static constexpr const char createCudaqStateFromDataComplexF64[] =
"__nvqpp_cudaq_state_createFromData_complex_f64";
static constexpr const char createCudaqStateFromDataComplexF32[] =
"__nvqpp_cudaq_state_createFromData_complex_f32";
static constexpr const char createCudaqStateFromDataF64[] =
"__nvqpp_cudaq_state_createFromData_f64";
static constexpr const char createCudaqStateFromDataF32[] =
"__nvqpp_cudaq_state_createFromData_f32";

// Delete a state created by the runtime functions above.
static constexpr const char deleteCudaqState[] = "__nvqpp_cudaq_state_delete";
Expand Down
2 changes: 1 addition & 1 deletion include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def quake_InitializeStateOp : QuakeOp<"init_state",
}];

let arguments = (ins
VeqType:$targets,
NonStruqRefType:$targets,
AnyStateInitType:$state
);
let results = (outs VeqType);
Expand Down
105 changes: 58 additions & 47 deletions lib/Frontend/nvqpp/ConvertExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -682,9 +682,12 @@ bool QuakeBridgeVisitor::VisitCastExpr(clang::CastExpr *x) {
if (cxxExpr->getNumArgs() == 1)
return true;
}
if (isa<ComplexType>(castToTy) && isa<ComplexType>(peekValue().getType())) {
if (isa<ComplexType>(castToTy) && isa<ComplexType>(peekValue().getType()))
return true;
}
if (isa<quake::StateType>(castToTy))
if (auto ptrTy = dyn_cast<cudaq::cc::PointerType>(peekValue().getType()))
if (isa<quake::StateType>(ptrTy.getElementType()))
return pushValue(builder.create<cudaq::cc::LoadOp>(loc, popValue()));
if (auto funcTy = peelPointerFromFunction(castToTy))
if (auto fromTy = dyn_cast<cc::CallableType>(peekValue().getType())) {
auto inputs = funcTy.getInputs();
Expand Down Expand Up @@ -1003,8 +1006,8 @@ bool QuakeBridgeVisitor::VisitMaterializeTemporaryExpr(
// The following cases are λ expressions, quantum data, or a std::vector view.
// In those cases, there is nothing to materialize, so we can just pass the
// Value on the top of the stack.
if (isa<cc::CallableType, quake::VeqType, quake::RefType, cc::SpanLikeType>(
ty))
if (isa<cc::CallableType, quake::VeqType, quake::RefType, cc::SpanLikeType,
quake::StateType>(ty))
return true;

// If not one of the above special cases, then materialize the value to a
Expand Down Expand Up @@ -2689,6 +2692,11 @@ bool QuakeBridgeVisitor::VisitInitListExpr(clang::InitListExpr *x) {
}

// List has 1 or more members.
if (size == 1 && isa<clang::MaterializeTemporaryExpr>(x->getInit(0)))
if (auto alloc = peekValue().getDefiningOp<cudaq::cc::AllocaOp>())
if (auto arrTy = dyn_cast<cudaq::cc::ArrayType>(initListTy))
if (alloc.getElementType() == arrTy.getElementType())
return true;
auto last = lastValues(size);
bool allRef = std::all_of(last.begin(), last.end(), [](auto v) {
return isa<quake::RefType, quake::VeqType>(v.getType());
Expand Down Expand Up @@ -2916,6 +2924,32 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
loc, quake::VeqType::getUnsized(builder.getContext()), sizeVal));
}

if (ctorName == "state") {
// cudaq::state ctor can be materialized when using local simulators and
// converting raw data to state vectors. Use a runtime helper function
// to perform the conversion.
Value stdvec = popValue();
auto stateTy = cudaq::cc::PointerType::get(
quake::StateType::get(builder.getContext()));
if (auto stdvecTy = dyn_cast<cudaq::cc::StdvecType>(stdvec.getType())) {
auto dataTy = cudaq::cc::PointerType::get(stdvecTy.getElementType());
Value data =
builder.create<cudaq::cc::StdvecDataOp>(loc, dataTy, stdvec);
auto i64Ty = builder.getI64Type();
Value size =
builder.create<cudaq::cc::StdvecSizeOp>(loc, i64Ty, stdvec);
return pushValue(builder.create<quake::CreateStateOp>(
loc, stateTy, ValueRange{data, size}));
}
if (auto alloc = stdvec.getDefiningOp<cudaq::cc::AllocaOp>()) {
Value size = alloc.getSeqSize();
return pushValue(builder.create<quake::CreateStateOp>(
loc, stateTy, ValueRange{alloc, size}));
}
TODO_loc(loc, "unhandled state constructor");
return false;
}

// lambda determines: is `t` a cudaq::state* ?
auto isStateType = [&](Type t) {
if (auto ptrTy = dyn_cast<cc::PointerType>(t))
Expand All @@ -2925,9 +2959,17 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {

if (ctorName == "qudit") {
auto initials = popValue();
if (isa<quake::StateType>(initials.getType()))
if (auto load = initials.getDefiningOp<cudaq::cc::LoadOp>())
initials = load.getPtrvalue();
if (isStateType(initials.getType())) {
TODO_x(loc, x, mangler, "qudit(state) ctor");
return false;
Value alloca = builder.create<quake::AllocaOp>(loc);
auto veq1Ty = quake::VeqType::get(builder.getContext(), 1);
Value initSt = builder.create<quake::InitializeStateOp>(
loc, veq1Ty, ValueRange{alloca, initials});
if (auto initOp = initials.getDefiningOp<quake::CreateStateOp>())
builder.create<quake::DeleteStateOp>(loc, initOp);
return pushValue(builder.create<quake::ExtractRefOp>(loc, initSt, 0));
}
bool ok = false;
if (auto ptrTy = dyn_cast<cc::PointerType>(initials.getType()))
Expand All @@ -2953,57 +2995,26 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
return pushValue(builder.create<quake::AllocaOp>(
loc, quake::VeqType::getUnsized(ctx), initials));
}
if (isa<quake::StateType>(initials.getType())) {
if (isa<quake::StateType>(initials.getType()))
if (auto load = initials.getDefiningOp<cudaq::cc::LoadOp>())
initials = load.getPtrvalue();
}
if (isStateType(initials.getType())) {
Value state = initials;
auto i64Ty = builder.getI64Type();
auto numQubits =
builder.create<quake::GetNumberOfQubitsOp>(loc, i64Ty, state);
auto veqTy = quake::VeqType::getUnsized(ctx);
Value alloc = builder.create<quake::AllocaOp>(loc, veqTy, numQubits);
return pushValue(builder.create<quake::InitializeStateOp>(
loc, veqTy, alloc, state));
Value initSt = builder.create<quake::InitializeStateOp>(loc, veqTy,
alloc, state);
if (auto initOp = initials.getDefiningOp<quake::CreateStateOp>())
builder.create<quake::DeleteStateOp>(loc, initOp);
return pushValue(initSt);
}
// Otherwise, it is the cudaq::qvector(std::vector<complex>) ctor.
Value numQubits;
Type initialsTy = initials.getType();
if (auto ptrTy = dyn_cast<cc::PointerType>(initialsTy)) {
if (auto arrTy = dyn_cast<cc::ArrayType>(ptrTy.getElementType())) {
if (arrTy.isUnknownSize()) {
if (auto allocOp = initials.getDefiningOp<cc::AllocaOp>())
if (auto size = allocOp.getSeqSize())
numQubits =
builder.create<math::CountTrailingZerosOp>(loc, size);
} else {
std::size_t arraySize = arrTy.getSize();
if (!std::has_single_bit(arraySize)) {
reportClangError(x, mangler,
"state vector must be a power of 2 in length");
}
numQubits = builder.create<arith::ConstantIntOp>(
loc, std::countr_zero(arraySize), 64);
}
}
} else if (auto stdvecTy = dyn_cast<cc::StdvecType>(initialsTy)) {
Value vecLen = builder.create<cc::StdvecSizeOp>(
loc, builder.getI64Type(), initials);
numQubits = builder.create<math::CountTrailingZerosOp>(loc, vecLen);
auto ptrTy = cc::PointerType::get(stdvecTy.getElementType());
initials = builder.create<cc::StdvecDataOp>(loc, ptrTy, initials);
}
if (!numQubits) {
reportClangError(
x, mangler,
"internal error: could not determine the number of qubits");
return false;
}
auto veqTy = quake::VeqType::getUnsized(ctx);
auto alloc = builder.create<quake::AllocaOp>(loc, veqTy, numQubits);
return pushValue(builder.create<quake::InitializeStateOp>(
loc, veqTy, alloc, initials));
reportClangError(
x, mangler,
"internal error: could not determine the number of qubits");
return false;
}
if ((ctorName == "qspan" || ctorName == "qview") &&
isa<quake::VeqType>(peekValue().getType())) {
Expand Down
14 changes: 10 additions & 4 deletions lib/Optimizer/Builder/Intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,17 @@ static constexpr IntrinsicCode intrinsicTable[] = {
}
)#"},

{cudaq::createCudaqStateFromDataFP32, {}, R"#(
func.func private @__nvqpp_cudaq_state_createFromData_fp32(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
{cudaq::createCudaqStateFromDataComplexF32, {}, R"#(
func.func private @__nvqpp_cudaq_state_createFromData_complex_f32(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
)#"},
{cudaq::createCudaqStateFromDataFP64, {}, R"#(
func.func private @__nvqpp_cudaq_state_createFromData_fp64(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
{cudaq::createCudaqStateFromDataComplexF64, {}, R"#(
func.func private @__nvqpp_cudaq_state_createFromData_complex_f64(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
)#"},
{cudaq::createCudaqStateFromDataF32, {}, R"#(
func.func private @__nvqpp_cudaq_state_createFromData_f32(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
)#"},
{cudaq::createCudaqStateFromDataF64, {}, R"#(
func.func private @__nvqpp_cudaq_state_createFromData_f64(%p : !cc.ptr<i8>, %s : i64) -> !cc.ptr<!quake.state>
)#"},

{cudaq::deleteCudaqState, {}, R"#(
Expand Down
4 changes: 2 additions & 2 deletions lib/Optimizer/CodeGen/ConvertToQIRAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -785,8 +785,8 @@ struct QmemRAIIOpRewrite : public OpConversionPattern<cudaq::codegen::RAIIOp> {

Value sizeOperand;
if (!adaptor.getAllocSize()) {
auto type = cast<quake::VeqType>(allocTy);
auto constantSize = type.getSize();
auto type = dyn_cast<quake::VeqType>(allocTy);
auto constantSize = type ? type.getSize() : 1;
sizeOperand =
rewriter.create<arith::ConstantIntOp>(loc, constantSize, 64);
} else {
Expand Down
26 changes: 18 additions & 8 deletions lib/Optimizer/CodeGen/QuakeToCodegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,27 @@ class CreateStateOpPattern : public OpRewritePattern<quake::CreateStateOp> {

auto bufferTy = buffer.getType();
auto ptrTy = cast<cudaq::cc::PointerType>(bufferTy);
auto arrTy = cast<cudaq::cc::ArrayType>(ptrTy.getElementType());
auto eleTy = arrTy.getElementType();
auto arrTy = dyn_cast<cudaq::cc::ArrayType>(ptrTy.getElementType());
auto eleTy = arrTy ? arrTy.getElementType() : ptrTy.getElementType();
auto is64Bit = isa<Float64Type>(eleTy);
bool isComplex = false;

if (auto cTy = dyn_cast<ComplexType>(eleTy))
if (auto cTy = dyn_cast<ComplexType>(eleTy)) {
is64Bit = isa<Float64Type>(cTy.getElementType());
isComplex = true;
}

auto createStateFunc = [&]() {
if (isComplex) {
if (is64Bit)
return cudaq::createCudaqStateFromDataComplexF64;
return cudaq::createCudaqStateFromDataComplexF32;
}
if (is64Bit)
return cudaq::createCudaqStateFromDataF64;
return cudaq::createCudaqStateFromDataF32;
}();

auto createStateFunc = is64Bit ? cudaq::createCudaqStateFromDataFP64
: cudaq::createCudaqStateFromDataFP32;
cudaq::IRBuilder irBuilder(ctx);
auto result = irBuilder.loadIntrinsic(module, createStateFunc);
assert(succeeded(result) && "loading intrinsic should never fail");
Expand All @@ -97,11 +109,9 @@ class CreateStateOpPattern : public OpRewritePattern<quake::CreateStateOp> {
auto statePtrTy = cudaq::cc::PointerType::get(stateTy);
auto i8PtrTy = cudaq::cc::PointerType::get(rewriter.getI8Type());
auto cast = rewriter.create<cudaq::cc::CastOp>(loc, i8PtrTy, buffer);
auto one = rewriter.create<arith::ConstantIntOp>(loc, 1, size.getType());
auto powsz = rewriter.create<arith::ShLIOp>(loc, size.getType(), one, size);

rewriter.replaceOpWithNewOp<func::CallOp>(
createStateOp, statePtrTy, createStateFunc, ValueRange{cast, powsz});
createStateOp, statePtrTy, createStateFunc, ValueRange{cast, size});
return success();
}
};
Expand Down
6 changes: 4 additions & 2 deletions lib/Optimizer/CodeGen/VerifyNVQIRCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ struct VerifyNVQIRCallOpsPass
cudaq::opt::QIRArrayQubitAllocateArrayWithStateComplex32,
cudaq::opt::QIRArrayQubitAllocateArrayWithStateComplex64,
cudaq::getNumQubitsFromCudaqState,
cudaq::createCudaqStateFromDataFP32,
cudaq::createCudaqStateFromDataFP64,
cudaq::createCudaqStateFromDataComplexF32,
cudaq::createCudaqStateFromDataComplexF64,
cudaq::createCudaqStateFromDataF32,
cudaq::createCudaqStateFromDataF64,
cudaq::deleteCudaqState};
// It must be either NVQIR extension functions or in the allowed list.
return std::find(NVQIR_FUNCS.begin(), NVQIR_FUNCS.end(), functionName) !=
Expand Down
Loading
Loading