Commit a97bef8

added tests for Faster-Whisper and HF models

- added tests for Faster-Whisper and Hugging Face models; cap `tokenizers` to 0.20.3 for tests on Python versions <3.9 due to huggingface/tokenizers#1691
- added a test for PyTorch 2.6.0
1 parent 4232dd4 commit a97bef8

File tree

4 files changed (+86, -39 lines)


.github/workflows/test.yml

Lines changed: 21 additions & 1 deletion
@@ -16,34 +16,54 @@ jobs:
       - python-version: '3.8'
         pytorch-version: 1.10.1
         numpy-requirement: "'numpy<2'"
+        tokenizers-requirement: "'tokenizers<=0.20.3'"
       - python-version: '3.8'
         pytorch-version: 1.13.1
         numpy-requirement: "'numpy<2'"
+        tokenizers-requirement: "'tokenizers<=0.20.3'"
       - python-version: '3.8'
         pytorch-version: 2.0.1
         numpy-requirement: "'numpy<2'"
+        tokenizers-requirement: "'tokenizers<=0.20.3'"
       - python-version: '3.9'
         pytorch-version: 2.1.2
         numpy-requirement: "'numpy<2'"
+        tokenizers-requirement: "'tokenizers'"
       - python-version: '3.10'
         pytorch-version: 2.2.2
         numpy-requirement: "'numpy<2'"
+        tokenizers-requirement: "'tokenizers'"
       - python-version: '3.11'
         pytorch-version: 2.3.1
         numpy-requirement: "'numpy'"
+        tokenizers-requirement: "'tokenizers'"
       - python-version: '3.12'
         pytorch-version: 2.4.1
         numpy-requirement: "'numpy'"
+        tokenizers-requirement: "'tokenizers'"
       - python-version: '3.12'
         pytorch-version: 2.5.0
         numpy-requirement: "'numpy'"
+        tokenizers-requirement: "'tokenizers'"
+      - python-version: '3.12'
+        pytorch-version: 2.6.0
+        numpy-requirement: "'numpy'"
+        tokenizers-requirement: "'tokenizers'"
     steps:
       - uses: conda-incubator/setup-miniconda@v3
       - run: conda install -n test ffmpeg python=${{ matrix.python-version }}
       - uses: actions/checkout@v4
       - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
-      - run: pip3 install .["dev"] ${{ matrix.numpy-requirement }} torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pypi.org/simple
+      - run: pip3 install . ${{ matrix.numpy-requirement }} torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pypi.org/simple
       - run: python test/test_transcribe.py
       - run: python test/test_align.py
       - run: python test/test_refine.py
       - run: python test/test_locate.py
+      - run: pip3 install .["fw"] ${{ matrix.tokenizers-requirement }}
+      - run: python test/test_transcribe.py load_faster_whisper
+      - run: python test/test_align.py load_faster_whisper
+      - run: python test/test_refine.py load_faster_whisper
+      - run: pip3 install .["hf"] 'transformers<=4.46.3'
+      - run: python test/test_transcribe.py load_hf_whisper
+      - run: python test/test_align.py load_hf_whisper
+      - run: python test/test_refine.py load_hf_whisper
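Note: the new test steps select a loader by passing its attribute name as a command-line argument (`load_faster_whisper`, `load_hf_whisper`). A minimal local sanity check, assuming only that these names are attributes of the installed `stable_whisper` package:

import stable_whisper

# The workflow passes these names as argv[1]; each must resolve to a
# callable on the package, since the tests look them up with getattr.
for name in ('load_model', 'load_faster_whisper', 'load_hf_whisper'):
    loader = getattr(stable_whisper, name, None)
    print(f'{name}: {"ok" if callable(loader) else "missing"}')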

test/test_align.py

Lines changed: 32 additions & 23 deletions
@@ -1,51 +1,60 @@
 import os
+import sys
 import torch
 import stable_whisper
 
 
+def get_load_method():
+    if len(sys.argv) >= 2:
+        return getattr(stable_whisper, sys.argv[1]), sys.argv[1]
+    return stable_whisper.load_model, 'load_model'
+
+
 def check_result(result, expected_text: str, test_name: str):
     assert result.text == expected_text
 
     timing_checked = False
-    for segment in result:
-        for word in segment:
-            assert word.start < word.end, (word.start, word.end, test_name)
-            if word.word.strip(" ,") == "americans":
-                assert word.start <= 1.8, (word.start, test_name)
-                assert word.end >= 1.8, (word.end, test_name)
-                timing_checked = True
+    all_words = result.all_words()
+    fail_count = 0
+    for word in all_words:
+        if word.start >= word.end:
+            fail_count += 1
+        if word.word.strip(" ,") == "americans":
+            assert word.start <= 1.8, (word.start, test_name)
+            assert word.end >= 1.8, (word.end, test_name)
+            timing_checked = True
+    fail_rate = fail_count / len(all_words)
+    print(f'Fail Count: {fail_count} / {len(all_words)} ({test_name})\n')
+    assert fail_rate < 0.1, (fail_rate, fail_count, test_name)
 
     assert timing_checked, test_name
 
 
 def test_align(model_names):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
-    models = [stable_whisper.load_model(name, device=device) for name in model_names]
-    orig_result = models[0].transcribe(
+    load, load_name = get_load_method()
+    models = [(load(name, device=device), f'{load_name}->{name}') for name in model_names]
+    orig_result = models[0][0].transcribe(
         audio_path, language='en', temperature=0.0, word_timestamps=True
     )
     for word in orig_result.all_words():
         word.word = word.word.replace('Americans', 'americans')
 
     def single_test(m, meth: str, prep, extra_check, **kwargs):
-        model_type = 'multilingual-model' if m.is_multilingual else 'en-model'
+        m, model_type = m
         meth = getattr(m, meth)
-        test_name = f'{model_type} {meth.__name__}(WhisperResult)'
-        try:
-            result = meth(audio_path, orig_result, **kwargs)
-            check_same_segment_text(orig_result, result)
-        except Exception as e:
-            raise Exception(f'failed test {test_name} -> {e.__class__.__name__}: {e}')
+        test_name = f'{model_type} + {meth.__name__}(WhisperResult)'
+        print(f'Start Test: {test_name}')
+        result = meth(audio_path, orig_result, **kwargs)
+        check_same_segment_text(orig_result, result)
         check_result(result, orig_result.text, test_name)
 
-        test_name = f'{model_type} {meth.__name__}(plain-text)'
-        try:
-            result = meth(audio_path, prep(orig_result), language=orig_result.language)
-            if extra_check:
-                extra_check(orig_result, result)
-        except Exception as e:
-            raise Exception(f'failed test {test_name} -> {e.__class__.__name__}: {e}')
+        test_name = f'{model_type} + {meth.__name__}(plain-text)'
+        print(f'Start Test: {test_name}')
+        result = meth(audio_path, prep(orig_result), language=orig_result.language)
+        if extra_check:
+            extra_check(orig_result, result)
        check_result(result, orig_result.text, test_name)
 
     def get_text(res):
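Note: the reworked `check_result` no longer aborts on the first inverted timestamp; it counts words with `start >= end` and fails only if more than 10% of all words are affected. A self-contained sketch of that tolerance check (the `Word` dataclass here is an illustrative stand-in, not stable_whisper's real word class):

from dataclasses import dataclass

@dataclass
class Word:
    word: str
    start: float
    end: float

def fail_rate(words):
    # count words whose timestamps are not strictly increasing
    fail_count = sum(1 for w in words if w.start >= w.end)
    return fail_count / len(words)

# one of three words has start >= end -> rate ~0.33; the test asserts < 0.1
words = [Word('my', 0.0, 0.2), Word('fellow', 0.2, 0.2), Word('americans', 1.0, 2.0)]
print(f'{fail_rate(words):.2f}')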

test/test_refine.py

Lines changed: 13 additions & 5 deletions
@@ -1,17 +1,24 @@
 import os
+import sys
 import torch
 import stable_whisper
 
 
-def check_result(result, orig_result, expect_change: bool = True):
+def get_load_method():
+    if len(sys.argv) >= 2:
+        return getattr(stable_whisper, sys.argv[1]), sys.argv[1]
+    return stable_whisper.load_model, 'load_model'
+
+
+def check_result(result, orig_result, test_name: str, expect_change: bool = True):
 
     timing_checked = False
     for segment in result:
         for word in segment:
             assert word.start < word.end
             if word.word.strip(" ,").lower() == "americans":
-                assert word.start <= 1.8, word.start
-                assert word.end >= 1.8, word.end
+                assert word.start <= 1.8, (word.start, test_name)
+                assert word.end >= 1.8, (word.end, test_name)
                 timing_checked = True
 
     if expect_change:
@@ -25,7 +32,8 @@ def check_result(result, orig_result, expect_change: bool = True):
 
 def test_refine(model0_name: str, model1_name: str):
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    model0 = stable_whisper.load_model(model0_name, device=device)
+    load, load_name = get_load_method()
+    model0 = load(model0_name, device=device)
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
 
     language = "en" if model0_name.endswith(".en") else None
@@ -36,7 +44,7 @@ def test_refine(model0_name: str, model1_name: str):
     model1 = stable_whisper.load_model(model1_name, device=device)
 
     result = model1.refine(audio_path, orig_result, inplace=False)
-    check_result(result, orig_result, True)
+    check_result(result, orig_result, load_name, True)
 
 
 def test():
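Note: threading `test_name` (here the loader name) into the assertion messages means a CI failure identifies which backend produced the bad timestamp. Roughly how such a failure reads, with illustrative values:

start, test_name = 2.5, 'load_hf_whisper'
try:
    assert start <= 1.8, (start, test_name)
except AssertionError as e:
    print(e)  # (2.5, 'load_hf_whisper') -- pinpoints the failing backend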

test/test_transcribe.py

Lines changed: 20 additions & 10 deletions
@@ -1,42 +1,52 @@
 import os
+import sys
 import torch
 import stable_whisper
 
 
-def check_result(result):
-    assert result.language == "en"
+def get_load_method():
+    if len(sys.argv) >= 2:
+        return getattr(stable_whisper, sys.argv[1]), sys.argv[1]
+    return stable_whisper.load_model, 'load_model'
+
+
+def check_result(result, test_name: str):
+    assert result.language in ('en', 'english'), result.language
 
     transcription = result.text.lower()
-    assert "my fellow americans" in transcription
-    assert "your country" in transcription
-    assert "do for you" in transcription
+    assert "my fellow americans" in transcription, test_name
+    assert "your country" in transcription, test_name
+    assert "do for you" in transcription, test_name
 
     timing_checked = False
     for segment in result:
         for word in segment:
             assert word.start < word.end
             if word.word.strip(" ,").lower() == "americans":
-                assert word.start <= 1.8, word.start
-                assert word.end >= 1.8, word.end
+                assert word.start <= 1.8, (word.start, test_name)
+                assert word.end >= 1.8, (word.end, test_name)
                 timing_checked = True
 
     assert timing_checked
 
 
 def test_transcribe(model_name: str):
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    model = stable_whisper.load_model(model_name, device=device)
+    load, load_name = get_load_method()
+    model = load(model_name, device=device)
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
 
     language = "en" if model_name.endswith(".en") else None
     result = model.transcribe(
         audio_path, language=language, temperature=0.0, word_timestamps=True
     )
-    check_result(result)
+    check_result(result, f'{load_name}->transcribe')
+    if not hasattr(model, 'transcribe_minimal'):
+        return
     result = model.transcribe_minimal(
         audio_path, language=language, temperature=0.0, word_timestamps=True
     )
-    check_result(result)
+    check_result(result, f'{load_name}->transcribe_minimal')
 
 
 def test():
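Note: the new `hasattr` guard skips `transcribe_minimal` on backends that do not implement it, presumably the Faster-Whisper and Hugging Face wrappers. A sketch of the duck-typing pattern (the stub class below is hypothetical, not the real wrapper):

class BackendStub:
    # hypothetical wrapper exposing only transcribe()
    def transcribe(self, *args, **kwargs):
        return None

model = BackendStub()
if not hasattr(model, 'transcribe_minimal'):
    print(f'{type(model).__name__}: no transcribe_minimal, skipping that check')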
