Skip to content

Commit aa280a3

Browse files
committed
feat: add flags to enable GPU acceleration
1 parent 94ce525 commit aa280a3

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

src/bin/vsops.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ use std::path::PathBuf;
1111

1212
use anyhow::{anyhow, Result};
1313
use clap::{crate_name, Parser};
14+
use ct2rs::config::Config;
15+
use ct2rs::config::Device;
1416
use tokio::signal;
1517
use tokio::sync::oneshot;
1618

@@ -45,6 +47,12 @@ struct Args {
4547
/// Specifies the path to the socket file.
4648
#[arg(long)]
4749
socket_file: Option<PathBuf>,
50+
/// Enable GPU acceleration,
51+
#[arg(long)]
52+
cuda: bool,
53+
/// Specifies the CUDA device ID to be used. Effective only when the `--cuda` flag is enabled.
54+
#[arg(long, default_value = "0")]
55+
cuda_device_id: i32,
4856
}
4957

5058
#[tokio::main]
@@ -67,7 +75,14 @@ async fn main() -> Result<()> {
6775
Some(path) => SocketFile::with_path(path)?,
6876
None => SocketFile::new(crate_name!())?,
6977
};
70-
let server = Server::new(model_dir)?;
78+
let server = Server::new(
79+
model_dir,
80+
Config {
81+
device: if args.cuda { Device::CUDA } else { Device::CPU },
82+
device_indices: vec![args.cuda_device_id],
83+
..Default::default()
84+
},
85+
)?;
7186
server
7287
.serve(socket_file, async move {
7388
if let Err(e) = rx.await {

src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use std::path::{Path, PathBuf};
1111

1212
use anyhow::{Error, Result};
1313
use ct2rs::auto::Tokenizer as AutoTokenizer;
14+
use ct2rs::config::Config;
1415
use ct2rs::{TranslationOptions, Translator};
1516
use tokio::net::{UnixListener, UnixStream};
1617
use tokio_stream::wrappers::UnixListenerStream;
@@ -60,9 +61,9 @@ pub struct Server {
6061
}
6162

6263
impl Server {
63-
pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
64+
pub fn new<P: AsRef<Path>>(model_path: P, config: Config) -> Result<Self> {
6465
Ok(Self {
65-
inner: Translator::new(model_path, &Default::default())?,
66+
inner: Translator::new(model_path, &config)?,
6667
options: TranslationOptions {
6768
beam_size: 12,
6869
use_vmap: true,

src/socket.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ impl AsRef<Path> for SocketFile {
4646

4747
impl Drop for SocketFile {
4848
fn drop(&mut self) {
49+
if !self.path.exists() {
50+
return;
51+
}
4952
if let Err(e) = remove_file(&self.path) {
5053
println!("failed to remove socket file {}: {}", &self, e);
5154
}

0 commit comments

Comments
 (0)