From 307c24cb3258ace2b2bfe9cab0549e6d05481adf Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Fri, 27 Jun 2025 18:03:38 +0000 Subject: [PATCH 01/21] Add pass to find subcircuits with phase polynomials Signed-off-by: Adam Geller --- include/cudaq/Optimizer/Transforms/Passes.td | 6 + lib/Optimizer/Transforms/CMakeLists.txt | 1 + .../Transforms/PhasePolynomialPreprocess.cpp | 57 ++++ lib/Optimizer/Transforms/Subcircuit.h | 249 ++++++++++++++++++ 4 files changed, 313 insertions(+) create mode 100644 lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp create mode 100644 lib/Optimizer/Transforms/Subcircuit.h diff --git a/include/cudaq/Optimizer/Transforms/Passes.td b/include/cudaq/Optimizer/Transforms/Passes.td index b819ad3d7d0..183e6c80728 100644 --- a/include/cudaq/Optimizer/Transforms/Passes.td +++ b/include/cudaq/Optimizer/Transforms/Passes.td @@ -865,6 +865,12 @@ def ObserveAnsatz : Pass<"observe-ansatz", "mlir::func::FuncOp"> { ]; } +def PhasePolynomialPreprocess: Pass<"phase-polynomial-preprocess", "mlir::func::FuncOp"> { + let summary = "Isolate subcircuits representable by a single phase polynomial."; + + let dependentDialects = ["cudaq::cc::CCDialect", "quake::QuakeDialect"]; +} + def PromoteRefToVeqAlloc : Pass<"promote-qubit-allocation"> { let summary = "Promote single qubit allocations."; let description = [{ diff --git a/lib/Optimizer/Transforms/CMakeLists.txt b/lib/Optimizer/Transforms/CMakeLists.txt index cf2fc80ed62..49c1f4b4b66 100644 --- a/lib/Optimizer/Transforms/CMakeLists.txt +++ b/lib/Optimizer/Transforms/CMakeLists.txt @@ -48,6 +48,7 @@ add_cudaq_library(OptTransforms DependencyAnalysis.cpp MultiControlDecomposition.cpp ObserveAnsatz.cpp + PhasePolynomialPreprocess.cpp PruneCtrlRelations.cpp PySynthCallableBlockArgs.cpp QuakeAddMetadata.cpp diff --git a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp new file mode 100644 index 00000000000..65d4fa63169 --- /dev/null +++ b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp @@ -0,0 +1,57 @@ +/******************************************************************************* + * Copyright (c) 2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +#include "PassDetails.h" +#include "cudaq/Optimizer/Transforms/Passes.h" +#include "mlir/Transforms/Passes.h" +#include "cudaq/Optimizer/Dialect/CC/CCOps.h" +#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" +#include "Subcircuit.h" + +namespace cudaq::opt { +#define GEN_PASS_DEF_PHASEPOLYNOMIALPREPROCESS +#include "cudaq/Optimizer/Transforms/Passes.h.inc" +} // namespace cudaq::opt + +#define DEBUG_TYPE "phase-polynomial-preprocess" + +using namespace mlir; + +namespace { +class PhasePolynomialPreprocessPass : public cudaq::opt::impl::PhasePolynomialPreprocessBase { + using PhasePolynomialPreprocessBase::PhasePolynomialPreprocessBase; + + SetVector processed; + SmallVector subcircuits; + +public: + // AXIS-SPECIFIC: could allow controlled y and z here + bool isControlledOp(Operation *op) { + return isa(op) && op->getNumOperands() == 2; + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + + func.walk([&](Operation *op) { + if (!isControlledOp(op) || ::processed(op)) + return; + + Subcircuit subcircuit(op); + subcircuits.push_back(subcircuit); + }); + + for (auto subcircuit : subcircuits) { + llvm::outs() << "Calculated subcircuit: \n"; + for (auto *op : subcircuit.getOps()) + op->dump(); + llvm::outs() << "\n"; + } + } +}; +} \ No newline at end of file diff --git a/lib/Optimizer/Transforms/Subcircuit.h b/lib/Optimizer/Transforms/Subcircuit.h new file mode 100644 index 00000000000..992d6413887 --- /dev/null +++ b/lib/Optimizer/Transforms/Subcircuit.h @@ -0,0 +1,249 @@ + +#include "cudaq/Optimizer/Dialect/CC/CCOps.h" +#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" + +using namespace mlir; + +#define RAW(X) quake::X +#define RAW_MEASURE_OPS MEASURE_OPS(RAW) +#define RAW_GATE_OPS GATE_OPS(RAW) +#define RAW_QUANTUM_OPS QUANTUM_OPS(RAW) +// AXIS-SPECIFIC: Defines which operations break a circuit into subcircuits +#define CIRCUIT_BREAKERS(MACRO) MACRO(YOp), MACRO(ZOp), MACRO(HOp), MACRO(R1Op), MACRO(RxOp), MACRO(PhasedRxOp), MACRO(RyOp), \ + MACRO(U2Op), MACRO(U3Op) +#define RAW_CIRCUIT_BREAKERS CIRCUIT_BREAKERS(RAW) + + +unsigned calculateSkip(Operation *op) { + auto i = 0; + for (auto type : op->getOperandTypes()) { + if (isa(type)) + return i; + i++; + } + + return i; +} + +Value getNextOperand(Value v) { + auto result = dyn_cast(v); + auto op = result.getDefiningOp(); + auto skip = calculateSkip(op); + auto operandIDX = result.getResultNumber() + skip; + return op->getOperand(operandIDX); +} + +// TODO: Handle block arguments +OpResult getNextResult(OpResult v) { + assert(v.hasOneUse()); + auto correspondingOperand = v.getUses().begin(); + auto op = correspondingOperand.getUser(); + auto skip = calculateSkip(op); + auto resultIDX = correspondingOperand.getOperand()->getOperandNumber() - skip; + return op->getResult(resultIDX); +} + +inline bool processed(Operation *op) { + return op->hasAttr("processed"); +} + +inline void markProcessed(Operation *op) { + op->setAttr("processed", OpBuilder(op).getUnitAttr()); +} + +class Subcircuit { +protected: + SetVector seen; + SetVector ops; + SetVector termination_points; + SetVector anchor_points; + + bool isAfterTerminationPoint(Value wire) { + return isTerminationPoint(wire.getDefiningOp()); + } + + bool isTerminationPoint(Operation *op) { + // The operation is already part of another subcircuit + if (processed(op)) + return true; + + if (isa(op)) + return true; + + if (isa(op)) + return true; + + auto opi = dyn_cast(op); + assert(opi); + // Only allow single control + if (opi.getControls().size() > 1) + return true; + return false; + } + + void maybeAddAnchorPoint(Value v) { + if (!seen.contains(v)) + anchor_points.insert(v); + } + + void calculateSubcircuitForQubitForward(OpResult v) { + seen.insert(v); + if (!v.hasOneUse()) { + termination_points.insert(v); + return; + } + Operation *op = v.getUses().begin().getUser(); + + if (isTerminationPoint(op)) { + termination_points.insert(v); + return; + } + + ops.insert(op); + + // Controlled not, figure out whether we are tracking the control + // or target, and add an anchor point to the other qubit + if (op->getResults().size() > 1) { + auto control = op->getResult(0); + auto target = op->getResult(1); + // Is this the control or target qubit? + if (v.getResultNumber() == 0) { + // Tracking the control qubit + calculateSubcircuitForQubitForward(control); + maybeAddAnchorPoint(target); + } else { + // Tracking the target qubit + maybeAddAnchorPoint(control); + calculateSubcircuitForQubitForward(target); + } + } else { + // Otherwise, single qubit gate, just follow result + calculateSubcircuitForQubitForward(getNextResult(v)); + } + } + + void calculateSubcircuitForQubitBackward(Value v) { + seen.insert(v); + Operation *op = v.getDefiningOp(); + + if (isTerminationPoint(op)) { + termination_points.insert(v); + return; + } + + ops.insert(op); + + // Controlled not, figure out whether we are tracking the control + // or target, and add an anchor point to the other qubit + // Use getResults() as Rz has two operands but only one result + if (op->getResults().size() > 1) { + auto control = op->getOperand(0); + auto target = op->getOperand(1); + // Is this the control or target qubit? + if (v == target) { + // Tracking the control qubit + calculateSubcircuitForQubitBackward(control); + maybeAddAnchorPoint(target); + } else { + // Tracking the target qubit + maybeAddAnchorPoint(control); + calculateSubcircuitForQubitBackward(target); + } + } else { + // Otherwise, single qubit gate, just follow operand + calculateSubcircuitForQubitBackward(getNextOperand(v)); + } + } + + void calculateInitialSubcircuit(Operation *op) { + // AXIS-SPECIFIC: This could be any controlled operation + auto cnot = dyn_cast(op); + assert(cnot && cnot.getWires().size() == 2); + + auto result = cnot->getResult(0); + auto operand = cnot->getOperand(0); + ops.insert(cnot); + anchor_points.insert(cnot->getResult(1)); + calculateSubcircuitForQubitForward(result); + calculateSubcircuitForQubitBackward(operand); + + while (!anchor_points.empty()) { + auto next = anchor_points.back(); + anchor_points.pop_back(); + calculateSubcircuitForQubitForward(dyn_cast(next)); + seen.remove(next); + calculateSubcircuitForQubitBackward(next); + } + } + + // Prune operations after a termination point from the subcircuit + void pruneWire(Value wire) { + if (termination_points.contains(wire)) + termination_points.remove(wire); + if (!wire.hasOneUse()) + return; + Operation *op = wire.getUses().begin().getUser(); + + ops.remove(op); + + // TODO: According to the paper, if the op is a CNot and the wire we are pruning along is the target, then we do not have to prune along the control wire. However, this prevents placing each subcircuit in a separate block, so it is currently not supported + // auto opi = dyn_cast(op); + // assert(opi); + // auto controls = opi.getControls(); + // if (controls.size() > 0 && + // std::find(controls.begin(), controls.end(), wire) == controls.end()) { + // pruneSubcircuit(opi.getWires()[1]); + // return; + // } + + for (auto result : op->getResults()) { + pruneWire(result); + // Adjust termination border + for (auto operand : op->getOperands()) + if (ops.contains(operand.getDefiningOp())) + termination_points.insert(operand); + } + } + + void pruneSubcircuit() { + // The termination boundary should be defined by the first + // termination point seen along each wire in the subcircuit + // (this means that it is important to build subcircuits + // by inspecting controlled gates in topological order) + for (auto wire : termination_points) { + if (!isAfterTerminationPoint(wire)) + pruneWire(wire); + } + } + +public: + /// @brief Constructs a subcircuit with a phase polynomial starting from a cnot + Subcircuit(Operation *cnot) { + calculateInitialSubcircuit(cnot); + pruneSubcircuit(); + for (auto *op : ops) + markProcessed(op); + } + + SetVector getInitialWires() { + SetVector initial; + for (auto wire : termination_points) + if (isAfterTerminationPoint(wire)) + initial.insert(wire); + return initial; + } + + bool isInSubcircuit(Operation *op) { + return ops.contains(op); + } + + // TODO: would be nice to make Subcircuit iterable directly + SetVector getOps() { + return ops; + } + + /// @brief returns the number of wires in the subcircuit + size_t numWires() { + return getInitialWires().size(); + } +}; From b2ca80594a9eacdc64872da7fe2a970355b6e026 Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Fri, 27 Jun 2025 18:05:26 +0000 Subject: [PATCH 02/21] Formatting Signed-off-by: Adam Geller --- .../Transforms/PhasePolynomialPreprocess.cpp | 60 +-- lib/Optimizer/Transforms/Subcircuit.h | 403 +++++++++--------- 2 files changed, 229 insertions(+), 234 deletions(-) diff --git a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp index 65d4fa63169..71367410e9e 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp @@ -7,11 +7,11 @@ ******************************************************************************/ #include "PassDetails.h" -#include "cudaq/Optimizer/Transforms/Passes.h" -#include "mlir/Transforms/Passes.h" +#include "Subcircuit.h" #include "cudaq/Optimizer/Dialect/CC/CCOps.h" #include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" -#include "Subcircuit.h" +#include "cudaq/Optimizer/Transforms/Passes.h" +#include "mlir/Transforms/Passes.h" namespace cudaq::opt { #define GEN_PASS_DEF_PHASEPOLYNOMIALPREPROCESS @@ -23,35 +23,37 @@ namespace cudaq::opt { using namespace mlir; namespace { -class PhasePolynomialPreprocessPass : public cudaq::opt::impl::PhasePolynomialPreprocessBase { +class PhasePolynomialPreprocessPass + : public cudaq::opt::impl::PhasePolynomialPreprocessBase< + PhasePolynomialPreprocessPass> { using PhasePolynomialPreprocessBase::PhasePolynomialPreprocessBase; - SetVector processed; - SmallVector subcircuits; + SetVector processed; + SmallVector subcircuits; public: - // AXIS-SPECIFIC: could allow controlled y and z here - bool isControlledOp(Operation *op) { - return isa(op) && op->getNumOperands() == 2; - } - - void runOnOperation() override { - func::FuncOp func = getOperation(); - - func.walk([&](Operation *op) { - if (!isControlledOp(op) || ::processed(op)) - return; - - Subcircuit subcircuit(op); - subcircuits.push_back(subcircuit); - }); - - for (auto subcircuit : subcircuits) { - llvm::outs() << "Calculated subcircuit: \n"; - for (auto *op : subcircuit.getOps()) - op->dump(); - llvm::outs() << "\n"; - } + // AXIS-SPECIFIC: could allow controlled y and z here + bool isControlledOp(Operation *op) { + return isa(op) && op->getNumOperands() == 2; + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + + func.walk([&](Operation *op) { + if (!isControlledOp(op) || ::processed(op)) + return; + + Subcircuit subcircuit(op); + subcircuits.push_back(subcircuit); + }); + + for (auto subcircuit : subcircuits) { + llvm::outs() << "Calculated subcircuit: \n"; + for (auto *op : subcircuit.getOps()) + op->dump(); + llvm::outs() << "\n"; } + } }; -} \ No newline at end of file +} // namespace diff --git a/lib/Optimizer/Transforms/Subcircuit.h b/lib/Optimizer/Transforms/Subcircuit.h index 992d6413887..efdf113799f 100644 --- a/lib/Optimizer/Transforms/Subcircuit.h +++ b/lib/Optimizer/Transforms/Subcircuit.h @@ -9,241 +9,234 @@ using namespace mlir; #define RAW_GATE_OPS GATE_OPS(RAW) #define RAW_QUANTUM_OPS QUANTUM_OPS(RAW) // AXIS-SPECIFIC: Defines which operations break a circuit into subcircuits -#define CIRCUIT_BREAKERS(MACRO) MACRO(YOp), MACRO(ZOp), MACRO(HOp), MACRO(R1Op), MACRO(RxOp), MACRO(PhasedRxOp), MACRO(RyOp), \ - MACRO(U2Op), MACRO(U3Op) +#define CIRCUIT_BREAKERS(MACRO) \ + MACRO(YOp), MACRO(ZOp), MACRO(HOp), MACRO(R1Op), MACRO(RxOp), \ + MACRO(PhasedRxOp), MACRO(RyOp), MACRO(U2Op), MACRO(U3Op) #define RAW_CIRCUIT_BREAKERS CIRCUIT_BREAKERS(RAW) - unsigned calculateSkip(Operation *op) { - auto i = 0; - for (auto type : op->getOperandTypes()) { - if (isa(type)) - return i; - i++; - } - - return i; + auto i = 0; + for (auto type : op->getOperandTypes()) { + if (isa(type)) + return i; + i++; + } + + return i; } Value getNextOperand(Value v) { - auto result = dyn_cast(v); - auto op = result.getDefiningOp(); - auto skip = calculateSkip(op); - auto operandIDX = result.getResultNumber() + skip; - return op->getOperand(operandIDX); + auto result = dyn_cast(v); + auto op = result.getDefiningOp(); + auto skip = calculateSkip(op); + auto operandIDX = result.getResultNumber() + skip; + return op->getOperand(operandIDX); } // TODO: Handle block arguments OpResult getNextResult(OpResult v) { - assert(v.hasOneUse()); - auto correspondingOperand = v.getUses().begin(); - auto op = correspondingOperand.getUser(); - auto skip = calculateSkip(op); - auto resultIDX = correspondingOperand.getOperand()->getOperandNumber() - skip; - return op->getResult(resultIDX); + assert(v.hasOneUse()); + auto correspondingOperand = v.getUses().begin(); + auto op = correspondingOperand.getUser(); + auto skip = calculateSkip(op); + auto resultIDX = correspondingOperand.getOperand()->getOperandNumber() - skip; + return op->getResult(resultIDX); } -inline bool processed(Operation *op) { - return op->hasAttr("processed"); -} +inline bool processed(Operation *op) { return op->hasAttr("processed"); } inline void markProcessed(Operation *op) { - op->setAttr("processed", OpBuilder(op).getUnitAttr()); + op->setAttr("processed", OpBuilder(op).getUnitAttr()); } class Subcircuit { protected: - SetVector seen; - SetVector ops; - SetVector termination_points; - SetVector anchor_points; - - bool isAfterTerminationPoint(Value wire) { - return isTerminationPoint(wire.getDefiningOp()); - } - - bool isTerminationPoint(Operation *op) { - // The operation is already part of another subcircuit - if (processed(op)) - return true; - - if (isa(op)) - return true; - - if (isa(op)) - return true; - - auto opi = dyn_cast(op); - assert(opi); - // Only allow single control - if (opi.getControls().size() > 1) - return true; - return false; - } - - void maybeAddAnchorPoint(Value v) { - if (!seen.contains(v)) - anchor_points.insert(v); + SetVector seen; + SetVector ops; + SetVector termination_points; + SetVector anchor_points; + + bool isAfterTerminationPoint(Value wire) { + return isTerminationPoint(wire.getDefiningOp()); + } + + bool isTerminationPoint(Operation *op) { + // The operation is already part of another subcircuit + if (processed(op)) + return true; + + if (isa(op)) + return true; + + if (isa(op)) + return true; + + auto opi = dyn_cast(op); + assert(opi); + // Only allow single control + if (opi.getControls().size() > 1) + return true; + return false; + } + + void maybeAddAnchorPoint(Value v) { + if (!seen.contains(v)) + anchor_points.insert(v); + } + + void calculateSubcircuitForQubitForward(OpResult v) { + seen.insert(v); + if (!v.hasOneUse()) { + termination_points.insert(v); + return; } + Operation *op = v.getUses().begin().getUser(); - void calculateSubcircuitForQubitForward(OpResult v) { - seen.insert(v); - if (!v.hasOneUse()) { - termination_points.insert(v); - return; - } - Operation *op = v.getUses().begin().getUser(); - - if (isTerminationPoint(op)) { - termination_points.insert(v); - return; - } - - ops.insert(op); - - // Controlled not, figure out whether we are tracking the control - // or target, and add an anchor point to the other qubit - if (op->getResults().size() > 1) { - auto control = op->getResult(0); - auto target = op->getResult(1); - // Is this the control or target qubit? - if (v.getResultNumber() == 0) { - // Tracking the control qubit - calculateSubcircuitForQubitForward(control); - maybeAddAnchorPoint(target); - } else { - // Tracking the target qubit - maybeAddAnchorPoint(control); - calculateSubcircuitForQubitForward(target); - } - } else { - // Otherwise, single qubit gate, just follow result - calculateSubcircuitForQubitForward(getNextResult(v)); - } + if (isTerminationPoint(op)) { + termination_points.insert(v); + return; } - void calculateSubcircuitForQubitBackward(Value v) { - seen.insert(v); - Operation *op = v.getDefiningOp(); - - if (isTerminationPoint(op)) { - termination_points.insert(v); - return; - } - - ops.insert(op); - - // Controlled not, figure out whether we are tracking the control - // or target, and add an anchor point to the other qubit - // Use getResults() as Rz has two operands but only one result - if (op->getResults().size() > 1) { - auto control = op->getOperand(0); - auto target = op->getOperand(1); - // Is this the control or target qubit? - if (v == target) { - // Tracking the control qubit - calculateSubcircuitForQubitBackward(control); - maybeAddAnchorPoint(target); - } else { - // Tracking the target qubit - maybeAddAnchorPoint(control); - calculateSubcircuitForQubitBackward(target); - } - } else { - // Otherwise, single qubit gate, just follow operand - calculateSubcircuitForQubitBackward(getNextOperand(v)); - } + ops.insert(op); + + // Controlled not, figure out whether we are tracking the control + // or target, and add an anchor point to the other qubit + if (op->getResults().size() > 1) { + auto control = op->getResult(0); + auto target = op->getResult(1); + // Is this the control or target qubit? + if (v.getResultNumber() == 0) { + // Tracking the control qubit + calculateSubcircuitForQubitForward(control); + maybeAddAnchorPoint(target); + } else { + // Tracking the target qubit + maybeAddAnchorPoint(control); + calculateSubcircuitForQubitForward(target); + } + } else { + // Otherwise, single qubit gate, just follow result + calculateSubcircuitForQubitForward(getNextResult(v)); } + } - void calculateInitialSubcircuit(Operation *op) { - // AXIS-SPECIFIC: This could be any controlled operation - auto cnot = dyn_cast(op); - assert(cnot && cnot.getWires().size() == 2); - - auto result = cnot->getResult(0); - auto operand = cnot->getOperand(0); - ops.insert(cnot); - anchor_points.insert(cnot->getResult(1)); - calculateSubcircuitForQubitForward(result); - calculateSubcircuitForQubitBackward(operand); - - while (!anchor_points.empty()) { - auto next = anchor_points.back(); - anchor_points.pop_back(); - calculateSubcircuitForQubitForward(dyn_cast(next)); - seen.remove(next); - calculateSubcircuitForQubitBackward(next); - } - } + void calculateSubcircuitForQubitBackward(Value v) { + seen.insert(v); + Operation *op = v.getDefiningOp(); - // Prune operations after a termination point from the subcircuit - void pruneWire(Value wire) { - if (termination_points.contains(wire)) - termination_points.remove(wire); - if (!wire.hasOneUse()) - return; - Operation *op = wire.getUses().begin().getUser(); - - ops.remove(op); - - // TODO: According to the paper, if the op is a CNot and the wire we are pruning along is the target, then we do not have to prune along the control wire. However, this prevents placing each subcircuit in a separate block, so it is currently not supported - // auto opi = dyn_cast(op); - // assert(opi); - // auto controls = opi.getControls(); - // if (controls.size() > 0 && - // std::find(controls.begin(), controls.end(), wire) == controls.end()) { - // pruneSubcircuit(opi.getWires()[1]); - // return; - // } - - for (auto result : op->getResults()) { - pruneWire(result); - // Adjust termination border - for (auto operand : op->getOperands()) - if (ops.contains(operand.getDefiningOp())) - termination_points.insert(operand); - } + if (isTerminationPoint(op)) { + termination_points.insert(v); + return; } - void pruneSubcircuit() { - // The termination boundary should be defined by the first - // termination point seen along each wire in the subcircuit - // (this means that it is important to build subcircuits - // by inspecting controlled gates in topological order) - for (auto wire : termination_points) { - if (!isAfterTerminationPoint(wire)) - pruneWire(wire); - } + ops.insert(op); + + // Controlled not, figure out whether we are tracking the control + // or target, and add an anchor point to the other qubit + // Use getResults() as Rz has two operands but only one result + if (op->getResults().size() > 1) { + auto control = op->getOperand(0); + auto target = op->getOperand(1); + // Is this the control or target qubit? + if (v == target) { + // Tracking the control qubit + calculateSubcircuitForQubitBackward(control); + maybeAddAnchorPoint(target); + } else { + // Tracking the target qubit + maybeAddAnchorPoint(control); + calculateSubcircuitForQubitBackward(target); + } + } else { + // Otherwise, single qubit gate, just follow operand + calculateSubcircuitForQubitBackward(getNextOperand(v)); } - -public: - /// @brief Constructs a subcircuit with a phase polynomial starting from a cnot - Subcircuit(Operation *cnot) { - calculateInitialSubcircuit(cnot); - pruneSubcircuit(); - for (auto *op : ops) - markProcessed(op); + } + + void calculateInitialSubcircuit(Operation *op) { + // AXIS-SPECIFIC: This could be any controlled operation + auto cnot = dyn_cast(op); + assert(cnot && cnot.getWires().size() == 2); + + auto result = cnot->getResult(0); + auto operand = cnot->getOperand(0); + ops.insert(cnot); + anchor_points.insert(cnot->getResult(1)); + calculateSubcircuitForQubitForward(result); + calculateSubcircuitForQubitBackward(operand); + + while (!anchor_points.empty()) { + auto next = anchor_points.back(); + anchor_points.pop_back(); + calculateSubcircuitForQubitForward(dyn_cast(next)); + seen.remove(next); + calculateSubcircuitForQubitBackward(next); } - - SetVector getInitialWires() { - SetVector initial; - for (auto wire : termination_points) - if (isAfterTerminationPoint(wire)) - initial.insert(wire); - return initial; + } + + // Prune operations after a termination point from the subcircuit + void pruneWire(Value wire) { + if (termination_points.contains(wire)) + termination_points.remove(wire); + if (!wire.hasOneUse()) + return; + Operation *op = wire.getUses().begin().getUser(); + + ops.remove(op); + + // TODO: According to the paper, if the op is a CNot and the wire we are + // pruning along is the target, then we do not have to prune along the + // control wire. However, this prevents placing each subcircuit in a + // separate block, so it is currently not supported auto opi = + // dyn_cast(op); assert(opi); auto controls = + // opi.getControls(); if (controls.size() > 0 && + // std::find(controls.begin(), controls.end(), wire) == controls.end()) + // { pruneSubcircuit(opi.getWires()[1]); return; + // } + + for (auto result : op->getResults()) { + pruneWire(result); + // Adjust termination border + for (auto operand : op->getOperands()) + if (ops.contains(operand.getDefiningOp())) + termination_points.insert(operand); } - - bool isInSubcircuit(Operation *op) { - return ops.contains(op); + } + + void pruneSubcircuit() { + // The termination boundary should be defined by the first + // termination point seen along each wire in the subcircuit + // (this means that it is important to build subcircuits + // by inspecting controlled gates in topological order) + for (auto wire : termination_points) { + if (!isAfterTerminationPoint(wire)) + pruneWire(wire); } + } - // TODO: would be nice to make Subcircuit iterable directly - SetVector getOps() { - return ops; - } - - /// @brief returns the number of wires in the subcircuit - size_t numWires() { - return getInitialWires().size(); - } +public: + /// @brief Constructs a subcircuit with a phase polynomial starting from a + /// cnot + Subcircuit(Operation *cnot) { + calculateInitialSubcircuit(cnot); + pruneSubcircuit(); + for (auto *op : ops) + markProcessed(op); + } + + SetVector getInitialWires() { + SetVector initial; + for (auto wire : termination_points) + if (isAfterTerminationPoint(wire)) + initial.insert(wire); + return initial; + } + + bool isInSubcircuit(Operation *op) { return ops.contains(op); } + + // TODO: would be nice to make Subcircuit iterable directly + SetVector getOps() { return ops; } + + /// @brief returns the number of wires in the subcircuit + size_t numWires() { return getInitialWires().size(); } }; From 9bdf5a646a956233458757a1a9f18bf74d4ca1bf Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Mon, 30 Jun 2025 23:40:08 +0000 Subject: [PATCH 03/21] Split subcircuits into separate functions Signed-off-by: Adam Geller --- include/cudaq/Optimizer/Transforms/Passes.td | 2 +- .../Transforms/PhasePolynomialPreprocess.cpp | 188 ++++++++++++++++-- lib/Optimizer/Transforms/Subcircuit.h | 45 +++-- 3 files changed, 203 insertions(+), 32 deletions(-) diff --git a/include/cudaq/Optimizer/Transforms/Passes.td b/include/cudaq/Optimizer/Transforms/Passes.td index 183e6c80728..b77cbeb1b61 100644 --- a/include/cudaq/Optimizer/Transforms/Passes.td +++ b/include/cudaq/Optimizer/Transforms/Passes.td @@ -865,7 +865,7 @@ def ObserveAnsatz : Pass<"observe-ansatz", "mlir::func::FuncOp"> { ]; } -def PhasePolynomialPreprocess: Pass<"phase-polynomial-preprocess", "mlir::func::FuncOp"> { +def PhasePolynomialPreprocess: Pass<"phase-polynomial-preprocess", "mlir::ModuleOp"> { let summary = "Isolate subcircuits representable by a single phase polynomial."; let dependentDialects = ["cudaq::cc::CCDialect", "quake::QuakeDialect"]; diff --git a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp index 71367410e9e..6febcb282b8 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp @@ -8,6 +8,7 @@ #include "PassDetails.h" #include "Subcircuit.h" +#include "cudaq/Optimizer/Builder/Factory.h" #include "cudaq/Optimizer/Dialect/CC/CCOps.h" #include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" #include "cudaq/Optimizer/Transforms/Passes.h" @@ -28,8 +29,157 @@ class PhasePolynomialPreprocessPass PhasePolynomialPreprocessPass> { using PhasePolynomialPreprocessBase::PhasePolynomialPreprocessBase; - SetVector processed; - SmallVector subcircuits; + // void markRounds(Value wire, int round) { + // for (auto &uses : wire.getUses()) { + // auto op = uses.getOwner(); + // if (op->hasAttr("round") && + // op->getAttrOfType("round").getInt()) + // continue; + // op->setAttr("round", IntegerAttr::get(mlir::IntegerType::get(), + // round)); for (auto result : op->getResults()) + // markRounds(wire, round + 1); + // } + // } + + class WireStepper { + Value old_wire; + Value new_wire; + Subcircuit *subcircuit; + + public: + WireStepper(Subcircuit *circuit, Value initial, Value arg) { + subcircuit = circuit; + old_wire = initial; + new_wire = arg; + } + + bool isStopped() { + return subcircuit->getTerminalWires().contains(old_wire); + } + + Value getNewWire() { return new_wire; } + + Value getOldWire() { return old_wire; } + + void step(DenseMap &cloned, OpBuilder &builder) { + if (isStopped()) + return; + + // TODO: Something more elegant here would be nice + Operation *op = nullptr; + auto opnum = -1; + for (auto &use : old_wire.getUses()) { + if (use.getOwner()->getBlock() == builder.getInsertionBlock()) + continue; + op = use.getOwner(); + opnum = use.getOperandNumber(); + } + + assert(op); + + if (cloned.count(op) == 1) { + cloned[op]->setOperand(opnum, new_wire); + assert(old_wire.hasOneUse()); + old_wire = getNextResult(old_wire); + new_wire = getNextResult(new_wire); + return; + } + + // Make sure all dependencies have been cloned + for (auto dependency : op->getOperands()) { + if (!isa(dependency.getType())) + continue; + auto dop = dependency.getDefiningOp(); + if (cloned.count(dop) != 1 && + !subcircuit->getInitialWires().contains(dependency)) + return; + } + + auto clone = builder.clone(*op); + clone->setOperand(opnum, new_wire); + + // For now, just copy over all classical constants + builder.setInsertionPointToStart(clone->getBlock()); + for (size_t i = 0; i < clone->getNumOperands(); i++) { + auto dependency = clone->getOperand(i); + if (!isa(dependency.getType())) { + auto dop = dependency.getDefiningOp(); + assert(isa(dop)); + auto clone_dop = builder.clone(*dop); + clone->setOperand(i, clone_dop->getResult(0)); + } + } + builder.setInsertionPointAfter(clone); + + cloned[op] = clone; + assert(old_wire.hasOneUse()); + old_wire = getNextResult(old_wire); + new_wire = getNextResult(new_wire); + } + }; + + void removeOld(Subcircuit &subcircuit, + SmallVector &removal_order, Operation *next) { + if (!subcircuit.getOps().contains(next) || + std::find(removal_order.begin(), removal_order.end(), next) != + removal_order.end()) + return; + + for (auto result : next->getResults()) + for (auto *user : result.getUsers()) + removeOld(subcircuit, removal_order, user); + + removal_order.push_back(next); + } + + void moveToFunc(Subcircuit *subcircuit, size_t subcircuit_num) { + auto module = getOperation(); + SmallVector types(subcircuit->getInitialWires().size(), + quake::WireType::get(module.getContext())); + auto name = std::string("subcircuit") + std::to_string(subcircuit_num); + auto fun = cudaq::opt::factory::createFunction(name, types, types, module); + fun.setPrivate(); + auto entry = fun.addEntryBlock(); + OpBuilder builder(fun); + + DenseMap cloned; + + // Need to keep ordering to match returns with arguments + SmallVector wires_in; + SmallVector steppers; + for (auto wire : subcircuit->getInitialWires()) { + wires_in.push_back(wire); + steppers.push_back( + new WireStepper(subcircuit, wire, fun.getArgument(steppers.size()))); + } + + builder.setInsertionPointToStart(entry); + while (true) { + auto stepped = false; + for (auto stepper : steppers) { + if (!stepper->isStopped()) + stepped = true; + stepper->step(cloned, builder); + } + + if (!stepped) + break; + } + + SmallVector new_wires; + for (size_t i = 0; i < steppers.size(); i++) + new_wires.push_back(steppers[i]->getNewWire()); + + builder.create(fun.getLoc(), new_wires); + + auto cnot = subcircuit->getStart(); + + builder.setInsertionPointAfter(cnot); + auto call = builder.create(cnot->getLoc(), types, + fun.getSymNameAttr(), wires_in); + for (size_t i = 0; i < steppers.size(); i++) + steppers[i]->getOldWire().replaceAllUsesWith(call.getResult(i)); + } public: // AXIS-SPECIFIC: could allow controlled y and z here @@ -38,21 +188,31 @@ class PhasePolynomialPreprocessPass } void runOnOperation() override { - func::FuncOp func = getOperation(); + auto module = getOperation(); + size_t i = 0; + SetVector subcircuits; - func.walk([&](Operation *op) { - if (!isControlledOp(op) || ::processed(op)) - return; + for (auto &op : module) { + if (auto func = dyn_cast(op)) { + // TODO: this is yucky, having to rewalk the function because we're + // mutating it as we go + func.walk([&](quake::XOp op) { + if (!isControlledOp(op) || ::processed(op)) + return; - Subcircuit subcircuit(op); - subcircuits.push_back(subcircuit); - }); + auto *subcircuit = new Subcircuit(op); + moveToFunc(subcircuit, i++); + subcircuits.insert(subcircuit); + }); + } + } - for (auto subcircuit : subcircuits) { - llvm::outs() << "Calculated subcircuit: \n"; - for (auto *op : subcircuit.getOps()) - op->dump(); - llvm::outs() << "\n"; + for (auto *subcircuit : subcircuits) { + for (auto op : subcircuit->getOps()) { + op->dropAllUses(); + op->erase(); + } + delete subcircuit; } } }; diff --git a/lib/Optimizer/Transforms/Subcircuit.h b/lib/Optimizer/Transforms/Subcircuit.h index efdf113799f..763cc29dd62 100644 --- a/lib/Optimizer/Transforms/Subcircuit.h +++ b/lib/Optimizer/Transforms/Subcircuit.h @@ -33,8 +33,7 @@ Value getNextOperand(Value v) { return op->getOperand(operandIDX); } -// TODO: Handle block arguments -OpResult getNextResult(OpResult v) { +OpResult getNextResult(Value v) { assert(v.hasOneUse()); auto correspondingOperand = v.getUses().begin(); auto op = correspondingOperand.getUser(); @@ -51,20 +50,22 @@ inline void markProcessed(Operation *op) { class Subcircuit { protected: - SetVector seen; SetVector ops; + SetVector initial_wires; + SetVector terminal_wires; + Operation *start; + // TODO: these three are really intermediate state + // for constructing the subcircuit, it would be nice + // to turn them into shared arguments instead SetVector termination_points; SetVector anchor_points; + SetVector seen; bool isAfterTerminationPoint(Value wire) { return isTerminationPoint(wire.getDefiningOp()); } bool isTerminationPoint(Operation *op) { - // The operation is already part of another subcircuit - if (processed(op)) - return true; - if (isa(op)) return true; @@ -72,7 +73,10 @@ class Subcircuit { return true; auto opi = dyn_cast(op); - assert(opi); + + if (!opi) + return true; + // Only allow single control if (opi.getControls().size() > 1) return true; @@ -92,7 +96,7 @@ class Subcircuit { } Operation *op = v.getUses().begin().getUser(); - if (isTerminationPoint(op)) { + if (isTerminationPoint(op) || processed(op)) { termination_points.insert(v); return; } @@ -124,7 +128,7 @@ class Subcircuit { seen.insert(v); Operation *op = v.getDefiningOp(); - if (isTerminationPoint(op)) { + if (isTerminationPoint(op) || processed(op)) { termination_points.insert(v); return; } @@ -209,7 +213,7 @@ class Subcircuit { // (this means that it is important to build subcircuits // by inspecting controlled gates in topological order) for (auto wire : termination_points) { - if (!isAfterTerminationPoint(wire)) + if (!isAfterTerminationPoint(wire) && wire.hasOneUse()) pruneWire(wire); } } @@ -222,16 +226,21 @@ class Subcircuit { pruneSubcircuit(); for (auto *op : ops) markProcessed(op); - } + start = cnot; - SetVector getInitialWires() { - SetVector initial; - for (auto wire : termination_points) + for (auto wire : termination_points) { + wire.dump(); if (isAfterTerminationPoint(wire)) - initial.insert(wire); - return initial; + initial_wires.insert(wire); + else + terminal_wires.insert(wire); + } } + SetVector getInitialWires() { return initial_wires; } + + SetVector getTerminalWires() { return terminal_wires; } + bool isInSubcircuit(Operation *op) { return ops.contains(op); } // TODO: would be nice to make Subcircuit iterable directly @@ -239,4 +248,6 @@ class Subcircuit { /// @brief returns the number of wires in the subcircuit size_t numWires() { return getInitialWires().size(); } + + Operation *getStart() { return start; } }; From 4ef12aab365c8bc9c9534a846ff29841343b28a6 Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Mon, 30 Jun 2025 23:41:19 +0000 Subject: [PATCH 04/21] Remove debug code Signed-off-by: Adam Geller --- lib/Optimizer/Transforms/Subcircuit.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lib/Optimizer/Transforms/Subcircuit.h b/lib/Optimizer/Transforms/Subcircuit.h index 763cc29dd62..62d73f7da1e 100644 --- a/lib/Optimizer/Transforms/Subcircuit.h +++ b/lib/Optimizer/Transforms/Subcircuit.h @@ -228,13 +228,11 @@ class Subcircuit { markProcessed(op); start = cnot; - for (auto wire : termination_points) { - wire.dump(); + for (auto wire : termination_points) if (isAfterTerminationPoint(wire)) initial_wires.insert(wire); else terminal_wires.insert(wire); - } } SetVector getInitialWires() { return initial_wires; } From ba7e16b63f1d492bbbeebd72b1825376d0411555 Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Wed, 2 Jul 2025 23:52:53 +0000 Subject: [PATCH 05/21] A bit of cleanup on subcircuit logic Signed-off-by: Adam Geller --- lib/Optimizer/Transforms/CMakeLists.txt | 1 + .../Transforms/PhasePolynomialPreprocess.cpp | 30 ++-- lib/Optimizer/Transforms/Subcircuit.cpp | 153 ++++++++++++++++++ lib/Optimizer/Transforms/Subcircuit.h | 92 +++-------- 4 files changed, 189 insertions(+), 87 deletions(-) create mode 100644 lib/Optimizer/Transforms/Subcircuit.cpp diff --git a/lib/Optimizer/Transforms/CMakeLists.txt b/lib/Optimizer/Transforms/CMakeLists.txt index 49c1f4b4b66..5a944f6c10c 100644 --- a/lib/Optimizer/Transforms/CMakeLists.txt +++ b/lib/Optimizer/Transforms/CMakeLists.txt @@ -58,6 +58,7 @@ add_cudaq_library(OptTransforms RegToMem.cpp ReplaceStateWithKernel.cpp SROA.cpp + Subcircuit.cpp StatePreparation.cpp UnitarySynthesis.cpp UpdateRegisterNames.cpp diff --git a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp index 6febcb282b8..2d77a3e966b 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp @@ -29,18 +29,7 @@ class PhasePolynomialPreprocessPass PhasePolynomialPreprocessPass> { using PhasePolynomialPreprocessBase::PhasePolynomialPreprocessBase; - // void markRounds(Value wire, int round) { - // for (auto &uses : wire.getUses()) { - // auto op = uses.getOwner(); - // if (op->hasAttr("round") && - // op->getAttrOfType("round").getInt()) - // continue; - // op->setAttr("round", IntegerAttr::get(mlir::IntegerType::get(), - // round)); for (auto result : op->getResults()) - // markRounds(wire, round + 1); - // } - // } - + // TODO: I think this could potentially be generalized nicely class WireStepper { Value old_wire; Value new_wire; @@ -99,6 +88,8 @@ class PhasePolynomialPreprocessPass clone->setOperand(opnum, new_wire); // For now, just copy over all classical constants + // TODO: make classical values arguments to the function instead, + // to allow non-constant rotation angles builder.setInsertionPointToStart(clone->getBlock()); for (size_t i = 0; i < clone->getNumOperands(); i++) { auto dependency = clone->getOperand(i); @@ -141,6 +132,9 @@ class PhasePolynomialPreprocessPass fun.setPrivate(); auto entry = fun.addEntryBlock(); OpBuilder builder(fun); + fun.getOperation()->setAttr("subcircuit", builder.getUnitAttr()); + fun.getOperation()->setAttr( + "num_cnots", builder.getUI32IntegerAttr(subcircuit->numCNots())); DenseMap cloned; @@ -182,11 +176,6 @@ class PhasePolynomialPreprocessPass } public: - // AXIS-SPECIFIC: could allow controlled y and z here - bool isControlledOp(Operation *op) { - return isa(op) && op->getNumOperands() == 2; - } - void runOnOperation() override { auto module = getOperation(); size_t i = 0; @@ -194,14 +183,15 @@ class PhasePolynomialPreprocessPass for (auto &op : module) { if (auto func = dyn_cast(op)) { - // TODO: this is yucky, having to rewalk the function because we're - // mutating it as we go func.walk([&](quake::XOp op) { - if (!isControlledOp(op) || ::processed(op)) + if (!::isControlledOp(op) || ::processed(op)) return; auto *subcircuit = new Subcircuit(op); moveToFunc(subcircuit, i++); + // Add the subcircuit to erase from the function after we + // finish walking it, as we don't want to erase ops from a + // function we are currently walking subcircuits.insert(subcircuit); }); } diff --git a/lib/Optimizer/Transforms/Subcircuit.cpp b/lib/Optimizer/Transforms/Subcircuit.cpp new file mode 100644 index 00000000000..93dced3a0b3 --- /dev/null +++ b/lib/Optimizer/Transforms/Subcircuit.cpp @@ -0,0 +1,153 @@ +/******************************************************************************* + * Copyright (c) 2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +/****************************************************************-*- C++ -*-**** + * Copyright (c) 2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +#include "Subcircuit.h" +#include "cudaq/Optimizer/Dialect/CC/CCOps.h" +#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" + +using namespace mlir; + +#define RAW(X) quake::X +#define RAW_MEASURE_OPS MEASURE_OPS(RAW) +#define RAW_GATE_OPS GATE_OPS(RAW) +#define RAW_QUANTUM_OPS QUANTUM_OPS(RAW) +// AXIS-SPECIFIC: Defines which operations break a circuit into subcircuits +#define CIRCUIT_BREAKERS(MACRO) \ + MACRO(YOp), MACRO(ZOp), MACRO(HOp), MACRO(R1Op), MACRO(RxOp), \ + MACRO(PhasedRxOp), MACRO(RyOp), MACRO(U2Op), MACRO(U3Op) +#define RAW_CIRCUIT_BREAKERS CIRCUIT_BREAKERS(RAW) + +unsigned calculateSkip(Operation *op) { + auto i = 0; + for (auto type : op->getOperandTypes()) { + if (isa(type)) + return i; + i++; + } + + return i; +} + +Value getNextOperand(Value v) { + auto result = dyn_cast(v); + auto op = result.getDefiningOp(); + auto skip = calculateSkip(op); + auto operandIDX = result.getResultNumber() + skip; + return op->getOperand(operandIDX); +} + +OpResult getNextResult(Value v) { + assert(v.hasOneUse()); + auto correspondingOperand = v.getUses().begin(); + auto op = correspondingOperand.getUser(); + auto skip = calculateSkip(op); + auto resultIDX = correspondingOperand.getOperand()->getOperandNumber() - skip; + return op->getResult(resultIDX); +} + +bool processed(Operation *op) { return op->hasAttr("processed"); } + +void markProcessed(Operation *op) { + op->setAttr("processed", OpBuilder(op).getUnitAttr()); +} + +// AXIS-SPECIFIC: could allow controlled y and z here +bool isControlledOp(Operation *op) { + return isa(op) && op->getNumOperands() == 2; +} + +bool isTerminationPoint(Operation *op) { + if (!isQuakeOperation(op)) + return true; + + if (isa(op)) + return true; + + if (isa(op)) + return true; + + auto opi = dyn_cast(op); + + if (!opi) + return true; + + // Only allow single control + if (opi.getControls().size() > 1) + return true; + return false; +} + +/// @brief Constructs a subcircuit with a phase polynomial starting from a +/// cnot +Subcircuit::Subcircuit(Operation *cnot) { + calculateInitialSubcircuit(cnot); + pruneSubcircuit(); + for (auto *op : ops) + markProcessed(op); + start = cnot; + + for (auto wire : termination_points) + if (isAfterTerminationPoint(wire)) + initial_wires.insert(wire); + else + terminal_wires.insert(wire); +} + +/// @brief Reconstructs a subcircuit from a subcircuit function +Subcircuit::Subcircuit(func::FuncOp subcircuit_func) { + // First, some validation + assert(subcircuit_func.getOperation()->hasAttr("subcircuit")); + assert(subcircuit_func.getBlocks().size() == 1); + auto &body_block = subcircuit_func.getRegion().getBlocks().front(); + // Construct the subcircuit + for (auto &op : body_block) { + auto *opp = &op; + if (opp == body_block.getTerminator()) + continue; + if (isa(op)) + continue; + assert(!isTerminationPoint(opp)); + ops.insert(opp); + } + // TODO: address possible constant args (and returns) + for (auto arg : body_block.getArguments()) + initial_wires.insert(arg); + for (auto ret : body_block.getTerminator()->getOperands()) + terminal_wires.insert(ret); +} + +SetVector Subcircuit::getInitialWires() { return initial_wires; } + +SetVector Subcircuit::getTerminalWires() { return terminal_wires; } + +bool Subcircuit::isInSubcircuit(Operation *op) { return ops.contains(op); } + +// TODO: would be nice to make Subcircuit iterable directly +SetVector Subcircuit::getOps() { return ops; } + +/// @brief returns the number of wires in the subcircuit +size_t Subcircuit::numWires() { return getInitialWires().size(); } + +/// @brief returns the number of two-qubit operations in the subcircuit +size_t Subcircuit::numCNots() { + size_t num_cnots = 0; + for (auto *op : ops) + if (isControlledOp(op)) + num_cnots++; + return num_cnots; +} + +Operation *Subcircuit::getStart() { return start; } diff --git a/lib/Optimizer/Transforms/Subcircuit.h b/lib/Optimizer/Transforms/Subcircuit.h index 62d73f7da1e..a171938089e 100644 --- a/lib/Optimizer/Transforms/Subcircuit.h +++ b/lib/Optimizer/Transforms/Subcircuit.h @@ -1,3 +1,4 @@ +#pragma once #include "cudaq/Optimizer/Dialect/CC/CCOps.h" #include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" @@ -14,39 +15,20 @@ using namespace mlir; MACRO(PhasedRxOp), MACRO(RyOp), MACRO(U2Op), MACRO(U3Op) #define RAW_CIRCUIT_BREAKERS CIRCUIT_BREAKERS(RAW) -unsigned calculateSkip(Operation *op) { - auto i = 0; - for (auto type : op->getOperandTypes()) { - if (isa(type)) - return i; - i++; - } +unsigned calculateSkip(Operation *op); + +Value getNextOperand(Value v); - return i; -} +OpResult getNextResult(Value v); -Value getNextOperand(Value v) { - auto result = dyn_cast(v); - auto op = result.getDefiningOp(); - auto skip = calculateSkip(op); - auto operandIDX = result.getResultNumber() + skip; - return op->getOperand(operandIDX); -} +bool processed(Operation *op); -OpResult getNextResult(Value v) { - assert(v.hasOneUse()); - auto correspondingOperand = v.getUses().begin(); - auto op = correspondingOperand.getUser(); - auto skip = calculateSkip(op); - auto resultIDX = correspondingOperand.getOperand()->getOperandNumber() - skip; - return op->getResult(resultIDX); -} +void markProcessed(Operation *op); -inline bool processed(Operation *op) { return op->hasAttr("processed"); } +// AXIS-SPECIFIC: could allow controlled y and z here +bool isControlledOp(Operation *op); -inline void markProcessed(Operation *op) { - op->setAttr("processed", OpBuilder(op).getUnitAttr()); -} +bool isTerminationPoint(Operation *op); class Subcircuit { protected: @@ -65,24 +47,6 @@ class Subcircuit { return isTerminationPoint(wire.getDefiningOp()); } - bool isTerminationPoint(Operation *op) { - if (isa(op)) - return true; - - if (isa(op)) - return true; - - auto opi = dyn_cast(op); - - if (!opi) - return true; - - // Only allow single control - if (opi.getControls().size() > 1) - return true; - return false; - } - void maybeAddAnchorPoint(Value v) { if (!seen.contains(v)) anchor_points.insert(v); @@ -96,7 +60,7 @@ class Subcircuit { } Operation *op = v.getUses().begin().getUser(); - if (isTerminationPoint(op) || processed(op)) { + if (isTerminationPoint(op)) { termination_points.insert(v); return; } @@ -128,7 +92,7 @@ class Subcircuit { seen.insert(v); Operation *op = v.getDefiningOp(); - if (isTerminationPoint(op) || processed(op)) { + if (isTerminationPoint(op)) { termination_points.insert(v); return; } @@ -221,31 +185,25 @@ class Subcircuit { public: /// @brief Constructs a subcircuit with a phase polynomial starting from a /// cnot - Subcircuit(Operation *cnot) { - calculateInitialSubcircuit(cnot); - pruneSubcircuit(); - for (auto *op : ops) - markProcessed(op); - start = cnot; - - for (auto wire : termination_points) - if (isAfterTerminationPoint(wire)) - initial_wires.insert(wire); - else - terminal_wires.insert(wire); - } + Subcircuit(Operation *cnot); - SetVector getInitialWires() { return initial_wires; } + /// @brief Reconstructs a subcircuit from a subcircuit function + Subcircuit(func::FuncOp subcircuit_func); - SetVector getTerminalWires() { return terminal_wires; } + SetVector getInitialWires(); - bool isInSubcircuit(Operation *op) { return ops.contains(op); } + SetVector getTerminalWires(); + + bool isInSubcircuit(Operation *op); // TODO: would be nice to make Subcircuit iterable directly - SetVector getOps() { return ops; } + SetVector getOps(); /// @brief returns the number of wires in the subcircuit - size_t numWires() { return getInitialWires().size(); } + size_t numWires(); + + /// @brief returns the number of two-qubit operations in the subcircuit + size_t numCNots(); - Operation *getStart() { return start; } + Operation *getStart(); }; From 986fceeb37fae3f6543446537165ca790a77bcf5 Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Fri, 4 Jul 2025 18:55:05 +0000 Subject: [PATCH 06/21] Initial working prototype of phase polynomial calculations and rotation merging Signed-off-by: Adam Geller --- include/cudaq/Optimizer/Transforms/Passes.td | 6 + lib/Optimizer/Transforms/CMakeLists.txt | 1 + .../PhasePolynomialRotationMerging.cpp | 286 ++++++++++++++++++ 3 files changed, 293 insertions(+) create mode 100644 lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp diff --git a/include/cudaq/Optimizer/Transforms/Passes.td b/include/cudaq/Optimizer/Transforms/Passes.td index b77cbeb1b61..363eefb7cfc 100644 --- a/include/cudaq/Optimizer/Transforms/Passes.td +++ b/include/cudaq/Optimizer/Transforms/Passes.td @@ -865,6 +865,12 @@ def ObserveAnsatz : Pass<"observe-ansatz", "mlir::func::FuncOp"> { ]; } +def PhasePolynomialRotationMerging: Pass<"phase-polynomial-rotation-merging", "mlir::func::FuncOp"> { + let summary = "Perform phase polynomial based rotation merging."; + + let dependentDialects = ["cudaq::cc::CCDialect", "quake::QuakeDialect"]; +} + def PhasePolynomialPreprocess: Pass<"phase-polynomial-preprocess", "mlir::ModuleOp"> { let summary = "Isolate subcircuits representable by a single phase polynomial."; diff --git a/lib/Optimizer/Transforms/CMakeLists.txt b/lib/Optimizer/Transforms/CMakeLists.txt index 5a944f6c10c..51f645b7c51 100644 --- a/lib/Optimizer/Transforms/CMakeLists.txt +++ b/lib/Optimizer/Transforms/CMakeLists.txt @@ -48,6 +48,7 @@ add_cudaq_library(OptTransforms DependencyAnalysis.cpp MultiControlDecomposition.cpp ObserveAnsatz.cpp + PhasePolynomialRotationMerging.cpp PhasePolynomialPreprocess.cpp PruneCtrlRelations.cpp PySynthCallableBlockArgs.cpp diff --git a/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp new file mode 100644 index 00000000000..b4031550791 --- /dev/null +++ b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp @@ -0,0 +1,286 @@ +/******************************************************************************* + * Copyright (c) 2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +#include "PassDetails.h" +#include "Subcircuit.h" +#include "cudaq/Optimizer/Builder/Factory.h" +#include "cudaq/Optimizer/Dialect/CC/CCOps.h" +#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" +#include "cudaq/Optimizer/Transforms/Passes.h" +#include "mlir/Transforms/Passes.h" + +namespace cudaq::opt { +#define GEN_PASS_DEF_PHASEPOLYNOMIALROTATIONMERGING +#include "cudaq/Optimizer/Transforms/Passes.h.inc" +} // namespace cudaq::opt + +#define DEBUG_TYPE "phase-polynomial-rotation-merging" + +using namespace mlir; + +namespace { +class PhasePolynomialRotationMergingPass + : public cudaq::opt::impl::PhasePolynomialRotationMergingBase< + PhasePolynomialRotationMergingPass> { + using PhasePolynomialRotationMergingBase::PhasePolynomialRotationMergingBase; + + struct PhaseVariable { + public: + size_t idx; + // TODO: do we really need the initial_wire here? + // I think it's just useful for debugging + Value initial_wire; + PhaseVariable(size_t index, Value wire) : idx(index), initial_wire(wire) {} + + bool operator==(PhaseVariable other) { return idx == other.idx; } + }; + + class Phase { + SetVector vars; + bool isInverted; + + public: + Phase() : isInverted(false) {} + + Phase(PhaseVariable *var) : isInverted(false) { vars.insert(var); } + + bool operator==(Phase other) { + for (auto var : vars) + if (!other.vars.contains(var)) + return false; + for (auto var : other.vars) + if (!vars.contains(var)) + return false; + return isInverted == other.isInverted; + } + + static Phase *combine(Phase *p1, Phase *p2) { + Phase *new_phase = new Phase(); + for (auto var : p1->vars) + new_phase->vars.insert(var); + for (auto var : p2->vars) + if (new_phase->vars.contains(var)) + new_phase->vars.remove(var); + else + new_phase->vars.insert(var); + return new_phase; + } + + static Phase *invert(Phase *p1) { + Phase *new_phase = new Phase(); + for (auto var : p1->vars) + new_phase->vars.insert(var); + new_phase->isInverted = !p1->isInverted; + return new_phase; + } + + void dump() { + llvm::outs() << "Phase: "; + if (isInverted) + llvm::outs() << "!"; + llvm::outs() << "{"; + auto first = true; + for (auto var : vars) { + if (!first) + llvm::outs() << " ^ "; + llvm::outs() << var->idx; + first = false; + } + llvm::outs() << "}\n"; + } + + std::optional getIntRepresentation() { + int64_t sum = 0; + for (auto var : vars) { + if (var->idx > sizeof(int64_t) - 1) + return std::nullopt; + sum += 1 << var->idx; + } + } + }; + + class PhaseStorage { + // TODO: If SmallVector, no need for pointers here, so make not + // pointer to avoid memory leaks + SmallVector phases; + SmallVector rotations; + + void combineRotations(size_t prev_idx, quake::RzOp rzop) { + auto old_rzop = rotations[prev_idx]; + auto builder = OpBuilder(old_rzop); + auto rot_arg1 = old_rzop.getOperand(0); + auto rot_arg2 = rzop.getOperand(0); + auto new_rot_arg = + builder.create(old_rzop.getLoc(), rot_arg1, rot_arg2); + auto new_rot = builder.clone(*old_rzop.getOperation()); + new_rot->setOperand(0, new_rot_arg.getResult()); + old_rzop.getResult(0).replaceAllUsesWith(new_rot->getResult(0)); + old_rzop.erase(); + rzop.getResult(0).replaceAllUsesWith(rzop.getOperand(1)); + rzop.erase(); + } + + public: + /// @brief registers a new rotation op for the given phase + /// @returns true if the rotation was combined, false otherwise + bool addOrCombineRotationForPhase(quake::RzOp op, Phase *phase) { + for (size_t i = 0; i < phases.size(); i++) + if (*phases[i] == *phase) { + combineRotations(i, op); + return true; + } + + phases.push_back(phase); + rotations.push_back(op); + return false; + } + }; + + class PhaseStepper { + Value wire; + Subcircuit *subcircuit; + PhaseStorage *store; + Phase *current_phase; + + public: + class StepperContainer { + SmallVector steppers; + PhaseStorage *store; + SetVector vars; + + PhaseStepper *getStepperForValue(Value v) { + for (auto *stepper : steppers) + if (stepper->wire == v) + return stepper; + return nullptr; + } + + public: + // Caller is responsible for cleaning up circuit + StepperContainer(Subcircuit *circuit) { + store = new PhaseStorage(); + size_t i = 0; + for (auto wire : circuit->getInitialWires()) { + auto *new_var = new PhaseVariable(i++, wire); + // StepperContainer is responsible for cleaning up PhaseSteppers + steppers.push_back(new PhaseStepper(circuit, store, wire, new_var)); + // StepperContainer is responsible for cleaning up PhaseVariables + vars.insert(new_var); + } + } + + ~StepperContainer() { + delete store; + for (auto stepper : steppers) + delete stepper; + for (auto var : vars) + delete var; + } + + bool isStopped() { + for (auto *stepper : steppers) + if (!stepper->isStopped()) + return false; + return true; + } + + void stepAll() { + if (isStopped()) + return; + for (auto *stepper : steppers) + stepper->step(this); + } + + std::optional + maybeGetControlPhase(quake::OperatorInterface opi) { + assert(isControlledOp(opi.getOperation())); + auto control = opi.getControls().front(); + auto *stepper = getStepperForValue(control); + if (stepper) + return stepper->current_phase; + return std::nullopt; + } + + bool targetVisited(quake::OperatorInterface opi) { + assert(isControlledOp(opi.getOperation())); + auto next_result = getNextResult(opi.getTarget(0)); + // Wait until target wire stepper steps to ensure + // control phase is available + if (getStepperForValue(next_result)) + return true; + return false; + } + }; + + PhaseStepper(Subcircuit *circuit, PhaseStorage *store, Value initial, + PhaseVariable *var) { + subcircuit = circuit; + this->store = store; + wire = initial; + current_phase = new Phase(var); + } + + bool isStopped() { + Operation *op = *wire.getUsers().begin(); + assert(op); + return isTerminationPoint(op); + } + + void step(StepperContainer *container) { + if (isStopped()) + return; + assert(wire.hasOneUse()); + + Operation *op = *wire.getUsers().begin(); + assert(op); + auto opi = dyn_cast(op); + assert(opi); + + if (isControlledOp(op)) { + // Controlled not, and we are the target, so update phase + if (opi.getTarget(0) == wire) { + auto phase_opt = container->maybeGetControlPhase(opi); + // Wait until we have the phase for the other wire + if (!phase_opt.has_value()) + return; + current_phase = Phase::combine(current_phase, phase_opt.value()); + } else { + // Wait until the target has visited the operation so it can + // access our phase (the control phase) + if (!container->targetVisited(opi)) + return; + } + } else if (isa(op) && opi.getControls().size() == 0) { + // Simple not, invert phase + // AXIS-SPECIFIC: Would want to handle y and z gates here too + current_phase = Phase::invert(current_phase); + } else if (auto rzop = dyn_cast(op)) { + if (store->addOrCombineRotationForPhase(rzop, current_phase)) + return; + } + + wire = getNextResult(wire); + } + }; + +public: + void runOnOperation() override { + auto func = getOperation(); + + if (!func.getOperation()->hasAttr("subcircuit")) + return; + + auto subcircuit = new Subcircuit(func); + auto container = PhaseStepper::StepperContainer(subcircuit); + + while (!container.isStopped()) + container.stepAll(); + delete subcircuit; + } +}; +} // namespace From 7f110a9b6555ee364f70b50b36c23f9db1436b7d Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Fri, 4 Jul 2025 19:39:57 +0000 Subject: [PATCH 07/21] Make classical values arguments to subcircuit functions to allow non-constant rotation angles Signed-off-by: Adam Geller --- .../Transforms/PhasePolynomialPreprocess.cpp | 27 +++++++++++-------- lib/Optimizer/Transforms/Subcircuit.cpp | 6 ++--- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp index 2d77a3e966b..a971e025e95 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp @@ -50,7 +50,8 @@ class PhasePolynomialPreprocessPass Value getOldWire() { return old_wire; } - void step(DenseMap &cloned, OpBuilder &builder) { + void step(DenseMap &cloned, OpBuilder &builder, + std::function addFuncArg) { if (isStopped()) return; @@ -87,17 +88,14 @@ class PhasePolynomialPreprocessPass auto clone = builder.clone(*op); clone->setOperand(opnum, new_wire); - // For now, just copy over all classical constants - // TODO: make classical values arguments to the function instead, + // Make classical values arguments to the function, // to allow non-constant rotation angles builder.setInsertionPointToStart(clone->getBlock()); for (size_t i = 0; i < clone->getNumOperands(); i++) { auto dependency = clone->getOperand(i); if (!isa(dependency.getType())) { - auto dop = dependency.getDefiningOp(); - assert(isa(dop)); - auto clone_dop = builder.clone(*dop); - clone->setOperand(i, clone_dop->getResult(0)); + auto new_arg = addFuncArg(dependency); + clone->setOperand(i, new_arg); } } builder.setInsertionPointAfter(clone); @@ -139,21 +137,28 @@ class PhasePolynomialPreprocessPass DenseMap cloned; // Need to keep ordering to match returns with arguments - SmallVector wires_in; + SmallVector args; SmallVector steppers; for (auto wire : subcircuit->getInitialWires()) { - wires_in.push_back(wire); + args.push_back(wire); steppers.push_back( new WireStepper(subcircuit, wire, fun.getArgument(steppers.size()))); } + auto add_arg_fun = [&](Value v) { + auto idx = args.size(); + args.push_back(v); + fun.insertArgument(idx, v.getType(), {}, v.getDefiningOp()->getLoc()); + return fun.getArgument(idx); + }; + builder.setInsertionPointToStart(entry); while (true) { auto stepped = false; for (auto stepper : steppers) { if (!stepper->isStopped()) stepped = true; - stepper->step(cloned, builder); + stepper->step(cloned, builder, add_arg_fun); } if (!stepped) @@ -170,7 +175,7 @@ class PhasePolynomialPreprocessPass builder.setInsertionPointAfter(cnot); auto call = builder.create(cnot->getLoc(), types, - fun.getSymNameAttr(), wires_in); + fun.getSymNameAttr(), args); for (size_t i = 0; i < steppers.size(); i++) steppers[i]->getOldWire().replaceAllUsesWith(call.getResult(i)); } diff --git a/lib/Optimizer/Transforms/Subcircuit.cpp b/lib/Optimizer/Transforms/Subcircuit.cpp index 93dced3a0b3..8f51e33e28c 100644 --- a/lib/Optimizer/Transforms/Subcircuit.cpp +++ b/lib/Optimizer/Transforms/Subcircuit.cpp @@ -117,14 +117,12 @@ Subcircuit::Subcircuit(func::FuncOp subcircuit_func) { auto *opp = &op; if (opp == body_block.getTerminator()) continue; - if (isa(op)) - continue; assert(!isTerminationPoint(opp)); ops.insert(opp); } - // TODO: address possible constant args (and returns) for (auto arg : body_block.getArguments()) - initial_wires.insert(arg); + if (isa(arg.getType())) + initial_wires.insert(arg); for (auto ret : body_block.getTerminator()->getOperands()) terminal_wires.insert(ret); } From 70b369b2952b97e23d21189140d9b8937fdf7d9a Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Mon, 7 Jul 2025 21:41:54 +0000 Subject: [PATCH 08/21] Some cleanup Signed-off-by: Adam Geller --- lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp | 4 ++-- lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp | 1 + lib/Optimizer/Transforms/Subcircuit.cpp | 5 +++++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp index a971e025e95..5d7044a67c8 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp @@ -57,9 +57,9 @@ class PhasePolynomialPreprocessPass // TODO: Something more elegant here would be nice Operation *op = nullptr; - auto opnum = -1; + size_t opnum = 0; for (auto &use : old_wire.getUses()) { - if (use.getOwner()->getBlock() == builder.getInsertionBlock()) + if (!subcircuit->getOps().contains(use.getOwner())) continue; op = use.getOwner(); opnum = use.getOperandNumber(); diff --git a/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp index b4031550791..780743a14ec 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp @@ -117,6 +117,7 @@ class PhasePolynomialRotationMergingPass auto rot_arg2 = rzop.getOperand(0); auto new_rot_arg = builder.create(old_rzop.getLoc(), rot_arg1, rot_arg2); + // TODO: Can replace operand 0 directly rather than cloning first? auto new_rot = builder.clone(*old_rzop.getOperation()); new_rot->setOperand(0, new_rot_arg.getResult()); old_rzop.getResult(0).replaceAllUsesWith(new_rot->getResult(0)); diff --git a/lib/Optimizer/Transforms/Subcircuit.cpp b/lib/Optimizer/Transforms/Subcircuit.cpp index 8f51e33e28c..1be5c6fa7e3 100644 --- a/lib/Optimizer/Transforms/Subcircuit.cpp +++ b/lib/Optimizer/Transforms/Subcircuit.cpp @@ -70,6 +70,11 @@ bool isControlledOp(Operation *op) { } bool isTerminationPoint(Operation *op) { + // TODO: it may be cleaner to only accept non-null input to + // ensure the null case is explicitly handled by users + if (!op) + return true; + if (!isQuakeOperation(op)) return true; From 69de2ccd012f97111c837283d822e782224f188f Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Mon, 7 Jul 2025 21:42:22 +0000 Subject: [PATCH 09/21] Remove incorrectly placed braces Signed-off-by: Adam Geller --- lib/Optimizer/Transforms/Subcircuit.h | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/lib/Optimizer/Transforms/Subcircuit.h b/lib/Optimizer/Transforms/Subcircuit.h index a171938089e..0158667112f 100644 --- a/lib/Optimizer/Transforms/Subcircuit.h +++ b/lib/Optimizer/Transforms/Subcircuit.h @@ -162,13 +162,12 @@ class Subcircuit { // { pruneSubcircuit(opi.getWires()[1]); return; // } - for (auto result : op->getResults()) { + for (auto result : op->getResults()) pruneWire(result); - // Adjust termination border - for (auto operand : op->getOperands()) - if (ops.contains(operand.getDefiningOp())) - termination_points.insert(operand); - } + // Adjust termination border + for (auto operand : op->getOperands()) + if (ops.contains(operand.getDefiningOp())) + termination_points.insert(operand); } void pruneSubcircuit() { From 89277a5aedf8a7272e32d77303ea5398d7bcc536 Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Mon, 7 Jul 2025 23:56:58 +0000 Subject: [PATCH 10/21] Add tests for circuit breaking ops in preprocessing Signed-off-by: Adam Geller --- test/Quake/phase_polynomial_preprocess.qke | 203 ++++++++++++++++++ .../phase_polynomial_rotation_merging.qke | 51 +++++ 2 files changed, 254 insertions(+) create mode 100644 test/Quake/phase_polynomial_preprocess.qke create mode 100644 test/Quake/phase_polynomial_rotation_merging.qke diff --git a/test/Quake/phase_polynomial_preprocess.qke b/test/Quake/phase_polynomial_preprocess.qke new file mode 100644 index 00000000000..48a6d5ecb4f --- /dev/null +++ b/test/Quake/phase_polynomial_preprocess.qke @@ -0,0 +1,203 @@ +// ========================================================================== // +// Copyright (c) 2025 NVIDIA Corporation & Affiliates. // +// All rights reserved. // +// // +// This source code and the accompanying materials are made available under // +// the terms of the Apache License 2.0 which accompanies this distribution. // +// ========================================================================== // + +// RUN: cudaq-opt --phase-polynomial-preprocess -split-input-file %s | FileCheck %s + +func.func @kernel1() { + %cst = arith.constant 1.000000e+00 : f64 + %0 = quake.null_wire + %1 = quake.null_wire + %2:2 = quake.x [%0] %1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %3 = quake.h %2#0 : (!quake.wire) -> !quake.wire + %4 = quake.rz (%cst) %2#1 : (f64, !quake.wire) -> !quake.wire + %5:2 = quake.x [%3] %4 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + return +} + +// CHECK: func.func private @subcircuit1(%arg0: !quake.wire, %arg1: !quake.wire) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { +// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: cc.return %0#0, %0#1 : !quake.wire, !quake.wire +// CHECK: } +// CHECK: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { +// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %1 = quake.rz (%arg2) %0#1 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK: cc.return %0#0, %1 : !quake.wire, !quake.wire +// CHECK: } +// CHECK: func.func @kernel1() { +// CHECK: %cst = arith.constant 1.000000e+00 : f64 +// CHECK: %0 = quake.null_wire +// CHECK: %1 = quake.null_wire +// CHECK: %2:2 = call @subcircuit0(%0, %1, %cst) : (!quake.wire, !quake.wire, f64) -> (!quake.wire, !quake.wire) +// CHECK: %3 = quake.h %2#0 : (!quake.wire) -> !quake.wire +// CHECK: %4:2 = call @subcircuit1(%3, %2#1) : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: return +// CHECK: } + +// ----- + +func.func @kernel2() { + %cst = arith.constant 1.000000e+00 : f64 + %0 = quake.null_wire + %1 = quake.null_wire + %2 = quake.null_wire + %3 = quake.h %0 : (!quake.wire) -> !quake.wire + %4 = quake.h %1 : (!quake.wire) -> !quake.wire + %5 = quake.h %2 : (!quake.wire) -> !quake.wire + %6 = quake.rz (%cst) %4 : (f64, !quake.wire) -> !quake.wire + %7 = quake.rz (%cst) %5 : (f64, !quake.wire) -> !quake.wire + %8:2 = quake.x [%6] %3 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %9 = quake.rz (%cst) %8#1 : (f64, !quake.wire) -> !quake.wire + %10:2 = quake.x [%8#0] %7 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %11:2 = quake.x [%9] %10#0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %12 = quake.h %10#1 : (!quake.wire) -> !quake.wire + %13:2 = quake.x [%11#1] %12 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %14:2 = quake.x [%11#0] %13#0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %15 = quake.rz (%cst) %14#1 : (f64, !quake.wire) -> !quake.wire + %16 = quake.h %14#0 : (!quake.wire) -> !quake.wire + %17 = quake.h %15 : (!quake.wire) -> !quake.wire + return +} + +// CHECK: func.func private @subcircuit1(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 2 : ui32, subcircuit} { +// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %1:2 = quake.x [%arg2] %0#0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %2 = quake.rz (%arg3) %1#1 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK: cc.return %2, %0#1, %1#0 : !quake.wire, !quake.wire, !quake.wire +// CHECK: } +// CHECK: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64, %arg4: f64, %arg5: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 3 : ui32, subcircuit} { +// CHECK-DAG: %0 = quake.rz (%arg3) %arg0 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK-DAG: %1 = quake.rz (%arg4) %arg1 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK: %2:2 = quake.x [%0] %arg2 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %3:2 = quake.x [%2#0] %1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %4 = quake.rz (%arg5) %2#1 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK: %5:2 = quake.x [%4] %3#0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: cc.return %5#1, %3#1, %5#0 : !quake.wire, !quake.wire, !quake.wire +// CHECK: } +// CHECK: func.func @kernel2() { +// CHECK: %cst = arith.constant 1.000000e+00 : f64 +// CHECK: %0 = quake.null_wire +// CHECK: %1 = quake.null_wire +// CHECK: %2 = quake.null_wire +// CHECK: %3 = quake.h %0 : (!quake.wire) -> !quake.wire +// CHECK: %4 = quake.h %1 : (!quake.wire) -> !quake.wire +// CHECK: %5 = quake.h %2 : (!quake.wire) -> !quake.wire +// CHECK: %6:3 = call @subcircuit0(%4, %5, %3, %cst, %cst, %cst) : (!quake.wire, !quake.wire, !quake.wire, f64, f64, f64) -> (!quake.wire, !quake.wire, !quake.wire) +// CHECK: %7 = quake.h %6#1 : (!quake.wire) -> !quake.wire +// CHECK: %8:3 = call @subcircuit1(%6#0, %7, %6#2, %cst) : (!quake.wire, !quake.wire, !quake.wire, f64) -> (!quake.wire, !quake.wire, !quake.wire) +// CHECK: %9 = quake.h %8#2 : (!quake.wire) -> !quake.wire +// CHECK: %10 = quake.h %8#0 : (!quake.wire) -> !quake.wire +// CHECK: return +// CHECK: } + +// ----- + +func.func @kernel3() { + %cst = arith.constant 1.000000e+00 : f64 + %true = arith.constant true + %0 = quake.null_wire + %1 = quake.null_wire + %2:2 = quake.x [%0] %1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %3 = quake.h %2#0 : (!quake.wire) -> !quake.wire + %4:2 = cc.if(%true) ((%arg0 = %2#1, %arg1 = %3)) -> (!quake.wire, !quake.wire) { + %5 = quake.rz (%cst) %arg0 : (f64, !quake.wire) -> !quake.wire + %6:2 = quake.x [%arg1] %5 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + cc.continue %6#0, %6#1 : !quake.wire, !quake.wire + } else { + cc.continue %arg0, %arg1 : !quake.wire, !quake.wire + } + return +} + +// CHECK-LABEL: func.func private @subcircuit1(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { +// CHECK: %0 = quake.rz (%arg2) %arg1 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK: %1:2 = quake.x [%arg0] %0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: cc.return %1#0, %1#1 : !quake.wire, !quake.wire +// CHECK: } +// CHECK-LABEL: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { +// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: cc.return %0#0, %0#1 : !quake.wire, !quake.wire +// CHECK: } +// CHECK-LABEL: func.func @kernel3() { +// CHECK: %cst = arith.constant 1.000000e+00 : f64 +// CHECK: %true = arith.constant true +// CHECK: %0 = quake.null_wire +// CHECK: %1 = quake.null_wire +// CHECK: %2:2 = call @subcircuit0(%0, %1) : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %3 = quake.h %2#0 : (!quake.wire) -> !quake.wire +// CHECK: %4:2 = cc.if(%true) ((%arg0 = %2#1, %arg1 = %3)) -> (!quake.wire, !quake.wire) { +// CHECK: %5:2 = func.call @subcircuit1(%arg1, %arg0, %cst) : (!quake.wire, !quake.wire, f64) -> (!quake.wire, !quake.wire) +// CHECK: cc.continue %5#0, %5#1 : !quake.wire, !quake.wire +// CHECK: } else { +// CHECK: cc.continue %arg0, %arg1 : !quake.wire, !quake.wire +// CHECK: } +// CHECK: return +// CHECK: } + +// ----- + +func.func @kernel4() { + %cst = arith.constant 1.000000e+00 : f64 + %0 = quake.null_wire + %1 = quake.null_wire + %2 = quake.null_wire + %3 = quake.rz (%cst) %0 : (f64, !quake.wire) -> !quake.wire + %4:3 = quake.x [%1, %2] %3 : (!quake.wire, !quake.wire, !quake.wire) -> (!quake.wire, !quake.wire, !quake.wire) + %5:3 = quake.x [%4#0, %4#1] %4#2 : (!quake.wire, !quake.wire, !quake.wire) -> (!quake.wire, !quake.wire, !quake.wire) + %6 = quake.rz (%cst) %5#2 : (f64, !quake.wire) -> !quake.wire + %7:2 = quake.x [%5#0] %5#1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + return +} + +// CHECK-LABEL: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { +// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: cc.return %0#0, %0#1 : !quake.wire, !quake.wire +// CHECK: } +// CHECK-LABEL: func.func @kernel4() { +// CHECK: %cst = arith.constant 1.000000e+00 : f64 +// CHECK: %0 = quake.null_wire +// CHECK: %1 = quake.null_wire +// CHECK: %2 = quake.null_wire +// CHECK: %3 = quake.rz (%cst) %0 : (f64, !quake.wire) -> !quake.wire +// CHECK: %4:3 = quake.x [%1, %2] %3 : (!quake.wire, !quake.wire, !quake.wire) -> (!quake.wire, !quake.wire, !quake.wire) +// CHECK: %5:3 = quake.x [%4#0, %4#1] %4#2 : (!quake.wire, !quake.wire, !quake.wire) -> (!quake.wire, !quake.wire, !quake.wire) +// CHECK: %6 = quake.rz (%cst) %5#2 : (f64, !quake.wire) -> !quake.wire +// CHECK: %7:2 = call @subcircuit0(%5#0, %5#1) : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: return +// CHECK: } + +// ----- + +func.func @kernel5() { + %cst = arith.constant 1.000000e+00 : f64 + %0 = quake.null_wire + %1 = quake.null_wire + %2:2 = quake.x [%0] %1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %3 = quake.ry (%cst) %2#0 : (f64, !quake.wire) -> !quake.wire + %4 = quake.rz (%cst) %2#1 : (f64, !quake.wire) -> !quake.wire + %5:2 = quake.x [%3] %4 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + return +} + +// CHECK-LABEL: func.func private @subcircuit1(%arg0: !quake.wire, %arg1: !quake.wire) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { +// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: cc.return %0#0, %0#1 : !quake.wire, !quake.wire +// CHECK: } +// CHECK-LABEL: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { +// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %1 = quake.rz (%arg2) %0#1 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK: cc.return %0#0, %1 : !quake.wire, !quake.wire +// CHECK: } +// CHECK-LABEL: func.func @kernel5() { +// CHECK: %cst = arith.constant 1.000000e+00 : f64 +// CHECK: %0 = quake.null_wire +// CHECK: %1 = quake.null_wire +// CHECK: %2:2 = call @subcircuit0(%0, %1, %cst) : (!quake.wire, !quake.wire, f64) -> (!quake.wire, !quake.wire) +// CHECK: %3 = quake.ry (%cst) %2#0 : (f64, !quake.wire) -> !quake.wire +// CHECK: %4:2 = call @subcircuit1(%3, %2#1) : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: return +// CHECK: } \ No newline at end of file diff --git a/test/Quake/phase_polynomial_rotation_merging.qke b/test/Quake/phase_polynomial_rotation_merging.qke new file mode 100644 index 00000000000..d9518e50d79 --- /dev/null +++ b/test/Quake/phase_polynomial_rotation_merging.qke @@ -0,0 +1,51 @@ +// ========================================================================== // +// Copyright (c) 2025 NVIDIA Corporation & Affiliates. // +// All rights reserved. // +// // +// This source code and the accompanying materials are made available under // +// the terms of the Apache License 2.0 which accompanies this distribution. // +// ========================================================================== // + +// RUN: cudaq-opt --phase-polynomial-rotation-merging -split-input-file %s | FileCheck %s + +func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 3 : ui32, subcircuit} { + %0 = quake.rz (%arg2) %arg1 : (f64, !quake.wire) -> !quake.wire + %1:2 = quake.x [%arg0] %0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %2 = quake.rz (%arg2) %1#1 : (f64, !quake.wire) -> !quake.wire + %3:2 = quake.x [%1#0] %2 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %4 = quake.rz (%arg2) %3#0 : (f64, !quake.wire) -> !quake.wire + %5 = quake.rz (%arg2) %3#1 : (f64, !quake.wire) -> !quake.wire + %6:2 = quake.x [%5] %4 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + cc.return %6#1, %6#0 : !quake.wire, !quake.wire +} + +// CHECK: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 3 : ui32, subcircuit} { +// CHECK: %[[VAL_0:.*]] = arith.addf %arg2, %arg2 : f64 +// CHECK: %[[VAL_1:.*]] = quake.rz (%[[VAL_0]]) %arg1 : (f64, !quake.wire) -> !quake.wire +// CHECK: %[[VAL_2:.*]]:2 = quake.x [%arg0] %[[VAL_1]] : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %[[VAL_3:.*]] = quake.rz (%arg2) %[[VAL_2]]#1 : (f64, !quake.wire) -> !quake.wire +// CHECK: %[[VAL_4:.*]]:2 = quake.x [%[[VAL_2]]#0] %[[VAL_3]] : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %[[VAL_5:.*]] = quake.rz (%arg2) %[[VAL_4]]#0 : (f64, !quake.wire) -> !quake.wire +// CHECK: %[[VAL_6:.*]]:2 = quake.x [%[[VAL_4]]#1] %[[VAL_5]] : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: cc.return %[[VAL_6]]#1, %[[VAL_6]]#0 : !quake.wire, !quake.wire +// CHECK: } + +// ----- + +func.func private @subcircuit1(%arg0: !quake.wire, %arg1: f64, %arg2: f64, %arg3: f64) -> !quake.wire attributes {num_cnots = 0 : ui32, subcircuit} { + %0 = quake.rz (%arg1) %arg0 : (f64, !quake.wire) -> !quake.wire + %1 = quake.x %0 : (!quake.wire) -> !quake.wire + %2 = quake.rz (%arg2) %1 : (f64, !quake.wire) -> !quake.wire + %3 = quake.x %2 : (!quake.wire) -> !quake.wire + %4 = quake.rz (%arg3) %3 : (f64, !quake.wire) -> !quake.wire + cc.return %4 : !quake.wire +} + +// CHECK: func.func private @subcircuit1(%arg0: !quake.wire, %arg1: f64, %arg2: f64, %arg3: f64) -> !quake.wire attributes {num_cnots = 0 : ui32, subcircuit} { +// CHECK: %[[VAL_0:.*]] = arith.addf %arg1, %arg3 : f64 +// CHECK: %[[VAL_1:.*]] = quake.rz (%[[VAL_0]]) %arg0 : (f64, !quake.wire) -> !quake.wire +// CHECK: %[[VAL_2:.*]] = quake.x %[[VAL_1]] : (!quake.wire) -> !quake.wire +// CHECK: %[[VAL_3:.*]] = quake.rz (%arg2) %[[VAL_2]] : (f64, !quake.wire) -> !quake.wire +// CHECK: %[[VAL_4:.*]] = quake.x %[[VAL_3]] : (!quake.wire) -> !quake.wire +// CHECK: cc.return %[[VAL_4]] : !quake.wire +// CHECK: } From 753bbb8f5f735ff6dbf9e7596051b9f40addadba Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Mon, 7 Jul 2025 23:58:13 +0000 Subject: [PATCH 11/21] Don't run phase polynomial based rotation merging opt on illegal circuits, instead of failing Signed-off-by: Adam Geller --- .../PhasePolynomialRotationMerging.cpp | 4 +- lib/Optimizer/Transforms/Subcircuit.cpp | 22 ++++++---- lib/Optimizer/Transforms/Subcircuit.h | 4 +- .../phase_polynomial_rotation_merging.qke | 42 +++++++++++++++++++ 4 files changed, 63 insertions(+), 9 deletions(-) diff --git a/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp index 780743a14ec..fc4c3ad9f6f 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp @@ -276,7 +276,9 @@ class PhasePolynomialRotationMergingPass if (!func.getOperation()->hasAttr("subcircuit")) return; - auto subcircuit = new Subcircuit(func); + auto subcircuit = Subcircuit::constructFromFunc(func); + if (!subcircuit) + return; auto container = PhaseStepper::StepperContainer(subcircuit); while (!container.isStopped()) diff --git a/lib/Optimizer/Transforms/Subcircuit.cpp b/lib/Optimizer/Transforms/Subcircuit.cpp index 1be5c6fa7e3..704e95dcc85 100644 --- a/lib/Optimizer/Transforms/Subcircuit.cpp +++ b/lib/Optimizer/Transforms/Subcircuit.cpp @@ -112,24 +112,32 @@ Subcircuit::Subcircuit(Operation *cnot) { } /// @brief Reconstructs a subcircuit from a subcircuit function -Subcircuit::Subcircuit(func::FuncOp subcircuit_func) { +/// @returns A newly allocated subcircuit if the function defines +/// a valid subcircuit, `nullptr` otherwise. +Subcircuit *Subcircuit::constructFromFunc(func::FuncOp subcircuit_func) { // First, some validation - assert(subcircuit_func.getOperation()->hasAttr("subcircuit")); - assert(subcircuit_func.getBlocks().size() == 1); + if (!subcircuit_func.getOperation()->hasAttr("subcircuit")) + return nullptr; + if (subcircuit_func.getBlocks().size() != 1) + return nullptr; auto &body_block = subcircuit_func.getRegion().getBlocks().front(); + auto subcircuit = new Subcircuit(); // Construct the subcircuit for (auto &op : body_block) { auto *opp = &op; if (opp == body_block.getTerminator()) continue; - assert(!isTerminationPoint(opp)); - ops.insert(opp); + // Ensure circuit only contains valid operations + if (isTerminationPoint(opp)) + return nullptr; + subcircuit->ops.insert(opp); } for (auto arg : body_block.getArguments()) if (isa(arg.getType())) - initial_wires.insert(arg); + subcircuit->initial_wires.insert(arg); for (auto ret : body_block.getTerminator()->getOperands()) - terminal_wires.insert(ret); + subcircuit->terminal_wires.insert(ret); + return subcircuit; } SetVector Subcircuit::getInitialWires() { return initial_wires; } diff --git a/lib/Optimizer/Transforms/Subcircuit.h b/lib/Optimizer/Transforms/Subcircuit.h index 0158667112f..af8b05be66d 100644 --- a/lib/Optimizer/Transforms/Subcircuit.h +++ b/lib/Optimizer/Transforms/Subcircuit.h @@ -181,13 +181,15 @@ class Subcircuit { } } + Subcircuit() {} + public: /// @brief Constructs a subcircuit with a phase polynomial starting from a /// cnot Subcircuit(Operation *cnot); /// @brief Reconstructs a subcircuit from a subcircuit function - Subcircuit(func::FuncOp subcircuit_func); + static Subcircuit *constructFromFunc(func::FuncOp subcircuit_func); SetVector getInitialWires(); diff --git a/test/Quake/phase_polynomial_rotation_merging.qke b/test/Quake/phase_polynomial_rotation_merging.qke index d9518e50d79..48592f41c2e 100644 --- a/test/Quake/phase_polynomial_rotation_merging.qke +++ b/test/Quake/phase_polynomial_rotation_merging.qke @@ -49,3 +49,45 @@ func.func private @subcircuit1(%arg0: !quake.wire, %arg1: f64, %arg2: f64, %arg3 // CHECK: %[[VAL_4:.*]] = quake.x %[[VAL_3]] : (!quake.wire) -> !quake.wire // CHECK: cc.return %[[VAL_4]] : !quake.wire // CHECK: } + +// ----- + +// Invalid subcircuit functions shouldn't be touched + +func.func private @subcircuit2(%arg0: !quake.wire, %arg1: f64, %arg2: f64, %arg3: f64) -> !quake.wire attributes {num_cnots = 0 : ui32, subcircuit} { + %0 = quake.rz (%arg1) %arg0 : (f64, !quake.wire) -> !quake.wire + %1 = quake.x %0 : (!quake.wire) -> !quake.wire + %2 = quake.ry (%arg2) %1 : (f64, !quake.wire) -> !quake.wire + %3 = quake.x %2 : (!quake.wire) -> !quake.wire + %4 = quake.rz (%arg3) %3 : (f64, !quake.wire) -> !quake.wire + cc.return %4 : !quake.wire +} + +// CHECK-LABEL: func.func private @subcircuit2(%arg0: !quake.wire, %arg1: f64, %arg2: f64, %arg3: f64) -> !quake.wire attributes {num_cnots = 0 : ui32, subcircuit} { +// CHECK: %0 = quake.rz (%arg1) %arg0 : (f64, !quake.wire) -> !quake.wire +// CHECK: %1 = quake.x %0 : (!quake.wire) -> !quake.wire +// CHECK: %2 = quake.ry (%arg2) %1 : (f64, !quake.wire) -> !quake.wire +// CHECK: %3 = quake.x %2 : (!quake.wire) -> !quake.wire +// CHECK: %4 = quake.rz (%arg3) %3 : (f64, !quake.wire) -> !quake.wire +// CHECK: cc.return %4 : !quake.wire +// CHECK: } + +// ----- + +func.func private @subcircuit3(%arg0: !quake.wire, %arg1: f64, %arg2: f64, %arg3: f64) -> !quake.wire attributes {num_cnots = 0 : ui32} { + %0 = quake.rz (%arg1) %arg0 : (f64, !quake.wire) -> !quake.wire + %1 = quake.x %0 : (!quake.wire) -> !quake.wire + %2 = quake.rz (%arg2) %1 : (f64, !quake.wire) -> !quake.wire + %3 = quake.x %2 : (!quake.wire) -> !quake.wire + %4 = quake.rz (%arg3) %3 : (f64, !quake.wire) -> !quake.wire + cc.return %4 : !quake.wire +} + +// CHECK-LABEL: func.func private @subcircuit3(%arg0: !quake.wire, %arg1: f64, %arg2: f64, %arg3: f64) -> !quake.wire attributes {num_cnots = 0 : ui32} { +// CHECK: %0 = quake.rz (%arg1) %arg0 : (f64, !quake.wire) -> !quake.wire +// CHECK: %1 = quake.x %0 : (!quake.wire) -> !quake.wire +// CHECK: %2 = quake.rz (%arg2) %1 : (f64, !quake.wire) -> !quake.wire +// CHECK: %3 = quake.x %2 : (!quake.wire) -> !quake.wire +// CHECK: %4 = quake.rz (%arg3) %3 : (f64, !quake.wire) -> !quake.wire +// CHECK: cc.return %4 : !quake.wire +// CHECK: } From 3fd0e655c60ee374ea1bf8d1f6b106cff721e07a Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Tue, 8 Jul 2025 20:09:16 +0000 Subject: [PATCH 12/21] Avoid some possible memory leaks Signed-off-by: Adam Geller --- .../Transforms/PhasePolynomialPreprocess.cpp | 3 ++ .../PhasePolynomialRotationMerging.cpp | 41 +++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp index 5d7044a67c8..ddd7197c201 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp @@ -178,6 +178,9 @@ class PhasePolynomialPreprocessPass fun.getSymNameAttr(), args); for (size_t i = 0; i < steppers.size(); i++) steppers[i]->getOldWire().replaceAllUsesWith(call.getResult(i)); + + for (auto stepper : steppers) + delete stepper; } public: diff --git a/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp index fc4c3ad9f6f..618119b5e71 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp @@ -59,23 +59,23 @@ class PhasePolynomialRotationMergingPass return isInverted == other.isInverted; } - static Phase *combine(Phase *p1, Phase *p2) { - Phase *new_phase = new Phase(); - for (auto var : p1->vars) - new_phase->vars.insert(var); - for (auto var : p2->vars) - if (new_phase->vars.contains(var)) - new_phase->vars.remove(var); + static Phase combine(Phase &p1, Phase &p2) { + Phase new_phase = Phase(); + for (auto var : p1.vars) + new_phase.vars.insert(var); + for (auto var : p2.vars) + if (new_phase.vars.contains(var)) + new_phase.vars.remove(var); else - new_phase->vars.insert(var); + new_phase.vars.insert(var); return new_phase; } - static Phase *invert(Phase *p1) { - Phase *new_phase = new Phase(); - for (auto var : p1->vars) - new_phase->vars.insert(var); - new_phase->isInverted = !p1->isInverted; + static Phase invert(Phase &p1) { + auto new_phase = Phase(); + for (auto var : p1.vars) + new_phase.vars.insert(var); + new_phase.isInverted = !p1.isInverted; return new_phase; } @@ -105,9 +105,7 @@ class PhasePolynomialRotationMergingPass }; class PhaseStorage { - // TODO: If SmallVector, no need for pointers here, so make not - // pointer to avoid memory leaks - SmallVector phases; + SmallVector phases; SmallVector rotations; void combineRotations(size_t prev_idx, quake::RzOp rzop) { @@ -129,9 +127,9 @@ class PhasePolynomialRotationMergingPass public: /// @brief registers a new rotation op for the given phase /// @returns true if the rotation was combined, false otherwise - bool addOrCombineRotationForPhase(quake::RzOp op, Phase *phase) { + bool addOrCombineRotationForPhase(quake::RzOp op, Phase phase) { for (size_t i = 0; i < phases.size(); i++) - if (*phases[i] == *phase) { + if (phases[i] == phase) { combineRotations(i, op); return true; } @@ -146,7 +144,7 @@ class PhasePolynomialRotationMergingPass Value wire; Subcircuit *subcircuit; PhaseStorage *store; - Phase *current_phase; + Phase current_phase; public: class StepperContainer { @@ -197,8 +195,7 @@ class PhasePolynomialRotationMergingPass stepper->step(this); } - std::optional - maybeGetControlPhase(quake::OperatorInterface opi) { + std::optional maybeGetControlPhase(quake::OperatorInterface opi) { assert(isControlledOp(opi.getOperation())); auto control = opi.getControls().front(); auto *stepper = getStepperForValue(control); @@ -223,7 +220,7 @@ class PhasePolynomialRotationMergingPass subcircuit = circuit; this->store = store; wire = initial; - current_phase = new Phase(var); + current_phase = Phase(var); } bool isStopped() { From 5b39536b6b3a3590a7e4c676f4191f42413c1aa5 Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Tue, 8 Jul 2025 21:20:37 +0000 Subject: [PATCH 13/21] Track inversion when combining phases, handle swaps Signed-off-by: Adam Geller --- .../PhasePolynomialRotationMerging.cpp | 35 ++++++++-- test/Quake/phase_polynomial_preprocess.qke | 2 +- .../phase_polynomial_rotation_merging.qke | 70 +++++++++++++++++++ 3 files changed, 99 insertions(+), 8 deletions(-) diff --git a/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp index 618119b5e71..23b13764c5a 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp @@ -68,6 +68,7 @@ class PhasePolynomialRotationMergingPass new_phase.vars.remove(var); else new_phase.vars.insert(var); + new_phase.isInverted = (p1.isInverted != p2.isInverted); return new_phase; } @@ -204,14 +205,31 @@ class PhasePolynomialRotationMergingPass return std::nullopt; } - bool targetVisited(quake::OperatorInterface opi) { - assert(isControlledOp(opi.getOperation())); - auto next_result = getNextResult(opi.getTarget(0)); + /// @brief handles a swap between two wires, swapping their phases + /// @returns `true` if the swap has been handled and stepping can + /// continue, `false` otherwise + bool handleSwap(quake::SwapOp swap) { + auto wire0 = swap.getTarget(0); + auto wire1 = swap.getTarget(1); + if (wireVisited(wire0) || wireVisited(wire1)) + return true; + + auto stepper0 = getStepperForValue(wire0); + auto stepper1 = getStepperForValue(wire1); + if (!stepper0 || !stepper1) + return false; + + auto tmp = stepper0->current_phase; + stepper0->current_phase = stepper1->current_phase; + stepper1->current_phase = tmp; + return true; + } + + bool wireVisited(Value wire) { + auto next_result = getNextResult(wire); // Wait until target wire stepper steps to ensure // control phase is available - if (getStepperForValue(next_result)) - return true; - return false; + return !!getStepperForValue(next_result); } }; @@ -250,7 +268,7 @@ class PhasePolynomialRotationMergingPass } else { // Wait until the target has visited the operation so it can // access our phase (the control phase) - if (!container->targetVisited(opi)) + if (!container->wireVisited(opi.getTarget(0))) return; } } else if (isa(op) && opi.getControls().size() == 0) { @@ -260,6 +278,9 @@ class PhasePolynomialRotationMergingPass } else if (auto rzop = dyn_cast(op)) { if (store->addOrCombineRotationForPhase(rzop, current_phase)) return; + } else if (auto swap = dyn_cast(op)) { + if (!container->handleSwap(swap)) + return; } wire = getNextResult(wire); diff --git a/test/Quake/phase_polynomial_preprocess.qke b/test/Quake/phase_polynomial_preprocess.qke index 48a6d5ecb4f..ad3e5cc6041 100644 --- a/test/Quake/phase_polynomial_preprocess.qke +++ b/test/Quake/phase_polynomial_preprocess.qke @@ -200,4 +200,4 @@ func.func @kernel5() { // CHECK: %3 = quake.ry (%cst) %2#0 : (f64, !quake.wire) -> !quake.wire // CHECK: %4:2 = call @subcircuit1(%3, %2#1) : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) // CHECK: return -// CHECK: } \ No newline at end of file +// CHECK: } diff --git a/test/Quake/phase_polynomial_rotation_merging.qke b/test/Quake/phase_polynomial_rotation_merging.qke index 48592f41c2e..d2211570a16 100644 --- a/test/Quake/phase_polynomial_rotation_merging.qke +++ b/test/Quake/phase_polynomial_rotation_merging.qke @@ -91,3 +91,73 @@ func.func private @subcircuit3(%arg0: !quake.wire, %arg1: f64, %arg2: f64, %arg3 // CHECK: %4 = quake.rz (%arg3) %3 : (f64, !quake.wire) -> !quake.wire // CHECK: cc.return %4 : !quake.wire // CHECK: } + +// ----- + +func.func private @subcircuit4(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64, %arg4: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 2 : ui32, subcircuit} { + %0 = quake.rz (%arg3) %arg2 : (f64, !quake.wire) -> !quake.wire {processed} + %1:2 = quake.x [%arg0] %0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} + %2:2 = quake.swap %1#0, %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} + %3:2 = quake.x [%2#1] %1#1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} + %4 = quake.rz (%arg4) %3#1 : (f64, !quake.wire) -> !quake.wire {processed} + cc.return %2#0, %3#0, %4 : !quake.wire, !quake.wire, !quake.wire +} + +// CHECK-LABEL: func.func private @subcircuit4(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64, %arg4: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 2 : ui32, subcircuit} { +// CHECK: %0 = arith.addf %arg3, %arg4 : f64 +// CHECK: %1 = quake.rz (%0) %arg2 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK: %2:2 = quake.x [%arg0] %1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %3:2 = quake.swap %2#0, %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %4:2 = quake.x [%3#1] %2#1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: cc.return %3#0, %4#0, %4#1 : !quake.wire, !quake.wire, !quake.wire +// CHECK: } + +// ----- + +func.func private @subcircuit4(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64, %arg4: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 2 : ui32, subcircuit} { + %0 = quake.rz (%arg3) %arg2 : (f64, !quake.wire) -> !quake.wire {processed} + %1:2 = quake.x [%arg0] %0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} + %2:2 = quake.swap %1#0, %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} + %3:2 = quake.x [%2#1] %1#1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} + %4 = quake.rz (%arg4) %3#1 : (f64, !quake.wire) -> !quake.wire {processed} + cc.return %2#0, %3#0, %4 : !quake.wire, !quake.wire, !quake.wire +} + +// CHECK-LABEL: func.func private @subcircuit4(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64, %arg4: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 2 : ui32, subcircuit} { +// CHECK: %0 = arith.addf %arg3, %arg4 : f64 +// CHECK: %1 = quake.rz (%0) %arg2 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK: %2:2 = quake.x [%arg0] %1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %3:2 = quake.swap %2#0, %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %4:2 = quake.x [%3#1] %2#1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: cc.return %3#0, %4#0, %4#1 : !quake.wire, !quake.wire, !quake.wire +// CHECK: } + +// ----- + +func.func private @subcircuit5(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64, %arg3: f64, %arg4: f64, %arg5: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 3 : ui32, subcircuit} { + %0 = quake.x %arg0 : (!quake.wire) -> !quake.wire {processed} + %1:2 = quake.x [%0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} + %2 = quake.rz (%arg2) %1#1 : (f64, !quake.wire) -> !quake.wire {processed} + %3:2 = quake.x [%1#0] %2 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} + %4 = quake.x %3#0 : (!quake.wire) -> !quake.wire {processed} + %5 = quake.rz (%arg3) %3#1 : (f64, !quake.wire) -> !quake.wire {processed} + %6:2 = quake.x [%4] %5 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} + %7 = quake.rz (%arg4) %6#1 : (f64, !quake.wire) -> !quake.wire {processed} + %8 = quake.x %7 : (!quake.wire) -> !quake.wire {processed} + %9 = quake.rz (%arg5) %8 : (f64, !quake.wire) -> !quake.wire {processed} + cc.return %6#0, %9 : !quake.wire, !quake.wire +} + +// CHECK-LABEL: func.func private @subcircuit5(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64, %arg3: f64, %arg4: f64, %arg5: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 3 : ui32, subcircuit} { +// CHECK: %0 = quake.x %arg0 : (!quake.wire) -> !quake.wire {processed} +// CHECK: %1:2 = quake.x [%0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %2 = arith.addf %arg2, %arg5 : f64 +// CHECK: %3 = quake.rz (%2) %1#1 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK: %4:2 = quake.x [%1#0] %3 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %5 = quake.x %4#0 : (!quake.wire) -> !quake.wire {processed} +// CHECK: %6 = quake.rz (%arg3) %4#1 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK: %7:2 = quake.x [%5] %6 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %8 = quake.rz (%arg4) %7#1 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK: %9 = quake.x %8 : (!quake.wire) -> !quake.wire {processed} +// CHECK: cc.return %7#0, %9 : !quake.wire, !quake.wire +// CHECK: } From 70a318807fbae52d8099726cdd7e11689c822aaf Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Wed, 9 Jul 2025 01:30:31 +0000 Subject: [PATCH 14/21] Myriad fixes and improvements to work on phase_estimation example Signed-off-by: Adam Geller --- .../Transforms/PhasePolynomialPreprocess.cpp | 18 ++++++++++++- .../PhasePolynomialRotationMerging.cpp | 6 +---- lib/Optimizer/Transforms/Subcircuit.h | 26 +++++++++++++++---- 3 files changed, 39 insertions(+), 11 deletions(-) diff --git a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp index ddd7197c201..b4de298e316 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp @@ -121,6 +121,13 @@ class PhasePolynomialPreprocessPass removal_order.push_back(next); } + void shiftAfter(Operation *pivot, Operation *to_shift) { + if (to_shift->isBeforeInBlock(pivot)) + to_shift->moveAfter(pivot); + for (auto user : to_shift->getUsers()) + shiftAfter(to_shift, user); + } + void moveToFunc(Subcircuit *subcircuit, size_t subcircuit_num) { auto module = getOperation(); SmallVector types(subcircuit->getInitialWires().size(), @@ -172,13 +179,22 @@ class PhasePolynomialPreprocessPass builder.create(fun.getLoc(), new_wires); auto cnot = subcircuit->getStart(); + auto latest = cnot; + for (auto arg : args) { + auto dop = arg.getDefiningOp(); + if (dop && latest->isBeforeInBlock(dop)) + latest = dop; + } + builder.setInsertionPointAfter(latest); - builder.setInsertionPointAfter(cnot); auto call = builder.create(cnot->getLoc(), types, fun.getSymNameAttr(), args); for (size_t i = 0; i < steppers.size(); i++) steppers[i]->getOldWire().replaceAllUsesWith(call.getResult(i)); + for (auto user : call->getUsers()) + shiftAfter(call, user); + for (auto stepper : steppers) delete stepper; } diff --git a/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp index 23b13764c5a..fcf2c1843e7 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp @@ -116,11 +116,7 @@ class PhasePolynomialRotationMergingPass auto rot_arg2 = rzop.getOperand(0); auto new_rot_arg = builder.create(old_rzop.getLoc(), rot_arg1, rot_arg2); - // TODO: Can replace operand 0 directly rather than cloning first? - auto new_rot = builder.clone(*old_rzop.getOperation()); - new_rot->setOperand(0, new_rot_arg.getResult()); - old_rzop.getResult(0).replaceAllUsesWith(new_rot->getResult(0)); - old_rzop.erase(); + old_rzop->setOperand(0, new_rot_arg.getResult()); rzop.getResult(0).replaceAllUsesWith(rzop.getOperand(1)); rzop.erase(); } diff --git a/lib/Optimizer/Transforms/Subcircuit.h b/lib/Optimizer/Transforms/Subcircuit.h index af8b05be66d..be18bfbd742 100644 --- a/lib/Optimizer/Transforms/Subcircuit.h +++ b/lib/Optimizer/Transforms/Subcircuit.h @@ -143,14 +143,18 @@ class Subcircuit { } // Prune operations after a termination point from the subcircuit - void pruneWire(Value wire) { + void pruneWire(Value wire, SetVector &pruned) { if (termination_points.contains(wire)) termination_points.remove(wire); if (!wire.hasOneUse()) return; Operation *op = wire.getUses().begin().getUser(); + if (pruned.contains(op)) + return; + ops.remove(op); + pruned.insert(op); // TODO: According to the paper, if the op is a CNot and the wire we are // pruning along is the target, then we do not have to prune along the @@ -163,11 +167,13 @@ class Subcircuit { // } for (auto result : op->getResults()) - pruneWire(result); + pruneWire(result, pruned); // Adjust termination border for (auto operand : op->getOperands()) if (ops.contains(operand.getDefiningOp())) termination_points.insert(operand); + else if (termination_points.contains(operand) && isAfterTerminationPoint(operand)) + termination_points.remove(operand); } void pruneSubcircuit() { @@ -175,10 +181,20 @@ class Subcircuit { // termination point seen along each wire in the subcircuit // (this means that it is important to build subcircuits // by inspecting controlled gates in topological order) - for (auto wire : termination_points) { + SmallVector sorted; + SetVector pruned; + for (auto wire : termination_points) if (!isAfterTerminationPoint(wire) && wire.hasOneUse()) - pruneWire(wire); - } + sorted.push_back(wire); + + auto cmp = [](Value v1, Value v2){ + return v1.getDefiningOp()->isBeforeInBlock(v2.getDefiningOp()); + }; + + std::sort(sorted.begin(), sorted.end(), cmp); + + for (size_t i = 0; i < sorted.size(); i++) + pruneWire(sorted[i], pruned); } Subcircuit() {} From a794e09575f5b26491fd36c0e5652352ffde89f5 Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Wed, 9 Jul 2025 01:31:13 +0000 Subject: [PATCH 15/21] Formatting Signed-off-by: Adam Geller --- lib/Optimizer/Transforms/Subcircuit.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/Optimizer/Transforms/Subcircuit.h b/lib/Optimizer/Transforms/Subcircuit.h index be18bfbd742..d0c7cf16ebf 100644 --- a/lib/Optimizer/Transforms/Subcircuit.h +++ b/lib/Optimizer/Transforms/Subcircuit.h @@ -172,7 +172,8 @@ class Subcircuit { for (auto operand : op->getOperands()) if (ops.contains(operand.getDefiningOp())) termination_points.insert(operand); - else if (termination_points.contains(operand) && isAfterTerminationPoint(operand)) + else if (termination_points.contains(operand) && + isAfterTerminationPoint(operand)) termination_points.remove(operand); } @@ -187,7 +188,7 @@ class Subcircuit { if (!isAfterTerminationPoint(wire) && wire.hasOneUse()) sorted.push_back(wire); - auto cmp = [](Value v1, Value v2){ + auto cmp = [](Value v1, Value v2) { return v1.getDefiningOp()->isBeforeInBlock(v2.getDefiningOp()); }; From ecccaa8310552d7aec70d2b6428823abac6b03e3 Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Thu, 10 Jul 2025 00:37:59 +0000 Subject: [PATCH 16/21] Various performance fixes Signed-off-by: Adam Geller --- .../Transforms/PhasePolynomialPreprocess.cpp | 27 ++++++---- .../PhasePolynomialRotationMerging.cpp | 3 ++ lib/Optimizer/Transforms/Subcircuit.h | 53 +++++++++---------- test/Quake/phase_polynomial_preprocess.qke | 39 +++++++------- 4 files changed, 64 insertions(+), 58 deletions(-) diff --git a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp index b4de298e316..ee7045ca418 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp @@ -34,6 +34,7 @@ class PhasePolynomialPreprocessPass Value old_wire; Value new_wire; Subcircuit *subcircuit; + bool stopped = false; public: WireStepper(Subcircuit *circuit, Value initial, Value arg) { @@ -43,7 +44,8 @@ class PhasePolynomialPreprocessPass } bool isStopped() { - return subcircuit->getTerminalWires().contains(old_wire); + return stopped || + (stopped = subcircuit->getTerminalWires().contains(old_wire)); } Value getNewWire() { return new_wire; } @@ -52,14 +54,14 @@ class PhasePolynomialPreprocessPass void step(DenseMap &cloned, OpBuilder &builder, std::function addFuncArg) { - if (isStopped()) - return; - - // TODO: Something more elegant here would be nice + // TODO: Something more elegant here would be nice. + // The problem is that the old_wire may have two uses, + // one in the original block, and one in the new function by the cloned + // op. We want to ignore the cloned op here. Operation *op = nullptr; size_t opnum = 0; for (auto &use : old_wire.getUses()) { - if (!subcircuit->getOps().contains(use.getOwner())) + if (use.getOwner()->hasAttr("clone")) continue; op = use.getOwner(); opnum = use.getOperandNumber(); @@ -87,6 +89,7 @@ class PhasePolynomialPreprocessPass auto clone = builder.clone(*op); clone->setOperand(opnum, new_wire); + clone->setAttr("clone", builder.getUnitAttr()); // Make classical values arguments to the function, // to allow non-constant rotation angles @@ -122,8 +125,9 @@ class PhasePolynomialPreprocessPass } void shiftAfter(Operation *pivot, Operation *to_shift) { - if (to_shift->isBeforeInBlock(pivot)) - to_shift->moveAfter(pivot); + if (pivot->isBeforeInBlock(to_shift)) + return; + to_shift->moveAfter(pivot); for (auto user : to_shift->getUsers()) shiftAfter(to_shift, user); } @@ -163,8 +167,9 @@ class PhasePolynomialPreprocessPass while (true) { auto stepped = false; for (auto stepper : steppers) { - if (!stepper->isStopped()) - stepped = true; + if (stepper->isStopped()) + continue; + stepped = true; stepper->step(cloned, builder, add_arg_fun); } @@ -181,6 +186,8 @@ class PhasePolynomialPreprocessPass auto cnot = subcircuit->getStart(); auto latest = cnot; for (auto arg : args) { + if (!isa(arg.getType())) + continue; auto dop = arg.getDefiningOp(); if (dop && latest->isBeforeInBlock(dop)) latest = dop; diff --git a/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp index fcf2c1843e7..8b9469d4b58 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp @@ -240,6 +240,9 @@ class PhasePolynomialRotationMergingPass bool isStopped() { Operation *op = *wire.getUsers().begin(); assert(op); + // Have to have explicit check for termination point + // because rotation merging may have removed old termination + // point return isTerminationPoint(op); } diff --git a/lib/Optimizer/Transforms/Subcircuit.h b/lib/Optimizer/Transforms/Subcircuit.h index d0c7cf16ebf..0c27aa18ebf 100644 --- a/lib/Optimizer/Transforms/Subcircuit.h +++ b/lib/Optimizer/Transforms/Subcircuit.h @@ -47,10 +47,7 @@ class Subcircuit { return isTerminationPoint(wire.getDefiningOp()); } - void maybeAddAnchorPoint(Value v) { - if (!seen.contains(v)) - anchor_points.insert(v); - } + void addAnchorPoint(Value v) { anchor_points.insert(v); } void calculateSubcircuitForQubitForward(OpResult v) { seen.insert(v); @@ -67,25 +64,23 @@ class Subcircuit { ops.insert(op); + auto nextResult = getNextResult(v); + // Controlled not, figure out whether we are tracking the control // or target, and add an anchor point to the other qubit if (op->getResults().size() > 1) { auto control = op->getResult(0); auto target = op->getResult(1); // Is this the control or target qubit? - if (v.getResultNumber() == 0) { + if (nextResult == control) // Tracking the control qubit - calculateSubcircuitForQubitForward(control); - maybeAddAnchorPoint(target); - } else { + addAnchorPoint(target); + else // Tracking the target qubit - maybeAddAnchorPoint(control); - calculateSubcircuitForQubitForward(target); - } - } else { - // Otherwise, single qubit gate, just follow result - calculateSubcircuitForQubitForward(getNextResult(v)); + addAnchorPoint(control); } + + calculateSubcircuitForQubitForward(nextResult); } void calculateSubcircuitForQubitBackward(Value v) { @@ -99,6 +94,8 @@ class Subcircuit { ops.insert(op); + auto nextOperand = getNextOperand(v); + // Controlled not, figure out whether we are tracking the control // or target, and add an anchor point to the other qubit // Use getResults() as Rz has two operands but only one result @@ -106,19 +103,15 @@ class Subcircuit { auto control = op->getOperand(0); auto target = op->getOperand(1); // Is this the control or target qubit? - if (v == target) { + if (nextOperand == control) // Tracking the control qubit - calculateSubcircuitForQubitBackward(control); - maybeAddAnchorPoint(target); - } else { + addAnchorPoint(target); + else // Tracking the target qubit - maybeAddAnchorPoint(control); - calculateSubcircuitForQubitBackward(target); - } - } else { - // Otherwise, single qubit gate, just follow operand - calculateSubcircuitForQubitBackward(getNextOperand(v)); + addAnchorPoint(control); } + + calculateSubcircuitForQubitBackward(nextOperand); } void calculateInitialSubcircuit(Operation *op) { @@ -136,16 +129,15 @@ class Subcircuit { while (!anchor_points.empty()) { auto next = anchor_points.back(); anchor_points.pop_back(); + if (seen.contains(next)) + continue; calculateSubcircuitForQubitForward(dyn_cast(next)); - seen.remove(next); calculateSubcircuitForQubitBackward(next); } } // Prune operations after a termination point from the subcircuit void pruneWire(Value wire, SetVector &pruned) { - if (termination_points.contains(wire)) - termination_points.remove(wire); if (!wire.hasOneUse()) return; Operation *op = wire.getUses().begin().getUser(); @@ -153,6 +145,9 @@ class Subcircuit { if (pruned.contains(op)) return; + if (termination_points.contains(wire)) + termination_points.remove(wire); + ops.remove(op); pruned.insert(op); @@ -170,7 +165,7 @@ class Subcircuit { pruneWire(result, pruned); // Adjust termination border for (auto operand : op->getOperands()) - if (ops.contains(operand.getDefiningOp())) + if (operand.getDefiningOp() && ops.contains(operand.getDefiningOp())) termination_points.insert(operand); else if (termination_points.contains(operand) && isAfterTerminationPoint(operand)) @@ -189,7 +184,7 @@ class Subcircuit { sorted.push_back(wire); auto cmp = [](Value v1, Value v2) { - return v1.getDefiningOp()->isBeforeInBlock(v2.getDefiningOp()); + return !v1.getDefiningOp()->isBeforeInBlock(v2.getDefiningOp()); }; std::sort(sorted.begin(), sorted.end(), cmp); diff --git a/test/Quake/phase_polynomial_preprocess.qke b/test/Quake/phase_polynomial_preprocess.qke index ad3e5cc6041..d7a109a539a 100644 --- a/test/Quake/phase_polynomial_preprocess.qke +++ b/test/Quake/phase_polynomial_preprocess.qke @@ -63,22 +63,22 @@ func.func @kernel2() { return } -// CHECK: func.func private @subcircuit1(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 2 : ui32, subcircuit} { -// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} -// CHECK: %1:2 = quake.x [%arg2] %0#0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK-LABEL: func.func private @subcircuit1(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 2 : ui32, subcircuit} { +// CHECK: %0:2 = quake.x [%arg0] %arg2 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %1:2 = quake.x [%arg1] %0#0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} // CHECK: %2 = quake.rz (%arg3) %1#1 : (f64, !quake.wire) -> !quake.wire {processed} -// CHECK: cc.return %2, %0#1, %1#0 : !quake.wire, !quake.wire, !quake.wire -// CHECK: } -// CHECK: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64, %arg4: f64, %arg5: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 3 : ui32, subcircuit} { -// CHECK-DAG: %0 = quake.rz (%arg3) %arg0 : (f64, !quake.wire) -> !quake.wire {processed} -// CHECK-DAG: %1 = quake.rz (%arg4) %arg1 : (f64, !quake.wire) -> !quake.wire {processed} -// CHECK: %2:2 = quake.x [%0] %arg2 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} -// CHECK: %3:2 = quake.x [%2#0] %1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} -// CHECK: %4 = quake.rz (%arg5) %2#1 : (f64, !quake.wire) -> !quake.wire {processed} -// CHECK: %5:2 = quake.x [%4] %3#0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} -// CHECK: cc.return %5#1, %3#1, %5#0 : !quake.wire, !quake.wire, !quake.wire -// CHECK: } -// CHECK: func.func @kernel2() { +// CHECK: cc.return %2, %1#0, %0#1 : !quake.wire, !quake.wire, !quake.wire +// CHECK: } +// CHECK-LABEL: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64, %arg4: f64, %arg5: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 3 : ui32, subcircuit} { +// CHECK: %0 = quake.rz (%arg3) %arg0 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK: %2 = quake.rz (%arg4) %arg2 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK: %1:2 = quake.x [%0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %4:2 = quake.x [%1#0] %2 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %3 = quake.rz (%arg5) %1#1 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK: %5:2 = quake.x [%3] %4#0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: cc.return %5#1, %5#0, %4#1 : !quake.wire, !quake.wire, !quake.wire +// CHECK: } +// CHECK-LABEL: func.func @kernel2() { // CHECK: %cst = arith.constant 1.000000e+00 : f64 // CHECK: %0 = quake.null_wire // CHECK: %1 = quake.null_wire @@ -86,12 +86,13 @@ func.func @kernel2() { // CHECK: %3 = quake.h %0 : (!quake.wire) -> !quake.wire // CHECK: %4 = quake.h %1 : (!quake.wire) -> !quake.wire // CHECK: %5 = quake.h %2 : (!quake.wire) -> !quake.wire -// CHECK: %6:3 = call @subcircuit0(%4, %5, %3, %cst, %cst, %cst) : (!quake.wire, !quake.wire, !quake.wire, f64, f64, f64) -> (!quake.wire, !quake.wire, !quake.wire) -// CHECK: %7 = quake.h %6#1 : (!quake.wire) -> !quake.wire -// CHECK: %8:3 = call @subcircuit1(%6#0, %7, %6#2, %cst) : (!quake.wire, !quake.wire, !quake.wire, f64) -> (!quake.wire, !quake.wire, !quake.wire) -// CHECK: %9 = quake.h %8#2 : (!quake.wire) -> !quake.wire +// CHECK: %6:3 = call @subcircuit0(%4, %3, %5, %cst, %cst, %cst) : (!quake.wire, !quake.wire, !quake.wire, f64, f64, f64) -> (!quake.wire, !quake.wire, !quake.wire) +// CHECK: %7 = quake.h %6#2 : (!quake.wire) -> !quake.wire +// CHECK: %8:3 = call @subcircuit1(%6#0, %6#1, %7, %cst) : (!quake.wire, !quake.wire, !quake.wire, f64) -> (!quake.wire, !quake.wire, !quake.wire) +// CHECK: %9 = quake.h %8#1 : (!quake.wire) -> !quake.wire // CHECK: %10 = quake.h %8#0 : (!quake.wire) -> !quake.wire // CHECK: return +// CHECK: } // CHECK: } // ----- From 523dd81bb9a41f9239206d42ac9730d9486dd23a Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Thu, 10 Jul 2025 21:18:51 +0000 Subject: [PATCH 17/21] Update tests Signed-off-by: Adam Geller --- .../Transforms/PhasePolynomialPreprocess.cpp | 5 + test/Quake/phase_polynomial_preprocess.qke | 122 ++++++++++++++---- .../phase_polynomial_rotation_merging.qke | 100 +++++++------- 3 files changed, 148 insertions(+), 79 deletions(-) diff --git a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp index ee7045ca418..943ee52d5ca 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp @@ -194,6 +194,11 @@ class PhasePolynomialPreprocessPass } builder.setInsertionPointAfter(latest); + fun.walk([&](Operation *op){ + op->removeAttr("clone"); + op->removeAttr("processed"); + }); + auto call = builder.create(cnot->getLoc(), types, fun.getSymNameAttr(), args); for (size_t i = 0; i < steppers.size(); i++) diff --git a/test/Quake/phase_polynomial_preprocess.qke b/test/Quake/phase_polynomial_preprocess.qke index d7a109a539a..e1f70fa81f2 100644 --- a/test/Quake/phase_polynomial_preprocess.qke +++ b/test/Quake/phase_polynomial_preprocess.qke @@ -19,16 +19,16 @@ func.func @kernel1() { return } -// CHECK: func.func private @subcircuit1(%arg0: !quake.wire, %arg1: !quake.wire) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { -// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK-LABEL: func.func private @subcircuit1(%arg0: !quake.wire, %arg1: !quake.wire) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { +// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) // CHECK: cc.return %0#0, %0#1 : !quake.wire, !quake.wire // CHECK: } -// CHECK: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { -// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} -// CHECK: %1 = quake.rz (%arg2) %0#1 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK-LABEL: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { +// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %1 = quake.rz (%arg2) %0#1 : (f64, !quake.wire) -> !quake.wire // CHECK: cc.return %0#0, %1 : !quake.wire, !quake.wire // CHECK: } -// CHECK: func.func @kernel1() { +// CHECK-LABEL: func.func @kernel1() { // CHECK: %cst = arith.constant 1.000000e+00 : f64 // CHECK: %0 = quake.null_wire // CHECK: %1 = quake.null_wire @@ -64,18 +64,18 @@ func.func @kernel2() { } // CHECK-LABEL: func.func private @subcircuit1(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 2 : ui32, subcircuit} { -// CHECK: %0:2 = quake.x [%arg0] %arg2 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} -// CHECK: %1:2 = quake.x [%arg1] %0#0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} -// CHECK: %2 = quake.rz (%arg3) %1#1 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK: %0:2 = quake.x [%arg0] %arg2 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %1:2 = quake.x [%arg1] %0#0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %2 = quake.rz (%arg3) %1#1 : (f64, !quake.wire) -> !quake.wire // CHECK: cc.return %2, %1#0, %0#1 : !quake.wire, !quake.wire, !quake.wire // CHECK: } // CHECK-LABEL: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64, %arg4: f64, %arg5: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 3 : ui32, subcircuit} { -// CHECK: %0 = quake.rz (%arg3) %arg0 : (f64, !quake.wire) -> !quake.wire {processed} -// CHECK: %2 = quake.rz (%arg4) %arg2 : (f64, !quake.wire) -> !quake.wire {processed} -// CHECK: %1:2 = quake.x [%0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} -// CHECK: %4:2 = quake.x [%1#0] %2 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} -// CHECK: %3 = quake.rz (%arg5) %1#1 : (f64, !quake.wire) -> !quake.wire {processed} -// CHECK: %5:2 = quake.x [%3] %4#0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %0 = quake.rz (%arg3) %arg0 : (f64, !quake.wire) -> !quake.wire +// CHECK: %1:2 = quake.x [%0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %2 = quake.rz (%arg4) %arg2 : (f64, !quake.wire) -> !quake.wire +// CHECK: %3 = quake.rz (%arg5) %1#1 : (f64, !quake.wire) -> !quake.wire +// CHECK: %4:2 = quake.x [%1#0] %2 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %5:2 = quake.x [%3] %4#0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) // CHECK: cc.return %5#1, %5#0, %4#1 : !quake.wire, !quake.wire, !quake.wire // CHECK: } // CHECK-LABEL: func.func @kernel2() { @@ -93,7 +93,6 @@ func.func @kernel2() { // CHECK: %10 = quake.h %8#0 : (!quake.wire) -> !quake.wire // CHECK: return // CHECK: } -// CHECK: } // ----- @@ -115,12 +114,12 @@ func.func @kernel3() { } // CHECK-LABEL: func.func private @subcircuit1(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { -// CHECK: %0 = quake.rz (%arg2) %arg1 : (f64, !quake.wire) -> !quake.wire {processed} -// CHECK: %1:2 = quake.x [%arg0] %0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %0 = quake.rz (%arg2) %arg1 : (f64, !quake.wire) -> !quake.wire +// CHECK: %1:2 = quake.x [%arg0] %0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) // CHECK: cc.return %1#0, %1#1 : !quake.wire, !quake.wire // CHECK: } // CHECK-LABEL: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { -// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) // CHECK: cc.return %0#0, %0#1 : !quake.wire, !quake.wire // CHECK: } // CHECK-LABEL: func.func @kernel3() { @@ -155,7 +154,7 @@ func.func @kernel4() { } // CHECK-LABEL: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { -// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) // CHECK: cc.return %0#0, %0#1 : !quake.wire, !quake.wire // CHECK: } // CHECK-LABEL: func.func @kernel4() { @@ -185,12 +184,12 @@ func.func @kernel5() { } // CHECK-LABEL: func.func private @subcircuit1(%arg0: !quake.wire, %arg1: !quake.wire) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { -// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) // CHECK: cc.return %0#0, %0#1 : !quake.wire, !quake.wire // CHECK: } // CHECK-LABEL: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { -// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} -// CHECK: %1 = quake.rz (%arg2) %0#1 : (f64, !quake.wire) -> !quake.wire {processed} +// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %1 = quake.rz (%arg2) %0#1 : (f64, !quake.wire) -> !quake.wire // CHECK: cc.return %0#0, %1 : !quake.wire, !quake.wire // CHECK: } // CHECK-LABEL: func.func @kernel5() { @@ -202,3 +201,80 @@ func.func @kernel5() { // CHECK: %4:2 = call @subcircuit1(%3, %2#1) : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) // CHECK: return // CHECK: } + +// ----- + +// A test of a wire which should immediately terminate +// (and therefore not be included in the subcircuit) +// This is seen in that subcircuit0 has no input for %2 + +func.func @kernel6() { + %cst = arith.constant 1.000000e+00 : f64 + %0 = quake.null_wire + %1 = quake.null_wire + %2 = quake.null_wire + %3:2 = quake.x [%0] %1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %4 = quake.h %3#1 : (!quake.wire) -> !quake.wire + %5:2 = quake.x [%4] %2 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %6:2 = quake.x [%3#0] %5#1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + return +} + +// CHECK-LABEL: func.func private @subcircuit1(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 2 : ui32, subcircuit} { +// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %1:2 = quake.x [%arg2] %0#1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: cc.return %0#0, %1#1, %1#0 : !quake.wire, !quake.wire, !quake.wire +// CHECK: } +// CHECK-LABEL: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { +// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: cc.return %0#0, %0#1 : !quake.wire, !quake.wire +// CHECK: } +// CHECK-LABEL: func.func @kernel6() { +// CHECK: %cst = arith.constant 1.000000e+00 : f64 +// CHECK: %0 = quake.null_wire +// CHECK: %1 = quake.null_wire +// CHECK: %2 = quake.null_wire +// CHECK: %3:2 = call @subcircuit0(%0, %1) : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %4 = quake.h %3#1 : (!quake.wire) -> !quake.wire +// CHECK: %5:3 = call @subcircuit1(%4, %2, %3#0) : (!quake.wire, !quake.wire, !quake.wire) -> (!quake.wire, !quake.wire, !quake.wire) +// CHECK: return +// CHECK: } + +// ----- + +// A check that the call is correctly placed, with the second +// quake.h moved after call @subcircuit0. + +func.func @kernel7() { + %cst = arith.constant 1.000000e+00 : f64 + %0 = quake.null_wire + %1 = quake.null_wire + %2 = quake.null_wire + %3:2 = quake.x [%0] %1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %4 = quake.h %3#0 : (!quake.wire) -> !quake.wire + %5 = quake.h %2 : (!quake.wire) -> !quake.wire + %6:2 = quake.x [%5] %3#1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %7:2 = quake.x [%6#0] %4 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + return +} + +// CHECK-LABEL: func.func private @subcircuit1(%arg0: !quake.wire, %arg1: !quake.wire) -> (!quake.wire, !quake.wire) attributes {num_cnots = 1 : ui32, subcircuit} { +// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: cc.return %0#0, %0#1 : !quake.wire, !quake.wire +// CHECK: } +// CHECK-LABEL: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 2 : ui32, subcircuit} { +// CHECK: %0:2 = quake.x [%arg0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %1:2 = quake.x [%arg2] %0#1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: cc.return %0#0, %1#1, %1#0 : !quake.wire, !quake.wire, !quake.wire +// CHECK: } +// CHECK-LABEL: func.func @kernel7() { +// CHECK: %cst = arith.constant 1.000000e+00 : f64 +// CHECK: %0 = quake.null_wire +// CHECK: %1 = quake.null_wire +// CHECK: %2 = quake.null_wire +// CHECK: %3 = quake.h %2 : (!quake.wire) -> !quake.wire +// CHECK: %4:3 = call @subcircuit0(%0, %1, %3) : (!quake.wire, !quake.wire, !quake.wire) -> (!quake.wire, !quake.wire, !quake.wire) +// CHECK: %5 = quake.h %4#0 : (!quake.wire) -> !quake.wire +// CHECK: %6:2 = call @subcircuit1(%4#2, %5) : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: return +// CHECK: } \ No newline at end of file diff --git a/test/Quake/phase_polynomial_rotation_merging.qke b/test/Quake/phase_polynomial_rotation_merging.qke index d2211570a16..f037d9b3f0b 100644 --- a/test/Quake/phase_polynomial_rotation_merging.qke +++ b/test/Quake/phase_polynomial_rotation_merging.qke @@ -6,7 +6,7 @@ // the terms of the Apache License 2.0 which accompanies this distribution. // // ========================================================================== // -// RUN: cudaq-opt --phase-polynomial-rotation-merging -split-input-file %s | FileCheck %s +// RUN: cudaq-opt --phase-polynomial-rotation-merging %s | FileCheck %s func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 3 : ui32, subcircuit} { %0 = quake.rz (%arg2) %arg1 : (f64, !quake.wire) -> !quake.wire @@ -19,7 +19,7 @@ func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f6 cc.return %6#1, %6#0 : !quake.wire, !quake.wire } -// CHECK: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 3 : ui32, subcircuit} { +// CHECK-LABEL: func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 3 : ui32, subcircuit} { // CHECK: %[[VAL_0:.*]] = arith.addf %arg2, %arg2 : f64 // CHECK: %[[VAL_1:.*]] = quake.rz (%[[VAL_0]]) %arg1 : (f64, !quake.wire) -> !quake.wire // CHECK: %[[VAL_2:.*]]:2 = quake.x [%arg0] %[[VAL_1]] : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) @@ -30,8 +30,6 @@ func.func private @subcircuit0(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f6 // CHECK: cc.return %[[VAL_6]]#1, %[[VAL_6]]#0 : !quake.wire, !quake.wire // CHECK: } -// ----- - func.func private @subcircuit1(%arg0: !quake.wire, %arg1: f64, %arg2: f64, %arg3: f64) -> !quake.wire attributes {num_cnots = 0 : ui32, subcircuit} { %0 = quake.rz (%arg1) %arg0 : (f64, !quake.wire) -> !quake.wire %1 = quake.x %0 : (!quake.wire) -> !quake.wire @@ -41,7 +39,7 @@ func.func private @subcircuit1(%arg0: !quake.wire, %arg1: f64, %arg2: f64, %arg3 cc.return %4 : !quake.wire } -// CHECK: func.func private @subcircuit1(%arg0: !quake.wire, %arg1: f64, %arg2: f64, %arg3: f64) -> !quake.wire attributes {num_cnots = 0 : ui32, subcircuit} { +// CHECK-LABEL: func.func private @subcircuit1(%arg0: !quake.wire, %arg1: f64, %arg2: f64, %arg3: f64) -> !quake.wire attributes {num_cnots = 0 : ui32, subcircuit} { // CHECK: %[[VAL_0:.*]] = arith.addf %arg1, %arg3 : f64 // CHECK: %[[VAL_1:.*]] = quake.rz (%[[VAL_0]]) %arg0 : (f64, !quake.wire) -> !quake.wire // CHECK: %[[VAL_2:.*]] = quake.x %[[VAL_1]] : (!quake.wire) -> !quake.wire @@ -50,8 +48,6 @@ func.func private @subcircuit1(%arg0: !quake.wire, %arg1: f64, %arg2: f64, %arg3 // CHECK: cc.return %[[VAL_4]] : !quake.wire // CHECK: } -// ----- - // Invalid subcircuit functions shouldn't be touched func.func private @subcircuit2(%arg0: !quake.wire, %arg1: f64, %arg2: f64, %arg3: f64) -> !quake.wire attributes {num_cnots = 0 : ui32, subcircuit} { @@ -72,8 +68,6 @@ func.func private @subcircuit2(%arg0: !quake.wire, %arg1: f64, %arg2: f64, %arg3 // CHECK: cc.return %4 : !quake.wire // CHECK: } -// ----- - func.func private @subcircuit3(%arg0: !quake.wire, %arg1: f64, %arg2: f64, %arg3: f64) -> !quake.wire attributes {num_cnots = 0 : ui32} { %0 = quake.rz (%arg1) %arg0 : (f64, !quake.wire) -> !quake.wire %1 = quake.x %0 : (!quake.wire) -> !quake.wire @@ -92,72 +86,66 @@ func.func private @subcircuit3(%arg0: !quake.wire, %arg1: f64, %arg2: f64, %arg3 // CHECK: cc.return %4 : !quake.wire // CHECK: } -// ----- - func.func private @subcircuit4(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64, %arg4: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 2 : ui32, subcircuit} { - %0 = quake.rz (%arg3) %arg2 : (f64, !quake.wire) -> !quake.wire {processed} - %1:2 = quake.x [%arg0] %0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} - %2:2 = quake.swap %1#0, %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} - %3:2 = quake.x [%2#1] %1#1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} - %4 = quake.rz (%arg4) %3#1 : (f64, !quake.wire) -> !quake.wire {processed} + %0 = quake.rz (%arg3) %arg2 : (f64, !quake.wire) -> !quake.wire + %1:2 = quake.x [%arg0] %0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %2:2 = quake.swap %1#0, %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %3:2 = quake.x [%2#1] %1#1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %4 = quake.rz (%arg4) %3#1 : (f64, !quake.wire) -> !quake.wire cc.return %2#0, %3#0, %4 : !quake.wire, !quake.wire, !quake.wire } // CHECK-LABEL: func.func private @subcircuit4(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64, %arg4: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 2 : ui32, subcircuit} { // CHECK: %0 = arith.addf %arg3, %arg4 : f64 -// CHECK: %1 = quake.rz (%0) %arg2 : (f64, !quake.wire) -> !quake.wire {processed} -// CHECK: %2:2 = quake.x [%arg0] %1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} -// CHECK: %3:2 = quake.swap %2#0, %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} -// CHECK: %4:2 = quake.x [%3#1] %2#1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %1 = quake.rz (%0) %arg2 : (f64, !quake.wire) -> !quake.wire +// CHECK: %2:2 = quake.x [%arg0] %1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %3:2 = quake.swap %2#0, %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %4:2 = quake.x [%3#1] %2#1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) // CHECK: cc.return %3#0, %4#0, %4#1 : !quake.wire, !quake.wire, !quake.wire // CHECK: } -// ----- - -func.func private @subcircuit4(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64, %arg4: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 2 : ui32, subcircuit} { - %0 = quake.rz (%arg3) %arg2 : (f64, !quake.wire) -> !quake.wire {processed} - %1:2 = quake.x [%arg0] %0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} - %2:2 = quake.swap %1#0, %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} - %3:2 = quake.x [%2#1] %1#1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} - %4 = quake.rz (%arg4) %3#1 : (f64, !quake.wire) -> !quake.wire {processed} +func.func private @subcircuit5(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64, %arg4: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 2 : ui32, subcircuit} { + %0 = quake.rz (%arg3) %arg2 : (f64, !quake.wire) -> !quake.wire + %1:2 = quake.x [%arg0] %0 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %2:2 = quake.swap %1#0, %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %3:2 = quake.x [%2#1] %1#1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %4 = quake.rz (%arg4) %3#1 : (f64, !quake.wire) -> !quake.wire cc.return %2#0, %3#0, %4 : !quake.wire, !quake.wire, !quake.wire } -// CHECK-LABEL: func.func private @subcircuit4(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64, %arg4: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 2 : ui32, subcircuit} { +// CHECK-LABEL: func.func private @subcircuit5(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: !quake.wire, %arg3: f64, %arg4: f64) -> (!quake.wire, !quake.wire, !quake.wire) attributes {num_cnots = 2 : ui32, subcircuit} { // CHECK: %0 = arith.addf %arg3, %arg4 : f64 -// CHECK: %1 = quake.rz (%0) %arg2 : (f64, !quake.wire) -> !quake.wire {processed} -// CHECK: %2:2 = quake.x [%arg0] %1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} -// CHECK: %3:2 = quake.swap %2#0, %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} -// CHECK: %4:2 = quake.x [%3#1] %2#1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK: %1 = quake.rz (%0) %arg2 : (f64, !quake.wire) -> !quake.wire +// CHECK: %2:2 = quake.x [%arg0] %1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %3:2 = quake.swap %2#0, %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %4:2 = quake.x [%3#1] %2#1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) // CHECK: cc.return %3#0, %4#0, %4#1 : !quake.wire, !quake.wire, !quake.wire // CHECK: } -// ----- - -func.func private @subcircuit5(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64, %arg3: f64, %arg4: f64, %arg5: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 3 : ui32, subcircuit} { - %0 = quake.x %arg0 : (!quake.wire) -> !quake.wire {processed} - %1:2 = quake.x [%0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} - %2 = quake.rz (%arg2) %1#1 : (f64, !quake.wire) -> !quake.wire {processed} - %3:2 = quake.x [%1#0] %2 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} - %4 = quake.x %3#0 : (!quake.wire) -> !quake.wire {processed} - %5 = quake.rz (%arg3) %3#1 : (f64, !quake.wire) -> !quake.wire {processed} - %6:2 = quake.x [%4] %5 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} - %7 = quake.rz (%arg4) %6#1 : (f64, !quake.wire) -> !quake.wire {processed} - %8 = quake.x %7 : (!quake.wire) -> !quake.wire {processed} - %9 = quake.rz (%arg5) %8 : (f64, !quake.wire) -> !quake.wire {processed} +func.func private @subcircuit6(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64, %arg3: f64, %arg4: f64, %arg5: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 3 : ui32, subcircuit} { + %0 = quake.x %arg0 : (!quake.wire) -> !quake.wire + %1:2 = quake.x [%0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %2 = quake.rz (%arg2) %1#1 : (f64, !quake.wire) -> !quake.wire + %3:2 = quake.x [%1#0] %2 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %4 = quake.x %3#0 : (!quake.wire) -> !quake.wire + %5 = quake.rz (%arg3) %3#1 : (f64, !quake.wire) -> !quake.wire + %6:2 = quake.x [%4] %5 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) + %7 = quake.rz (%arg4) %6#1 : (f64, !quake.wire) -> !quake.wire + %8 = quake.x %7 : (!quake.wire) -> !quake.wire + %9 = quake.rz (%arg5) %8 : (f64, !quake.wire) -> !quake.wire cc.return %6#0, %9 : !quake.wire, !quake.wire } -// CHECK-LABEL: func.func private @subcircuit5(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64, %arg3: f64, %arg4: f64, %arg5: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 3 : ui32, subcircuit} { -// CHECK: %0 = quake.x %arg0 : (!quake.wire) -> !quake.wire {processed} -// CHECK: %1:2 = quake.x [%0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} +// CHECK-LABEL: func.func private @subcircuit6(%arg0: !quake.wire, %arg1: !quake.wire, %arg2: f64, %arg3: f64, %arg4: f64, %arg5: f64) -> (!quake.wire, !quake.wire) attributes {num_cnots = 3 : ui32, subcircuit} { +// CHECK: %0 = quake.x %arg0 : (!quake.wire) -> !quake.wire +// CHECK: %1:2 = quake.x [%0] %arg1 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) // CHECK: %2 = arith.addf %arg2, %arg5 : f64 -// CHECK: %3 = quake.rz (%2) %1#1 : (f64, !quake.wire) -> !quake.wire {processed} -// CHECK: %4:2 = quake.x [%1#0] %3 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} -// CHECK: %5 = quake.x %4#0 : (!quake.wire) -> !quake.wire {processed} -// CHECK: %6 = quake.rz (%arg3) %4#1 : (f64, !quake.wire) -> !quake.wire {processed} -// CHECK: %7:2 = quake.x [%5] %6 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) {processed} -// CHECK: %8 = quake.rz (%arg4) %7#1 : (f64, !quake.wire) -> !quake.wire {processed} -// CHECK: %9 = quake.x %8 : (!quake.wire) -> !quake.wire {processed} +// CHECK: %3 = quake.rz (%2) %1#1 : (f64, !quake.wire) -> !quake.wire +// CHECK: %4:2 = quake.x [%1#0] %3 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %5 = quake.x %4#0 : (!quake.wire) -> !quake.wire +// CHECK: %6 = quake.rz (%arg3) %4#1 : (f64, !quake.wire) -> !quake.wire +// CHECK: %7:2 = quake.x [%5] %6 : (!quake.wire, !quake.wire) -> (!quake.wire, !quake.wire) +// CHECK: %8 = quake.rz (%arg4) %7#1 : (f64, !quake.wire) -> !quake.wire +// CHECK: %9 = quake.x %8 : (!quake.wire) -> !quake.wire // CHECK: cc.return %7#0, %9 : !quake.wire, !quake.wire // CHECK: } From a0fa4c74b93ba56d3bbf4245df44bf786cf7ef47 Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Thu, 10 Jul 2025 21:18:51 +0000 Subject: [PATCH 18/21] Update tests Signed-off-by: Adam Geller --- .../Transforms/PhasePolynomialPreprocess.cpp | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp index 943ee52d5ca..a82cb3d0f96 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp @@ -194,7 +194,7 @@ class PhasePolynomialPreprocessPass } builder.setInsertionPointAfter(latest); - fun.walk([&](Operation *op){ + fun.walk([&](Operation *op) { op->removeAttr("clone"); op->removeAttr("processed"); }); @@ -243,3 +243,27 @@ class PhasePolynomialPreprocessPass } }; } // namespace + +static void createUnrollingPipeline(OpPassManager &pm, unsigned threshold, + bool signalFailure, bool allowBreak, + bool allowClosedInterval) { + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(cudaq::opt::createClassicalMemToReg()); + pm.addNestedPass(createCanonicalizerPass()); + cudaq::opt::LoopNormalizeOptions lno{allowClosedInterval, allowBreak}; + pm.addNestedPass(cudaq::opt::createLoopNormalize(lno)); + pm.addNestedPass(createCanonicalizerPass()); + cudaq::opt::LoopUnrollOptions luo{threshold, signalFailure, allowBreak}; + pm.addNestedPass(cudaq::opt::createLoopUnroll(luo)); + pm.addNestedPass(cudaq::opt::createUpdateRegisterNames()); +} + +void cudaq::opt::registerUnrollingPipeline() { + PassPipelineRegistration( + "unrolling-pipeline", + "Fully unroll loops that can be completely unrolled.", + [](OpPassManager &pm, const UnrollPipelineOptions &upo) { + createUnrollingPipeline(pm, upo.threshold, upo.signalFailure, + upo.allowBreak, upo.allowClosedInterval); + }); +} From 047fc6f34615a669c8ceb670a4129055a63bb0e1 Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Fri, 11 Jul 2025 00:27:58 +0000 Subject: [PATCH 19/21] Add opt pipeline and test using CircuitCheck Signed-off-by: Adam Geller --- include/cudaq/Optimizer/InitAllPasses.h | 1 + include/cudaq/Optimizer/Transforms/Passes.h | 1 + .../Transforms/PhasePolynomialPreprocess.cpp | 46 +++++++----- .../qpe2qubits.qke | 70 +++++++++++++++++++ 4 files changed, 102 insertions(+), 16 deletions(-) create mode 100644 test/Transforms/PhasePolynomialRotationMerging/qpe2qubits.qke diff --git a/include/cudaq/Optimizer/InitAllPasses.h b/include/cudaq/Optimizer/InitAllPasses.h index 91724d36a40..2279e20e0e9 100644 --- a/include/cudaq/Optimizer/InitAllPasses.h +++ b/include/cudaq/Optimizer/InitAllPasses.h @@ -21,6 +21,7 @@ inline void registerCudaqPassesAndPipelines() { // CUDA-Q pipelines opt::registerAggressiveEarlyInliningPipeline(); + opt::registerPhasePolynomialOptimizationPipeline(); opt::registerUnrollingPipeline(); opt::registerClassicalOptimizationPipeline(); opt::registerToExecutionManagerCCPipeline(); diff --git a/include/cudaq/Optimizer/Transforms/Passes.h b/include/cudaq/Optimizer/Transforms/Passes.h index f3ccfc5baa9..7eb9822f205 100644 --- a/include/cudaq/Optimizer/Transforms/Passes.h +++ b/include/cudaq/Optimizer/Transforms/Passes.h @@ -25,6 +25,7 @@ void addAggressiveEarlyInlining(mlir::OpPassManager &pm, bool fatalCheck = false); void registerAggressiveEarlyInliningPipeline(); +void registerPhasePolynomialOptimizationPipeline(); void registerUnrollingPipeline(); void registerClassicalOptimizationPipeline(); void registerMappingPipeline(); diff --git a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp index a82cb3d0f96..ad7097ad95a 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp @@ -244,26 +244,40 @@ class PhasePolynomialPreprocessPass }; } // namespace -static void createUnrollingPipeline(OpPassManager &pm, unsigned threshold, - bool signalFailure, bool allowBreak, - bool allowClosedInterval) { +static void createPhasePolynomialOptPipeline(OpPassManager &pm) { pm.addNestedPass(createCanonicalizerPass()); - pm.addNestedPass(cudaq::opt::createClassicalMemToReg()); + pm.addNestedPass(createCSEPass()); + //opt::LoopUnrollOptions luo; + //luo.threshold = 2048; + //pm.addNestedPass(opt::createLoopUnroll(luo)); + //pm.addNestedPass(createCanonicalizerPass()); + //pm.addNestedPass(createCSEPass()); + pm.addNestedPass( + cudaq::opt::createFactorQuantumAllocations()); + pm.addNestedPass(cudaq::opt::createMemToReg()); pm.addNestedPass(createCanonicalizerPass()); - cudaq::opt::LoopNormalizeOptions lno{allowClosedInterval, allowBreak}; - pm.addNestedPass(cudaq::opt::createLoopNormalize(lno)); + pm.addNestedPass(createCSEPass()); + pm.addPass(cudaq::opt::createPhasePolynomialPreprocess()); + pm.addNestedPass( + cudaq::opt::createPhasePolynomialRotationMerging()); + pm.addNestedPass(cudaq::opt::createQuakeSimplify()); pm.addNestedPass(createCanonicalizerPass()); - cudaq::opt::LoopUnrollOptions luo{threshold, signalFailure, allowBreak}; - pm.addNestedPass(cudaq::opt::createLoopUnroll(luo)); - pm.addNestedPass(cudaq::opt::createUpdateRegisterNames()); + pm.addNestedPass(createCSEPass()); + cudaq::opt::addAggressiveEarlyInlining(pm); + pm.addNestedPass(cudaq::opt::createRegToMem()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createCSEPass()); + pm.addNestedPass( + cudaq::opt::createCombineQuantumAllocations()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createCSEPass()); } -void cudaq::opt::registerUnrollingPipeline() { - PassPipelineRegistration( - "unrolling-pipeline", - "Fully unroll loops that can be completely unrolled.", - [](OpPassManager &pm, const UnrollPipelineOptions &upo) { - createUnrollingPipeline(pm, upo.threshold, upo.signalFailure, - upo.allowBreak, upo.allowClosedInterval); +void cudaq::opt::registerPhasePolynomialOptimizationPipeline() { + PassPipelineRegistration<>( + "phase-polynomial-opt-pipeline", + "Apply phase polynomial based rotation merging.", + [](OpPassManager &pm) { + createPhasePolynomialOptPipeline(pm); }); } diff --git a/test/Transforms/PhasePolynomialRotationMerging/qpe2qubits.qke b/test/Transforms/PhasePolynomialRotationMerging/qpe2qubits.qke new file mode 100644 index 00000000000..4947adca828 --- /dev/null +++ b/test/Transforms/PhasePolynomialRotationMerging/qpe2qubits.qke @@ -0,0 +1,70 @@ +// ========================================================================== // +// Copyright (c) 2025 NVIDIA Corporation & Affiliates. // +// All rights reserved. // +// // +// This source code and the accompanying materials are made available under // +// the terms of the Apache License 2.0 which accompanies this distribution. // +// ========================================================================== // + +// ========================================================================== // +// Copyright (c) 2025 NVIDIA Corporation & Affiliates. // +// All rights reserved. // +// // +// This source code and the accompanying materials are made available under // +// the terms of the Apache License 2.0 which accompanies this distribution. // +// ========================================================================== // + +// This test is a cleaned up version of quake IR from +// docs/sphinx/applications/phase_estimation.cpp using 2 qubits. +// It uses CircuitCheck to verify that the optimization produces +// an equivalent circuit. + +// RUN: cudaq-opt --phase-polynomial-opt-pipeline %s | CircuitCheck %s + +module attributes {cc.sizeof_string = 32 : i64, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.triple = "x86_64-unknown-linux-gnu", quake.mangled_name_map = {__nvqpp__mlirgen__Z4mainE3$_0 = "_ZZ4mainENK3$_0clERN5cudaq5quditILm2EEE", __nvqpp__mlirgen__function_iqft._Z4iqftN5cudaq5qviewILm2EEE = "_Z4iqftN5cudaq5qviewILm2EEE", __nvqpp__mlirgen__instance_qpeZ4mainE3$_0r1PiGate._ZN3qpeclIZ4mainE3$_08r1PiGateEEviOT_OT0_ = "_ZN3qpeclIZ4mainE3$_08r1PiGateEEviOT_OT0_", __nvqpp__mlirgen__r1PiGate = "_ZN8r1PiGateclERN5cudaq5quditILm2EEE"}} { + func.func @__nvqpp__mlirgen__instance_qpeZ4mainE3$_0r1PiGate._ZN3qpeclIZ4mainE3$_08r1PiGateEEviOT_OT0_(%arg0: !cc.callable<(!quake.ref) -> ()>) attributes {"cudaq-entrypoint", "cudaq-kernel"} { + %cst = arith.constant 5.000000e-01 : f64 + %cst_0 = arith.constant -5.000000e-01 : f64 + %0 = quake.alloca !quake.veq<3> + %1 = quake.extract_ref %0[0] : (!quake.veq<3>) -> !quake.ref + %2 = quake.extract_ref %0[1] : (!quake.veq<3>) -> !quake.ref + %3 = quake.extract_ref %0[2] : (!quake.veq<3>) -> !quake.ref + quake.x %3 : (!quake.ref) -> () + quake.h %1 : (!quake.ref) -> () + quake.h %2 : (!quake.ref) -> () + quake.rz (%cst) %1 : (f64, !quake.ref) -> () + quake.x [%1] %3 : (!quake.ref, !quake.ref) -> () + quake.rz (%cst_0) %3 : (f64, !quake.ref) -> () + quake.x [%1] %3 : (!quake.ref, !quake.ref) -> () + quake.rz (%cst) %3 : (f64, !quake.ref) -> () + quake.rz (%cst) %2 : (f64, !quake.ref) -> () + quake.x [%2] %3 : (!quake.ref, !quake.ref) -> () + quake.rz (%cst_0) %3 : (f64, !quake.ref) -> () + quake.x [%2] %3 : (!quake.ref, !quake.ref) -> () + quake.rz (%cst) %3 : (f64, !quake.ref) -> () + quake.rz (%cst) %2 : (f64, !quake.ref) -> () + quake.x [%2] %3 : (!quake.ref, !quake.ref) -> () + quake.rz (%cst_0) %3 : (f64, !quake.ref) -> () + quake.x [%2] %3 : (!quake.ref, !quake.ref) -> () + quake.rz (%cst) %3 : (f64, !quake.ref) -> () + quake.x [%2] %1 : (!quake.ref, !quake.ref) -> () + quake.x [%1] %2 : (!quake.ref, !quake.ref) -> () + quake.x [%2] %1 : (!quake.ref, !quake.ref) -> () + quake.h %1 : (!quake.ref) -> () + quake.h %2 : (!quake.ref) -> () + %4 = cc.alloca !cc.array + %measOut = quake.mz %1 : (!quake.ref) -> !quake.measure + %5 = quake.discriminate %measOut : (!quake.measure) -> i1 + %6 = cc.cast %4 : (!cc.ptr>) -> !cc.ptr + %7 = cc.cast unsigned %5 : (i1) -> i8 + cc.store %7, %6 : !cc.ptr + %measOut_1 = quake.mz %2 : (!quake.ref) -> !quake.measure + %8 = quake.discriminate %measOut_1 : (!quake.measure) -> i1 + %9 = cc.compute_ptr %4[1] : (!cc.ptr>) -> !cc.ptr + %10 = cc.cast unsigned %8 : (i1) -> i8 + cc.store %10, %9 : !cc.ptr + quake.dealloc %0 : !quake.veq<3> + return + } +} + From 1746f4752745ec34e05430bb768d6018172a416c Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Tue, 15 Jul 2025 19:22:20 +0000 Subject: [PATCH 20/21] Fix bug in sort function Signed-off-by: Adam Geller --- .../Transforms/PhasePolynomialPreprocess.cpp | 21 +++++++------------ .../PhasePolynomialRotationMerging.cpp | 4 ++-- lib/Optimizer/Transforms/Subcircuit.h | 17 ++++++++++----- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp index ad7097ad95a..c522e7f1deb 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialPreprocess.cpp @@ -216,7 +216,6 @@ class PhasePolynomialPreprocessPass auto module = getOperation(); size_t i = 0; SetVector subcircuits; - for (auto &op : module) { if (auto func = dyn_cast(op)) { func.walk([&](quake::XOp op) { @@ -247,13 +246,12 @@ class PhasePolynomialPreprocessPass static void createPhasePolynomialOptPipeline(OpPassManager &pm) { pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createCSEPass()); - //opt::LoopUnrollOptions luo; - //luo.threshold = 2048; - //pm.addNestedPass(opt::createLoopUnroll(luo)); - //pm.addNestedPass(createCanonicalizerPass()); - //pm.addNestedPass(createCSEPass()); - pm.addNestedPass( - cudaq::opt::createFactorQuantumAllocations()); + // opt::LoopUnrollOptions luo; + // luo.threshold = 2048; + // pm.addNestedPass(opt::createLoopUnroll(luo)); + // pm.addNestedPass(createCanonicalizerPass()); + // pm.addNestedPass(createCSEPass()); + pm.addNestedPass(cudaq::opt::createFactorQuantumAllocations()); pm.addNestedPass(cudaq::opt::createMemToReg()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createCSEPass()); @@ -267,8 +265,7 @@ static void createPhasePolynomialOptPipeline(OpPassManager &pm) { pm.addNestedPass(cudaq::opt::createRegToMem()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createCSEPass()); - pm.addNestedPass( - cudaq::opt::createCombineQuantumAllocations()); + pm.addNestedPass(cudaq::opt::createCombineQuantumAllocations()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createCSEPass()); } @@ -277,7 +274,5 @@ void cudaq::opt::registerPhasePolynomialOptimizationPipeline() { PassPipelineRegistration<>( "phase-polynomial-opt-pipeline", "Apply phase polynomial based rotation merging.", - [](OpPassManager &pm) { - createPhasePolynomialOptPipeline(pm); - }); + [](OpPassManager &pm) { createPhasePolynomialOptPipeline(pm); }); } diff --git a/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp index 8b9469d4b58..eed86775889 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp @@ -204,7 +204,7 @@ class PhasePolynomialRotationMergingPass /// @brief handles a swap between two wires, swapping their phases /// @returns `true` if the swap has been handled and stepping can /// continue, `false` otherwise - bool handleSwap(quake::SwapOp swap) { + bool maybeHandleSwap(quake::SwapOp swap) { auto wire0 = swap.getTarget(0); auto wire1 = swap.getTarget(1); if (wireVisited(wire0) || wireVisited(wire1)) @@ -278,7 +278,7 @@ class PhasePolynomialRotationMergingPass if (store->addOrCombineRotationForPhase(rzop, current_phase)) return; } else if (auto swap = dyn_cast(op)) { - if (!container->handleSwap(swap)) + if (!container->maybeHandleSwap(swap)) return; } diff --git a/lib/Optimizer/Transforms/Subcircuit.h b/lib/Optimizer/Transforms/Subcircuit.h index 0c27aa18ebf..0925e7747d8 100644 --- a/lib/Optimizer/Transforms/Subcircuit.h +++ b/lib/Optimizer/Transforms/Subcircuit.h @@ -49,16 +49,18 @@ class Subcircuit { void addAnchorPoint(Value v) { anchor_points.insert(v); } + void addTerminationPoint(Value v) { termination_points.insert(v); } + void calculateSubcircuitForQubitForward(OpResult v) { seen.insert(v); if (!v.hasOneUse()) { - termination_points.insert(v); + addTerminationPoint(v); return; } Operation *op = v.getUses().begin().getUser(); if (isTerminationPoint(op)) { - termination_points.insert(v); + addTerminationPoint(v); return; } @@ -88,7 +90,7 @@ class Subcircuit { Operation *op = v.getDefiningOp(); if (isTerminationPoint(op)) { - termination_points.insert(v); + addTerminationPoint(v); return; } @@ -166,7 +168,7 @@ class Subcircuit { // Adjust termination border for (auto operand : op->getOperands()) if (operand.getDefiningOp() && ops.contains(operand.getDefiningOp())) - termination_points.insert(operand); + addTerminationPoint(operand); else if (termination_points.contains(operand) && isAfterTerminationPoint(operand)) termination_points.remove(operand); @@ -177,13 +179,16 @@ class Subcircuit { // termination point seen along each wire in the subcircuit // (this means that it is important to build subcircuits // by inspecting controlled gates in topological order) - SmallVector sorted; + std::vector sorted; SetVector pruned; for (auto wire : termination_points) if (!isAfterTerminationPoint(wire) && wire.hasOneUse()) sorted.push_back(wire); auto cmp = [](Value v1, Value v2) { + if (v1.getDefiningOp() == v2.getDefiningOp()) + return dyn_cast(v1).getResultNumber() >= + dyn_cast(v2).getResultNumber(); return !v1.getDefiningOp()->isBeforeInBlock(v2.getDefiningOp()); }; @@ -201,6 +206,8 @@ class Subcircuit { Subcircuit(Operation *cnot); /// @brief Reconstructs a subcircuit from a subcircuit function + /// @returns A newly allocated subcircuit if the function defines + /// a valid subcircuit, `nullptr` otherwise. static Subcircuit *constructFromFunc(func::FuncOp subcircuit_func); SetVector getInitialWires(); From 8c0d5806698ae983a624ed3b6f74fba2385b6e78 Mon Sep 17 00:00:00 2001 From: Adam Geller Date: Tue, 15 Jul 2025 19:23:37 +0000 Subject: [PATCH 21/21] Add preliminary code to detect phase invariant loops Signed-off-by: Adam Geller --- .../PhasePolynomialRotationMerging.cpp | 37 ++++++++++++++++++- lib/Optimizer/Transforms/Subcircuit.cpp | 24 ++++++++++-- lib/Optimizer/Transforms/Subcircuit.h | 3 ++ 3 files changed, 59 insertions(+), 5 deletions(-) diff --git a/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp index eed86775889..253c705b9cc 100644 --- a/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp +++ b/lib/Optimizer/Transforms/PhasePolynomialRotationMerging.cpp @@ -147,7 +147,7 @@ class PhasePolynomialRotationMergingPass class StepperContainer { SmallVector steppers; PhaseStorage *store; - SetVector vars; + SmallVector vars; PhaseStepper *getStepperForValue(Value v) { for (auto *stepper : steppers) @@ -166,7 +166,7 @@ class PhasePolynomialRotationMergingPass // StepperContainer is responsible for cleaning up PhaseSteppers steppers.push_back(new PhaseStepper(circuit, store, wire, new_var)); // StepperContainer is responsible for cleaning up PhaseVariables - vars.insert(new_var); + vars.push_back(new_var); } } @@ -178,6 +178,29 @@ class PhasePolynomialRotationMergingPass delete var; } + static bool isPhaseInvariant(Block *b) { + llvm::outs() << "Inspecting "; + b->dump(); + + auto subcircuit = Subcircuit::constructFromBlock(b); + + if (!subcircuit) + return false; + + llvm::outs() << "Valid subcircuit!\n"; + + auto stepper = StepperContainer(subcircuit); + + while (!stepper.isStopped()) + stepper.stepAll(); + + for (size_t i = 0; i < stepper.steppers.size(); i++) + if (stepper.steppers[i]->current_phase != stepper.vars[i]) + return false; + + return true; + } + bool isStopped() { for (auto *stepper : steppers) if (!stepper->isStopped()) @@ -301,6 +324,16 @@ class PhasePolynomialRotationMergingPass while (!container.isStopped()) container.stepAll(); delete subcircuit; + + // func.walk([&](Operation *op){ + // if (auto loop = dyn_cast(op)) + // if + // (PhaseStepper::StepperContainer::isPhaseInvariant(&loop.getLoopBody().front())) + // { + // llvm::outs() << "Phase invariant!: "; + // loop.dump(); + // } + // }); } }; } // namespace diff --git a/lib/Optimizer/Transforms/Subcircuit.cpp b/lib/Optimizer/Transforms/Subcircuit.cpp index 704e95dcc85..3248a6421e3 100644 --- a/lib/Optimizer/Transforms/Subcircuit.cpp +++ b/lib/Optimizer/Transforms/Subcircuit.cpp @@ -111,9 +111,27 @@ Subcircuit::Subcircuit(Operation *cnot) { terminal_wires.insert(wire); } -/// @brief Reconstructs a subcircuit from a subcircuit function -/// @returns A newly allocated subcircuit if the function defines -/// a valid subcircuit, `nullptr` otherwise. +Subcircuit *Subcircuit::constructFromBlock(Block *b) { + // First, some validation + auto subcircuit = new Subcircuit(); + // Construct the subcircuit + for (auto &op : b->getOperations()) { + auto *opp = &op; + if (opp == b->getTerminator()) + continue; + // Ensure circuit only contains valid operations + if (isTerminationPoint(opp)) + return nullptr; + subcircuit->ops.insert(opp); + } + for (auto arg : b->getArguments()) + if (isa(arg.getType())) + subcircuit->initial_wires.insert(arg); + for (auto ret : b->getTerminator()->getOperands()) + subcircuit->terminal_wires.insert(ret); + return subcircuit; +} + Subcircuit *Subcircuit::constructFromFunc(func::FuncOp subcircuit_func) { // First, some validation if (!subcircuit_func.getOperation()->hasAttr("subcircuit")) diff --git a/lib/Optimizer/Transforms/Subcircuit.h b/lib/Optimizer/Transforms/Subcircuit.h index 0925e7747d8..44a07b99640 100644 --- a/lib/Optimizer/Transforms/Subcircuit.h +++ b/lib/Optimizer/Transforms/Subcircuit.h @@ -210,6 +210,9 @@ class Subcircuit { /// a valid subcircuit, `nullptr` otherwise. static Subcircuit *constructFromFunc(func::FuncOp subcircuit_func); + /// @brief Reconstructs a subcircuit from a block + static Subcircuit *constructFromBlock(Block *b); + SetVector getInitialWires(); SetVector getTerminalWires();