@@ -13,6 +13,12 @@ use std::path::PathBuf;
13
13
use glob:: glob;
14
14
use which:: which;
15
15
16
+ const PYTHON_PRINT_DIRS : & str = r"
17
+ import sysconfig
18
+ print('PYTHON_INCLUDE_DIR:', sysconfig.get_config_var('INCLUDEDIR'))
19
+ print('PYTHON_LIB_DIR:', sysconfig.get_config_var('LIBDIR'))
20
+ " ;
21
+
16
22
// Translated from torch/utils/cpp_extension.py
17
23
fn find_cuda_home ( ) -> Option < String > {
18
24
// Guess #1
@@ -67,6 +73,26 @@ fn emit_cuda_link_directives(cuda_home: &str) {
67
73
println ! ( "cargo:rustc-link-lib=cudart" ) ;
68
74
}
69
75
76
+ fn python_env_dirs ( ) -> ( Option < String > , Option < String > ) {
77
+ let output = std:: process:: Command :: new ( PathBuf :: from ( "python3" ) )
78
+ . arg ( "-c" )
79
+ . arg ( PYTHON_PRINT_DIRS )
80
+ . output ( )
81
+ . unwrap_or_else ( |_| panic ! ( "error running python" ) ) ;
82
+
83
+ let mut include_dir = None ;
84
+ let mut lib_dir = None ;
85
+ for line in String :: from_utf8_lossy ( & output. stdout ) . lines ( ) {
86
+ if let Some ( path) = line. strip_prefix ( "PYTHON_INCLUDE_DIR: " ) {
87
+ include_dir = Some ( path. to_string ( ) ) ;
88
+ }
89
+ if let Some ( path) = line. strip_prefix ( "PYTHON_LIB_DIR: " ) {
90
+ lib_dir = Some ( path. to_string ( ) ) ;
91
+ }
92
+ }
93
+ ( include_dir, lib_dir)
94
+ }
95
+
70
96
fn main ( ) {
71
97
// Tell cargo to look for shared libraries in the specified directory
72
98
println ! ( "cargo:rustc-link-search=/usr/lib" ) ;
@@ -78,30 +104,65 @@ fn main() {
78
104
// Link against the mlx5 library
79
105
println ! ( "cargo:rustc-link-lib=mlx5" ) ;
80
106
81
- // Link against cuda library library
82
- if let Some ( cuda_home) = find_cuda_home ( ) {
83
- emit_cuda_link_directives ( & cuda_home) ;
84
- }
85
-
86
107
// Tell cargo to invalidate the built crate whenever the wrapper changes
87
108
println ! ( "cargo:rerun-if-changed=src/rdmaxcel.h" ) ;
88
109
89
- // Add cargo metadata
90
- println ! ( "cargo:rustc-cfg=cargo" ) ;
91
- println ! ( "cargo:rustc-check-cfg=cfg(cargo)" ) ;
110
+ // Get the directory of the current crate
111
+ let manifest_dir = env:: var ( "CARGO_MANIFEST_DIR" ) . unwrap_or_else ( |_| {
112
+ // For buck2 run, we know the package is in fbcode/monarch/rdmaxcel-sys
113
+ // Get the fbsource directory from the current directory path
114
+ let current_dir = std:: env:: current_dir ( ) . expect ( "Failed to get current directory" ) ;
115
+ let current_path = current_dir. to_string_lossy ( ) ;
116
+
117
+ // Find the fbsource part of the path
118
+ if let Some ( fbsource_pos) = current_path. find ( "fbsource" ) {
119
+ let fbsource_path = & current_path[ ..fbsource_pos + "fbsource" . len ( ) ] ;
120
+ format ! ( "{}/fbcode/monarch/rdmaxcel-sys" , fbsource_path)
121
+ } else {
122
+ // If we can't find fbsource in the path, just use the current directory
123
+ format ! ( "{}/src" , current_dir. to_string_lossy( ) )
124
+ }
125
+ } ) ;
92
126
93
- // The bindgen::Builder is the main entry point to bindgen
94
- let bindings = bindgen:: Builder :: default ( )
127
+ // Create the absolute path to the header file
128
+ let header_path = format ! ( "{}/src/rdmaxcel.h" , manifest_dir) ;
129
+
130
+ // Check if the header file exists
131
+ if !Path :: new ( & header_path) . exists ( ) {
132
+ panic ! ( "Header file not found at {}" , header_path) ;
133
+ }
134
+
135
+ // Start building the bindgen configuration
136
+ let mut builder = bindgen:: Builder :: default ( )
95
137
// The input header we would like to generate bindings for
96
- . header ( "src/rdmaxcel.h" )
138
+ . header ( & header_path)
139
+ . clang_arg ( "-x" )
140
+ . clang_arg ( "c++" )
141
+ . clang_arg ( "-std=gnu++20" )
142
+ . parse_callbacks ( Box :: new ( bindgen:: CargoCallbacks :: new ( ) ) )
97
143
// Allow the specified functions, types, and variables
98
144
. allowlist_function ( "ibv_.*" )
99
145
. allowlist_function ( "mlx5dv_.*" )
100
146
. allowlist_function ( "mlx5_wqe_.*" )
147
+ . allowlist_function ( "create_qp" )
148
+ . allowlist_function ( "create_mlx5dv_.*" )
149
+ . allowlist_function ( "register_cuda_memory" )
150
+ . allowlist_function ( "db_ring" )
151
+ . allowlist_function ( "cqe_poll" )
152
+ . allowlist_function ( "send_wqe" )
153
+ . allowlist_function ( "recv_wqe" )
154
+ . allowlist_function ( "launch_db_ring" )
155
+ . allowlist_function ( "launch_cqe_poll" )
156
+ . allowlist_function ( "launch_send_wqe" )
157
+ . allowlist_function ( "launch_recv_wqe" )
101
158
. allowlist_type ( "ibv_.*" )
102
159
. allowlist_type ( "mlx5dv_.*" )
103
160
. allowlist_type ( "mlx5_wqe_.*" )
161
+ . allowlist_type ( "cqe_poll_result_t" )
162
+ . allowlist_type ( "wqe_params_t" )
163
+ . allowlist_type ( "cqe_poll_params_t" )
104
164
. allowlist_var ( "MLX5_.*" )
165
+ . allowlist_var ( "IBV_.*" )
105
166
// Block specific types that are manually defined in lib.rs
106
167
. blocklist_type ( "ibv_wc" )
107
168
. blocklist_type ( "mlx5_wqe_ctrl_seg" )
@@ -118,15 +179,58 @@ fn main() {
118
179
. constified_enum_module ( "ibv_wr_opcode" )
119
180
. constified_enum_module ( "ibv_wc_status" )
120
181
. derive_default ( true )
121
- . prepend_enum_name ( false )
122
- // Finish the builder and generate the bindings
123
- . generate ( )
124
- // Unwrap the Result and panic on failure
125
- . expect ( "Unable to generate bindings" ) ;
182
+ . prepend_enum_name ( false ) ;
183
+
184
+ // Add CUDA include path if available
185
+ if let Some ( cuda_home) = find_cuda_home ( ) {
186
+ let cuda_include_path = format ! ( "{}/include" , cuda_home) ;
187
+ if Path :: new ( & cuda_include_path) . exists ( ) {
188
+ builder = builder. clang_arg ( format ! ( "-I{}" , cuda_include_path) ) ;
189
+ } else {
190
+ eprintln ! (
191
+ "Warning: CUDA include directory not found at {}" ,
192
+ cuda_include_path
193
+ ) ;
194
+ }
195
+ } else {
196
+ eprintln ! ( "Warning: CUDA home directory not found. Continuing without CUDA include path." ) ;
197
+ }
198
+
199
+ // Include headers and libs from the active environment.
200
+ let ( include_dir, lib_dir) = python_env_dirs ( ) ;
201
+ if let Some ( include_dir) = include_dir {
202
+ builder = builder. clang_arg ( format ! ( "-I{}" , include_dir) ) ;
203
+ }
204
+ if let Some ( lib_dir) = lib_dir {
205
+ println ! ( "cargo::rustc-link-search=native={}" , lib_dir) ;
206
+ // Set cargo metadata to inform dependent binaries about how to set their
207
+ // RPATH (see controller/build.rs for an example).
208
+ println ! ( "cargo::metadata=LIB_PATH={}" , lib_dir) ;
209
+ }
210
+ if let Some ( cuda_home) = find_cuda_home ( ) {
211
+ emit_cuda_link_directives ( & cuda_home) ;
212
+ }
213
+
214
+ // Generate bindings
215
+ let bindings = builder. generate ( ) . expect ( "Unable to generate bindings" ) ;
126
216
127
217
// Write the bindings to the $OUT_DIR/bindings.rs file
128
- let out_path = PathBuf :: from ( env:: var ( "OUT_DIR" ) . unwrap ( ) ) ;
129
- bindings
130
- . write_to_file ( out_path. join ( "bindings.rs" ) )
131
- . expect ( "Couldn't write bindings!" ) ;
218
+ match env:: var ( "OUT_DIR" ) {
219
+ Ok ( out_dir) => {
220
+ let out_path = PathBuf :: from ( out_dir) ;
221
+ match bindings. write_to_file ( out_path. join ( "bindings.rs" ) ) {
222
+ Ok ( _) => {
223
+ println ! ( "cargo:rustc-cfg=cargo" ) ;
224
+ println ! ( "cargo:rustc-check-cfg=cfg(cargo)" ) ;
225
+ }
226
+ Err ( e) => eprintln ! ( "Warning: Couldn't write bindings: {}" , e) ,
227
+ }
228
+ }
229
+ Err ( _) => {
230
+ // When running via buck2 run, OUT_DIR might not be set
231
+ // This is expected and not an error - we're just running the script directly
232
+ // The actual build will happen later with cargo
233
+ println ! ( "Note: OUT_DIR not set, skipping bindings file generation" ) ;
234
+ }
235
+ }
132
236
}
0 commit comments