Skip to content

Commit c9b1b6e

Browse files
committed
address comments
Signed-off-by: Luca Mondada <[email protected]>
1 parent 12252ad commit c9b1b6e

File tree

2 files changed

+43
-14
lines changed

2 files changed

+43
-14
lines changed

lib/Optimizer/Transforms/DecompositionPatterns.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "mlir/IR/PatternMatch.h"
3131
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
3232
#include <llvm/ADT/SmallVector.h>
33+
#include <llvm/ADT/StringMap.h>
3334
#include <llvm/ADT/StringRef.h>
3435
#include <llvm/Support/Casting.h>
3536
#include <llvm/Support/Error.h>
@@ -333,8 +334,9 @@ LogicalResult checkAndExtractControls(quake::OperatorInterface op,
333334
}; \
334335
CUDAQ_REGISTER_TYPE(cudaq::DecompositionPatternType, PATTERN##Type, PATTERN)
335336

336-
// TODO: "SToR1", "TToR1", "R1ToU3", "U3ToRotations" can be generalised
337-
// arbitrary number of controls, but we would need to reason over n HOp patterns
337+
// TODO: The decomposition patterns "SToR1", "TToR1", "R1ToU3", "U3ToRotations"
338+
// can handle arbitrary number of controls, but currently metadata cannot
339+
// capture this. The pattern types therefore only advertise them for 0 controls.
338340

339341
//===----------------------------------------------------------------------===//
340342
// HOp decompositions
@@ -1812,8 +1814,22 @@ REGISTER_DECOMPOSITION_PATTERN(U3ToRotations, "u3", "rz", "rx");
18121814

18131815
void cudaq::populateWithAllDecompositionPatterns(
18141816
mlir::RewritePatternSet &patterns) {
1815-
for (auto &patternType :
1816-
cudaq::DecompositionPatternType::RegistryType::entries()) {
1817-
patterns.add(patternType.instantiate()->create(patterns.getContext()));
1817+
// For deterministic ordering, sort the registered pattern types by name
1818+
// Note that this assumes that no additional patterns are registered at
1819+
// runtime.
1820+
static std::map<std::string, std::unique_ptr<cudaq::DecompositionPatternType>>
1821+
patternTypes = []() {
1822+
std::map<std::string, std::unique_ptr<cudaq::DecompositionPatternType>>
1823+
map;
1824+
for (auto &patternType :
1825+
cudaq::DecompositionPatternType::RegistryType::entries()) {
1826+
map[patternType.getName().str()] = patternType.instantiate();
1827+
}
1828+
return map;
1829+
}();
1830+
1831+
for (auto it = patternTypes.begin(), ie = patternTypes.end(); it != ie;
1832+
++it) {
1833+
patterns.add(it->second->create(patterns.getContext()));
18181834
}
18191835
}

unittests/Optimizer/DecompositionPatternsTest.cpp

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
#include <iterator>
2626
#include <llvm/ADT/APFloat.h>
2727
#include <llvm/ADT/STLExtras.h>
28+
#include <llvm/ADT/StringMap.h>
2829
#include <memory>
2930
#include <mlir/IR/BuiltinOps.h>
30-
#include <ranges>
3131

3232
using namespace mlir;
3333

@@ -203,8 +203,10 @@ void stripNamespace(std::string &debugName) {
203203

204204
// Test 1: Verify the total number of registered decomposition patterns
205205
TEST_F(DecompositionPatternsTest, TotalPatternCount) {
206-
auto patternNames = cudaq::DecompositionPatternType::RegistryType::entries();
207-
unsigned int size = std::distance(patternNames.begin(), patternNames.end());
206+
auto patternEntries =
207+
cudaq::DecompositionPatternType::RegistryType::entries();
208+
unsigned int size =
209+
std::distance(patternEntries.begin(), patternEntries.end());
208210
EXPECT_EQ(size, 31) << "Expected 31 decomposition patterns, but found "
209211
<< size;
210212
}
@@ -297,18 +299,29 @@ TEST_F(DecompositionPatternsTest, DecompositionProducesOnlyTargetGates) {
297299
// Collect all gates in the output
298300
auto outputGates = collectGateTypesInModule(module);
299301

302+
// Map from gate prefix to allowed number of controls
303+
llvm::StringMap<llvm::SmallVector<size_t>> allowedGates;
304+
for (auto targetGate : targetGates) {
305+
auto [tPrefix, tNum] = splitGateAndControls(targetGate);
306+
allowedGates[tPrefix].push_back(tNum);
307+
}
300308
auto isAllowedGate = [&](StringRef gate) {
301309
// Split gate into prefix and number (e.g., "h(1)" -> "h", 1) using
302310
// utility function
303311
auto [gatePrefix, gateNum] = splitGateAndControls(gate);
304312

305-
for (auto targetGate : targetGates) {
306-
auto [tPrefix, tNum] = splitGateAndControls(targetGate);
307-
if (gatePrefix == tPrefix && gateNum <= tNum) {
308-
return true;
309-
}
313+
auto it = allowedGates.find(gatePrefix);
314+
if (it == allowedGates.end()) {
315+
return false;
310316
}
311-
return false;
317+
auto allowedNumControls = it->second;
318+
// Check if the number of controls is in the allowed list (or if any
319+
// number is allowed)
320+
auto isEqOrMax = [gateNum](size_t num) {
321+
return num == gateNum || num == std::numeric_limits<size_t>::max();
322+
};
323+
return std::find_if(allowedNumControls.begin(), allowedNumControls.end(),
324+
isEqOrMax) != allowedNumControls.end();
312325
};
313326

314327
std::vector<std::string> unexpectedGates;

0 commit comments

Comments
 (0)