Skip to content

Commit 69e30a7

Browse files
authored
Merge pull request #2 from cstorm125/change_airesearch
Change to airesearch/wav2vec2-large-xlsr-53-th
2 parents e3c7e3a + b6e75f3 commit 69e30a7

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

pythaiasr/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,23 @@
33
from datasets import ClassLabel
44
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2ForCTC, Wav2Vec2Processor
55
import torchaudio
6-
import librosa
76
import numpy as np
87

9-
processor = Wav2Vec2Processor.from_pretrained("chompk/wav2vec2-large-xlsr-thai-tokenized")
10-
model = Wav2Vec2ForCTC.from_pretrained("chompk/wav2vec2-large-xlsr-thai-tokenized")
8+
processor = Wav2Vec2Processor.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")
9+
model = Wav2Vec2ForCTC.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")
1110
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1211

1312

1413
def speech_file_to_array_fn(batch: dict) -> dict:
1514
speech_array, sampling_rate = torchaudio.load(batch["path"])
16-
batch["speech"] = speech_array[0].numpy()
15+
batch["speech"] = speech_array[0]
1716
batch["sampling_rate"] = sampling_rate
1817
return batch
1918

2019

2120
def resample(batch: dict) -> dict:
22-
batch["speech"] = librosa.resample(np.asarray(batch["speech"]), 48_000, 16_000)
21+
resampler=torchaudio.transforms.Resample(batch['sampling_rate'], 16_000)
22+
batch["speech"] = resampler(batch["speech"]).numpy()
2323
batch["sampling_rate"] = 16_000
2424
return batch
2525

@@ -30,7 +30,7 @@ def prepare_dataset(batch: dict) -> dict:
3030
return batch
3131

3232

33-
def asr(file: str, show_pad: bool = False) -> str:
33+
def asr(file: str, tokenized: bool = False) -> str:
3434
"""
3535
:param str file: path of sound file
3636
:param bool show_pad: show [PAD] in output
@@ -44,9 +44,9 @@ def asr(file: str, show_pad: bool = False) -> str:
4444
logits = model(input_dict.input_values.to(device)).logits
4545
pred_ids = torch.argmax(logits, dim=-1)[0]
4646

47-
if show_pad:
47+
if tokenized:
4848
txt = processor.decode(pred_ids)
4949
else:
50-
txt = processor.decode(pred_ids).replace('[PAD]','')
50+
txt = processor.decode(pred_ids).replace(' ','')
5151

5252
return txt

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ def read(*paths):
1212
'datasets',
1313
'transformers',
1414
'torchaudio',
15+
'soundfile',
1516
'torch',
16-
'librosa',
1717
'numpy'
1818
]
1919

0 commit comments

Comments
 (0)