Skip to content

Commit 636e67b

Browse files
committed
PyThaiASR v1.1.0
- Remove the `tokenized` parameter from `ASR.__call__` and `asr`
1 parent 30abe5f commit 636e67b

File tree

3 files changed

+7
-12
lines changed

3 files changed

+7
-12
lines changed

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,10 @@ print(asr(file))
3737
### API
3838

3939
```python
40-
asr(file: str, show_pad: bool = False, model: str = "airesearch/wav2vec2-large-xlsr-53-th")
40+
asr(file: str, model: str = "airesearch/wav2vec2-large-xlsr-53-th")
4141
```
4242

4343
- file: path to the sound file
44-
- show_pad: show [PAD] in output
4544
- model: name of the ASR model to use (default: "airesearch/wav2vec2-large-xlsr-53-th")
4645
- return: Thai text transcribed by the ASR model
4746

pythaiasr/__init__.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", device=Non
1919
* *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model
2020
"""
2121
self.processor = AutoProcessor.from_pretrained(model)
22+
self.model_name = model
2223
self.model = AutoModelForCTC.from_pretrained(model)
2324
if device!=None:
2425
self.device = torch.device(device)
@@ -40,10 +41,9 @@ def prepare_dataset(self, batch: dict) -> dict:
4041
batch["input_values"] = self.processor(batch["speech"], sampling_rate=batch["sampling_rate"]).input_values
4142
return batch
4243

43-
def __call__(self, file: str, tokenized: bool = False) -> str:
44+
def __call__(self, file: str) -> str:
4445
"""
4546
:param str file: path of sound file
46-
:param bool show_pad: show [PAD] in output
4747
:param str model: The ASR model
4848
"""
4949
b = {}
@@ -53,20 +53,16 @@ def __call__(self, file: str, tokenized: bool = False) -> str:
5353
logits = self.model(input_dict.input_values).logits
5454
pred_ids = torch.argmax(logits, dim=-1)[0]
5555

56-
if tokenized:
57-
txt = self.processor.batch_decode(logits.detach().numpy()).text
58-
else:
59-
txt = self.processor.batch_decode(logits.detach().numpy()).text.replace(' ','')
56+
txt = self.processor.batch_decode(logits.detach().numpy()).text[0]
6057
return txt
6158

6259
_model_name = "airesearch/wav2vec2-large-xlsr-53-th"
6360
_model = None
6461

6562

66-
def asr(file: str, tokenized: bool = False, model: str = _model_name) -> str:
63+
def asr(file: str, model: str = _model_name) -> str:
6764
"""
6865
:param str file: path of sound file
69-
:param bool show_pad: show [PAD] in output
7066
:param str model: The ASR model
7167
:return: thai text from ASR
7268
:rtype: str
@@ -81,4 +77,4 @@ def asr(file: str, tokenized: bool = False, model: str = _model_name) -> str:
8177
_model = ASR(model)
8278
_model_name = model
8379

84-
return _model(file=file, tokenized=tokenized)
80+
return _model(file=file)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def read(*paths):
2727

2828
setup(
2929
name='pythaiasr',
30-
version='1.0.1',
30+
version='1.1.0',
3131
packages=['pythaiasr'],
3232
url='https://github.com/pythainlp/pythaiasr',
3333
license='Apache Software License 2.0',

0 commit comments

Comments
 (0)