@@ -682,9 +682,12 @@ bool QuakeBridgeVisitor::VisitCastExpr(clang::CastExpr *x) {
682682 if (cxxExpr->getNumArgs () == 1 )
683683 return true ;
684684 }
685- if (isa<ComplexType>(castToTy) && isa<ComplexType>(peekValue ().getType ())) {
685+ if (isa<ComplexType>(castToTy) && isa<ComplexType>(peekValue ().getType ()))
686686 return true ;
687- }
687+ if (isa<quake::StateType>(castToTy))
688+ if (auto ptrTy = dyn_cast<cudaq::cc::PointerType>(peekValue ().getType ()))
689+ if (isa<quake::StateType>(ptrTy.getElementType ()))
690+ return pushValue (builder.create <cudaq::cc::LoadOp>(loc, popValue ()));
688691 if (auto funcTy = peelPointerFromFunction (castToTy))
689692 if (auto fromTy = dyn_cast<cc::CallableType>(peekValue ().getType ())) {
690693 auto inputs = funcTy.getInputs ();
@@ -1003,8 +1006,8 @@ bool QuakeBridgeVisitor::VisitMaterializeTemporaryExpr(
10031006 // The following cases are λ expressions, quantum data, or a std::vector view.
10041007 // In those cases, there is nothing to materialize, so we can just pass the
10051008 // Value on the top of the stack.
1006- if (isa<cc::CallableType, quake::VeqType, quake::RefType, cc::SpanLikeType>(
1007- ty))
1009+ if (isa<cc::CallableType, quake::VeqType, quake::RefType, cc::SpanLikeType,
1010+ quake::StateType>( ty))
10081011 return true ;
10091012
10101013 // If not one of the above special cases, then materialize the value to a
@@ -2689,6 +2692,11 @@ bool QuakeBridgeVisitor::VisitInitListExpr(clang::InitListExpr *x) {
26892692 }
26902693
26912694 // List has 1 or more members.
2695+ if (size == 1 && isa<clang::MaterializeTemporaryExpr>(x->getInit (0 )))
2696+ if (auto alloc = peekValue ().getDefiningOp <cudaq::cc::AllocaOp>())
2697+ if (auto arrTy = dyn_cast<cudaq::cc::ArrayType>(initListTy))
2698+ if (alloc.getElementType () == arrTy.getElementType ())
2699+ return true ;
26922700 auto last = lastValues (size);
26932701 bool allRef = std::all_of (last.begin (), last.end (), [](auto v) {
26942702 return isa<quake::RefType, quake::VeqType>(v.getType ());
@@ -2916,6 +2924,32 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
29162924 loc, quake::VeqType::getUnsized (builder.getContext ()), sizeVal));
29172925 }
29182926
2927+ if (ctorName == " state" ) {
2928+ // cudaq::state ctor can be materialized when using local simulators and
2929+ // converting raw data to state vectors. Use a runtime helper function
2930+ // to perform the conversion.
2931+ Value stdvec = popValue ();
2932+ auto stateTy = cudaq::cc::PointerType::get (
2933+ quake::StateType::get (builder.getContext ()));
2934+ if (auto stdvecTy = dyn_cast<cudaq::cc::StdvecType>(stdvec.getType ())) {
2935+ auto dataTy = cudaq::cc::PointerType::get (stdvecTy.getElementType ());
2936+ Value data =
2937+ builder.create <cudaq::cc::StdvecDataOp>(loc, dataTy, stdvec);
2938+ auto i64Ty = builder.getI64Type ();
2939+ Value size =
2940+ builder.create <cudaq::cc::StdvecSizeOp>(loc, i64Ty, stdvec);
2941+ return pushValue (builder.create <quake::CreateStateOp>(
2942+ loc, stateTy, ValueRange{data, size}));
2943+ }
2944+ if (auto alloc = stdvec.getDefiningOp <cudaq::cc::AllocaOp>()) {
2945+ Value size = alloc.getSeqSize ();
2946+ return pushValue (builder.create <quake::CreateStateOp>(
2947+ loc, stateTy, ValueRange{alloc, size}));
2948+ }
2949+ TODO_loc (loc, " unhandled state constructor" );
2950+ return false ;
2951+ }
2952+
29192953 // lambda determines: is `t` a cudaq::state* ?
29202954 auto isStateType = [&](Type t) {
29212955 if (auto ptrTy = dyn_cast<cc::PointerType>(t))
@@ -2925,9 +2959,17 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
29252959
29262960 if (ctorName == " qudit" ) {
29272961 auto initials = popValue ();
2962+ if (isa<quake::StateType>(initials.getType ()))
2963+ if (auto load = initials.getDefiningOp <cudaq::cc::LoadOp>())
2964+ initials = load.getPtrvalue ();
29282965 if (isStateType (initials.getType ())) {
2929- TODO_x (loc, x, mangler, " qudit(state) ctor" );
2930- return false ;
2966+ Value alloca = builder.create <quake::AllocaOp>(loc);
2967+ auto veq1Ty = quake::VeqType::get (builder.getContext (), 1 );
2968+ Value initSt = builder.create <quake::InitializeStateOp>(
2969+ loc, veq1Ty, ValueRange{alloca, initials});
2970+ if (auto initOp = initials.getDefiningOp <quake::CreateStateOp>())
2971+ builder.create <quake::DeleteStateOp>(loc, initOp);
2972+ return pushValue (builder.create <quake::ExtractRefOp>(loc, initSt, 0 ));
29312973 }
29322974 bool ok = false ;
29332975 if (auto ptrTy = dyn_cast<cc::PointerType>(initials.getType ()))
@@ -2953,57 +2995,26 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
29532995 return pushValue (builder.create <quake::AllocaOp>(
29542996 loc, quake::VeqType::getUnsized (ctx), initials));
29552997 }
2956- if (isa<quake::StateType>(initials.getType ())) {
2998+ if (isa<quake::StateType>(initials.getType ()))
29572999 if (auto load = initials.getDefiningOp <cudaq::cc::LoadOp>())
29583000 initials = load.getPtrvalue ();
2959- }
29603001 if (isStateType (initials.getType ())) {
29613002 Value state = initials;
29623003 auto i64Ty = builder.getI64Type ();
29633004 auto numQubits =
29643005 builder.create <quake::GetNumberOfQubitsOp>(loc, i64Ty, state);
29653006 auto veqTy = quake::VeqType::getUnsized (ctx);
29663007 Value alloc = builder.create <quake::AllocaOp>(loc, veqTy, numQubits);
2967- return pushValue (builder.create <quake::InitializeStateOp>(
2968- loc, veqTy, alloc, state));
3008+ Value initSt = builder.create <quake::InitializeStateOp>(loc, veqTy,
3009+ alloc, state);
3010+ if (auto initOp = initials.getDefiningOp <quake::CreateStateOp>())
3011+ builder.create <quake::DeleteStateOp>(loc, initOp);
3012+ return pushValue (initSt);
29693013 }
2970- // Otherwise, it is the cudaq::qvector(std::vector<complex>) ctor.
2971- Value numQubits;
2972- Type initialsTy = initials.getType ();
2973- if (auto ptrTy = dyn_cast<cc::PointerType>(initialsTy)) {
2974- if (auto arrTy = dyn_cast<cc::ArrayType>(ptrTy.getElementType ())) {
2975- if (arrTy.isUnknownSize ()) {
2976- if (auto allocOp = initials.getDefiningOp <cc::AllocaOp>())
2977- if (auto size = allocOp.getSeqSize ())
2978- numQubits =
2979- builder.create <math::CountTrailingZerosOp>(loc, size);
2980- } else {
2981- std::size_t arraySize = arrTy.getSize ();
2982- if (!std::has_single_bit (arraySize)) {
2983- reportClangError (x, mangler,
2984- " state vector must be a power of 2 in length" );
2985- }
2986- numQubits = builder.create <arith::ConstantIntOp>(
2987- loc, std::countr_zero (arraySize), 64 );
2988- }
2989- }
2990- } else if (auto stdvecTy = dyn_cast<cc::StdvecType>(initialsTy)) {
2991- Value vecLen = builder.create <cc::StdvecSizeOp>(
2992- loc, builder.getI64Type (), initials);
2993- numQubits = builder.create <math::CountTrailingZerosOp>(loc, vecLen);
2994- auto ptrTy = cc::PointerType::get (stdvecTy.getElementType ());
2995- initials = builder.create <cc::StdvecDataOp>(loc, ptrTy, initials);
2996- }
2997- if (!numQubits) {
2998- reportClangError (
2999- x, mangler,
3000- " internal error: could not determine the number of qubits" );
3001- return false ;
3002- }
3003- auto veqTy = quake::VeqType::getUnsized (ctx);
3004- auto alloc = builder.create <quake::AllocaOp>(loc, veqTy, numQubits);
3005- return pushValue (builder.create <quake::InitializeStateOp>(
3006- loc, veqTy, alloc, initials));
3014+ reportClangError (
3015+ x, mangler,
3016+ " internal error: could not determine the number of qubits" );
3017+ return false ;
30073018 }
30083019 if ((ctorName == " qspan" || ctorName == " qview" ) &&
30093020 isa<quake::VeqType>(peekValue ().getType ())) {
0 commit comments