Skip to content

Commit d131cb8

Browse files
committed
test: use CUDA if cuda feature is enabled
Updated test configurations for Translator and Generator to use CUDA device when the `cuda` feature is enabled; defaulting to CPU otherwise. This ensures that tests are run on the appropriate device based on the feature flag.
1 parent eb7af30 commit d131cb8

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

src/generator.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,15 +242,26 @@ impl<T: Tokenizer> Debug for Generator<T> {
242242
#[cfg(feature = "hub")]
243243
mod tests {
244244
use super::Generator;
245-
use crate::{download_model, GenerationOptions};
245+
use crate::{download_model, Config, Device, GenerationOptions};
246246

247247
const MODEL_ID: &str = "jkawamoto/gpt2-ct2";
248248

249249
#[test]
250250
#[ignore]
251251
fn test_generate() {
252252
let model_path = download_model(MODEL_ID).unwrap();
253-
let g = Generator::new(&model_path, &Default::default()).unwrap();
253+
let g = Generator::new(
254+
&model_path,
255+
&Config {
256+
device: if cfg!(feature = "cuda") {
257+
Device::CUDA
258+
} else {
259+
Device::CPU
260+
},
261+
..Default::default()
262+
},
263+
)
264+
.unwrap();
254265

255266
let prompt = "CTranslate2 is a library";
256267
let res = g

src/translator.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,15 +327,27 @@ impl<T: Tokenizer> Debug for Translator<T> {
327327
#[cfg(test)]
328328
#[cfg(feature = "hub")]
329329
mod tests {
330-
use crate::{download_model, TranslationOptions, Translator};
330+
use crate::{download_model, Config, Device, TranslationOptions, Translator};
331331

332332
const MODEL_ID: &str = "jkawamoto/fugumt-en-ja-ct2";
333333

334334
#[test]
335335
#[ignore]
336336
fn test_translate() {
337337
let model_path = download_model(MODEL_ID).unwrap();
338-
let t = Translator::new(&model_path, &Default::default()).unwrap();
338+
let t = Translator::new(
339+
&model_path,
340+
&Config {
341+
device: if cfg!(feature = "cuda") {
342+
Device::CUDA
343+
} else {
344+
Device::CPU
345+
},
346+
347+
..Default::default()
348+
},
349+
)
350+
.unwrap();
339351

340352
let res = t
341353
.translate_batch(

0 commit comments

Comments
 (0)