diff --git a/lib/Optimizer/Transforms/DecompositionPatterns.cpp b/lib/Optimizer/Transforms/DecompositionPatterns.cpp index b16013ebe3c..957cf6fc27b 100644 --- a/lib/Optimizer/Transforms/DecompositionPatterns.cpp +++ b/lib/Optimizer/Transforms/DecompositionPatterns.cpp @@ -6,6 +6,22 @@ * the terms of the Apache License 2.0 which accompanies this distribution. * ******************************************************************************/ +/** + * This file contains the decomposition patterns that match single gates and + * decompose them into a sequence of other gates. + * + * Each pattern definition contains 3 elements: + * 1. The pattern itself, which defines what ops to match and how to replace + * them. It must inherit from DecompositionPattern. + * 2. The pattern type, which contains the pattern metadata. It must inherit + * from DecompositionPatternType. + * 3. A call to the CUDAQ_REGISTER_TYPE macro to register the pattern in the + * registry. + * + * Writing 2 and 3 manually is a bit verbose. The REGISTER_DECOMPOSITION_PATTERN + * macro can be used for this purpose instead. + */ + #include "DecompositionPatterns.h" #include "cudaq/Optimizer/Builder/Factory.h" #include "cudaq/Optimizer/Dialect/CC/CCOps.h" @@ -13,9 +29,18 @@ #include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include +#include +#include +#include +#include +#include +#include using namespace mlir; +LLVM_INSTANTIATE_REGISTRY(cudaq::DecompositionPatternType::RegistryType) + namespace { //===----------------------------------------------------------------------===// @@ -283,6 +308,36 @@ LogicalResult checkAndExtractControls(quake::OperatorInterface op, return success(); } +// From here on, we define the decomposition patterns ========================== + +/// Macro to register a decomposition pattern with its metadata +/// Usage: REGISTER_DECOMPOSITION_PATTERN(PatternName, "source_op", "target1", +/// "target2", ...) +/// where "source_op" is the operation that the pattern matches and +/// {"target1", "target2", ...} are the operations that the pattern may produce. +#define REGISTER_DECOMPOSITION_PATTERN(PATTERN, SOURCE_OP, ...) \ + struct PATTERN##Type : public cudaq::DecompositionPatternType { \ + using cudaq::DecompositionPatternType::DecompositionPatternType; \ + llvm::StringRef getSourceOp() const override { return SOURCE_OP; } \ + llvm::ArrayRef getTargetOps() const override { \ + static constexpr llvm::StringRef ops[] = {__VA_ARGS__}; \ + return ops; \ + } \ + llvm::StringRef getPatternName() const override { return #PATTERN; } \ + std::unique_ptr \ + create(mlir::MLIRContext *context, \ + mlir::PatternBenefit benefit = 1) const override { \ + std::unique_ptr pattern = \ + RewritePattern::create(context, benefit); \ + return pattern; \ + } \ + }; \ + CUDAQ_REGISTER_TYPE(cudaq::DecompositionPatternType, PATTERN##Type, PATTERN) + +// TODO: The decomposition patterns "SToR1", "TToR1", "R1ToU3", "U3ToRotations" +// can handle arbitrary number of controls, but currently metadata cannot +// capture this. The pattern types therefore only advertise them for 0 controls. + //===----------------------------------------------------------------------===// // HOp decompositions //===----------------------------------------------------------------------===// @@ -291,10 +346,14 @@ LogicalResult checkAndExtractControls(quake::OperatorInterface op, // ─────────────────────────────────── // quake.phased_rx(π/2, π/2) target // quake.phased_rx(π, 0) target -struct HToPhasedRx : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - void initialize() { setDebugName("HToPhasedRx"); } +struct HToPhasedRxType; // forward declare the pattern type, defined in the + // macro below +struct HToPhasedRx + : public cudaq::DecompositionPattern { + + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::HOp op, PatternRewriter &rewriter) const override { @@ -323,14 +382,18 @@ struct HToPhasedRx : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(HToPhasedRx, "h", "phased_rx"); // quake.exp_pauli(theta) target pauliWord // ─────────────────────────────────── // Basis change operations, cnots, rz(theta), adjoint basis change -struct ExpPauliDecomposition : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("ExpPauliDecomposition"); } +struct ExpPauliDecompositionType; // forward declare the pattern type, defined + // in the macro below +struct ExpPauliDecomposition + : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::ExpPauliOp expPauliOp, PatternRewriter &rewriter) const override { @@ -496,12 +559,18 @@ struct ExpPauliDecomposition : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(ExpPauliDecomposition, "exp_pauli", "rx", "h", + "x(1)", "rz"); // Naive mapping of R1 to Rz, ignoring the global phase. // This is only expected to work with full inlining and // quake apply specialization. -struct R1ToRz : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct R1ToRzType; // forward declare the pattern type, defined in the macro + // below +struct R1ToRz : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; + LogicalResult matchAndRewrite(quake::R1Op r1Op, PatternRewriter &rewriter) const override { if (!r1Op.getControls().empty()) @@ -513,15 +582,17 @@ struct R1ToRz : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(R1ToRz, "r1", "rz"); // Naive mapping of R1 to U3 // quake.r1(λ) [control] target // ─────────────────────────────────── // quake.u3(0, 0, λ) [control] target -struct R1ToU3 : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("R1ToU3"); } +struct R1ToU3Type; // forward declare the pattern type, defined in the macro + // below +struct R1ToU3 : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::R1Op r1Op, PatternRewriter &rewriter) const override { @@ -533,14 +604,17 @@ struct R1ToU3 : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(R1ToU3, "r1", "u3"); // quake.r1 (θ) target // ───────────────────────────────── // quake.r1(-θ) target -struct R1AdjToR1 : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("R1AdjToR1"); } +struct R1AdjToR1Type; // forward declare the pattern type, defined in the macro + // below +struct R1AdjToR1 + : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::R1Op op, PatternRewriter &rewriter) const override { @@ -567,16 +641,19 @@ struct R1AdjToR1 : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(R1AdjToR1, "r1", "r1"); // quake.swap a, b // ─────────────────────────────────── // quake.cnot b, a; // quake.cnot a, b; // quake.cnot b, a; -struct SwapToCX : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("SwapToCX"); } +struct SwapToCXType; // forward declare the pattern type, defined in the macro + // below +struct SwapToCX + : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::SwapOp op, PatternRewriter &rewriter) const override { @@ -595,6 +672,7 @@ struct SwapToCX : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(SwapToCX, "swap", "x(1)"); // quake.h control, target // ─────────────────────────────────── @@ -605,10 +683,11 @@ struct SwapToCX : public OpRewritePattern { // quake.t target; // quake.h target; // quake.s target; -struct CHToCX : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("CHToCX"); } +struct CHToCXType; // forward declare the pattern type, defined in the macro + // below +struct CHToCX : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::HOp op, PatternRewriter &rewriter) const override { @@ -634,6 +713,7 @@ struct CHToCX : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(CHToCX, "h(1)", "s", "h", "t", "x(1)"); //===----------------------------------------------------------------------===// // SOp decompositions @@ -644,10 +724,12 @@ struct CHToCX : public OpRewritePattern { // phased_rx(π/2, 0) target // phased_rx(-π/2, π/2) target // phased_rx(-π/2, 0) target -struct SToPhasedRx : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("SToPhasedRx"); } +struct SToPhasedRxType; // forward declare the pattern type, defined in the + // macro below +struct SToPhasedRx + : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::SOp op, PatternRewriter &rewriter) const override { @@ -681,6 +763,7 @@ struct SToPhasedRx : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(SToPhasedRx, "s", "phased_rx"); // quake.s [control] target // ──────────────────────────────────── @@ -688,10 +771,11 @@ struct SToPhasedRx : public OpRewritePattern { // // Adding this gate equivalence will enable further decomposition via other // patterns such as controlled-r1 to cnot. -struct SToR1 : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("SToR1"); } +struct SToR1Type; // forward declare the pattern type, defined in the macro + // below +struct SToR1 : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::SOp op, PatternRewriter &rewriter) const override { @@ -710,6 +794,7 @@ struct SToR1 : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(SToR1, "s", "r1"); //===----------------------------------------------------------------------===// // TOp decompositions @@ -720,10 +805,12 @@ struct SToR1 : public OpRewritePattern { // quake.phased_rx(π/2, 0) target // quake.phased_rx(-π/4, π/2) target // quake.phased_rx(-π/2, 0) target -struct TToPhasedRx : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("TToPhasedRx"); } +struct TToPhasedRxType; // forward declare the pattern type, defined in the + // macro below +struct TToPhasedRx + : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::TOp op, PatternRewriter &rewriter) const override { @@ -758,6 +845,7 @@ struct TToPhasedRx : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(TToPhasedRx, "t", "phased_rx"); // quake.t [control] target // ──────────────────────────────────── @@ -765,10 +853,11 @@ struct TToPhasedRx : public OpRewritePattern { // // Adding this gate equivalence will enable further decomposition via other // patterns such as controlled-r1 to cnot. -struct TToR1 : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("TToR1"); } +struct TToR1Type; // forward declare the pattern type, defined in the macro + // below +struct TToR1 : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::TOp op, PatternRewriter &rewriter) const override { @@ -786,6 +875,7 @@ struct TToR1 : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(TToR1, "t", "r1"); //===----------------------------------------------------------------------===// // XOp decompositions @@ -796,10 +886,11 @@ struct TToR1 : public OpRewritePattern { // quake.h target // quake.z [control] target // quake.h target -struct CXToCZ : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("CXToCZ"); } +struct CXToCZType; // forward declare the pattern type, defined in the macro + // below +struct CXToCZ : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::XOp op, PatternRewriter &rewriter) const override { @@ -833,16 +924,18 @@ struct CXToCZ : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(CXToCZ, "x(1)", "h", "z(1)"); // quake.x [controls] target // ────────────────────────────────── // quake.h target // quake.z [controls] target // quake.h target -struct CCXToCCZ : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("CCXToCCZ"); } +struct CCXToCCZType; // forward declare the pattern type, defined in the macro + // below +struct CCXToCCZ : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::XOp op, PatternRewriter &rewriter) const override { @@ -865,14 +958,17 @@ struct CCXToCCZ : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(CCXToCCZ, "x(2)", "h", "z(2)"); // quake.x target // ─────────────────────────────── // quake.phased_rx(π, 0) target -struct XToPhasedRx : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("XToPhasedRx"); } +struct XToPhasedRxType; // forward declare the pattern type, defined in the + // macro below +struct XToPhasedRx + : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::XOp op, PatternRewriter &rewriter) const override { @@ -897,6 +993,7 @@ struct XToPhasedRx : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(XToPhasedRx, "x", "phased_rx"); //===----------------------------------------------------------------------===// // YOp decompositions @@ -905,10 +1002,12 @@ struct XToPhasedRx : public OpRewritePattern { // quake.y target // ───────────────────────────────── // quake.phased_rx(π, -π/2) target -struct YToPhasedRx : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("YToPhasedRx"); } +struct YToPhasedRxType; // forward declare the pattern type, defined in the + // macro below +struct YToPhasedRx + : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::YOp op, PatternRewriter &rewriter) const override { @@ -934,6 +1033,7 @@ struct YToPhasedRx : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(YToPhasedRx, "y", "phased_rx"); // quake.y [control] target // ─────────────────────────────────── @@ -941,10 +1041,11 @@ struct YToPhasedRx : public OpRewritePattern { // quake.x [control] target; // quake.s target; -struct CYToCX : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("CYToCX"); } +struct CYToCXType; // forward declare the pattern type, defined in the macro + // below +struct CYToCX : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::YOp op, PatternRewriter &rewriter) const override { @@ -970,6 +1071,7 @@ struct CYToCX : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(CYToCX, "y(1)", "s", "x(1)"); //===----------------------------------------------------------------------===// // ZOp decompositions @@ -986,10 +1088,11 @@ struct CYToCX : public OpRewritePattern { // └───┘ └───┘└───┘└───┘└───┘└───┘└───┘└───┘ └───┘ // // NOTE: `┴` denotes the adjoint of `T`. -struct CCZToCX : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("CCZToCX"); } +struct CCZToCXType; // forward declare the pattern type, defined in the macro + // below +struct CCZToCX : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::ZOp op, PatternRewriter &rewriter) const override { @@ -1045,16 +1148,19 @@ struct CCZToCX : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(CCZToCX, "z(2)", "t", "x(1)"); // quake.z [control] target // ────────────────────────────────── // quake.h target // quake.x [control] target // quake.h target -struct CZToCX : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - void initialize() { setDebugName("CZToCX"); } +struct CZToCXType; // forward declare the pattern type, defined in the macro + // below +struct CZToCX : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::ZOp op, PatternRewriter &rewriter) const override { @@ -1088,16 +1194,19 @@ struct CZToCX : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(CZToCX, "z(1)", "h", "x(1)"); // quake.z target // ────────────────────────────────── // quake.phased_rx(π/2, 0) target // quake.phased_rx(-π, π/2) target // quake.phased_rx(-π/2, 0) target -struct ZToPhasedRx : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("ZToPhasedRx"); } +struct ZToPhasedRxType; // forward declare the pattern type, defined in the + // macro below +struct ZToPhasedRx + : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::ZOp op, PatternRewriter &rewriter) const override { @@ -1130,6 +1239,7 @@ struct ZToPhasedRx : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(ZToPhasedRx, "z", "phased_rx"); //===----------------------------------------------------------------------===// // R1Op decompositions @@ -1142,10 +1252,11 @@ struct ZToPhasedRx : public OpRewritePattern { // quake.r1(-λ/2) target // quake.x [control] target // quake.r1(λ/2) target -struct CR1ToCX : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("CR1ToCX"); } +struct CR1ToCXType; // forward declare the pattern type, defined in the macro + // below +struct CR1ToCX : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::R1Op op, PatternRewriter &rewriter) const override { @@ -1187,16 +1298,19 @@ struct CR1ToCX : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(CR1ToCX, "r1(1)", "r1", "x(1)"); // quake.r1(λ) target // ────────────────────────────────── // quake.phased_rx(π/2, 0) target // quake.phased_rx(-λ, π/2) target // quake.phased_rx(-π/2, 0) target -struct R1ToPhasedRx : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("R1ToPhasedRx"); } +struct R1ToPhasedRxType; // forward declare the pattern type, defined in the + // macro below +struct R1ToPhasedRx + : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::R1Op op, PatternRewriter &rewriter) const override { @@ -1233,6 +1347,7 @@ struct R1ToPhasedRx : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(R1ToPhasedRx, "r1", "phased_rx"); //===----------------------------------------------------------------------===// // RxOp decompositions @@ -1246,10 +1361,11 @@ struct R1ToPhasedRx : public OpRewritePattern { // quake.x [control] target // quake.ry(θ/2) target // quake.rz(-π/2) target -struct CRxToCX : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("CRxToCX"); } +struct CRxToCXType; // forward declare the pattern type, defined in the macro + // below +struct CRxToCX : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::RxOp op, PatternRewriter &rewriter) const override { @@ -1292,14 +1408,17 @@ struct CRxToCX : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(CRxToCX, "rx(1)", "s", "x(1)", "ry", "rz"); // quake.rx(θ) target // ─────────────────────────────── // quake.phased_rx(θ, 0) target -struct RxToPhasedRx : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("RxToPhasedRx"); } +struct RxToPhasedRxType; // forward declare the pattern type, defined in the + // macro below +struct RxToPhasedRx + : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::RxOp op, PatternRewriter &rewriter) const override { @@ -1327,14 +1446,17 @@ struct RxToPhasedRx : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(RxToPhasedRx, "rx", "phased_rx"); // quake.rx (θ) target // ───────────────────────────────── // quake.rx(-θ) target -struct RxAdjToRx : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("RxAdjToRx"); } +struct RxAdjToRxType; // forward declare the pattern type, defined in the macro + // below +struct RxAdjToRx + : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::RxOp op, PatternRewriter &rewriter) const override { @@ -1362,6 +1484,7 @@ struct RxAdjToRx : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(RxAdjToRx, "rx", "rx"); //===----------------------------------------------------------------------===// // RyOp decompositions @@ -1373,10 +1496,11 @@ struct RxAdjToRx : public OpRewritePattern { // quake.x [control] target // quake.ry(-θ/2) target // quake.x [control] target -struct CRyToCX : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("CRyToCX"); } +struct CRyToCXType; // forward declare the pattern type, defined in the macro + // below +struct CRyToCX : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::RyOp op, PatternRewriter &rewriter) const override { @@ -1413,14 +1537,17 @@ struct CRyToCX : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(CRyToCX, "ry(1)", "ry", "x(1)"); // quake.ry(θ) target // ───────────────────────────────── // quake.phased_rx(θ, π/2) target -struct RyToPhasedRx : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("RyToPhasedRx"); } +struct RyToPhasedRxType; // forward declare the pattern type, defined in the + // macro below +struct RyToPhasedRx + : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::RyOp op, PatternRewriter &rewriter) const override { @@ -1448,14 +1575,17 @@ struct RyToPhasedRx : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(RyToPhasedRx, "ry", "phased_rx"); // quake.ry (θ) target // ───────────────────────────────── // quake.ry(-θ) target -struct RyAdjToRy : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("RyAdjToRy"); } +struct RyAdjToRyType; // forward declare the pattern type, defined in the macro + // below +struct RyAdjToRy + : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::RyOp op, PatternRewriter &rewriter) const override { @@ -1483,6 +1613,7 @@ struct RyAdjToRy : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(RyAdjToRy, "ry", "ry"); //===----------------------------------------------------------------------===// // RzOp decompositions @@ -1494,10 +1625,11 @@ struct RyAdjToRy : public OpRewritePattern { // quake.x [control] target // quake.rz(-λ/2) target // quake.x [control] target -struct CRzToCX : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("CRzToCX"); } +struct CRzToCXType; // forward declare the pattern type, defined in the macro + // below +struct CRzToCX : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::RzOp op, PatternRewriter &rewriter) const override { @@ -1534,16 +1666,19 @@ struct CRzToCX : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(CRzToCX, "rz(1)", "rz", "x(1)"); // quake.rz(θ) target // ────────────────────────────────── // quake.phased_rx(π/2, 0) target // quake.phased_rx(-θ, π/2) target // quake.phased_rx(-π/2, 0) target -struct RzToPhasedRx : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("RzToPhasedRx"); } +struct RzToPhasedRxType; // forward declare the pattern type, defined in the + // macro below +struct RzToPhasedRx + : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::RzOp op, PatternRewriter &rewriter) const override { @@ -1580,14 +1715,17 @@ struct RzToPhasedRx : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(RzToPhasedRx, "rz", "phased_rx"); // quake.rz (θ) target // ───────────────────────────────── // quake.rz(-θ) target -struct RzAdjToRz : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("RzAdjToRz"); } +struct RzAdjToRzType; // forward declare the pattern type, defined in the macro + // below +struct RzAdjToRz + : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::RzOp op, PatternRewriter &rewriter) const override { @@ -1615,6 +1753,7 @@ struct RzAdjToRz : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(RzAdjToRz, "rz", "rz"); //===----------------------------------------------------------------------===// // U3Op decompositions @@ -1627,10 +1766,12 @@ struct RzAdjToRz : public OpRewritePattern { // quake.rz(θ) target // quake.rx(-π/2) target // quake.rz(ϕ) target -struct U3ToRotations : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("U3ToRotations"); } +struct U3ToRotationsType; // forward declare the pattern type, defined in the + // macro below +struct U3ToRotations + : public cudaq::DecompositionPattern { + using cudaq::DecompositionPattern::DecompositionPattern; LogicalResult matchAndRewrite(quake::U3Op op, PatternRewriter &rewriter) const override { @@ -1667,59 +1808,28 @@ struct U3ToRotations : public OpRewritePattern { return success(); } }; +REGISTER_DECOMPOSITION_PATTERN(U3ToRotations, "u3", "rz", "rx"); } // namespace -//===----------------------------------------------------------------------===// -// Populating pattern sets -//===----------------------------------------------------------------------===// +void cudaq::populateWithAllDecompositionPatterns( + mlir::RewritePatternSet &patterns) { + // For deterministic ordering, sort the registered pattern types by name + // Note that this assumes that no additional patterns are registered at + // runtime. + static std::map> + patternTypes = []() { + std::map> + map; + for (auto &patternType : + cudaq::DecompositionPatternType::RegistryType::entries()) { + map[patternType.getName().str()] = patternType.instantiate(); + } + return map; + }(); -void cudaq::populateWithAllDecompositionPatterns(RewritePatternSet &patterns) { - // clang-format off - patterns.insert< - // HOp patterns - HToPhasedRx, - CHToCX, - // SOp patterns - SToPhasedRx, - SToR1, - // TOp patterns - TToPhasedRx, - TToR1, - // XOp patterns - CXToCZ, - CCXToCCZ, - XToPhasedRx, - // YOp patterns - YToPhasedRx, - CYToCX, - // ZOp patterns - CZToCX, - CCZToCX, - ZToPhasedRx, - // R1Op patterns - CR1ToCX, - R1ToPhasedRx, - R1ToRz, - R1ToU3, - R1AdjToR1, - // RxOp patterns - CRxToCX, - RxToPhasedRx, - RxAdjToRx, - // RyOp patterns - CRyToCX, - RyToPhasedRx, - RyAdjToRy, - // RzOp patterns - CRzToCX, - RzToPhasedRx, - RzAdjToRz, - // Swap - SwapToCX, - // U3Op - U3ToRotations, - ExpPauliDecomposition - >(patterns.getContext()); - // clang-format on + for (auto it = patternTypes.begin(), ie = patternTypes.end(); it != ie; + ++it) { + patterns.add(it->second->create(patterns.getContext())); + } } diff --git a/lib/Optimizer/Transforms/DecompositionPatterns.h b/lib/Optimizer/Transforms/DecompositionPatterns.h index c55e0f78e86..9cb68e40522 100644 --- a/lib/Optimizer/Transforms/DecompositionPatterns.h +++ b/lib/Optimizer/Transforms/DecompositionPatterns.h @@ -8,12 +8,63 @@ #pragma once +#include "common/Registry.h" +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/PatternMatch.h" + namespace mlir { class RewritePatternSet; } namespace cudaq { +//===----------------------------------------------------------------------===// +// Base classes for decomposition patterns +//===----------------------------------------------------------------------===// + +/// Base class for pattern types to enable registration via the llvm::Registry +/// system. Stores the pattern metadata and provides a factory method to create +/// new instances of the pattern. +/// +/// Register decomposition patterns using +/// CUDAQ_REGISTER_TYPE(cudaq::DecompositionPatternType, MyPatternType, +/// pattern_name) +/// where pattern_name is the same as MyPatternType().getPatternName(). +class DecompositionPatternType + : public registry::RegisteredType { +public: + virtual ~DecompositionPatternType() = default; + + /// Get the source operation this pattern matches and decomposes. + virtual llvm::StringRef getSourceOp() const = 0; + + /// Get the target operations this pattern may produce + virtual llvm::ArrayRef getTargetOps() const = 0; + + /// Get the name of the pattern. + virtual llvm::StringRef getPatternName() const = 0; + + /// Create a new instance of the pattern. + virtual std::unique_ptr + create(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) const = 0; +}; + +/// Base class for all decomposition patterns. All decomposition patterns must +/// inherit from this class. Templated on +/// - the pattern type (which inherits from DecompositionPatternType), and +/// - the operation type that the pattern matches. +/// Used as follows class MyPattern : public DecompositionPattern +/// {...}; +template +class DecompositionPattern : public mlir::OpRewritePattern { +public: + using mlir::OpRewritePattern::OpRewritePattern; + + /// Set the debug name to the registered name + void initialize() { this->setDebugName(PatternType().getPatternName()); } +}; + void populateWithAllDecompositionPatterns(mlir::RewritePatternSet &patterns); -} +} // namespace cudaq diff --git a/test/Transforms/BasisConversion/all-qir-gates.qke b/test/Transforms/BasisConversion/all-qir-gates.qke index 19d44d89d29..c5b3de13b24 100644 --- a/test/Transforms/BasisConversion/all-qir-gates.qke +++ b/test/Transforms/BasisConversion/all-qir-gates.qke @@ -304,8 +304,7 @@ module { // CHECK: %[[VAL_87:.*]] = cc.cast unsigned %[[VAL_85]] : (i1) -> i8 // CHECK: cc.store %[[VAL_87]], %[[VAL_86]] : !cc.ptr // CHECK: %[[VAL_88:.*]] = quake.alloca !quake.ref -// CHECK: %[[VAL_89:.*]] = arith.negf %[[VAL_4]] : f64 -// CHECK: quake.rz (%[[VAL_89]]) %[[VAL_88]] : (f64, !quake.ref) -> () +// CHECK: quake.rz (%[[VAL_4]]) %[[VAL_88]] : (f64, !quake.ref) -> () // CHECK: quake.rx (%[[VAL_0]]) %[[VAL_88]] : (f64, !quake.ref) -> () // CHECK: quake.ry (%[[VAL_3]]) %[[VAL_88]] : (f64, !quake.ref) -> () // CHECK: quake.rz (%[[VAL_2]]) %[[VAL_88]] : (f64, !quake.ref) -> () diff --git a/unittests/Optimizer/CMakeLists.txt b/unittests/Optimizer/CMakeLists.txt index b8cbd46f946..3ff19b74813 100644 --- a/unittests/Optimizer/CMakeLists.txt +++ b/unittests/Optimizer/CMakeLists.txt @@ -8,16 +8,30 @@ include(HandleLLVMOptions) -add_executable(OptimizerUnitTests HermitianTrait.cpp FactoryMergeModuleTest.cpp) +add_executable(OptimizerUnitTests + HermitianTrait.cpp + FactoryMergeModuleTest.cpp + DecompositionPatternsTest.cpp +) target_link_libraries(OptimizerUnitTests PRIVATE + MLIRArithDialect + MLIRFuncDialect MLIRParser + MLIRPass + MLIRTransforms QuakeDialect + CCDialect + OptTransforms gtest_main OptimBuilder ) +target_include_directories(OptimizerUnitTests + PRIVATE ${CMAKE_SOURCE_DIR}/runtime +) + gtest_discover_tests(OptimizerUnitTests) add_executable(test_quake_synth QuakeSynthTester.cpp) diff --git a/unittests/Optimizer/DecompositionPatternsTest.cpp b/unittests/Optimizer/DecompositionPatternsTest.cpp new file mode 100644 index 00000000000..a4022add8c7 --- /dev/null +++ b/unittests/Optimizer/DecompositionPatternsTest.cpp @@ -0,0 +1,344 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +#include "../../lib/Optimizer/Transforms/DecompositionPatterns.h" +#include "cudaq/Optimizer/Builder/Factory.h" +#include "cudaq/Optimizer/Dialect/CC/CCDialect.h" +#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h" +#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" +#include "cudaq/Optimizer/Transforms/Passes.h" + +#include "llvm/ADT/StringSet.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/Builders.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; + +namespace { + +class DecompositionPatternsTest : public ::testing::Test { +protected: + void SetUp() override { + context = std::make_unique(); + context->loadDialect(); + } + + std::unique_ptr context; +}; + +// Helper to parse control count from gate string like "x(1)" or "z(2)" +std::pair parseGateSpec(StringRef gateSpec) { + auto pos = gateSpec.find('('); + if (pos == StringRef::npos) { + return {gateSpec.str(), 0}; + } + + std::string gateName = gateSpec.substr(0, pos).str(); + StringRef numStr = gateSpec.substr(pos + 1); + size_t numControls = 0; + + if (numStr.startswith("n")) { + // Arbitrary number of controls - use a reasonable test value + numControls = std::numeric_limits::max(); + } else { + numStr.consumeInteger(10, numControls); + } + + return {gateName, numControls}; +} + +// Helper function to create a test module with a single gate operation +ModuleOp createTestModule(MLIRContext *context, StringRef gateSpec) { + auto [gateName, numControls] = parseGateSpec(gateSpec); + + // Limit the number of controls to 2 + numControls = std::min(numControls, 2); + + size_t numQubits; + if (gateName == "swap") { + assert(numControls == 0); + numQubits = 2; + } else { + numQubits = numControls + 1; + } + + OpBuilder builder(context); + auto module = builder.create(builder.getUnknownLoc()); + builder.setInsertionPointToEnd(module.getBody()); + + // Create function type: (qubits...) -> () + SmallVector inputTypes; + auto refType = quake::RefType::get(context); + for (size_t i = 0; i < numQubits; ++i) { + inputTypes.push_back(refType); + } + auto funcType = builder.getFunctionType(inputTypes, {}); + + // Create function + auto func = builder.create(builder.getUnknownLoc(), "test_func", + funcType); + auto *entry = func.addEntryBlock(); + builder.setInsertionPointToStart(entry); + + // Get operands (controls and target) + SmallVector controls; + for (size_t i = 0; i < numControls; ++i) { + controls.push_back(entry->getArgument(i)); + } + Value target = entry->getArgument(numControls); + + // Create the gate operation based on gate name + Location loc = builder.getUnknownLoc(); + + Value pi_2 = cudaq::opt::factory::createFloatConstant(loc, builder, M_PI_2, + builder.getF64Type()); + + if (gateName == "h") { + builder.create(loc, controls, target); + } else if (gateName == "s") { + builder.create(loc, controls, target); + } else if (gateName == "t") { + builder.create(loc, controls, target); + } else if (gateName == "x") { + builder.create(loc, controls, target); + } else if (gateName == "y") { + builder.create(loc, controls, target); + } else if (gateName == "z") { + builder.create(loc, controls, target); + } else if (gateName == "rx") { + builder.create(loc, ValueRange{pi_2}, controls, target); + } else if (gateName == "ry") { + builder.create(loc, ValueRange{pi_2}, controls, target); + } else if (gateName == "rz") { + builder.create(loc, ValueRange{pi_2}, controls, target); + } else if (gateName == "r1") { + builder.create(loc, ValueRange{pi_2}, controls, target); + } else if (gateName == "u3") { + builder.create(loc, ValueRange{pi_2, pi_2, pi_2}, controls, + target); + } else if (gateName == "phased_rx") { + builder.create(loc, ValueRange{{pi_2, pi_2}}, controls, + target); + } else if (gateName == "swap") { + // Swap needs 2 targets + Value target = entry->getArgument(0); + Value target2 = entry->getArgument(1); + builder.create(loc, ValueRange{target, target2}); + } else { + // Unsupported gate for this test + ADD_FAILURE() << "unknown gate: " << gateName; + } + + builder.create(loc); + return module; +} + +// Helper to collect all gate types in a module +llvm::StringSet<> collectGateTypesInModule(ModuleOp module) { + llvm::StringSet<> gates; + + module.walk([&](Operation *op) { + if (auto optor = dyn_cast(op)) { + std::string gateName = optor->getName().stripDialect().str(); + auto numControls = optor.getControls().size(); + + if (numControls > 0) { + gateName += "(" + std::to_string(numControls) + ")"; + } + + gates.insert(gateName); + } + }); + + return gates; +} + +inline std::pair +splitGateAndControls(llvm::StringRef gate) { + auto parenOpen = gate.find('('); + std::string gatePrefix; + size_t gateNum = 0; + if (parenOpen != llvm::StringRef::npos) { + gatePrefix = gate.substr(0, parenOpen).str(); + auto parenClose = gate.find(')', parenOpen); + assert(parenClose != llvm::StringRef::npos); + std::string numStr = + gate.substr(parenOpen + 1, parenClose - parenOpen - 1).str(); + if (numStr == "n") + gateNum = std::numeric_limits::max(); + else + gateNum = static_cast(std::stoul(numStr)); + } else { + gatePrefix = gate.str(); + } + return {gatePrefix, gateNum}; +}; + +void stripNamespace(std::string &debugName) { + auto lastColon = debugName.find_last_of(':'); + if (lastColon != llvm::StringRef::npos) { + debugName = debugName.substr(lastColon + 1); + } +} + +} // namespace + +// Test 1: Verify the total number of registered decomposition patterns +TEST_F(DecompositionPatternsTest, TotalPatternCount) { + auto patternEntries = + cudaq::DecompositionPatternType::RegistryType::entries(); + unsigned int size = + std::distance(patternEntries.begin(), patternEntries.end()); + EXPECT_EQ(size, 31) << "Expected 31 decomposition patterns, but found " + << size; +} + +// Test 2: Verify pattern names match getDebugName() +TEST_F(DecompositionPatternsTest, PatternNamesMatchDebugNames) { + auto patternEntries = + cudaq::DecompositionPatternType::RegistryType::entries(); + + for (auto &entry : patternEntries) { + auto patternName = entry.getName(); + // Create the pattern + auto patternType = cudaq::registry::get( + patternName.str()); + ASSERT_NE(patternType, nullptr) + << "Failed to recover registered pattern type: " << patternName.str(); + + auto pattern = patternType->create(context.get()); + ASSERT_NE(pattern, nullptr) + << "Failed to create pattern: " << patternName.str(); + + // Get the debug name + auto debugName = pattern->getDebugName().str(); + stripNamespace(debugName); + + // Verify they match + EXPECT_EQ(patternName.str(), debugName) + << "Pattern name '" << patternName.str() + << "' does not match debug name '" << debugName << "'"; + } +} + +// Test 3: Verify metadata is consistent (source and target gates are valid) +TEST_F(DecompositionPatternsTest, MetadataConsistency) { + auto patternEntries = + cudaq::DecompositionPatternType::RegistryType::entries(); + + for (auto &entry : patternEntries) { + std::string patternName = entry.getName().str(); + auto patternType = entry.instantiate(); + std::string sourceGate = patternType->getSourceOp().str(); + auto targetGates = patternType->getTargetOps(); + + // Source gate should not be empty + EXPECT_FALSE(sourceGate.empty()) + << "Pattern '" << patternName << "' has empty source gate"; + + // Target gates should not be empty + EXPECT_FALSE(targetGates.empty()) + << "Pattern '" << patternName << "' has empty target gates"; + + // All target gates should be non-empty + for (auto targetGate : targetGates) { + EXPECT_FALSE(targetGate.empty()) + << "Pattern '" << patternName << "' has empty target gate in list"; + } + } +} + +// Test 4: Verify pattern decompositions produce only target gates +TEST_F(DecompositionPatternsTest, DecompositionProducesOnlyTargetGates) { + auto patternEntries = + cudaq::DecompositionPatternType::RegistryType::entries(); + + for (auto &entry : patternEntries) { + std::string patternName = entry.getName().str(); + auto patternType = entry.instantiate(); + std::string sourceGate = patternType->getSourceOp().str(); + auto targetGates = patternType->getTargetOps(); + + // TODO: add support for testing exp_pauli + if (sourceGate.starts_with("exp_pauli")) + continue; + + // Create a test module with the source gate + auto module = createTestModule(context.get(), sourceGate); + + // Apply the decomposition pass with only this pattern enabled + PassManager pm(context.get()); + cudaq::opt::DecompositionPassOptions options; + std::string ownedEnabledPatterns[]{patternName}; + options.enabledPatterns = ownedEnabledPatterns; + pm.addPass(cudaq::opt::createDecompositionPass(options)); + + // Run the pass + auto result = pm.run(module); + ASSERT_TRUE(succeeded(result)) + << "Decomposition pass failed for pattern: " << patternName; + + // Collect all gates in the output + auto outputGates = collectGateTypesInModule(module); + + // Map from gate prefix to allowed number of controls + llvm::StringMap> allowedGates; + for (auto targetGate : targetGates) { + auto [tPrefix, tNum] = splitGateAndControls(targetGate); + allowedGates[tPrefix].push_back(tNum); + } + auto isAllowedGate = [&](StringRef gate) { + // Split gate into prefix and number (e.g., "h(1)" -> "h", 1) using + // utility function + auto [gatePrefix, gateNum] = splitGateAndControls(gate); + + auto it = allowedGates.find(gatePrefix); + if (it == allowedGates.end()) { + return false; + } + auto allowedNumControls = it->second; + // Check if the number of controls is in the allowed list (or if any + // number is allowed) + auto isEqOrMax = [gateNum](size_t num) { + return num == gateNum || num == std::numeric_limits::max(); + }; + return std::find_if(allowedNumControls.begin(), allowedNumControls.end(), + isEqOrMax) != allowedNumControls.end(); + }; + + std::vector unexpectedGates; + for (auto &outputGate : outputGates) { + if (!isAllowedGate(outputGate.getKey())) { + unexpectedGates.push_back(outputGate.getKey().str()); + } + } + + if (!unexpectedGates.empty()) { + auto expectedGatesStr = llvm::join(targetGates, ", "); + auto unexpectedGatesStr = llvm::join(unexpectedGates, ", "); + + ADD_FAILURE() << "Pattern '" << patternName + << "' produced unexpected gates.\n" + << " Allowed gates: {" << expectedGatesStr << "}\n" + << " Found: {" << unexpectedGatesStr << "}"; + } + } +}