import numpy as np
import torch
import torchaudio
from datasets import ClassLabel
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2ForCTC, Wav2Vec2Processor

9- processor = Wav2Vec2Processor .from_pretrained ("chompk /wav2vec2-large-xlsr-thai-tokenized " )
10- model = Wav2Vec2ForCTC .from_pretrained ("chompk /wav2vec2-large-xlsr-thai-tokenized " )
8+ processor = Wav2Vec2Processor .from_pretrained ("airesearch /wav2vec2-large-xlsr-53-th " )
9+ model = Wav2Vec2ForCTC .from_pretrained ("airesearch /wav2vec2-large-xlsr-53-th " )
1110device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
1211
1312
1413def speech_file_to_array_fn (batch : dict ) -> dict :
1514 speech_array , sampling_rate = torchaudio .load (batch ["path" ])
16- batch ["speech" ] = speech_array [0 ]. numpy ()
15+ batch ["speech" ] = speech_array [0 ]
1716 batch ["sampling_rate" ] = sampling_rate
1817 return batch
1918
2019
2120def resample (batch : dict ) -> dict :
22- batch ["speech" ] = librosa .resample (np .asarray (batch ["speech" ]), 48_000 , 16_000 )
21+ resampler = torchaudio .transforms .Resample (batch ['sampling_rate' ], 16_000 )
22+ batch ["speech" ] = resampler (batch ["speech" ]).numpy ()
2323 batch ["sampling_rate" ] = 16_000
2424 return batch
2525
@@ -30,7 +30,7 @@ def prepare_dataset(batch: dict) -> dict:
3030 return batch
3131
3232
33- def asr (file : str , show_pad : bool = False ) -> str :
33+ def asr (file : str , tokenized : bool = False ) -> str :
3434 """
3535 :param str file: path of sound file
    :param bool tokenized: if True, return the raw decoded output with spaces between tokens kept; otherwise spaces are stripped from the decoded text
@@ -44,9 +44,9 @@ def asr(file: str, show_pad: bool = False) -> str:
4444 logits = model (input_dict .input_values .to (device )).logits
4545 pred_ids = torch .argmax (logits , dim = - 1 )[0 ]
4646
47- if show_pad :
47+ if tokenized :
4848 txt = processor .decode (pred_ids )
4949 else :
50- txt = processor .decode (pred_ids ).replace ('[PAD] ' ,'' )
50+ txt = processor .decode (pred_ids ).replace (' ' ,'' )
5151
5252 return txt
0 commit comments