Skip to content

Commit 090d682

Browse files
committed
updated tests
1 parent 65018aa commit 090d682

File tree

2 files changed

+45
-20
lines changed

2 files changed

+45
-20
lines changed

test/test_align.py

Lines changed: 43 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -3,45 +3,70 @@
33
import stable_whisper
44

55

6-
def check_result(result, expected_text: str):
6+
def check_result(result, expected_text: str, test_name: str):
77
assert result.text == expected_text
88

99
timing_checked = False
1010
for segment in result:
1111
for word in segment:
12-
assert word.start < word.end
12+
assert word.start < word.end, (word.start, word.end, test_name)
1313
if word.word.strip(" ,") == "americans":
14-
assert word.start <= 1.8, word.start
15-
assert word.end >= 1.8, word.end
14+
assert word.start <= 1.8, (word.start, test_name)
15+
assert word.end >= 1.8, (word.end, test_name)
1616
timing_checked = True
1717

18-
assert timing_checked
18+
assert timing_checked, test_name
1919

2020

21-
def test_transcribe(model0_name: str, model1_name: str):
21+
def test_align(model_names):
2222
device = "cuda" if torch.cuda.is_available() else "cpu"
23-
model0 = stable_whisper.load_model(model0_name, device=device)
2423
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
25-
26-
language = "en" if model0_name.endswith(".en") else None
27-
orig_result = model0.transcribe(
28-
audio_path, language=language, temperature=0.0, word_timestamps=True
24+
models = [stable_whisper.load_model(name, device=device) for name in model_names]
25+
orig_result = models[0].transcribe(
26+
audio_path, language='en', temperature=0.0, word_timestamps=True
2927
)
3028
for word in orig_result.all_words():
3129
word.word = word.word.replace('Americans', 'americans')
3230

33-
model1 = stable_whisper.load_model(model1_name, device=device)
31+
def single_test(m, meth: str, prep, extra_check, **kwargs):
32+
model_type = 'multilingual-model' if m.is_multilingual else 'en-model'
33+
meth = getattr(m, meth)
34+
test_name = f'{model_type} {meth.__name__}(WhisperResult)'
35+
try:
36+
result = meth(audio_path, orig_result, **kwargs)
37+
check_same_segment_text(orig_result, result)
38+
except Exception as e:
39+
raise Exception(f'failed test {test_name} -> {e.__class__.__name__}: {e}')
40+
check_result(result, orig_result.text, test_name)
41+
42+
test_name = f'{model_type} {meth.__name__}(plain-text)'
43+
try:
44+
result = meth(audio_path, prep(orig_result), language=orig_result.language)
45+
if extra_check:
46+
extra_check(orig_result, result)
47+
except Exception as e:
48+
raise Exception(f'failed test {test_name} -> {e.__class__.__name__}: {e}')
49+
check_result(result, orig_result.text, test_name)
50+
51+
def get_text(res):
52+
return res.text
53+
54+
def get_segment_dicts(res):
55+
return [dict(start=s.start, end=s.end, text=s.text) for s in res]
3456

35-
result = model1.align(audio_path, orig_result, original_split=True)
36-
assert [s.text for s in result] == [s.text for s in orig_result]
37-
check_result(result, orig_result.text)
57+
def check_same_segment_text(res0, res1):
58+
assert [s.text for s in res0] == [s.text for s in res1], 'mismatch segment text'
3859

39-
result = model1.align(audio_path, orig_result.text, language=orig_result.language)
40-
check_result(result, orig_result.text)
60+
for model in models:
61+
for method in ('align', 'align_words'):
62+
options = dict(original_split=True) if method == 'align' else {}
63+
preprocess = get_text if method == 'align' else get_segment_dicts
64+
check_seg = None if method == 'align' else check_same_segment_text
65+
single_test(model, method, preprocess, check_seg, **options)
4166

4267

4368
def test():
44-
test_transcribe('tiny', 'tiny.en')
69+
test_align(['tiny', 'tiny.en'])
4570

4671

4772
if __name__ == '__main__':

test/test_refine.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ def check_result(result, orig_result, expect_change: bool = True):
2323
assert timing_checked
2424

2525

26-
def test_transcribe(model0_name: str, model1_name: str):
26+
def test_refine(model0_name: str, model1_name: str):
2727
device = "cuda" if torch.cuda.is_available() else "cpu"
2828
model0 = stable_whisper.load_model(model0_name, device=device)
2929
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
@@ -40,7 +40,7 @@ def test_transcribe(model0_name: str, model1_name: str):
4040

4141

4242
def test():
43-
test_transcribe('tiny.en', 'tiny')
43+
test_refine('tiny.en', 'tiny')
4444

4545

4646
if __name__ == '__main__':

0 commit comments

Comments
 (0)