@@ -11,6 +11,8 @@ use std::path::PathBuf;
1111
1212use anyhow:: { anyhow, Result } ;
1313use clap:: { crate_name, Parser } ;
14+ use ct2rs:: config:: Config ;
15+ use ct2rs:: config:: Device ;
1416use tokio:: signal;
1517use tokio:: sync:: oneshot;
1618
@@ -45,6 +47,12 @@ struct Args {
4547 /// Specifies the path to the socket file.
4648 #[ arg( long) ]
4749 socket_file : Option < PathBuf > ,
50+ /// Enable GPU acceleration,
51+ #[ arg( long) ]
52+ cuda : bool ,
53+ /// Specifies the CUDA device ID to be used. Effective only when the `--cuda` flag is enabled.
54+ #[ arg( long, default_value = "0" ) ]
55+ cuda_device_id : i32 ,
4856}
4957
5058#[ tokio:: main]
@@ -67,7 +75,14 @@ async fn main() -> Result<()> {
6775 Some ( path) => SocketFile :: with_path ( path) ?,
6876 None => SocketFile :: new ( crate_name ! ( ) ) ?,
6977 } ;
70- let server = Server :: new ( model_dir) ?;
78+ let server = Server :: new (
79+ model_dir,
80+ Config {
81+ device : if args. cuda { Device :: CUDA } else { Device :: CPU } ,
82+ device_indices : vec ! [ args. cuda_device_id] ,
83+ ..Default :: default ( )
84+ } ,
85+ ) ?;
7186 server
7287 . serve ( socket_file, async move {
7388 if let Err ( e) = rx. await {
0 commit comments