|
3 | 3 | import stable_whisper |
4 | 4 |
|
5 | 5 |
|
def check_result(result, expected_text: str, test_name: str):
    """Assert that a result has the expected text and plausible word timings.

    Args:
        result: Object exposing ``text`` whose iteration yields segments;
            iterating a segment yields words carrying ``word``, ``start``
            and ``end`` attributes.
        expected_text: Text that ``result.text`` must equal exactly.
        test_name: Label attached to every timing assertion message so a
            failure identifies which model/method combination produced it.

    Raises:
        AssertionError: If the text differs, any word does not start before
            it ends, the word "americans" does not span the 1.8 mark, or the
            word "americans" never appears in ``result``.
    """
    assert result.text == expected_text

    timing_checked = False
    for segment in result:
        for word in segment:
            assert word.start < word.end, (word.start, word.end, test_name)
            if word.word.strip(" ,") == "americans":
                # The reference word must straddle 1.8 (presumably seconds
                # into the audio clip -- confirm against the fixture).
                assert word.start <= 1.8, (word.start, test_name)
                assert word.end >= 1.8, (word.end, test_name)
                timing_checked = True

    # Guard against the check passing vacuously when the word is missing.
    assert timing_checked, test_name
19 | 19 |
|
20 | 20 |
|
def test_align(model_names):
    """Run ``align`` and ``align_words`` on every model against one reference.

    The first model in ``model_names`` transcribes ``jfk.flac`` (located next
    to this file) with ``language='en'``; that transcription is the reference
    each model's alignment output is compared with.

    Args:
        model_names: Names passed to ``stable_whisper.load_model``. The first
            entry produces the reference transcription.

    Raises:
        Exception: If an alignment call or its segment-text comparison fails;
            the message names the failing test and the original error, which
            is chained as ``__cause__``.
        AssertionError: If ``check_result`` rejects an alignment result.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
    models = [stable_whisper.load_model(name, device=device) for name in model_names]
    orig_result = models[0].transcribe(
        audio_path, language='en', temperature=0.0, word_timestamps=True
    )
    # check_result looks for the lowercase word, so normalise it up front.
    for word in orig_result.all_words():
        word.word = word.word.replace('Americans', 'americans')

    def get_text(res):
        """Return the full text of ``res``."""
        return res.text

    def get_segment_dicts(res):
        """Return each segment of ``res`` as a start/end/text dict."""
        return [dict(start=s.start, end=s.end, text=s.text) for s in res]

    def check_same_segment_text(res0, res1):
        """Assert both results have identical per-segment text."""
        assert [s.text for s in res0] == [s.text for s in res1], 'mismatch segment text'

    def single_test(m, method_name: str, prep, extra_check, **kwargs):
        """Exercise one method of ``m`` with result input, then prepared input.

        ``kwargs`` are forwarded only to the first call (reference result as
        input); the second call receives ``prep(orig_result)`` plus the
        reference language.
        """
        model_type = 'multilingual-model' if m.is_multilingual else 'en-model'
        meth = getattr(m, method_name)

        test_name = f'{model_type} {meth.__name__}(WhisperResult)'
        try:
            result = meth(audio_path, orig_result, **kwargs)
            check_same_segment_text(orig_result, result)
        except Exception as e:
            # Chain explicitly so the original traceback stays attached.
            raise Exception(f'failed test {test_name} -> {e.__class__.__name__}: {e}') from e
        check_result(result, orig_result.text, test_name)

        test_name = f'{model_type} {meth.__name__}(plain-text)'
        try:
            result = meth(audio_path, prep(orig_result), language=orig_result.language)
            if extra_check:
                extra_check(orig_result, result)
        except Exception as e:
            raise Exception(f'failed test {test_name} -> {e.__class__.__name__}: {e}') from e
        check_result(result, orig_result.text, test_name)

    for model in models:
        for method in ('align', 'align_words'):
            # Only ``align`` gets ``original_split`` and plain text input;
            # ``align_words`` gets segment dicts and a segment-text check.
            options = dict(original_split=True) if method == 'align' else {}
            preprocess = get_text if method == 'align' else get_segment_dicts
            check_seg = None if method == 'align' else check_same_segment_text
            single_test(model, method, preprocess, check_seg, **options)
41 | 66 |
|
42 | 67 |
|
def test():
    """Run the alignment tests with the 'tiny' and 'tiny.en' models."""
    test_align(['tiny', 'tiny.en'])
45 | 70 |
|
46 | 71 |
|
47 | 72 | if __name__ == '__main__': |
|
0 commit comments