diff --git a/examples/mala_asr_slidespeech/utils/compute_wer_details.py b/examples/mala_asr_slidespeech/utils/compute_wer_details.py
new file mode 100644
index 00000000..73a021a8
--- /dev/null
+++ b/examples/mala_asr_slidespeech/utils/compute_wer_details.py
@@ -0,0 +1,760 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+
+from enum import Enum
+import sys
+import unicodedata
+import codecs
+import argparse
+from tqdm import tqdm
+
+remove_tag = False
+spacelist = [" ", "\t", "\r", "\n"]
+puncts = [
+    "!",
+    ",",
+    "?",
+    "、",
+    "。",
+    "!",
+    ",",
+    ";",
+    "?",
+    ":",
+    "「",
+    "」",
+    "︰",
+    "『",
+    "』",
+    "《",
+    "》",
+]
+
+
+class Code(Enum):
+    match = 1
+    substitution = 2
+    insertion = 3
+    deletion = 4
+
+
+class WordError(object):
+    def __init__(self):
+        self.errors = {
+            Code.substitution: 0,
+            Code.insertion: 0,
+            Code.deletion: 0,
+        }
+        self.ref_words = 0
+
+    def get_wer(self):
+        assert self.ref_words != 0
+        errors = (
+            self.errors[Code.substitution]
+            + self.errors[Code.insertion]
+            + self.errors[Code.deletion]
+        )
+        return 100.0 * errors / self.ref_words
+
+    def get_result_string(self):
+        return (
+            f"error_rate={self.get_wer():.4f}, "
+            f"ref_words={self.ref_words}, "
+            f"subs={self.errors[Code.substitution]}, "
+            f"ins={self.errors[Code.insertion]}, "
+            f"dels={self.errors[Code.deletion]}"
+        )
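+
+
+# Illustrative sketch (added comment, not behavior): WordError accumulates
+# per-class error counts against a running reference-word total, so
+# WER = (S + D + I) / N * 100. A hypothetical use:
+#
+#     we = WordError()
+#     we.ref_words = 10
+#     we.errors[Code.substitution] = 1
+#     we.errors[Code.deletion] = 1
+#     we.get_wer()  # 100.0 * (1 + 0 + 1) / 10 == 20.0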
+
+
+def characterize(string):
+    res = []
+    i = 0
+    while i < len(string):
+        char = string[i]
+        if char in puncts:
+            i += 1
+            continue
+        cat1 = unicodedata.category(char)
+        # https://unicodebook.readthedocs.io/unicode.html#unicode-categories
+        if cat1 == "Zs" or cat1 == "Cn" or char in spacelist:  # space or not assigned
+            i += 1
+            continue
+        if cat1 == "Lo":  # letter-other
+            res.append(char)
+            i += 1
+        else:
+            # Some input carries an ASCII word or a <...> tag glued to the
+            # following token; we want to separate it into two words.
+            sep = " "
+            if char == "<":
+                sep = ">"
+            j = i + 1
+            while j < len(string):
+                c = string[j]
+                if ord(c) >= 128 or (c in spacelist) or (c == sep):
+                    break
+                j += 1
+            if j < len(string) and string[j] == ">":
+                j += 1
+            res.append(string[i:j])
+            i = j
+    return res
+
+
+def stripoff_tags(x):
+    if not x:
+        return ""
+    chars = []
+    i = 0
+    T = len(x)
+    while i < T:
+        if x[i] == "<":
+            while i < T and x[i] != ">":
+                i += 1
+            i += 1
+        else:
+            chars.append(x[i])
+            i += 1
+    return "".join(chars)
+
+
+def normalize(sentence, ignore_words, cs, split=None):
+    """sentence, ignore_words are both in unicode"""
+    new_sentence = []
+    for token in sentence:
+        x = token
+        if not cs:
+            x = x.upper()
+        if x in ignore_words:
+            continue
+        if remove_tag:
+            x = stripoff_tags(x)
+        if not x:
+            continue
+        if split and x in split:
+            new_sentence += split[x]
+        else:
+            new_sentence.append(x)
+    return new_sentence
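+
+
+# Illustrative examples (added as comments; behavior inferred from the code
+# above — a Latin run is grouped into one token while CJK characters are
+# split, and normalize upper-cases when case-insensitive):
+#
+#     characterize("你好 WORLD")                # ["你", "好", "WORLD"]
+#     normalize(["Hello", "你"], set(), False)  # ["HELLO", "你"]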
+
+
+class Calculator:
+    def __init__(self):
+        self.data = {}
+        self.space = []
+        self.cost = {}
+        self.cost["cor"] = 0
+        self.cost["sub"] = 1
+        self.cost["del"] = 1
+        self.cost["ins"] = 1
+
+    def calculate(self, lab, rec):
+        # Initialization
+        lab.insert(0, "")
+        rec.insert(0, "")
+        while len(self.space) < len(lab):
+            self.space.append([])
+        for row in self.space:
+            for element in row:
+                element["dist"] = 0
+                element["error"] = "non"
+            while len(row) < len(rec):
+                row.append({"dist": 0, "error": "non"})
+        for i in range(len(lab)):
+            self.space[i][0]["dist"] = i
+            self.space[i][0]["error"] = "del"
+        for j in range(len(rec)):
+            self.space[0][j]["dist"] = j
+            self.space[0][j]["error"] = "ins"
+        self.space[0][0]["error"] = "non"
+        for token in lab:
+            if token not in self.data and len(token) > 0:
+                self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
+        for token in rec:
+            if token not in self.data and len(token) > 0:
+                self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
+        # Computing edit distance
+        for i, lab_token in enumerate(lab):
+            for j, rec_token in enumerate(rec):
+                # i indexes lab, j indexes rec
+                if i == 0 or j == 0:
+                    continue
+                min_dist = sys.maxsize
+                min_error = "non"
+                dist = self.space[i - 1][j]["dist"] + self.cost["del"]
+                error = "del"
+                if dist < min_dist:
+                    min_dist = dist
+                    min_error = error
+                dist = self.space[i][j - 1]["dist"] + self.cost["ins"]
+                error = "ins"
+                if dist < min_dist:
+                    min_dist = dist
+                    min_error = error
+                if lab_token == rec_token.replace("", ""):
+                    dist = self.space[i - 1][j - 1]["dist"] + self.cost["cor"]
+                    error = "cor"
+                else:
+                    dist = self.space[i - 1][j - 1]["dist"] + self.cost["sub"]
+                    error = "sub"
+                if dist < min_dist:
+                    min_dist = dist
+                    min_error = error
+                self.space[i][j]["dist"] = min_dist
+                self.space[i][j]["error"] = min_error
+        # Tracing back
+        result = {
+            "lab": [],
+            "rec": [],
+            "code": [],
+            "all": 0,
+            "cor": 0,
+            "sub": 0,
+            "ins": 0,
+            "del": 0,
+        }
+        i = len(lab) - 1
+        j = len(rec) - 1
+        while True:
+            if self.space[i][j]["error"] == "cor":  # correct
+                if len(lab[i]) > 0:
+                    self.data[lab[i]]["all"] += 1
+                    self.data[lab[i]]["cor"] += 1
+                result["all"] += 1
+                result["cor"] += 1
+                result["lab"].insert(0, lab[i])
+                result["rec"].insert(0, rec[j])
+                result["code"].insert(0, Code.match)
+                i -= 1
+                j -= 1
+            elif self.space[i][j]["error"] == "sub":  # substitution
+                if len(lab[i]) > 0:
+                    self.data[lab[i]]["all"] += 1
+                    self.data[lab[i]]["sub"] += 1
+                result["all"] += 1
+                result["sub"] += 1
+                result["lab"].insert(0, lab[i])
+                result["rec"].insert(0, rec[j])
+                result["code"].insert(0, Code.substitution)
+                i -= 1
+                j -= 1
+            elif self.space[i][j]["error"] == "del":  # deletion
+                if len(lab[i]) > 0:
+                    self.data[lab[i]]["all"] += 1
+                    self.data[lab[i]]["del"] += 1
+                result["all"] += 1
+                result["del"] += 1
+                result["lab"].insert(0, lab[i])
+                result["rec"].insert(0, "")
+                result["code"].insert(0, Code.deletion)
+                i -= 1
+            elif self.space[i][j]["error"] == "ins":  # insertion
+                if len(rec[j]) > 0:
+                    self.data[rec[j]]["ins"] += 1
+                result["ins"] += 1
+                result["lab"].insert(0, "")
+                result["rec"].insert(0, rec[j])
+                result["code"].insert(0, Code.insertion)
+                j -= 1
+            elif self.space[i][j]["error"] == "non":  # starting point
+                break
+            else:  # shouldn't reach here
+                print(
+                    "this should not happen, i = {i}, j = {j}, error = {error}".format(
+                        i=i, j=j, error=self.space[i][j]["error"]
+                    )
+                )
+        return result
+
+    def overall(self):
+        result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
+        for token in self.data:
+            result["all"] += self.data[token]["all"]
+            result["cor"] += self.data[token]["cor"]
+            result["sub"] += self.data[token]["sub"]
+            result["ins"] += self.data[token]["ins"]
+            result["del"] += self.data[token]["del"]
+        return result
+
+    def cluster(self, data):
+        result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
+        for token in data:
+            if token in self.data:
+                result["all"] += self.data[token]["all"]
+                result["cor"] += self.data[token]["cor"]
+                result["sub"] += self.data[token]["sub"]
+                result["ins"] += self.data[token]["ins"]
+                result["del"] += self.data[token]["del"]
+        return result
+
+    def keys(self):
+        return list(self.data.keys())
+
+
+def width(string):
+    return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
+
+
+def default_cluster(word):
+    unicode_names = [unicodedata.name(char) for char in word]
+    for i in reversed(range(len(unicode_names))):
+        if unicode_names[i].startswith("DIGIT"):  # 1
+            unicode_names[i] = "Number"  # 'DIGIT'
+        elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[
+            i
+        ].startswith("CJK COMPATIBILITY IDEOGRAPH"):
+            # 明 / 郎
+            unicode_names[i] = "Mandarin"  # 'CJK IDEOGRAPH'
+        elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[
+            i
+        ].startswith("LATIN SMALL LETTER"):
+            # A / a
+            unicode_names[i] = "English"  # 'LATIN LETTER'
+        elif unicode_names[i].startswith("HIRAGANA LETTER"):  # は こ め
+            unicode_names[i] = "Japanese"  # 'GANA LETTER'
+        elif (
+            unicode_names[i].startswith("AMPERSAND")
+            or unicode_names[i].startswith("APOSTROPHE")
+            or unicode_names[i].startswith("COMMERCIAL AT")
+            or unicode_names[i].startswith("DEGREE CELSIUS")
+            or unicode_names[i].startswith("EQUALS SIGN")
+            or unicode_names[i].startswith("FULL STOP")
+            or unicode_names[i].startswith("HYPHEN-MINUS")
+            or unicode_names[i].startswith("LOW LINE")
+            or unicode_names[i].startswith("NUMBER SIGN")
+            or unicode_names[i].startswith("PLUS SIGN")
+            or unicode_names[i].startswith("SEMICOLON")
+        ):
+            # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
+            del unicode_names[i]
+        else:
+            return "Other"
+    if len(unicode_names) == 0:
+        return "Other"
+    if len(unicode_names) == 1:
+        return unicode_names[0]
+    for i in range(len(unicode_names) - 1):
+        if unicode_names[i] != unicode_names[i + 1]:
+            return "Other"
+    return unicode_names[0]
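+
+
+# Illustrative note (added): Calculator.calculate fills a Levenshtein lattice
+# over the padded sequences and backtracks to an alignment. For example,
+# lab=["A", "B", "C"] vs rec=["A", "X", "C"] should yield
+# code=[match, substitution, match] and all=3, cor=2, sub=1, ins=0, del=0.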
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        description="Compute detailed WER (overall / U-WER / B-WER) and hotword recall"
+    )
+    parser.add_argument("--ref", type=str, help="reference transcript path")
+    parser.add_argument("--ref_ocr", type=str, help="OCR hotword list path")
+    parser.add_argument("--ref2session", type=str, default="", help="utt-to-session mapping path")
+    parser.add_argument("--rec_name", type=str, action="append", default=[],
+                        help="name of a hypothesis set, repeatable")
+    parser.add_argument("--rec_file", type=str, action="append", default=[],
+                        help="hypothesis file path, repeatable, paired with --rec_name")
+    parser.add_argument("--verbose", type=int, default=1, help="verbosity level (0, 1, or 2)")
+    # Note: argparse's type=bool treats any non-empty string as True,
+    # so only the default is reliable here; leave the flag unset to keep it.
+    parser.add_argument("--char", type=bool, default=True, help="tokenize into characters")
+    args = parser.parse_args()
+    return args
+
+
+def main(args):
+    cluster_file = ""
+    ignore_words = set()
+    tochar = args.char
+    verbose = args.verbose
+    padding_symbol = " "
+    case_sensitive = False
+    max_words_per_line = sys.maxsize
+    split = None
+
+    if not case_sensitive:
+        ig = set([w.upper() for w in ignore_words])
+        ignore_words = ig
+
+    default_clusters = {}
+    default_words = {}
+    ref_file = args.ref
+    ref_ocr = args.ref_ocr
+    rec_files = args.rec_file
+    rec_names = args.rec_name
+    assert len(rec_files) == len(rec_names)
+
+    # load ocr
+    ref_ocr_dict = {}
+    with codecs.open(ref_ocr, "r", "utf-8") as fh:
+        for line in fh:
+            if "$" in line:
+                line = line.replace("$", " ")
+            if tochar:
+                array = characterize(line)
+            else:
+                array = line.strip().split()
+            if len(array) == 0:
+                continue
+            fid = array[0]
+            ref_ocr_dict[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
+
+    if split and not case_sensitive:
+        newsplit = dict()
+        for w in split:
+            words = split[w]
+            for i in range(len(words)):
+                words[i] = words[i].upper()
+            newsplit[w.upper()] = words
+        split = newsplit
+
+    utt2session = None
+
+    if args.ref2session != "":
+        utt2session = dict()
+        with codecs.open(args.ref2session, "r", "utf-8") as fh:
+            for line in fh:
+                uttid, session = line.strip().split()
+                utt2session[uttid] = session
+
+    rec_sets = {}
+    calculators_dict = dict()
+    ub_wer_dict = dict()
+    hotwords_related_dict = dict()  # tracks hotword recall statistics
+    # kept as a dict even without --ref2session, so the per-rec bookkeeping
+    # below never dereferences None (it simply stays unused in that case)
+    multirec_session_result = dict()
+    for i, hyp_file in enumerate(rec_files):
+        rec_sets[rec_names[i]] = dict()
+        with codecs.open(hyp_file, "r", "utf-8") as fh:
+            for line in fh:
+                if tochar:
+                    array = characterize(line)
+                else:
+                    array = line.strip().split()
+                if len(array) == 0:
+                    continue
+                fid = array[0]
+                rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split)
+
+        calculators_dict[rec_names[i]] = Calculator()
+        ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()}
+        multirec_session_result[rec_names[i]] = dict()
+        hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
+        # tp: the hotword is in the label and in the rec
+        # tn: the hotword is neither in the label nor in the rec
+        # fp: the hotword is not in the label but appears in the rec
+        # fn: the hotword is in the label but missing from the rec
+
+    # record words recognized wrongly even though they appear in the ocr
+    wrong_rec_but_in_ocr_dict = {}
+    for rec_name in rec_names:
+        wrong_rec_but_in_ocr_dict[rec_name] = 0
+
+    # count reference lines once so tqdm can show progress
+    with codecs.open(ref_file, "r", "utf-8") as fh:
+        _file_total_len = sum(1 for _ in fh)
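+
+    # Expected line format (an assumption, inferred from the parsing above):
+    # every input file is one utterance per line, utterance id first, e.g.
+    #
+    #     --ref / --rec_file :  utt_001 今 天 天 气 不 错
+    #     --ref_ocr          :  utt_001 天 气 预 报
+    #     --ref2session      :  utt_001 session_A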
+
+    # compute error rate on the intersection of the reference file and each hyp file
+    for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len):
+        if tochar:
+            array = characterize(line)
+        else:
+            array = line.rstrip('\n').split()
+        if len(array) == 0:
+            continue
+        fid = array[0]
+        lab = normalize(array[1:], ignore_words, case_sensitive, split)
+
+        if verbose:
+            print('\nutt: %s' % fid)
+
+        ocr_text = ref_ocr_dict[fid]
+        ocr_set = set(ocr_text)
+        print('ocr: {}'.format(" ".join(ocr_text)))
+        list_match = []  # label tokens that appear in the ocr hotword list
+        list_not_match = []
+        tmp_error = 0
+        tmp_match = 0
+        for index in range(len(lab)):
+            if lab[index] not in ocr_set:
+                tmp_error += 1
+                list_not_match.append(lab[index])
+            else:
+                tmp_match += 1
+                list_match.append(lab[index])
+        print('label in ocr: {}'.format(" ".join(list_match)))
+
+        # for each reco file
+        base_wrong_ocr_wer = None
+        ocr_wrong_ocr_wer = None
+
+        for rec_name in rec_names:
+            rec_set = rec_sets[rec_name]
+            if fid not in rec_set:
+                continue
+            rec = rec_set[fid]
+
+            for word in rec + lab:
+                if word not in default_words:
+                    default_cluster_name = default_cluster(word)
+                    if default_cluster_name not in default_clusters:
+                        default_clusters[default_cluster_name] = {}
+                    if word not in default_clusters[default_cluster_name]:
+                        default_clusters[default_cluster_name][word] = 1
+                    default_words[word] = default_cluster_name
+
+            result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy())
+            if verbose:
+                if result['all'] != 0:
+                    wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+                else:
+                    wer = 0.0
+                print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ')
+                print('N=%d C=%d S=%d D=%d I=%d' %
+                      (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+
+            if utt2session is not None:
+                session_result = multirec_session_result[rec_name]
+                session = utt2session[fid]
+                if session not in session_result:
+                    session_result[session] = {'lab': [], 'rec': [], 'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
+
+                session_result[session]['all'] += result['all']
+                session_result[session]['cor'] += result['cor']
+                session_result[session]['sub'] += result['sub']
+                session_result[session]['ins'] += result['ins']
+                session_result[session]['del'] += result['del']
+
+            wrong_rec_but_in_ocr = []
+            for idx in range(len(result['lab'])):
+                if result['lab'][idx] != "":
+                    if result['lab'][idx] != result['rec'][idx].replace("", ""):
+                        if result['lab'][idx] in list_match:
+                            wrong_rec_but_in_ocr.append(result['lab'][idx])
+                            wrong_rec_but_in_ocr_dict[rec_name] += 1
+            print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr)))
+
+            if rec_name == "base":
+                base_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
+            if "ocr" in rec_name or "hot" in rec_name:
+                ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
+                # the base system may be absent for this utt, so guard the comparison
+                if base_wrong_ocr_wer is not None:
+                    if ocr_wrong_ocr_wer < base_wrong_ocr_wer:
+                        print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
+                    elif ocr_wrong_ocr_wer > base_wrong_ocr_wer:
+                        print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
+
+            # recall = 0
+            # false_alarm = 0
+            # for idx in range(len(result['lab'])):
+            #     if "" in result['rec'][idx]:
+            #         if result['rec'][idx].replace("", "") in list_match:
+            #             recall += 1
+            #         else:
+            #             false_alarm += 1
+            # print("bias hotwords recall: {}, fa: {}, list_match {}, recall: {:.2f}, fa: {:.2f}".format(
+            #     recall, false_alarm, len(list_match), recall / len(list_match) if len(list_match) != 0 else 0, false_alarm / len(list_match) if len(list_match) != 0 else 0
+            # ))
+            # tp: the hotword is in the label and in the rec
+            # tn: the hotword is neither in the label nor in the rec
+            # fp: the hotword is not in the label but appears in the rec
+            # fn: the hotword is in the label but missing from the rec
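+            # Worked example (added, illustrative): with ocr hotwords {A, B},
+            # lab = [A, C] and rec = [B, C], A is a "true" hotword (it occurs
+            # in the label) that is never recognized -> tp = 0, fn = 1, while
+            # B is a "bad" hotword (absent from the label) recognized once
+            # -> fp = 1.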
+            _rec_list = [word.replace("", "") for word in rec]
+            _label_list = [word for word in lab]
+            _tp = _tn = _fp = _fn = 0
+            hot_true_list = set([hotword for hotword in ocr_text if hotword in _label_list])
+            hot_bad_list = set([hotword for hotword in ocr_text if hotword not in _label_list])
+            for badhotword in hot_bad_list:
+                count = len([word for word in _rec_list if word == badhotword])
+                if count == 0:
+                    hotwords_related_dict[rec_name]['tn'] += 1
+                    _tn += 1
+                    # fp: 0
+                else:
+                    hotwords_related_dict[rec_name]['fp'] += count
+                    _fp += count
+                    # tn: 0
+            for hotword in hot_true_list:
+                true_count = len([word for word in _label_list if hotword == word])
+                rec_count = len([word for word in _rec_list if hotword == word])
+                if rec_count == true_count:
+                    hotwords_related_dict[rec_name]['tp'] += true_count
+                    _tp += true_count
+                elif rec_count > true_count:
+                    hotwords_related_dict[rec_name]['tp'] += true_count
+                    # surplus occurrences in the rec count as false positives
+                    hotwords_related_dict[rec_name]['fp'] += rec_count - true_count
+                    _tp += true_count
+                    _fp += rec_count - true_count
+                else:
+                    hotwords_related_dict[rec_name]['tp'] += rec_count
+                    # missing occurrences count as false negatives
+                    hotwords_related_dict[rec_name]['fn'] += true_count - rec_count
+                    _tp += rec_count
+                    _fn += true_count - rec_count
+            print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
+                _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0
+            ))
+
+            # accumulate overall WER, unbiased U-WER and biased B-WER
+            for code, rec_word, lab_word in zip(result["code"], result["rec"], result["lab"]):
+                if code == Code.match:
+                    ub_wer_dict[rec_name]["wer"].ref_words += 1
+                    if lab_word in hot_true_list:
+                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
+                    else:
+                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
+                elif code == Code.substitution:
+                    ub_wer_dict[rec_name]["wer"].ref_words += 1
+                    ub_wer_dict[rec_name]["wer"].errors[Code.substitution] += 1
+                    if lab_word in hot_true_list:
+                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
+                        ub_wer_dict[rec_name]["b_wer"].errors[Code.substitution] += 1
+                    else:
+                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
+                        ub_wer_dict[rec_name]["u_wer"].errors[Code.substitution] += 1
+                elif code == Code.deletion:
+                    ub_wer_dict[rec_name]["wer"].ref_words += 1
+                    ub_wer_dict[rec_name]["wer"].errors[Code.deletion] += 1
+                    if lab_word in hot_true_list:
+                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
+                        ub_wer_dict[rec_name]["b_wer"].errors[Code.deletion] += 1
+                    else:
+                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
+                        ub_wer_dict[rec_name]["u_wer"].errors[Code.deletion] += 1
+                elif code == Code.insertion:
+                    ub_wer_dict[rec_name]["wer"].errors[Code.insertion] += 1
+                    if rec_word in hot_true_list:
+                        ub_wer_dict[rec_name]["b_wer"].errors[Code.insertion] += 1
+                    else:
+                        ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1
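+
+            # Note (added): B-WER restricts the error counting to reference
+            # words that are hotwords for this utterance (hot_true_list) and
+            # U-WER to the remaining words; insertions are attributed by
+            # whether the inserted rec word is itself a hotword. This mirrors
+            # the U-WER/B-WER split commonly reported for contextual biasing.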
+
+            space = {}
+            space['lab'] = []
+            space['rec'] = []
+            for idx in range(len(result['lab'])):
+                len_lab = width(result['lab'][idx])
+                len_rec = width(result['rec'][idx])
+                length = max(len_lab, len_rec)
+                space['lab'].append(length - len_lab)
+                space['rec'].append(length - len_rec)
+            upper_lab = len(result['lab'])
+            upper_rec = len(result['rec'])
+            lab1, rec1 = 0, 0
+            while lab1 < upper_lab or rec1 < upper_rec:
+                if verbose > 1:
+                    print('lab(%s):' % fid, end=' ')
+                else:
+                    print('lab:', end=' ')
+                lab2 = min(upper_lab, lab1 + max_words_per_line)
+                for idx in range(lab1, lab2):
+                    token = result['lab'][idx]
+                    print('{token}'.format(token=token), end='')
+                    for n in range(space['lab'][idx]):
+                        print(padding_symbol, end='')
+                    print(' ', end='')
+                print()
+                if verbose > 1:
+                    print('rec(%s):' % fid, end=' ')
+                else:
+                    print('rec:', end=' ')
+
+                rec2 = min(upper_rec, rec1 + max_words_per_line)
+                for idx in range(rec1, rec2):
+                    token = result['rec'][idx]
+                    print('{token}'.format(token=token), end='')
+                    for n in range(space['rec'][idx]):
+                        print(padding_symbol, end='')
+                    print(' ', end='')
+                print()
+                lab1 = lab2
+                rec1 = rec2
+            print('\n', end='\n')
+        if verbose:
+            print('===========================================================================')
+            print()
+
+    session_results_zip = dict()
+    for rec_name, session_result in multirec_session_result.items():
+        for session, result in session_result.items():
+            if session not in session_results_zip.keys():
+                session_results_zip[session] = []
+            session_results_zip[session].append((rec_name, result))
+    for session, result_list in session_results_zip.items():
+        for rec_name, result in result_list:
+            if result['all'] != 0:
+                wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+            else:
+                wer = 0.0
+            print('{} {} -> {:4.2f} %'.format(rec_name, session, wer), end=' ')
+            print('N=%d C=%d S=%d D=%d I=%d' %
+                  (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+        print()
+
+    if verbose:
+        print('===========================================================================')
+        print()
+
+    print(wrong_rec_but_in_ocr_dict)
+    for rec_name in rec_names:
+        result = calculators_dict[rec_name].overall()
+
+        if result['all'] != 0:
+            wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+        else:
+            wer = 0.0
+        print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ')
+        print('N=%d C=%d S=%d D=%d I=%d' %
+              (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+        print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}")
+        print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}")
+        print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}")
+
+        recall = hotwords_related_dict[rec_name]['tp'] / (
+            hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn']
+        ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0
+        print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format(
+            hotwords_related_dict[rec_name]['tp'],
+            hotwords_related_dict[rec_name]['tn'],
+            hotwords_related_dict[rec_name]['fp'],
+            hotwords_related_dict[rec_name]['fn'],
+            sum([v for k, v in hotwords_related_dict[rec_name].items()]),
+            recall
+        ))
+
+        print()
+        print("{:.3f}; {:.3f}; {:.3f}; {:.2f}%".format(
+            ub_wer_dict[rec_name]['wer'].get_wer(),
+            ub_wer_dict[rec_name]['u_wer'].get_wer(),
+            ub_wer_dict[rec_name]['b_wer'].get_wer(),
+            recall
+        ))
+
+    if not verbose:
+        print()
+        print()
+
+
+if __name__ == "__main__":
+    args = get_args()
+    print(args)
+    main(args)
diff --git a/examples/mala_asr_slidespeech/utils/run_L95_test.sh b/examples/mala_asr_slidespeech/utils/run_L95_test.sh
new file mode 100644
index 00000000..26817d83
--- /dev/null
+++ b/examples/mala_asr_slidespeech/utils/run_L95_test.sh
@@ -0,0 +1,14 @@
+export CUDA_VISIBLE_DEVICES=0
+
+mkdir -p /nfs/yangguanrou.ygr/slidespeech/for_pr/slides-finetune-wavlm/asr/3840/test/
+python /nfs/yangguanrou.ygr/slidespeech/compute_wer_details/compute_wer_details.py --verbose 1 \
+--ref /nfs/yangguanrou.ygr/experiments_slides_wavlm/slides-finetune-wavlm/asr/3840/paper/decode_log_test_clean_beam4_repetition_penalty1_gt.proc \
+--ref_ocr /nfs/yangguanrou.ygr/slidespeech/test_oracle_v1/hot_related/ocr_1gram_top50_mmr070_hotwords_list \
+--ref2session /nfs/yangguanrou.ygr/slidespeech/test_oracle_v1/utt2spk \
+--rec_name base \
+--rec_name hot \
+--rec_file /nfs/yangguanrou.ygr/experiments_slides_wavlm/slides-finetune-wavlm_notext/asr/1760/decode_log_test_clean_beam4_repetition_penalty1_pred.proc \
+--rec_file /nfs/yangguanrou.ygr/experiments_slides_wavlm/slides-finetune-wavlm/asr/3840/paper/decode_log_test_clean_beam4_repetition_penalty1_pred.proc \
+> /nfs/yangguanrou.ygr/slidespeech/for_pr/slides-finetune-wavlm/asr/3840/test/metric.log
+
+# -m debugpy --listen 5678 --wait-for-client
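+
+# Illustrative invocation with placeholder paths (added for reference; the
+# filenames below are hypothetical, substitute your own decode outputs):
+#
+#   python compute_wer_details.py --verbose 1 \
+#     --ref ref_gt.proc --ref_ocr hotwords_list --ref2session utt2spk \
+#     --rec_name base --rec_file base_pred.proc \
+#     --rec_name hot  --rec_file hot_pred.proc > metric.log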