1
- use std:: ffi:: c_void;
2
-
3
1
use llvm_sys:: {
4
2
core:: {
5
3
LLVMGetFirstBasicBlock , LLVMGetFirstFunction , LLVMGetFirstInstruction ,
@@ -9,85 +7,64 @@ use llvm_sys::{
9
7
LLVMBasicBlock , LLVMValue ,
10
8
} ;
11
9
use melior:: ir:: { BlockLike , BlockRef , OperationRef } ;
12
- use mlir_sys:: { MlirOperation , MlirWalkResult } ;
13
-
14
- type OperationWalkCallback =
15
- unsafe extern "C" fn ( MlirOperation , * mut :: std:: os:: raw:: c_void ) -> MlirWalkResult ;
16
10
17
11
/// Traverses the given operation tree in preorder.
18
12
///
19
- /// Calls `f` on each operation encountered. The second argument to `f` should
20
- /// be interpreted as a pointer to a value of type `T`.
13
+ /// Calls `f` on each operation encountered.
14
+ pub fn walk_mlir_operations ( top_op : OperationRef , f : & mut impl FnMut ( OperationRef ) ) {
15
+ f ( top_op) ;
16
+
17
+ for region in top_op. regions ( ) {
18
+ let mut next_block = region. first_block ( ) ;
19
+
20
+ while let Some ( block) = next_block {
21
+ if let Some ( operation) = block. first_operation ( ) {
22
+ walk_mlir_block_operations ( operation, f) ;
23
+ }
24
+
25
+ next_block = block. next_in_region ( ) ;
26
+ }
27
+ }
28
+ }
29
+
30
+ /// Traverses all following operations in the current block
21
31
///
22
- /// TODO: Can we receive a closure instead?
23
- /// We may need to save a pointer to the closure
24
- /// inside of the callback data.
25
- pub fn walk_mlir_operations < T : Sized > (
26
- top_op : OperationRef ,
27
- f : OperationWalkCallback ,
28
- initial : T ,
29
- ) -> T {
30
- let mut data = Box :: new ( initial) ;
31
- unsafe {
32
- mlir_sys:: mlirOperationWalk (
33
- top_op. to_raw ( ) ,
34
- Some ( f) ,
35
- data. as_mut ( ) as * mut _ as * mut c_void ,
36
- mlir_sys:: MlirWalkOrder_MlirWalkPreOrder ,
37
- ) ;
38
- } ;
39
- * data
32
+ /// Calls `f` on each operation encountered.
33
+ ///
34
+ /// NOTE: The lifetime of each operation is bound to the previous operation,
35
+ /// so the only way I found to comply with the borrow checker was to make the
36
+ /// function recursive. This convinces the compiler that the full operation
37
+ /// chain is in scope. This has been fixed in the latest melior release, but
38
+ /// updating the dependency requires us to update to LLVM 20.
39
+ pub fn walk_mlir_block_operations ( operation : OperationRef , f : & mut impl FnMut ( OperationRef ) ) {
40
+ walk_mlir_operations ( operation, f) ;
41
+
42
+ if let Some ( next_operation) = operation. next_in_block ( ) {
43
+ walk_mlir_block_operations ( next_operation, f) ;
44
+ }
40
45
}
41
46
42
47
/// Traverses from start block to end block (including) in preorder.
43
48
///
44
- /// Calls `f` on each operation encountered. The second argument to `f` should
45
- /// be interpreted as a pointer to a value of type `T`.
46
- ///
47
- /// TODO: Can we receive a closure instead?
48
- /// We may need to save a pointer to the closure
49
- /// inside of the callback data.
50
- pub fn walk_mlir_block < T : Sized > (
49
+ /// Calls `f` on each operation encountered.
50
+ pub fn walk_mlir_block (
51
51
start_block : BlockRef ,
52
52
end_block : BlockRef ,
53
- f : OperationWalkCallback ,
54
- initial : T ,
55
- ) -> T {
56
- let mut data = Box :: new ( initial) ;
53
+ f : & mut impl FnMut ( OperationRef ) ,
54
+ ) {
55
+ let mut next_block = Some ( start_block) ;
57
56
58
- let mut current_block = start_block;
59
- loop {
60
- let mut next_operation = current_block. first_operation ( ) ;
61
-
62
- while let Some ( operation) = next_operation {
63
- unsafe {
64
- mlir_sys:: mlirOperationWalk (
65
- operation. to_raw ( ) ,
66
- Some ( f) ,
67
- data. as_mut ( ) as * mut _ as * mut c_void ,
68
- mlir_sys:: MlirWalkOrder_MlirWalkPreOrder ,
69
- ) ;
70
- } ;
71
-
72
- // we have to convert it to raw, and back to ref to bypass borrow checker.
73
- next_operation = unsafe {
74
- operation
75
- . next_in_block ( )
76
- . map ( OperationRef :: to_raw)
77
- . map ( |op| OperationRef :: from_raw ( op) )
78
- }
57
+ while let Some ( block) = next_block {
58
+ if let Some ( operation) = block. first_operation ( ) {
59
+ walk_mlir_block_operations ( operation, f) ;
79
60
}
80
61
81
- if current_block == end_block {
82
- break ;
62
+ if block == end_block {
63
+ return ;
83
64
}
84
65
85
- current_block = current_block
86
- . next_in_region ( )
87
- . expect ( "should always reach `end_block`" ) ;
66
+ next_block = block. next_in_region ( ) ;
88
67
}
89
-
90
- * data
91
68
}
92
69
93
70
/// Traverses the whole LLVM Module, calling `f` on each instruction.
0 commit comments