"""Code for transcription using the onnx-asr library."""

import asyncio
import logging
import os
import tempfile
import wave
from pathlib import Path
from typing import Optional, Union
from unittest.mock import patch

import numpy as np
import onnx_asr
from huggingface_hub import snapshot_download
from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
from wyoming.event import Event
from wyoming.info import Describe, Info
from wyoming.server import AsyncEventHandler

_LOGGER = logging.getLogger(__name__)

_RATE = 16000


class OnnxAsrModel:
    """Wrapper for onnx-asr model."""

    def __init__(
        self, model_id: str, cache_dir: Union[str, Path], local_files_only: bool
    ) -> None:
        """Initialize model."""

        # Force download to our cache dir
        def snapshot_download_with_cache(*args, **kwargs) -> str:
            kwargs["cache_dir"] = str(Path(cache_dir).resolve())
            kwargs["local_files_only"] = local_files_only

            return snapshot_download(*args, **kwargs)

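        # onnx_asr.load_model resolves model files through
        # huggingface_hub.snapshot_download (hence the patch below), so this
        # redirects downloads into cache_dir and enforces local_files_only.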
        with patch("huggingface_hub.snapshot_download", snapshot_download_with_cache):
            self.onnx_model = onnx_asr.load_model(model_id)

    def transcribe(
        self, wav_path: Union[str, Path], language: Optional[str], *args, **kwargs
    ) -> str:
        """Returns the transcription for a WAV file.

        The WAV file must be 16 kHz, 16-bit, mono audio.
        """
        wav_file: wave.Wave_read = wave.open(str(wav_path), "rb")
        with wav_file:
            assert wav_file.getframerate() == _RATE, "Sample rate must be 16 kHz"
            assert wav_file.getsampwidth() == 2, "Width must be 16-bit (2 bytes)"
            assert wav_file.getnchannels() == 1, "Audio must be mono"
            audio_bytes = wav_file.readframes(wav_file.getnframes())

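        # Convert 16-bit signed PCM samples to float32 in [-1.0, 1.0]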
        audio_array = (
            np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32767.0
        )

        recognize_kwargs = {}
        if language:
            recognize_kwargs["language"] = language

        text = self.onnx_model.recognize(audio_array, **recognize_kwargs)
        return text


class OnnxAsrEventHandler(AsyncEventHandler):
    """Event handler for clients."""

    def __init__(
        self,
        wyoming_info: Info,
        language: Optional[str],
        beam_size: int,
        model: OnnxAsrModel,
        model_lock: asyncio.Lock,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)

        self.wyoming_info_event = wyoming_info.event()
        self.model = model
        self.model_lock = model_lock
        self._beam_size = beam_size
        self._default_language = language
        self._language = language
        self._wav_dir = tempfile.TemporaryDirectory()
        self._wav_path = os.path.join(self._wav_dir.name, "speech.wav")
        self._wav_file: Optional[wave.Wave_write] = None
        self._audio_converter = AudioChunkConverter(rate=_RATE, width=2, channels=1)

    async def handle_event(self, event: Event) -> bool:
        if AudioChunk.is_type(event.type):
            chunk = self._audio_converter.convert(AudioChunk.from_event(event))

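            # Lazily open the temporary WAV file on the first audio chunk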
            if self._wav_file is None:
                self._wav_file = wave.open(self._wav_path, "wb")
                self._wav_file.setframerate(chunk.rate)
                self._wav_file.setsampwidth(chunk.width)
                self._wav_file.setnchannels(chunk.channels)

            self._wav_file.writeframes(chunk.audio)
            return True

        if AudioStop.is_type(event.type):
            _LOGGER.debug(
                "Audio stopped. Transcribing with language=%s", self._language
            )
            assert self._wav_file is not None

            self._wav_file.close()
            self._wav_file = None

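            # The model is shared between clients, so only one transcription
            # runs at a time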
            async with self.model_lock:
                text = self.model.transcribe(
                    self._wav_path,
                    language=self._language,
                )

            _LOGGER.info(text)

            await self.write_event(Transcript(text=text).event())
            _LOGGER.debug("Completed request")

            # Reset to the default language for the next request
            self._language = self._default_language

            return False

        if Transcribe.is_type(event.type):
            transcribe = Transcribe.from_event(event)
            if transcribe.language:
                self._language = transcribe.language
                _LOGGER.debug("Language set to %s", transcribe.language)
            return True

        if Describe.is_type(event.type):
            await self.write_event(self.wyoming_info_event)
            _LOGGER.debug("Sent info")
            return True

        return True
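

# A minimal usage sketch (illustrative only, not part of this module), assuming a
# `wyoming_info` Info object describing the ASR service, a placeholder model id,
# and example values for language (None) and beam_size (5). The Wyoming server
# creates one handler per client connection, all sharing the same model and lock:
#
#     from functools import partial
#     from wyoming.server import AsyncServer
#
#     model = OnnxAsrModel("model-id", cache_dir="./cache", local_files_only=False)
#     model_lock = asyncio.Lock()
#     server = AsyncServer.from_uri("tcp://0.0.0.0:10300")
#     await server.run(
#         partial(OnnxAsrEventHandler, wyoming_info, None, 5, model, model_lock)
#     )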