Skip to content

Commit 3358505

Browse files
Fix walk_mlir_operations with manual iterator (#1335)
* Implement manual iterator * Remove unsafe
1 parent 1c60fa8 commit 3358505

File tree

3 files changed

+50
-87
lines changed

3 files changed

+50
-87
lines changed

src/compiler.rs

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ use mlir_sys::{
100100
use std::{
101101
cell::Cell,
102102
collections::{hash_map::Entry, BTreeMap, HashMap, HashSet},
103-
ffi::c_void,
104103
ops::Deref,
105104
};
106105

@@ -652,17 +651,11 @@ fn compile_func(
652651
// When statistics are enabled, we iterate from the start
653652
// to the end block of the compiled libfunc, and count all the operations.
654653
if let Some(&mut ref mut stats) = stats {
655-
unsafe extern "C" fn callback(
656-
_: mlir_sys::MlirOperation,
657-
data: *mut c_void,
658-
) -> mlir_sys::MlirWalkResult {
659-
let data = data.cast::<u128>().as_mut().unwrap();
660-
*data += 1;
661-
0
662-
}
663-
let data = walk_mlir_block(*block, *helper.last_block.get(), callback, 0);
654+
let mut operations = 0;
655+
walk_mlir_block(*block, *helper.last_block.get(), &mut |_| operations += 1);
664656
let name = libfunc_to_name(libfunc).to_string();
665-
*stats.mlir_operations_by_libfunc.entry(name).or_insert(0) += data;
657+
658+
*stats.mlir_operations_by_libfunc.entry(name).or_insert(0) += operations;
666659
}
667660

668661
native_assert!(

src/context.rs

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ use mlir_sys::{
3333
mlirLLVMDIModuleAttrGet, MlirLLVMDIEmissionKind_MlirLLVMDIEmissionKindFull,
3434
MlirLLVMDINameTableKind_MlirLLVMDINameTableKindDefault,
3535
};
36-
use std::{ffi::c_void, sync::OnceLock, time::Instant};
36+
use std::{sync::OnceLock, time::Instant};
3737

3838
/// Context of IRs, dialects and passes for Cairo programs compilation.
3939
#[derive(Debug, Eq, PartialEq)]
@@ -204,16 +204,9 @@ impl NativeContext {
204204
}
205205

206206
if let Some(&mut ref mut stats) = stats {
207-
unsafe extern "C" fn callback(
208-
_: mlir_sys::MlirOperation,
209-
data: *mut c_void,
210-
) -> mlir_sys::MlirWalkResult {
211-
let data = data.cast::<u128>().as_mut().unwrap();
212-
*data += 1;
213-
0
214-
}
215-
let data = walk_mlir_operations(module.as_operation(), callback, 0);
216-
stats.mlir_operation_count = Some(data)
207+
let mut operations = 0;
208+
walk_mlir_operations(module.as_operation(), &mut |_| operations += 1);
209+
stats.mlir_operation_count = Some(operations)
217210
}
218211

219212
let pre_mlir_passes_instant = Instant::now();

src/utils/walk_ir.rs

Lines changed: 42 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use std::ffi::c_void;
2-
31
use llvm_sys::{
42
core::{
53
LLVMGetFirstBasicBlock, LLVMGetFirstFunction, LLVMGetFirstInstruction,
@@ -9,85 +7,64 @@ use llvm_sys::{
97
LLVMBasicBlock, LLVMValue,
108
};
119
use melior::ir::{BlockLike, BlockRef, OperationRef};
12-
use mlir_sys::{MlirOperation, MlirWalkResult};
13-
14-
type OperationWalkCallback =
15-
unsafe extern "C" fn(MlirOperation, *mut ::std::os::raw::c_void) -> MlirWalkResult;
1610

1711
/// Traverses the given operation tree in preorder.
1812
///
19-
/// Calls `f` on each operation encountered. The second argument to `f` should
20-
/// be interpreted as a pointer to a value of type `T`.
13+
/// Calls `f` on each operation encountered.
14+
pub fn walk_mlir_operations(top_op: OperationRef, f: &mut impl FnMut(OperationRef)) {
15+
f(top_op);
16+
17+
for region in top_op.regions() {
18+
let mut next_block = region.first_block();
19+
20+
while let Some(block) = next_block {
21+
if let Some(operation) = block.first_operation() {
22+
walk_mlir_block_operations(operation, f);
23+
}
24+
25+
next_block = block.next_in_region();
26+
}
27+
}
28+
}
29+
30+
/// Traverses all following operations in the current block
2131
///
22-
/// TODO: Can we receive a closure instead?
23-
/// We may need to save a pointer to the closure
24-
/// inside of the callback data.
25-
pub fn walk_mlir_operations<T: Sized>(
26-
top_op: OperationRef,
27-
f: OperationWalkCallback,
28-
initial: T,
29-
) -> T {
30-
let mut data = Box::new(initial);
31-
unsafe {
32-
mlir_sys::mlirOperationWalk(
33-
top_op.to_raw(),
34-
Some(f),
35-
data.as_mut() as *mut _ as *mut c_void,
36-
mlir_sys::MlirWalkOrder_MlirWalkPreOrder,
37-
);
38-
};
39-
*data
32+
/// Calls `f` on each operation encountered.
33+
///
34+
/// NOTE: The lifetime of each operation is bound to the previous operation,
35+
/// so the only way I found to comply with the borrow checker was to make the
36+
/// function recursive. This convinces the compiler that the full operation
37+
/// chain is in scope. This has been fixed in the latest melior release, but
38+
/// updating the dependency requires us to update to LLVM 20.
39+
pub fn walk_mlir_block_operations(operation: OperationRef, f: &mut impl FnMut(OperationRef)) {
40+
walk_mlir_operations(operation, f);
41+
42+
if let Some(next_operation) = operation.next_in_block() {
43+
walk_mlir_block_operations(next_operation, f);
44+
}
4045
}
4146

4247
/// Traverses from start block to end block (including) in preorder.
4348
///
44-
/// Calls `f` on each operation encountered. The second argument to `f` should
45-
/// be interpreted as a pointer to a value of type `T`.
46-
///
47-
/// TODO: Can we receive a closure instead?
48-
/// We may need to save a pointer to the closure
49-
/// inside of the callback data.
50-
pub fn walk_mlir_block<T: Sized>(
49+
/// Calls `f` on each operation encountered.
50+
pub fn walk_mlir_block(
5151
start_block: BlockRef,
5252
end_block: BlockRef,
53-
f: OperationWalkCallback,
54-
initial: T,
55-
) -> T {
56-
let mut data = Box::new(initial);
53+
f: &mut impl FnMut(OperationRef),
54+
) {
55+
let mut next_block = Some(start_block);
5756

58-
let mut current_block = start_block;
59-
loop {
60-
let mut next_operation = current_block.first_operation();
61-
62-
while let Some(operation) = next_operation {
63-
unsafe {
64-
mlir_sys::mlirOperationWalk(
65-
operation.to_raw(),
66-
Some(f),
67-
data.as_mut() as *mut _ as *mut c_void,
68-
mlir_sys::MlirWalkOrder_MlirWalkPreOrder,
69-
);
70-
};
71-
72-
// we have to convert it to raw, and back to ref to bypass borrow checker.
73-
next_operation = unsafe {
74-
operation
75-
.next_in_block()
76-
.map(OperationRef::to_raw)
77-
.map(|op| OperationRef::from_raw(op))
78-
}
57+
while let Some(block) = next_block {
58+
if let Some(operation) = block.first_operation() {
59+
walk_mlir_block_operations(operation, f);
7960
}
8061

81-
if current_block == end_block {
82-
break;
62+
if block == end_block {
63+
return;
8364
}
8465

85-
current_block = current_block
86-
.next_in_region()
87-
.expect("should always reach `end_block`");
66+
next_block = block.next_in_region();
8867
}
89-
90-
*data
9168
}
9269

9370
/// Traverses the whole LLVM Module, calling `f` on each instruction.

0 commit comments

Comments
 (0)