@@ -16,7 +16,6 @@ def __init__(
1616 n_class : int ,
1717 lr : float ,
1818 text_process : TextProcess ,
19- ctc_decoder : CTCDecoder ,
2019 cfg_optim : dict ,
2120 ):
2221 super ().__init__ ()
@@ -25,7 +24,6 @@ def __init__(
2524 )
2625 self .lr = lr
2726 self .text_process = text_process
28- self .ctc_decoder = ctc_decoder
2927 self .cal_wer = torchmetrics .WordErrorRate ()
3028 self .cfg_optim = cfg_optim
3129 self .criterion = nn .CTCLoss (zero_infinity = True )
@@ -63,12 +61,8 @@ def validation_step(self, batch, batch_idx):
6361 outputs .permute (1 , 0 , 2 ), targets , input_lengths , target_lengths
6462 )
6563
66- if self .ctc_decoder :
67- # unsqueeze for batchsize 1
68- predicts = [self .ctc_decoder (sent .unsqueeze (0 )) for sent in outputs ]
69- else :
70- decode = outputs .argmax (dim = - 1 )
71- predicts = [self .text_process .decode (sent ) for sent in decode ]
64+ decode = outputs .argmax (dim = - 1 )
65+ predicts = [self .text_process .decode (sent ) for sent in decode ]
7266
7367 targets = [self .text_process .int2text (sent ) for sent in targets ]
7468
@@ -92,12 +86,8 @@ def test_step(self, batch, batch_idx):
9286 outputs .permute (1 , 0 , 2 ), targets , input_lengths , target_lengths
9387 )
9488
95- if self .ctc_decoder :
96- # unsqueeze for batchsize 1
97- predicts = [self .ctc_decoder (sent .unsqueeze (0 )) for sent in outputs ]
98- else :
99- decode = outputs .argmax (dim = - 1 )
100- predicts = [self .text_process .decode (sent ) for sent in decode ]
89+ decode = outputs .argmax (dim = - 1 )
90+ predicts = [self .text_process .decode (sent ) for sent in decode ]
10191
10292 targets = [self .text_process .int2text (sent ) for sent in targets ]
10393
0 commit comments