diff --git a/ct2rs/build.rs b/ct2rs/build.rs index 399b6ab..de3a33f 100644 --- a/ct2rs/build.rs +++ b/ct2rs/build.rs @@ -98,9 +98,15 @@ fn build_ctranslate2() { let cuda = cuda_root().expect("CUDA_TOOLKIT_ROOT_DIR is not specified"); cmake.define("WITH_CUDA", "ON"); cmake.define("CUDA_TOOLKIT_ROOT_DIR", &cuda); - cmake.define("CUDA_ARCH_LIST", "Common"); + let arch_config = resolve_cuda_arch_list(); + cmake.define("CUDA_ARCH_LIST", &arch_config.cmake_value); + let mut nvcc_flags = Vec::new(); if cfg!(feature = "cuda-small-binary") { - cmake.define("CUDA_NVCC_FLAGS", "-Xfatbin=-compress-all"); + nvcc_flags.push("-Xfatbin=-compress-all".to_string()); + } + nvcc_flags.extend(arch_config.extra_nvcc_flags); + if !nvcc_flags.is_empty() { + cmake.define("CUDA_NVCC_FLAGS", nvcc_flags.join(";")); } println!("cargo:rustc-link-search={}", cuda.join("lib").display()); println!("cargo:rustc-link-search={}", cuda.join("lib64").display()); @@ -192,6 +198,96 @@ fn build_ctranslate2() { link_libraries(ctranslate2.join("build")); } +struct CudaArchConfig { + cmake_value: String, + extra_nvcc_flags: Vec, +} + +fn resolve_cuda_arch_list() -> CudaArchConfig { + let raw = env::var("CUDA_ARCH_LIST") + .or_else(|_| env::var("CT2_CUDA_ARCH_LIST")) + .unwrap_or_else(|_| "Common".to_string()); + let trimmed = raw.trim(); + if trimmed.is_empty() { + return CudaArchConfig { + cmake_value: "Common".to_string(), + extra_nvcc_flags: Vec::new(), + }; + } + + let mut cmake_tokens = Vec::new(); + let mut extra_flags = Vec::new(); + + for token in trimmed.split(|c: char| c == ';' || c == ',' || c.is_whitespace()) { + if token.is_empty() { + continue; + } + if let Some(arch) = parse_cuda_arch(token) { + if arch.major < 10 { + cmake_tokens.push(format!("{}.{}", arch.major, arch.minor)); + } else { + extra_flags.push(format!( + "-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}", + major = arch.major, + minor = arch.minor + )); + } + } else { + cmake_tokens.push(token.trim().to_string()); + } + } + + if cmake_tokens.is_empty() { + cmake_tokens.push("Common".to_string()); + } + + CudaArchConfig { + cmake_value: cmake_tokens.join(";"), + extra_nvcc_flags: extra_flags, + } +} + +#[derive(Copy, Clone)] +struct ParsedArch { + major: u32, + minor: u32, +} + +fn parse_cuda_arch(token: &str) -> Option { + let trimmed = token.trim(); + if trimmed.is_empty() { + return None; + } + if let Some(rest) = trimmed.strip_prefix("compute_") { + return parse_arch_pair(rest.parse::().ok()?); + } + if trimmed.contains('.') { + let mut parts = trimmed.split('.'); + let major: u32 = parts.next()?.parse().ok()?; + let minor_part = parts.next().unwrap_or("0"); + let minor_char = minor_part + .chars() + .find(|c| c.is_ascii_digit()) + .unwrap_or('0'); + let minor = minor_char.to_digit(10)? as u32; + return Some(ParsedArch { major, minor }); + } + if let Ok(value) = trimmed.parse::() { + return parse_arch_pair(value); + } + None +} + +fn parse_arch_pair(value: u32) -> Option { + if value < 10 { + return None; + } + Some(ParsedArch { + major: value / 10, + minor: value % 10, + }) +} + fn link_system_libraries() { println!("cargo:rustc-link-lib=ctranslate2"); if cfg!(target_arch = "x86_64") {