Skip to content

Commit 1f44b3d

Browse files
committed
added locate() test
1 parent 7338f85 commit 1f44b3d

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,4 @@ jobs:
4646
- run: python test/test_transcribe.py
4747
- run: python test/test_align.py
4848
- run: python test/test_refine.py
49+
- run: python test/test_locate.py

test/test_locate.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import os
2+
import torch
3+
import stable_whisper
4+
5+
6+
def test_locate(model_names):
7+
device = "cuda" if torch.cuda.is_available() else "cpu"
8+
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
9+
models = [stable_whisper.load_model(name, device=device) for name in model_names]
10+
for model in models:
11+
matches = model.locate(audio_path, 'americans', 'en', mode=0)
12+
assert len(matches), len(matches)
13+
words = [word.word.lower().strip(',').strip() for match in matches for word in match]
14+
assert 'americans' in words, words
15+
matches = model.locate(audio_path, 'americans', 'en', mode=1)
16+
assert len(matches), len(matches)
17+
any(['americans' in match['duration_window_text'].lower() for match in matches])
18+
19+
20+
def test():
21+
test_locate(['tiny', 'tiny.en'])
22+
23+
24+
if __name__ == '__main__':
25+
test()

0 commit comments

Comments
 (0)