Skip to content

Commit 12e1855

Browse files
authored
Allow Faster-whisper evals to run on multiple GPUs in parallel (#15)
Co-authored-by: yoad <[email protected]>
1 parent 75a6723 commit 12e1855

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

engines/faster_whisper_engine.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,35 @@ def transcribe(model, entry: Dict[str, Any]) -> str:
1818
return " ".join(texts)
1919

2020

21+
def get_device_and_index(device: str) -> tuple[str, int | None]:
22+
if len(device.split(":")) == 2:
23+
device, device_index = device.split(":")
24+
device_index = int(device_index)
25+
return device, device_index
26+
27+
return device, None
28+
29+
2130
def create_app(**kwargs) -> Callable:
2231
model_path = kwargs.get("model_path")
2332
device: str = kwargs.get("device", "auto")
2433
device_index = None
25-
if len(device.split(":")) == 2:
26-
device, device_index = device.split(":")
27-
device_index = int(device_index)
2834

35+
if len(device.split(",")) > 1:
36+
device_indexes = []
37+
base_device = None
38+
for device_instance in device.split(","):
39+
device, device_index = get_device_and_index(device_instance)
40+
base_device = base_device or device
41+
if base_device != device:
42+
raise ValueError("Multiple devices must be instances of the same base device (e.g cuda:0, cuda:1 etc.)")
43+
device_indexes.append(device_index)
44+
device = base_device
45+
device_index = device_indexes
46+
else:
47+
device, device_index = get_device_and_index(device)
48+
49+
print(f'Loading faster-whisper model: {model_path} on {device} with index: {device_index or "auto"}')
2950
model = faster_whisper.WhisperModel(model_path, device=device, device_index=device_index)
3051

3152
def transcribe_fn(entry):

merge-checkpoints-whisper.py

100644 → 100755
File mode changed.

0 commit comments

Comments
 (0)