@@ -259,8 +259,21 @@ struct ApplyOpPattern : public OpRewritePattern<quake::ApplyOp> {
259259
260260 LogicalResult matchAndRewrite (quake::ApplyOp apply,
261261 PatternRewriter &rewriter) const override {
262- auto calleeName = getVariantFunctionName (
263- apply, apply.getCallee ()->getRootReference ().str ());
262+ std::string calleeOrigName;
263+ if (apply.getCallee ()) {
264+ calleeOrigName = apply.getCallee ()->getRootReference ().str ();
265+ } else {
266+ // Check if the first argument is a func.ConstantOp.
267+ auto calleeVals = apply.getIndirectCallee ();
268+ if (calleeVals.empty ())
269+ return failure ();
270+ Value calleeVal = calleeVals.front ();
271+ auto fc = calleeVal.getDefiningOp <func::ConstantOp>();
272+ if (!fc)
273+ return failure ();
274+ calleeOrigName = fc.getValue ().str ();
275+ }
276+ auto calleeName = getVariantFunctionName (apply, calleeOrigName);
264277 auto *ctx = apply.getContext ();
265278 auto consTy = quake::VeqType::getUnsized (ctx);
266279 SmallVector<Value> newArgs;
@@ -286,14 +299,44 @@ struct ApplyOpPattern : public OpRewritePattern<quake::ApplyOp> {
286299 const bool constProp;
287300};
288301
302+ struct FoldCallable : public OpRewritePattern <quake::ApplyOp> {
303+ using OpRewritePattern::OpRewritePattern;
304+
305+ LogicalResult matchAndRewrite (quake::ApplyOp apply,
306+ PatternRewriter &rewriter) const override {
307+ // If we already know the callee function, there's nothing to do.
308+ if (apply.getCallee ())
309+ return failure ();
310+
311+ Value ind = apply.getIndirectCallee ()[0 ];
312+ if (auto callee = ind.getDefiningOp <cudaq::cc::InstantiateCallableOp>()) {
313+ auto sym = callee.getCallee ();
314+ SmallVector<Value> newArguments = {ind};
315+ newArguments.append (apply.getArgs ().begin (), apply.getArgs ().end ());
316+ rewriter.replaceOpWithNewOp <quake::ApplyOp>(
317+ apply, apply.getResultTypes (), sym, apply.getIsAdj (),
318+ apply.getControls (), newArguments);
319+ return success ();
320+ }
321+ return failure ();
322+ }
323+ };
324+
289325class ApplySpecializationPass
290326 : public cudaq::opt::impl::ApplySpecializationBase<
291327 ApplySpecializationPass> {
292328public:
293329 using ApplySpecializationBase::ApplySpecializationBase;
294330
295331 void runOnOperation () override {
296- ApplyOpAnalysis analysis (getOperation (), constantPropagation);
332+ ModuleOp module = getOperation ();
333+ auto *ctx = module .getContext ();
334+ RewritePatternSet patterns (ctx);
335+ patterns.insert <FoldCallable>(ctx);
336+ if (failed (applyPatternsAndFoldGreedily (module , std::move (patterns))))
337+ signalPassFailure ();
338+
339+ ApplyOpAnalysis analysis (module , constantPropagation);
297340 const auto &applyVariants = analysis.getAnalysisInfo ();
298341 if (succeeded (step1 (applyVariants)))
299342 step2 ();
0 commit comments