Skip to content

Commit 6938b95

Browse files
committed
test: use CUDA if cuda feature is enabled
Updated test configurations for Translator and Generator to use CUDA device when the `cuda` feature is enabled; defaulting to CPU otherwise. This ensures that tests are run on the appropriate device based on the feature flag.
1 parent 8c1bb5c commit 6938b95

File tree

5 files changed

+89
-12
lines changed

5 files changed

+89
-12
lines changed

src/generator.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,15 +242,26 @@ impl<T: Tokenizer> Debug for Generator<T> {
242242
#[cfg(feature = "hub")]
243243
mod tests {
244244
use super::Generator;
245-
use crate::{download_model, GenerationOptions};
245+
use crate::{download_model, Config, Device, GenerationOptions};
246246

247247
const MODEL_ID: &str = "jkawamoto/gpt2-ct2";
248248

249249
#[test]
250250
#[ignore]
251251
fn test_generate() {
252252
let model_path = download_model(MODEL_ID).unwrap();
253-
let g = Generator::new(&model_path, &Default::default()).unwrap();
253+
let g = Generator::new(
254+
&model_path,
255+
&Config {
256+
device: if cfg!(feature = "cuda") {
257+
Device::CUDA
258+
} else {
259+
Device::CPU
260+
},
261+
..Default::default()
262+
},
263+
)
264+
.unwrap();
254265

255266
let prompt = "CTranslate2 is a library";
256267
let res = g
@@ -271,7 +282,18 @@ mod tests {
271282
#[ignore]
272283
fn test_generator_debug() {
273284
let model_path = download_model(MODEL_ID).unwrap();
274-
let g = Generator::new(&model_path, &Default::default()).unwrap();
285+
let g = Generator::new(
286+
&model_path,
287+
&Config {
288+
device: if cfg!(feature = "cuda") {
289+
Device::CUDA
290+
} else {
291+
Device::CPU
292+
},
293+
..Default::default()
294+
},
295+
)
296+
.unwrap();
275297

276298
assert!(format!("{:?}", g).contains(model_path.file_name().unwrap().to_str().unwrap()));
277299
}

src/sys/generator.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -619,16 +619,27 @@ mod tests {
619619

620620
#[cfg(feature = "hub")]
621621
mod hub {
622-
use crate::download_model;
623622
use crate::sys::Generator;
623+
use crate::{download_model, Config, Device};
624624

625625
const MODEL_ID: &str = "jkawamoto/gpt2-ct2";
626626
#[test]
627627
#[ignore]
628628
fn test_generator_debug() {
629629
let model_path = download_model(MODEL_ID).unwrap();
630630

631-
let generator = Generator::new(&model_path, &Default::default()).unwrap();
631+
let generator = Generator::new(
632+
&model_path,
633+
&Config {
634+
device: if cfg!(feature = "cuda") {
635+
Device::CUDA
636+
} else {
637+
Device::CPU
638+
},
639+
..Default::default()
640+
},
641+
)
642+
.unwrap();
632643
assert!(format!("{:?}", generator)
633644
.contains(model_path.file_name().unwrap().to_str().unwrap()));
634645
}

src/sys/translator.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -728,16 +728,27 @@ mod tests {
728728

729729
#[cfg(feature = "hub")]
730730
mod hub {
731-
use crate::download_model;
732731
use crate::sys::Translator;
732+
use crate::{download_model, Config, Device};
733733

734734
const MODEL_ID: &str = "jkawamoto/fugumt-en-ja-ct2";
735735
#[test]
736736
#[ignore]
737737
fn test_translator_debug() {
738738
let model_path = download_model(MODEL_ID).unwrap();
739739

740-
let translator = Translator::new(&model_path, &Default::default()).unwrap();
740+
let translator = Translator::new(
741+
&model_path,
742+
&Config {
743+
device: if cfg!(feature = "cuda") {
744+
Device::CUDA
745+
} else {
746+
Device::CPU
747+
},
748+
..Default::default()
749+
},
750+
)
751+
.unwrap();
741752
assert!(format!("{:?}", translator)
742753
.contains(model_path.file_name().unwrap().to_str().unwrap()));
743754
}

src/translator.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,15 +327,26 @@ impl<T: Tokenizer> Debug for Translator<T> {
327327
#[cfg(test)]
328328
#[cfg(feature = "hub")]
329329
mod tests {
330-
use crate::{download_model, TranslationOptions, Translator};
330+
use crate::{download_model, Config, Device, TranslationOptions, Translator};
331331

332332
const MODEL_ID: &str = "jkawamoto/fugumt-en-ja-ct2";
333333

334334
#[test]
335335
#[ignore]
336336
fn test_translate() {
337337
let model_path = download_model(MODEL_ID).unwrap();
338-
let t = Translator::new(&model_path, &Default::default()).unwrap();
338+
let t = Translator::new(
339+
&model_path,
340+
&Config {
341+
device: if cfg!(feature = "cuda") {
342+
Device::CUDA
343+
} else {
344+
Device::CPU
345+
},
346+
..Default::default()
347+
},
348+
)
349+
.unwrap();
339350

340351
let res = t
341352
.translate_batch(
@@ -355,7 +366,18 @@ mod tests {
355366
#[ignore]
356367
fn test_translator_debug() {
357368
let model_path = download_model(MODEL_ID).unwrap();
358-
let t = Translator::new(&model_path, &Default::default()).unwrap();
369+
let t = Translator::new(
370+
&model_path,
371+
&Config {
372+
device: if cfg!(feature = "cuda") {
373+
Device::CUDA
374+
} else {
375+
Device::CPU
376+
},
377+
..Default::default()
378+
},
379+
)
380+
.unwrap();
359381

360382
assert!(format!("{:?}", t).contains(model_path.file_name().unwrap().to_str().unwrap()));
361383
}

src/whisper.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,15 +295,26 @@ impl PreprocessorConfig {
295295
#[cfg(test)]
296296
#[cfg(feature = "hub")]
297297
mod tests {
298-
use crate::{download_model, Whisper};
298+
use crate::{download_model, Config, Device, Whisper};
299299

300300
const MODEL_ID: &str = "jkawamoto/whisper-tiny-ct2";
301301

302302
#[test]
303303
#[ignore]
304304
fn test_whisper_debug() {
305305
let model_path = download_model(MODEL_ID).unwrap();
306-
let w = Whisper::new(&model_path, Default::default()).unwrap();
306+
let w = Whisper::new(
307+
&model_path,
308+
Config {
309+
device: if cfg!(feature = "cuda") {
310+
Device::CUDA
311+
} else {
312+
Device::CPU
313+
},
314+
..Default::default()
315+
},
316+
)
317+
.unwrap();
307318

308319
assert!(format!("{:?}", w).contains(model_path.file_name().unwrap().to_str().unwrap()));
309320
}

0 commit comments

Comments
 (0)