|
13 | 13 | from texterrors_align import StringVector |
14 | 14 | from loguru import logger |
15 | 15 | from termcolor import colored |
| 16 | +from importlib.resources import files, as_file |
| 17 | + |
16 | 18 |
|
17 | 19 | OOV_SYM = '<unk>' |
18 | 20 | CPP_WORDS_CONTAINER = True |
@@ -691,7 +693,7 @@ def process_multiple_outputs(ref_utts, hypa_utts, hypb_utts, fh, num_top_errors, |
691 | 693 | def process_output(ref_utts, hyp_utts, fh, ref_file, hyp_file, cer=False, num_top_errors=10, oov_set=None, debug=False, |
692 | 694 | use_chardiff=True, isctm=False, skip_detailed=False, |
693 | 695 | keywords=None, utt_group_map=None, oracle_wer=False, |
694 | | - freq_sort=False, nocolor=False, insert_tok='<eps>', terminal_width=None): |
| 696 | + freq_sort=False, nocolor=False, insert_tok='<eps>', terminal_width=None, weighted_wer=False): |
695 | 697 |
|
696 | 698 | if terminal_width is None: |
697 | 699 | terminal_width, _ = shutil.get_terminal_size() |
@@ -744,6 +746,40 @@ def process_output(ref_utts, hyp_utts, fh, ref_file, hyp_file, cer=False, num_to |
744 | 746 | fh.write(f'WER: {100.*wer:.1f} (ins {ins_count}, del {del_count}, sub {sub_count} / {error_stats.total_count})' |
745 | 747 | f'\nSER: {100.*error_stats.utt_wrong / len(error_stats.utts):.1f}\n') |
746 | 748 |
|
| 749 | + if weighted_wer: |
| 750 | + words = [] |
| 751 | + probs = [] |
| 752 | + |
| 753 | + wordlist_resource= files('texterrors') / 'data' / 'wordlist' |
| 754 | + with as_file(wordlist_resource) as wordlist_path: |
| 755 | + with open(wordlist_path) as fh_wordlist: |
| 756 | + for line in fh_wordlist: |
| 757 | + word, prob = line.strip().split() |
| 758 | + words.append(word) |
| 759 | + probs.append(float(prob)) |
| 760 | + probs = -np.log(np.array(probs)) |
| 761 | + minscore, maxscore = probs[100], probs[-1] |
| 762 | + probs[:100] = minscore |
| 763 | + word2weight = {} |
| 764 | + maxweight = 0. |
| 765 | + for word, prob in zip(words, probs): |
| 766 | + word2weight[word] = max((prob - minscore) / (maxscore - minscore), 1e-2) |
| 767 | + maxweight = max(maxweight, word2weight[word]) |
| 768 | + |
| 769 | + num = 0 |
| 770 | + for word, cnt in error_stats.subs.items(): |
| 771 | + ref_w, hyp_w = word.split('>') |
| 772 | + weight = (word2weight.get(ref_w, maxweight) + word2weight.get(hyp_w, maxweight)) / 2. |
| 773 | + num += weight * cnt |
| 774 | + for word, cnt in error_stats.ins.items(): |
| 775 | + num += word2weight.get(word, maxweight) * cnt |
| 776 | + for word, cnt in error_stats.dels.items(): |
| 777 | + num += word2weight.get(word, maxweight) * cnt |
| 778 | + denom = sum(word2weight.get(word, maxweight) * cnt for word, cnt in error_stats.word_counts.items()) |
| 779 | + |
| 780 | + weighted_wer = num / denom |
| 781 | + fh.write(f'Weighted WER: {100.*weighted_wer:.1f}\n') |
| 782 | + |
747 | 783 | if cer: |
748 | 784 | cer = error_stats.char_error_count / float(error_stats.char_count) |
749 | 785 | fh.write(f'CER: {100.*cer:.1f} ({error_stats.char_error_count} / {error_stats.char_count})\n') |
@@ -785,7 +821,8 @@ def main( |
785 | 821 | utt_group_map_f: ('Should be a file which maps uttids to group, WER will be output per group.', 'option', '') = '', |
786 | 822 | usecolor: ('Show detailed output with color (use less -R). Red/white is reference, Green/white model output.', 'flag', 'c')=False, |
787 | 823 | num_top_errors: ('Number of errors to show per type in detailed output.', 'option')=10, |
788 | | - second_hyp_f: ('Will compare outputs between two hypothesis files.', 'option')='' |
| 824 | + second_hyp_f: ('Will compare outputs between two hypothesis files.', 'option')='', |
| 825 | + weighted_wer: ('Use weighted WER, will weight the errors by word frequency.', 'flag', None) = False, |
789 | 826 | ): |
790 | 827 |
|
791 | 828 | logger.remove() |
@@ -820,7 +857,8 @@ def main( |
820 | 857 | process_output(ref_utts, hyp_utts, fh, cer=cer, debug=debug, oov_set=oov_set, |
821 | 858 | ref_file=ref_file, hyp_file=hyp_file, use_chardiff=use_chardiff, skip_detailed=skip_detailed, |
822 | 859 | keywords=keywords, utt_group_map=utt_group_map, freq_sort=freq_sort, |
823 | | - isctm=isctm, oracle_wer=oracle_wer, nocolor=not usecolor, num_top_errors=num_top_errors) |
| 860 | + isctm=isctm, oracle_wer=oracle_wer, nocolor=not usecolor, num_top_errors=num_top_errors, |
| 861 | + weighted_wer=weighted_wer) |
824 | 862 | else: |
825 | 863 | ref_utts = read_ref_file(ref_file, isark) |
826 | 864 | hyp_uttsa = read_hyp_file(hyp_file, isark, False) |
|
0 commit comments