@@ -18,14 +18,35 @@ def transcribe(model, entry: Dict[str, Any]) -> str:
1818 return " " .join (texts )
1919
2020
21+ def get_device_and_index (device : str ) -> tuple [str , int | None ]:
22+ if len (device .split (":" )) == 2 :
23+ device , device_index = device .split (":" )
24+ device_index = int (device_index )
25+ return device , device_index
26+
27+ return device , None
28+
29+
2130def create_app (** kwargs ) -> Callable :
2231 model_path = kwargs .get ("model_path" )
2332 device : str = kwargs .get ("device" , "auto" )
2433 device_index = None
25- if len (device .split (":" )) == 2 :
26- device , device_index = device .split (":" )
27- device_index = int (device_index )
2834
35+ if len (device .split ("," )) > 1 :
36+ device_indexes = []
37+ base_device = None
38+ for device_instance in device .split ("," ):
39+ device , device_index = get_device_and_index (device_instance )
40+ base_device = base_device or device
41+ if base_device != device :
42+ raise ValueError ("Multiple devices must be instances of the same base device (e.g cuda:0, cuda:1 etc.)" )
43+ device_indexes .append (device_index )
44+ device = base_device
45+ device_index = device_indexes
46+ else :
47+ device , device_index = get_device_and_index (device )
48+
49+ print (f'Loading faster-whisper model: { model_path } on { device } with index: { device_index or "auto" } ' )
2950 model = faster_whisper .WhisperModel (model_path , device = device , device_index = device_index )
3051
3152 def transcribe_fn (entry ):
0 commit comments