Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 98 additions & 2 deletions ct2rs/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,15 @@ fn build_ctranslate2() {
let cuda = cuda_root().expect("CUDA_TOOLKIT_ROOT_DIR is not specified");
cmake.define("WITH_CUDA", "ON");
cmake.define("CUDA_TOOLKIT_ROOT_DIR", &cuda);
cmake.define("CUDA_ARCH_LIST", "Common");
let arch_config = resolve_cuda_arch_list();
cmake.define("CUDA_ARCH_LIST", &arch_config.cmake_value);
let mut nvcc_flags = Vec::new();
if cfg!(feature = "cuda-small-binary") {
cmake.define("CUDA_NVCC_FLAGS", "-Xfatbin=-compress-all");
nvcc_flags.push("-Xfatbin=-compress-all".to_string());
}
nvcc_flags.extend(arch_config.extra_nvcc_flags);
if !nvcc_flags.is_empty() {
cmake.define("CUDA_NVCC_FLAGS", nvcc_flags.join(";"));
}
println!("cargo:rustc-link-search={}", cuda.join("lib").display());
println!("cargo:rustc-link-search={}", cuda.join("lib64").display());
Expand Down Expand Up @@ -192,6 +198,96 @@ fn build_ctranslate2() {
link_libraries(ctranslate2.join("build"));
}

/// Result of resolving the requested CUDA architectures for the CTranslate2 build.
///
/// Produced by `resolve_cuda_arch_list` and consumed when configuring CMake.
struct CudaArchConfig {
/// Value for the `CUDA_ARCH_LIST` CMake definition: architecture tokens joined with `;`.
cmake_value: String,
/// Additional `-gencode=...` arguments that get appended to `CUDA_NVCC_FLAGS`.
extra_nvcc_flags: Vec<String>,
}

/// Resolves which CUDA architectures to build for.
///
/// The list is read from the `CUDA_ARCH_LIST` environment variable, falling back to
/// `CT2_CUDA_ARCH_LIST`, and finally to `"Common"` when neither can be read. Entries may
/// be separated by `;`, `,` or whitespace; empty entries are ignored.
///
/// Each entry is handled as follows:
///
/// * An entry that `parse_cuda_arch` understands and whose major version is below 10 is
///   forwarded to CMake in normalised `major.minor` form.
/// * An entry with major version 10 or above is not forwarded to CMake; an explicit
///   `-gencode=arch=compute_XY,code=sm_XY` flag is generated for it instead.
///   NOTE(review): presumably because the CMake architecture selection used by the
///   build does not know these architectures -- confirm.
/// * Anything else (e.g. a named architecture) is forwarded to CMake verbatim.
///
/// A trailing `+PTX` on an entry is preserved: it is re-attached to the normalised
/// `major.minor` value, or, for explicit flags, translated into an additional
/// `-gencode=arch=compute_XY,code=compute_XY` flag so PTX is still embedded.
///
/// If no entry ends up in the CMake list, `"Common"` is used so the list is never empty.
fn resolve_cuda_arch_list() -> CudaArchConfig {
    const DEFAULT_ARCH: &str = "Common";
    const PTX_SUFFIX: &str = "+PTX";

    let raw = env::var("CUDA_ARCH_LIST")
        .or_else(|_| env::var("CT2_CUDA_ARCH_LIST"))
        .unwrap_or_else(|_| DEFAULT_ARCH.to_string());

    let mut cmake_tokens = Vec::new();
    let mut extra_flags = Vec::new();

    // Splitting on whitespace and skipping empty pieces also covers a blank or
    // whitespace-only value: no token is produced and the default applies below.
    let tokens = raw
        .split(|c: char| c == ';' || c == ',' || c.is_whitespace())
        .filter(|token| !token.is_empty());

    for token in tokens {
        // Detach the PTX request before parsing so it is not lost during normalisation.
        let (base, wants_ptx) = match token.strip_suffix(PTX_SUFFIX) {
            Some(base) => (base, true),
            None => (token, false),
        };

        match parse_cuda_arch(base) {
            Some(arch) if arch.major < 10 => {
                let suffix = if wants_ptx { PTX_SUFFIX } else { "" };
                cmake_tokens.push(format!("{}.{}{}", arch.major, arch.minor, suffix));
            }
            Some(arch) => {
                let code = format!("{}{}", arch.major, arch.minor);
                extra_flags.push(format!(
                    "-gencode=arch=compute_{code},code=sm_{code}",
                    code = code
                ));
                if wants_ptx {
                    // `code=compute_XY` asks nvcc to embed PTX for this architecture.
                    extra_flags.push(format!(
                        "-gencode=arch=compute_{code},code=compute_{code}",
                        code = code
                    ));
                }
            }
            // Unrecognised entries are passed through untouched, suffix included.
            None => cmake_tokens.push(token.to_string()),
        }
    }

    if cmake_tokens.is_empty() {
        cmake_tokens.push(DEFAULT_ARCH.to_string());
    }

    CudaArchConfig {
        cmake_value: cmake_tokens.join(";"),
        extra_nvcc_flags: extra_flags,
    }
}

/// A CUDA architecture version split into its two components,
/// e.g. `8.6` is `major = 8`, `minor = 6`.
#[derive(Copy, Clone)]
struct ParsedArch {
/// Major version component (the packed number divided by ten, or the part before the dot).
major: u32,
/// Minor version component (a single decimal digit).
minor: u32,
}

fn parse_cuda_arch(token: &str) -> Option<ParsedArch> {
let trimmed = token.trim();
if trimmed.is_empty() {
return None;
}
if let Some(rest) = trimmed.strip_prefix("compute_") {
return parse_arch_pair(rest.parse::<u32>().ok()?);
}
if trimmed.contains('.') {
let mut parts = trimmed.split('.');
let major: u32 = parts.next()?.parse().ok()?;
let minor_part = parts.next().unwrap_or("0");
let minor_char = minor_part
.chars()
.find(|c| c.is_ascii_digit())
.unwrap_or('0');
let minor = minor_char.to_digit(10)? as u32;
return Some(ParsedArch { major, minor });
}
if let Ok(value) = trimmed.parse::<u32>() {
return parse_arch_pair(value);
}
None
}

/// Splits a packed architecture number into major and minor parts.
///
/// The last decimal digit is the minor version and everything before it is the major
/// version, so `86` becomes 8.6 and `120` becomes 12.0. Single-digit values carry no
/// major part and yield `None`.
fn parse_arch_pair(value: u32) -> Option<ParsedArch> {
    match value {
        0..=9 => None,
        packed => Some(ParsedArch {
            major: packed / 10,
            minor: packed % 10,
        }),
    }
}

fn link_system_libraries() {
println!("cargo:rustc-link-lib=ctranslate2");
if cfg!(target_arch = "x86_64") {
Expand Down
Loading