Skip to content

Commit b07ff17

Browse files
authored
MergeModules utility function (#3449)
* * Adding a mergeModule utility function to merge 2 modules * Avoid using walk as it is wrong. Filtering artifacts by their IR node type is incorrect. * Adding unit tests co-authored by: Eric Schweitz <[email protected]> Signed-off-by: Sachin Pisal <[email protected]> * formatting Signed-off-by: Sachin Pisal <[email protected]> * * Moving mergeModules definition to Factory.cpp * Updating the unittests name Signed-off-by: Sachin Pisal <[email protected]> --------- Signed-off-by: Sachin Pisal <[email protected]>
1 parent b72c2c5 commit b07ff17

File tree

5 files changed

+119
-6
lines changed

5 files changed

+119
-6
lines changed

include/cudaq/Optimizer/Builder/Factory.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ std::pair<mlir::func::FuncOp, /*alreadyDefined=*/bool>
294294
getOrAddFunc(mlir::Location loc, mlir::StringRef funcName,
295295
mlir::FunctionType funcTy, mlir::ModuleOp module);
296296

297+
void mergeModules(mlir::ModuleOp into, mlir::ModuleOp from);
297298
} // namespace factory
298299

299300
std::size_t getDataSize(llvm::DataLayout &dataLayout, mlir::Type ty);

lib/Optimizer/Builder/Factory.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,4 +774,17 @@ factory::getOrAddFunc(mlir::Location loc, mlir::StringRef funcName,
774774
return {func, /*defined=*/false};
775775
}
776776

777+
void factory::mergeModules(ModuleOp into, ModuleOp from) {
778+
for (Operation &op : *from.getBody()) {
779+
auto sym = dyn_cast<SymbolOpInterface>(op);
780+
if (!sym)
781+
continue; // Only merge named symbols, avoids duplicating anonymous ops.
782+
783+
// If `into` already has a symbol with this name, skip it.
784+
if (SymbolTable::lookupSymbolIn(into, sym.getName()))
785+
continue;
786+
787+
into.push_back(op.clone());
788+
}
789+
}
777790
} // namespace cudaq::opt

python/extension/CUDAQuantumExtension.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -311,11 +311,9 @@ PYBIND11_MODULE(_quakeDialects, m) {
311311
auto ctx = unwrap(mod).getContext();
312312
auto moduleB = mlir::parseSourceString<mlir::ModuleOp>(code, ctx);
313313
auto moduleA = unwrap(mod);
314-
moduleB->walk([&moduleA](mlir::func::FuncOp op) {
315-
if (!moduleA.lookupSymbol<mlir::func::FuncOp>(op.getName()))
316-
moduleA.push_back(op.clone());
317-
return mlir::WalkResult::advance();
318-
});
314+
315+
// Merge symbols from moduleB into moduleA.
316+
cudaq::opt::factory::mergeModules(moduleA, *moduleB);
319317
return kName;
320318
},
321319
"Given a python module name like `mod1.mod2.func`, see if there is a "

unittests/Optimizer/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88

99
include(HandleLLVMOptions)
1010

11-
add_executable(OptimizerUnitTests HermitianTrait.cpp)
11+
add_executable(OptimizerUnitTests HermitianTrait.cpp FactoryMergeModuleTest.cpp)
1212

1313
target_link_libraries(OptimizerUnitTests
1414
PRIVATE
15+
MLIRParser
1516
QuakeDialect
1617
gtest_main
18+
OptimBuilder
1719
)
1820

1921
gtest_discover_tests(OptimizerUnitTests)
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*******************************************************************************
2+
* Copyright (c) 2025 NVIDIA Corporation & Affiliates. *
3+
* All rights reserved. *
4+
* *
5+
* This source code and the accompanying materials are made available under *
6+
* the terms of the Apache License 2.0 which accompanies this distribution. *
7+
******************************************************************************/
8+
9+
#include "cudaq/Optimizer/Builder/Factory.h"
10+
#include "mlir/Dialect/Func/IR/FuncOps.h"
11+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12+
#include "mlir/IR/BuiltinOps.h"
13+
#include "mlir/IR/DialectRegistry.h"
14+
#include "mlir/IR/MLIRContext.h"
15+
#include "mlir/Parser/Parser.h"
16+
#include <gtest/gtest.h>
17+
18+
using namespace mlir;
19+
using cudaq::opt::factory::mergeModules;
20+
21+
static mlir::OwningOpRef<mlir::ModuleOp> parse(MLIRContext &ctx,
22+
llvm::StringRef ir) {
23+
return mlir::parseSourceString<mlir::ModuleOp>(ir, &ctx);
24+
}
25+
26+
TEST(FactoryMergeModuleTest, CopiesMissingFunction) {
27+
DialectRegistry registry;
28+
registry.insert<mlir::func::FuncDialect, mlir::LLVM::LLVMDialect>();
29+
MLIRContext ctx(registry);
30+
31+
auto dst = parse(ctx, R"mlir(
32+
module {
33+
func.func @alreadyThere() { return }
34+
}
35+
)mlir");
36+
ASSERT_TRUE(dst);
37+
38+
auto src = parse(ctx, R"mlir(
39+
module {
40+
func.func @newFunc() { return }
41+
func.func @alreadyThere() { return }
42+
}
43+
)mlir");
44+
ASSERT_TRUE(src);
45+
46+
mergeModules(*dst, *src);
47+
48+
auto newFunc = dst->lookupSymbol<mlir::func::FuncOp>("newFunc");
49+
EXPECT_TRUE(newFunc);
50+
51+
int countAlreadyThere = 0;
52+
dst->walk([&](mlir::func::FuncOp f) {
53+
if (f.getSymName() == "alreadyThere")
54+
countAlreadyThere++;
55+
});
56+
EXPECT_EQ(countAlreadyThere, 1);
57+
}
58+
59+
TEST(FactoryMergeModuleTest, RetainOriginalModuleSymbols) {
60+
DialectRegistry registry;
61+
registry.insert<mlir::func::FuncDialect, mlir::LLVM::LLVMDialect>();
62+
MLIRContext ctx(registry);
63+
64+
auto dst = parse(ctx, R"mlir(
65+
module attributes { test.attr = "keepme" } {
66+
func.func @a() { return }
67+
}
68+
)mlir");
69+
ASSERT_TRUE(dst);
70+
71+
auto src = parse(ctx, R"mlir(
72+
module {
73+
func.func @b() { return }
74+
}
75+
)mlir");
76+
ASSERT_TRUE(src);
77+
78+
mergeModules(*dst, *src);
79+
80+
// Verify the destination attribute remains
81+
auto sattr =
82+
dst->getOperation()->getAttrOfType<mlir::StringAttr>("test.attr");
83+
ASSERT_TRUE(sattr);
84+
ASSERT_EQ(sattr.getValue(), "keepme");
85+
86+
// Both symbols exists exactly once
87+
EXPECT_TRUE(dst->lookupSymbol<mlir::func::FuncOp>("a"));
88+
EXPECT_TRUE(dst->lookupSymbol<mlir::func::FuncOp>("b"));
89+
90+
int countA = 0, countB = 0;
91+
dst->walk([&](mlir::func::FuncOp f) {
92+
if (f.getSymName() == "a")
93+
countA++;
94+
if (f.getSymName() == "b")
95+
countB++;
96+
});
97+
EXPECT_EQ(countA, 1);
98+
EXPECT_EQ(countB, 1);
99+
}

0 commit comments

Comments
 (0)