|
25 | 25 | #include <iterator> |
26 | 26 | #include <llvm/ADT/APFloat.h> |
27 | 27 | #include <llvm/ADT/STLExtras.h> |
| 28 | +#include <llvm/ADT/StringMap.h> |
28 | 29 | #include <memory> |
29 | 30 | #include <mlir/IR/BuiltinOps.h> |
30 | | -#include <ranges> |
31 | 31 |
|
32 | 32 | using namespace mlir; |
33 | 33 |
|
@@ -203,8 +203,10 @@ void stripNamespace(std::string &debugName) { |
203 | 203 |
|
204 | 204 | // Test 1: Verify the total number of registered decomposition patterns |
205 | 205 | TEST_F(DecompositionPatternsTest, TotalPatternCount) { |
206 | | - auto patternNames = cudaq::DecompositionPatternType::RegistryType::entries(); |
207 | | - unsigned int size = std::distance(patternNames.begin(), patternNames.end()); |
| 206 | + auto patternEntries = |
| 207 | + cudaq::DecompositionPatternType::RegistryType::entries(); |
| 208 | + unsigned int size = |
| 209 | + std::distance(patternEntries.begin(), patternEntries.end()); |
208 | 210 | EXPECT_EQ(size, 31) << "Expected 31 decomposition patterns, but found " |
209 | 211 | << size; |
210 | 212 | } |
@@ -297,18 +299,29 @@ TEST_F(DecompositionPatternsTest, DecompositionProducesOnlyTargetGates) { |
297 | 299 | // Collect all gates in the output |
298 | 300 | auto outputGates = collectGateTypesInModule(module); |
299 | 301 |
|
| 302 | + // Map from gate prefix to allowed number of controls |
| 303 | + llvm::StringMap<llvm::SmallVector<size_t>> allowedGates; |
| 304 | + for (auto targetGate : targetGates) { |
| 305 | + auto [tPrefix, tNum] = splitGateAndControls(targetGate); |
| 306 | + allowedGates[tPrefix].push_back(tNum); |
| 307 | + } |
300 | 308 | auto isAllowedGate = [&](StringRef gate) { |
301 | 309 | // Split gate into prefix and number (e.g., "h(1)" -> "h", 1) using |
302 | 310 | // utility function |
303 | 311 | auto [gatePrefix, gateNum] = splitGateAndControls(gate); |
304 | 312 |
|
305 | | - for (auto targetGate : targetGates) { |
306 | | - auto [tPrefix, tNum] = splitGateAndControls(targetGate); |
307 | | - if (gatePrefix == tPrefix && gateNum <= tNum) { |
308 | | - return true; |
309 | | - } |
| 313 | + auto it = allowedGates.find(gatePrefix); |
| 314 | + if (it == allowedGates.end()) { |
| 315 | + return false; |
310 | 316 | } |
311 | | - return false; |
| 317 | + auto allowedNumControls = it->second; |
| 318 | + // Check if the number of controls is in the allowed list (or if any |
| 319 | + // number is allowed) |
| 320 | + auto isEqOrMax = [gateNum](size_t num) { |
| 321 | + return num == gateNum || num == std::numeric_limits<size_t>::max(); |
| 322 | + }; |
| 323 | + return std::find_if(allowedNumControls.begin(), allowedNumControls.end(), |
| 324 | + isEqOrMax) != allowedNumControls.end(); |
312 | 325 | }; |
313 | 326 |
|
314 | 327 | std::vector<std::string> unexpectedGates; |
|
0 commit comments