@@ -19,6 +19,7 @@ def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", device=Non
1919 * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model
2020 """
2121 self .processor = AutoProcessor .from_pretrained (model )
22+ self .model_name = model
2223 self .model = AutoModelForCTC .from_pretrained (model )
2324 if device != None :
2425 self .device = torch .device (device )
@@ -40,10 +41,9 @@ def prepare_dataset(self, batch: dict) -> dict:
4041 batch ["input_values" ] = self .processor (batch ["speech" ], sampling_rate = batch ["sampling_rate" ]).input_values
4142 return batch
4243
43- def __call__ (self , file : str , tokenized : bool = False ) -> str :
44+ def __call__ (self , file : str ) -> str :
4445 """
4546 :param str file: path of sound file
46- :param bool show_pad: show [PAD] in output
4747 :param str model: The ASR model
4848 """
4949 b = {}
@@ -53,20 +53,16 @@ def __call__(self, file: str, tokenized: bool = False) -> str:
5353 logits = self .model (input_dict .input_values ).logits
5454 pred_ids = torch .argmax (logits , dim = - 1 )[0 ]
5555
56- if tokenized :
57- txt = self .processor .batch_decode (logits .detach ().numpy ()).text
58- else :
59- txt = self .processor .batch_decode (logits .detach ().numpy ()).text .replace (' ' ,'' )
56+ txt = self .processor .batch_decode (logits .detach ().numpy ()).text [0 ]
6057 return txt
6158
6259_model_name = "airesearch/wav2vec2-large-xlsr-53-th"
6360_model = None
6461
6562
66- def asr (file : str , tokenized : bool = False , model : str = _model_name ) -> str :
63+ def asr (file : str , model : str = _model_name ) -> str :
6764 """
6865 :param str file: path of sound file
69- :param bool show_pad: show [PAD] in output
7066 :param str model: The ASR model
7167 :return: thai text from ASR
7268 :rtype: str
@@ -81,4 +77,4 @@ def asr(file: str, tokenized: bool = False, model: str = _model_name) -> str:
8177 _model = ASR (model )
8278 _model_name = model
8379
84- return _model (file = file , tokenized = tokenized )
80+ return _model (file = file )
0 commit comments