Skip to content

Commit 6509757

Browse files
committed
Adding weighted WER feature.
1 parent d4586b2 commit 6509757

File tree

6 files changed

+64
-5
lines changed

6 files changed

+64
-5
lines changed

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
include LICENSE README.md requirements.txt
22
recursive-include libs *.*
3+
recursive-include texterrors/data *

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Features:
1010
- Metrics by group (for example speaker)
1111
- Comparing two hypothesis files to reference
1212
- Oracle WER
13+
- **NEW** Weighted WER (English only)
1314
- Sorting most common errors by frequency or count
1415
- Measuring performance on keywords
1516
- Measuring OOV-CER (see [https://arxiv.org/abs/2107.08091](https://arxiv.org/abs/2107.08091) )
@@ -89,6 +90,7 @@ This results in a WER of 83.3\% because of the extra insertion and deletion. And
8990

9091
Recent changes:
9192

93+
- 11.11.25 Weighted WER for English
9294
- 26.02.25 Faster alignment, better multihyp support, fixed multihyp bug.
9395
- 22.06.22 refactored internals to make them simpler, character aware alignment is off by default, added more explanations
9496
- 20.05.22 fixed bug missing regex dependency

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ termcolor
66
Levenshtein
77
regex
88
pytest
9+
importlib_resources

setup.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import setuptools
55
import sys
66

7-
__version__ = "1.0.10"
7+
__version__ = "1.0.11"
88

99

1010
class get_pybind_include(object):
@@ -101,5 +101,7 @@ def get_requires():
101101
entry_points={'console_scripts': ['texterrors=texterrors.texterrors:cli']},
102102
install_requires=get_requires(),
103103
setup_requires=['pybind11'],
104-
python_requires='>=3.6'
104+
python_requires='>=3.6',
105+
package_data={"texterrors": ["data/wordlist"]},
106+
include_package_data=True,
105107
)

tests/test_functions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,21 @@ def test_oov_cer():
117117
assert err / cnt == 0., err / cnt
118118

119119

120+
def test_weighted_wer():
121+
reflines = ['1 my name is john doe']
122+
hyplines = ['1 my name is joe doe']
123+
refs = create_inp(reflines)
124+
hyps = create_inp(hyplines)
125+
buffer = io.StringIO()
126+
texterrors.process_output(refs, hyps, buffer, 'A', 'B',weighted_wer=True, skip_detailed=True)
127+
output = buffer.getvalue()
128+
ref ="""WER: 20.0 (ins 0, del 0, sub 1 / 5)
129+
SER: 100.0
130+
Weighted WER: 28.3
131+
"""
132+
assert output == ref, show_diff(output, ref)
133+
134+
120135
def test_seq_distance():
121136
a, b = 'a b', 'a b'
122137
d = texterrors.seq_distance(StringVector(a.split()), StringVector(b.split()))

texterrors/texterrors.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from texterrors_align import StringVector
1414
from loguru import logger
1515
from termcolor import colored
16+
from importlib.resources import files, as_file
17+
1618

1719
OOV_SYM = '<unk>'
1820
CPP_WORDS_CONTAINER = True
@@ -691,7 +693,7 @@ def process_multiple_outputs(ref_utts, hypa_utts, hypb_utts, fh, num_top_errors,
691693
def process_output(ref_utts, hyp_utts, fh, ref_file, hyp_file, cer=False, num_top_errors=10, oov_set=None, debug=False,
692694
use_chardiff=True, isctm=False, skip_detailed=False,
693695
keywords=None, utt_group_map=None, oracle_wer=False,
694-
freq_sort=False, nocolor=False, insert_tok='<eps>', terminal_width=None):
696+
freq_sort=False, nocolor=False, insert_tok='<eps>', terminal_width=None, weighted_wer=False):
695697

696698
if terminal_width is None:
697699
terminal_width, _ = shutil.get_terminal_size()
@@ -744,6 +746,40 @@ def process_output(ref_utts, hyp_utts, fh, ref_file, hyp_file, cer=False, num_to
744746
fh.write(f'WER: {100.*wer:.1f} (ins {ins_count}, del {del_count}, sub {sub_count} / {error_stats.total_count})'
745747
f'\nSER: {100.*error_stats.utt_wrong / len(error_stats.utts):.1f}\n')
746748

749+
if weighted_wer:
750+
words = []
751+
probs = []
752+
753+
wordlist_resource= files('texterrors') / 'data' / 'wordlist'
754+
with as_file(wordlist_resource) as wordlist_path:
755+
with open(wordlist_path) as fh_wordlist:
756+
for line in fh_wordlist:
757+
word, prob = line.strip().split()
758+
words.append(word)
759+
probs.append(float(prob))
760+
probs = -np.log(np.array(probs))
761+
minscore, maxscore = probs[100], probs[-1]
762+
probs[:100] = minscore
763+
word2weight = {}
764+
maxweight = 0.
765+
for word, prob in zip(words, probs):
766+
word2weight[word] = max((prob - minscore) / (maxscore - minscore), 1e-2)
767+
maxweight = max(maxweight, word2weight[word])
768+
769+
num = 0
770+
for word, cnt in error_stats.subs.items():
771+
ref_w, hyp_w = word.split('>')
772+
weight = (word2weight.get(ref_w, maxweight) + word2weight.get(hyp_w, maxweight)) / 2.
773+
num += weight * cnt
774+
for word, cnt in error_stats.ins.items():
775+
num += word2weight.get(word, maxweight) * cnt
776+
for word, cnt in error_stats.dels.items():
777+
num += word2weight.get(word, maxweight) * cnt
778+
denom = sum(word2weight.get(word, maxweight) * cnt for word, cnt in error_stats.word_counts.items())
779+
780+
weighted_wer = num / denom
781+
fh.write(f'Weighted WER: {100.*weighted_wer:.1f}\n')
782+
747783
if cer:
748784
cer = error_stats.char_error_count / float(error_stats.char_count)
749785
fh.write(f'CER: {100.*cer:.1f} ({error_stats.char_error_count} / {error_stats.char_count})\n')
@@ -785,7 +821,8 @@ def main(
785821
utt_group_map_f: ('Should be a file which maps uttids to group, WER will be output per group.', 'option', '') = '',
786822
usecolor: ('Show detailed output with color (use less -R). Red/white is reference, Green/white model output.', 'flag', 'c')=False,
787823
num_top_errors: ('Number of errors to show per type in detailed output.', 'option')=10,
788-
second_hyp_f: ('Will compare outputs between two hypothesis files.', 'option')=''
824+
second_hyp_f: ('Will compare outputs between two hypothesis files.', 'option')='',
825+
weighted_wer: ('Use weighted WER, will weight the errors by word frequency.', 'flag', None) = False,
789826
):
790827

791828
logger.remove()
@@ -820,7 +857,8 @@ def main(
820857
process_output(ref_utts, hyp_utts, fh, cer=cer, debug=debug, oov_set=oov_set,
821858
ref_file=ref_file, hyp_file=hyp_file, use_chardiff=use_chardiff, skip_detailed=skip_detailed,
822859
keywords=keywords, utt_group_map=utt_group_map, freq_sort=freq_sort,
823-
isctm=isctm, oracle_wer=oracle_wer, nocolor=not usecolor, num_top_errors=num_top_errors)
860+
isctm=isctm, oracle_wer=oracle_wer, nocolor=not usecolor, num_top_errors=num_top_errors,
861+
weighted_wer=weighted_wer)
824862
else:
825863
ref_utts = read_ref_file(ref_file, isark)
826864
hyp_uttsa = read_hyp_file(hyp_file, isark, False)

0 commit comments

Comments
 (0)