Skip to content

Commit 3588f65

Browse files
authored
Fix bug with separate compilation of quasi-entry point kernels. (#3352)
Signed-off-by: Eric Schweitz <[email protected]>
1 parent bf0d50a commit 3588f65

File tree

4 files changed

+50
-17
lines changed

4 files changed

+50
-17
lines changed

lib/Optimizer/Builder/Marshal.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -785,7 +785,8 @@ cudaq::opt::marshal::dropAnyHiddenArguments(MutableArrayRef<BlockArgument> args,
785785
std::pair<bool, func::FuncOp> cudaq::opt::marshal::lookupHostEntryPointFunc(
786786
StringRef mangledEntryPointName, ModuleOp module, func::FuncOp funcOp) {
787787
if (mangledEntryPointName == "BuilderKernel.EntryPoint" ||
788-
mangledEntryPointName.contains("_PyKernelEntryPointRewrite")) {
788+
mangledEntryPointName.contains("_PyKernelEntryPointRewrite") ||
789+
funcOp.empty()) {
789790
// No host entry point needed.
790791
return {false, func::FuncOp{}};
791792
}

lib/Optimizer/Transforms/GenKernelExecution.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -869,14 +869,13 @@ class GenerateKernelExecution
869869
for (auto &op : *module.getBody())
870870
if (auto funcOp = dyn_cast<func::FuncOp>(op))
871871
if (funcOp.getName().startswith(cudaq::runtime::cudaqGenPrefixName) &&
872-
cudaq::opt::marshal::hasLegalType(funcOp.getFunctionType()))
872+
cudaq::opt::marshal::hasLegalType(funcOp.getFunctionType()) &&
873+
!funcOp.empty() && !funcOp->hasAttr(cudaq::generatorAnnotation))
873874
workList.push_back(funcOp);
874875

875876
LLVM_DEBUG(llvm::dbgs()
876877
<< workList.size() << " kernel entry functions to process\n");
877878
for (auto funcOp : workList) {
878-
if (funcOp->hasAttr(cudaq::generatorAnnotation))
879-
continue;
880879
auto loc = funcOp.getLoc();
881880
[[maybe_unused]] auto className =
882881
funcOp.getName().drop_front(cudaq::runtime::cudaqGenPrefixLength);

lib/Optimizer/Transforms/LowerToCFG.cpp

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,10 @@ class RewriteScope : public OpRewritePattern<cudaq::cc::ScopeOp> {
5757
auto loc = scopeOp.getLoc();
5858
auto *initBlock = rewriter.getInsertionBlock();
5959
Value stacksave;
60-
auto module = scopeOp.getOperation()->getParentOfType<ModuleOp>();
6160
auto ptrTy = cudaq::cc::PointerType::get(rewriter.getI8Type());
6261
if (scopeOp.hasAllocation(/*quantumAllocs=*/false)) {
63-
auto fun = cudaq::opt::factory::createFunction(
64-
"llvm.stacksave", ArrayRef<Type>{ptrTy}, {}, module);
65-
fun.setPrivate();
6662
auto call = rewriter.create<func::CallOp>(
67-
loc, ptrTy, fun.getSymNameAttr(), ArrayRef<Value>{});
63+
loc, ptrTy, cudaq::llvmStackSave, ArrayRef<Value>{});
6864
stacksave = call.getResult(0);
6965
}
7066
auto initPos = rewriter.getInsertionPoint();
@@ -93,10 +89,8 @@ class RewriteScope : public OpRewritePattern<cudaq::cc::ScopeOp> {
9389
rewriter.inlineRegionBefore(scopeOp.getInitRegion(), endBlock);
9490
if (stacksave) {
9591
rewriter.setInsertionPointToStart(endBlock);
96-
auto fun = cudaq::opt::factory::createFunction(
97-
"llvm.stackrestore", {}, ArrayRef<Type>{ptrTy}, module);
98-
fun.setPrivate();
99-
rewriter.create<func::CallOp>(loc, ArrayRef<Type>{}, fun.getSymNameAttr(),
92+
rewriter.create<func::CallOp>(loc, ArrayRef<Type>{},
93+
cudaq::llvmStackRestore,
10094
ArrayRef<Value>{stacksave});
10195
}
10296
rewriter.replaceOp(scopeOp, scopeResults);
@@ -331,10 +325,6 @@ class ConvertToCFGPrep
331325
mod.emitError("could not load llvm.stacksave intrinsic.");
332326
signalPassFailure();
333327
}
334-
if (failed(irBuilder.loadIntrinsic(mod, cudaq::llvmStackRestore))) {
335-
mod.emitError("could not load llvm.stackrestore intrinsic.");
336-
signalPassFailure();
337-
}
338328
}
339329
};
340330
} // namespace
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
/*******************************************************************************
 * Copyright (c) 2025 NVIDIA Corporation & Affiliates.                         *
 * All rights reserved.                                                        *
 *                                                                             *
 * This source code and the accompanying materials are made available under   *
 * the terms of the Apache License 2.0 which accompanies this distribution.   *
 ******************************************************************************/

// Regression test for separate compilation of quasi-entry point kernels:
// a kernel declared in one translation unit and defined in another must
// link and run correctly (see the lookupHostEntryPointFunc / funcOp.empty()
// handling in this commit).
//
// clang-format off
// NOTE: `command -v` must be run as a command, not placed inside `[ ... ]`.
// The original `if [ command -v split-file ]` never executed `command -v`
// (test(1) saw three literal words), so the guard did not actually detect
// whether split-file is available.
// RUN: if command -v split-file > /dev/null 2>&1 ; then \
// RUN:   split-file %s %t && \
// RUN:   nvq++ %cpp_std -fenable-cudaq-run --target stim -c %t/gke-1.cpp \
// RUN:     %t/gke-2.cpp -o %t/gke.out && %t/gke.out ; else \
// RUN:   echo "skipping" ; fi
// clang-format on

//--- gke-1.cpp

#include <cudaq.h>

// Declared here; defined in the second translation unit (gke-2.cpp).
__qpu__ int mytest(int x, std::vector<cudaq::measure_result> y);

// Entry-point kernel: measures a 2-qubit register and forwards the results
// to the separately compiled quasi-entry point kernel via device_call.
__qpu__ int mykernel() {
  cudaq::qvector q(2);
  h(q);
  auto mzq = mz(q);
  int res = cudaq::device_call(mytest, 1, mzq);
  return res;
}

int main() {
  // Success is simply compiling, linking, and running without error.
  auto res = cudaq::run(1, mykernel);
  return 0;
}

//--- gke-2.cpp

#include <cudaq.h>

// Definition of the kernel declared in gke-1.cpp. The measurement results
// are accepted but unused; the return value only needs to be deterministic.
__qpu__ int mytest(int x, std::vector<cudaq::measure_result> y) {
  return x * 2;
}

0 commit comments

Comments
 (0)