Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions tests/test_faster_whisper_translate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Tests for wyoming-faster-whisper"""

import asyncio
import re
import sys
import wave
from asyncio.subprocess import PIPE
from pathlib import Path

import pytest
from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioStart, AudioStop, wav_to_chunks
from wyoming.event import async_read_event, async_write_event
from wyoming.info import Describe, Info

_DIR = Path(__file__).parent
_PROGRAM_DIR = _DIR.parent
_LOCAL_DIR = _PROGRAM_DIR / "local"
_SAMPLES_PER_CHUNK = 1024

# Need to give time for the model to download
_START_TIMEOUT = 60
_TRANSCRIBE_TIMEOUT = 60


@pytest.mark.asyncio
async def test_faster_whisper() -> None:
proc = await asyncio.create_subprocess_exec(
sys.executable,
"-m",
"wyoming_faster_whisper",
"--uri",
"stdio://",
"--model",
"base-int8",
"--data-dir",
str(_LOCAL_DIR),
"--task",
"translate",
"--language",
"fr",
stdin=PIPE,
stdout=PIPE,
)
assert proc.stdin is not None
assert proc.stdout is not None

# Check info
await async_write_event(Describe().event(), proc.stdin)
while True:
event = await asyncio.wait_for(
async_read_event(proc.stdout), timeout=_START_TIMEOUT
)
assert event is not None

if not Info.is_type(event.type):
continue

info = Info.from_event(event)
assert len(info.asr) == 1, "Expected one asr service"
asr = info.asr[0]
assert len(asr.models) > 0, "Expected at least one model"
assert any(
m.name == "base-int8" for m in asr.models
), "Expected base-int8 model"
break

# We want to use the whisper model
await async_write_event(Transcribe(name="base-int8").event(), proc.stdin)

# Test known WAV
with wave.open(str(_DIR / "whats_your_name_french.wav"), "rb") as example_wav:
await async_write_event(
AudioStart(
rate=example_wav.getframerate(),
width=example_wav.getsampwidth(),
channels=example_wav.getnchannels(),
).event(),
proc.stdin,
)
for chunk in wav_to_chunks(example_wav, _SAMPLES_PER_CHUNK):
await async_write_event(chunk.event(), proc.stdin)

await async_write_event(AudioStop().event(), proc.stdin)

while True:
event = await asyncio.wait_for(
async_read_event(proc.stdout), timeout=_TRANSCRIBE_TIMEOUT
)
assert event is not None

if not Transcript.is_type(event.type):
continue

transcript = Transcript.from_event(event)
text = transcript.text.lower().strip()
text = re.sub(r"[^a-z ]", "", text)
assert text == "how do you call yourself"
break

# Need to close stdin for graceful termination
proc.stdin.close()
_, stderr = await proc.communicate()

assert proc.returncode == 0, stderr.decode()
Binary file added tests/whats_your_name_french.wav
Binary file not shown.
6 changes: 6 additions & 0 deletions wyoming_faster_whisper/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ async def main() -> None:
"--initial-prompt",
help="Optional text to provide as a prompt for the first window",
)
parser.add_argument(
"--task",
default="transcribe",
help="Whether to transcribe or translate (default: transcribe)",
choices=["transcribe", "translate"],
)
#
parser.add_argument("--debug", action="store_true", help="Log DEBUG messages")
parser.add_argument(
Expand Down
2 changes: 2 additions & 0 deletions wyoming_faster_whisper/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
self.model_lock = model_lock
self.initial_prompt = initial_prompt
self._language = self.cli_args.language
self._task = self.cli_args.task
self._wav_dir = tempfile.TemporaryDirectory()
self._wav_path = os.path.join(self._wav_dir.name, "speech.wav")
self._wav_file: Optional[wave.Wave_write] = None
Expand Down Expand Up @@ -71,6 +72,7 @@ async def handle_event(self, event: Event) -> bool:
beam_size=self.cli_args.beam_size,
language=self._language,
initial_prompt=self.initial_prompt,
task=self._task,
)

text = " ".join(segment.text for segment in segments)
Expand Down