Skip to content

Commit f2b079f

Browse files
committed
Add support for GigaAM for Russian
1 parent 4f3f9b6 commit f2b079f

File tree

6 files changed

+185
-1
lines changed

6 files changed

+185
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## 3.0.0
44

55
- Add support for `sherpa-onnx` and Nvidia's parakeet model
6+
- Add support for [GigaAM](https://github.com/salute-developers/GigaAM) for Russian via [`onnx-asr`](https://github.com/istupakov/onnx-asr)
67
- Add `--stt-library` to select speech-to-text library (deprecate `--use-transformers`)
78
- Default `--model` to "auto" (prefer parakeet)
89
- Add Docker build here

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ RUN \
2323
\
2424
&& .venv/bin/pip3 install --no-cache-dir \
2525
--extra-index-url https://www.piwheels.org/simple \
26-
-e '.[transformers,sherpa]' \
26+
-e '.[transformers,sherpa,onnx-asr]' \
2727
\
2828
&& rm -rf /var/lib/apt/lists/*
2929

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,6 @@ transformers = [
7575
sherpa = [
7676
"sherpa-onnx==1.12.15",
7777
]
78+
onnx_asr = [
79+
"onnx-asr[cpu,hub]==0.7.0",
80+
]

wyoming_faster_whisper/__main__.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,14 @@ async def main() -> None:
117117
stt_library = SttLibrary.SHERPA
118118
except ImportError:
119119
stt_library = SttLibrary.FASTER_WHISPER
120+
elif args.language == "ru":
121+
# Prefer GigaAM via onnx-asr
122+
try:
123+
from .sherpa_handler import SherpaModel
124+
125+
stt_library = SttLibrary.ONNX_ASR
126+
except ImportError:
127+
stt_library = SttLibrary.FASTER_WHISPER
120128

121129
_LOGGER.debug("Speech-to-text library automatically selected: %s", stt_library)
122130

@@ -194,6 +202,13 @@ async def main() -> None:
194202
whisper_model = TransformersWhisperModel(
195203
args.model, args.download_dir, args.local_files_only
196204
)
205+
elif stt_library == SttLibrary.ONNX_ASR:
206+
# Use onnx-asr
207+
from .onnx_asr_handler import OnnxAsrModel
208+
209+
whisper_model = OnnxAsrModel(
210+
args.model, args.download_dir, args.local_files_only
211+
)
197212
else:
198213
# Use faster-whisper
199214
whisper_model = faster_whisper.WhisperModel(
@@ -254,6 +269,22 @@ async def main() -> None:
254269
model_lock,
255270
)
256271
)
272+
elif stt_library == SttLibrary.ONNX_ASR:
273+
# Use onnx-asr
274+
from .onnx_asr_handler import OnnxAsrEventHandler, OnnxAsrModel
275+
276+
assert isinstance(whisper_model, OnnxAsrModel)
277+
278+
await server.run(
279+
partial(
280+
OnnxAsrEventHandler,
281+
wyoming_info,
282+
args.language,
283+
args.beam_size,
284+
whisper_model,
285+
model_lock,
286+
)
287+
)
257288
else:
258289
# faster-whisper
259290
from .faster_whisper_handler import FasterWhisperEventHandler
@@ -296,6 +327,9 @@ def guess_model(stt_library: SttLibrary, language: Optional[str], is_arm: bool)
296327

297328
return "openai/whisper-base"
298329

330+
if stt_library == SttLibrary.ONNX_ASR:
331+
return "gigaam-v2-rnnt"
332+
299333
# faster-whisper
300334
if is_arm:
301335
return "tiny-int8"

wyoming_faster_whisper/const.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ class SttLibrary(str, Enum):
1010
FASTER_WHISPER = "faster-whisper"
1111
TRANSFORMERS = "transformers"
1212
SHERPA = "sherpa"
13+
ONNX_ASR = "onnx-asr"
1314

1415

1516
PARAKEET_LANGUAGES = {
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""Code for transcription using the onnx-asr library."""
2+
3+
import asyncio
4+
import logging
5+
import os
6+
import tempfile
7+
import wave
8+
from pathlib import Path
9+
from typing import Optional, Union
10+
from unittest.mock import patch
11+
12+
import numpy as np
13+
import onnx_asr
14+
from huggingface_hub import snapshot_download
15+
from wyoming.asr import Transcribe, Transcript
16+
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
17+
from wyoming.event import Event
18+
from wyoming.info import Describe, Info
19+
from wyoming.server import AsyncEventHandler
20+
21+
_LOGGER = logging.getLogger(__name__)
22+
23+
_RATE = 16000
24+
25+
26+
class OnnxAsrModel:
    """Wrapper around an onnx-asr model with downloads pinned to a cache dir."""

    def __init__(
        self, model_id: str, cache_dir: Union[str, Path], local_files_only: bool
    ) -> None:
        """Load *model_id*, forcing Hugging Face snapshots into *cache_dir*."""
        resolved_cache = str(Path(cache_dir).resolve())

        # Redirect hub downloads into our cache directory.
        # NOTE(review): this patches the attribute on the huggingface_hub
        # module object; confirm onnx-asr resolves snapshot_download through
        # the module rather than a "from huggingface_hub import ..." binding.
        def _cached_download(*args, **kwargs) -> str:
            kwargs["cache_dir"] = resolved_cache
            kwargs["local_files_only"] = local_files_only
            return snapshot_download(*args, **kwargs)

        with patch("huggingface_hub.snapshot_download", _cached_download):
            self.onnx_model = onnx_asr.load_model(model_id)

    def transcribe(
        self, wav_path: Union[str, Path], language: Optional[str], *args, **kwargs
    ) -> str:
        """Returns transcription for WAV file.

        WAV file must be 16Khz 16-bit mono audio.
        """
        reader: wave.Wave_read = wave.open(str(wav_path), "rb")
        with reader:
            assert reader.getframerate() == _RATE, "Sample rate must be 16Khz"
            assert reader.getsampwidth() == 2, "Width must be 16-bit (2 bytes)"
            assert reader.getnchannels() == 1, "Audio must be mono"
            pcm_bytes = reader.readframes(reader.getnframes())

        # int16 PCM -> float32 samples in [-1, 1]
        samples = (
            np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) / 32767.0
        )

        # Only pass a language when one was requested (some models, e.g.
        # GigaAM, are monolingual and take no language argument).
        extra_kwargs = {"language": language} if language else {}

        return self.onnx_model.recognize(samples, **extra_kwargs)
68+
69+
70+
class OnnxAsrEventHandler(AsyncEventHandler):
    """Event handler for clients.

    Buffers incoming audio into a temporary WAV file, transcribes it with the
    shared OnnxAsrModel when the audio stream stops, and writes the transcript
    back to the client.
    """

    def __init__(
        self,
        wyoming_info: Info,
        language: Optional[str],
        beam_size: int,
        model: OnnxAsrModel,
        model_lock: asyncio.Lock,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)

        self.wyoming_info_event = wyoming_info.event()
        self.model = model
        # Serializes access to the shared (non-thread-safe) model
        self.model_lock = model_lock
        # Not used by onnx-asr; kept so all handlers share one signature
        self._beam_size = beam_size
        self._language = language
        # Server-default language, restored after each request so a
        # client-chosen Transcribe language does not leak into later requests
        self._default_language = language
        self._wav_dir = tempfile.TemporaryDirectory()
        self._wav_path = os.path.join(self._wav_dir.name, "speech.wav")
        self._wav_file: Optional[wave.Wave_write] = None
        # Normalize all incoming audio to 16Khz 16-bit mono for the model
        self._audio_converter = AudioChunkConverter(rate=_RATE, width=2, channels=1)

    async def handle_event(self, event: Event) -> bool:
        """Handle one Wyoming event; returns False to close the connection."""
        if AudioChunk.is_type(event.type):
            chunk = self._audio_converter.convert(AudioChunk.from_event(event))

            if self._wav_file is None:
                # First chunk of a request: open the WAV file lazily
                self._wav_file = wave.open(self._wav_path, "wb")
                self._wav_file.setframerate(chunk.rate)
                self._wav_file.setsampwidth(chunk.width)
                self._wav_file.setnchannels(chunk.channels)

            self._wav_file.writeframes(chunk.audio)
            return True

        if AudioStop.is_type(event.type):
            _LOGGER.debug(
                "Audio stopped. Transcribing with language=%s", self._language
            )
            assert self._wav_file is not None

            self._wav_file.close()
            self._wav_file = None

            async with self.model_lock:
                text = self.model.transcribe(
                    self._wav_path,
                    language=self._language,
                )

            _LOGGER.info(text)

            await self.write_event(Transcript(text=text).event())
            _LOGGER.debug("Completed request")

            # Reset to the server default language. (The previous code was a
            # no-op self-assignment, so a per-request language stuck forever.)
            self._language = self._default_language

            return False

        if Transcribe.is_type(event.type):
            transcribe = Transcribe.from_event(event)
            if transcribe.language:
                # Per-request language override; reset after AudioStop
                self._language = transcribe.language
                _LOGGER.debug("Language set to %s", transcribe.language)
            return True

        if Describe.is_type(event.type):
            await self.write_event(self.wyoming_info_event)
            _LOGGER.debug("Sent info")
            return True

        return True

0 commit comments

Comments
 (0)