
Commit 85d4b0b

author 蒄骰 committed

upload ctc_file and remove irrelevant code
1 parent 52fab27 commit 85d4b0b

File tree

4 files changed: +5562 / -41 lines


examples/contextual_asr/contextual_asr_config.py

Lines changed: 0 additions & 2 deletions

@@ -105,8 +105,6 @@ class DataConfig:
     infer_type: str = "bias"
     infer_file: str = "/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/my_ref/test-clean.biasing_100.tsv"
     ctc_file: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_large_libri_test_other_char.txt"
-    filter_type: str = "char"
-    phn_to_name_dict: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_ft_libri960_${ref_split}_phn.json"
     common_words_5k_dir: str="/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/words/common_words_5k.txt"
     probability_threshold: float = 0.9
     word_num: int = 15
examples/contextual_asr/dataset/hotwordsinfer_dataset.py

Lines changed: 3 additions & 39 deletions

@@ -36,25 +36,6 @@ def find_candidate_names(sentence, ngram_index, n=2):
         candidates.update(ngram_index.get(ngram, []))
     return candidates
 
-def build_ngram_index_phn(names, n=2):
-    """Build an N-gram inverted index."""
-    index = {}
-    for name in names:
-        phonemes = name.split()
-        for i in range(len(phonemes) - n + 1):
-            ngram = ' '.join(phonemes[i:i+n])
-            index.setdefault(ngram, set()).add(name)
-    return index
-
-def find_candidate_names_phn(phonemes, ngram_index, n=2):
-    """Find candidate names via the N-gram inverted index."""
-    candidates = set()
-    phonemes = phonemes.split()
-    for i in range(len(phonemes) - n + 1):
-        ngram = ' '.join(phonemes[i:i+n])
-        candidates.update(ngram_index.get(ngram, []))
-    return candidates
-
 @lru_cache(maxsize=100000)
 def similarity(name, sentence):
     return Levenshtein.ratio(name, sentence)
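
The deleted helpers are the phoneme-level twins of build_ngram_index and find_candidate_names, which this commit keeps. Their bodies are not fully shown in this diff, so the sketch below reconstructs them by mirroring the deleted _phn variants over word tokens; that reading matches the visible context lines but is an assumption (the real index could use character n-grams instead):

def build_ngram_index(names, n=2):
    """Build an n-gram inverted index: each n-gram maps to the set of names containing it."""
    index = {}
    for name in names:
        words = name.split()
        for i in range(len(words) - n + 1):
            ngram = ' '.join(words[i:i+n])
            index.setdefault(ngram, set()).add(name)
    return index


def find_candidate_names(sentence, ngram_index, n=2):
    """Collect every indexed name sharing at least one n-gram with the sentence."""
    candidates = set()
    words = sentence.split()
    for i in range(len(words) - n + 1):
        ngram = ' '.join(words[i:i+n])
        candidates.update(ngram_index.get(ngram, []))
    return candidates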
@@ -139,11 +120,6 @@ def __init__(
         # analyze
         self.hotwords_num=0
         self.miss_words_num=0
-        self.filter_type=dataset_config.filter_type
-        if self.filter_type=="phn":
-            with open(dataset_config.phn_to_name_dict, 'r') as file:
-                self.phn_to_name_dict = json.load(file)
-
         self.probability_threshold = dataset_config.get("probability_threshold", 0.95)
         self.word_num = dataset_config.get("word_num", 15)
         self.prompt_word_num = 0
@@ -202,22 +178,14 @@ def __getitem__(self, index):
             ocr = ocr.upper()
         elif self.infer_type=="filter":
             gt = eval(self.hotwords_list[index])
-            if self.filter_type == "char":
-                infer_sentence = self.infer_list[index].lower()
-            else:
-                infer_sentence = self.infer_list[index]
-
+            infer_sentence = self.infer_list[index].lower()
             words_list = infer_sentence.split()
             filtered_words = [word for word in words_list if word not in self.common_words_5k]
             infer_sentence = ' '.join(filtered_words)
 
             biaswords=eval(self.biaswords_list[index])
-            if self.filter_type=="char":
-                ngram_index = build_ngram_index(biaswords)
-                candidates = find_candidate_names(infer_sentence, ngram_index)
-            elif self.filter_type=="phn":
-                ngram_index = build_ngram_index_phn(biaswords)
-                candidates = find_candidate_names_phn(infer_sentence, ngram_index)
+            ngram_index = build_ngram_index(biaswords)
+            candidates = find_candidate_names(infer_sentence, ngram_index)
             if not self.filter_infer_sentence_few:
                 scores = score_candidates(candidates, infer_sentence)
                 sorted_dict = sorted(scores.items(), key=lambda item: item[1], reverse=True)
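
After candidate retrieval, the dataset ranks candidates by Levenshtein similarity (threshold defaults to 0.95 in __init__, with 0.9 set in the config file above). score_candidates only appears here as a call site, so its body below is a guess built from the cached similarity helper shown in the first hunk; select_keys is a hypothetical stand-in for the thresholding that follows sorted_dict:

from functools import lru_cache

import Levenshtein  # the python-Levenshtein package this dataset already uses


@lru_cache(maxsize=100000)
def similarity(name, sentence):
    # Shown verbatim in the first hunk.
    return Levenshtein.ratio(name, sentence)


def score_candidates(candidates, sentence):
    # Hypothetical body: the diff only shows the call site.
    return {name: similarity(name, sentence) for name in candidates}


def select_keys(scores, probability_threshold=0.95, word_num=15):
    # Hypothetical selection rule: keep candidates above the threshold,
    # capped at word_num, ranked best-first.
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    return [name for name, score in ranked if score >= probability_threshold][:word_num]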
@@ -246,10 +214,6 @@ def __getitem__(self, index):
             logger.info("infer sentence: %s",infer_sentence)
             logger.info("target sentence: %s", target)
             logger.info("gt: %s, keys_list: %s", gt, keys_list)
-            # ===============================
-            if self.filter_type=="phn":
-                keys_list = [self.phn_to_name_dict[phn] for phn in keys_list]
-                keys_list = [item for sublist in keys_list for item in sublist]
 
             ocr = " ".join(keys_list).upper()
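Putting the retained char path together, a made-up end-to-end run using the sketches above, and assuming common words have already been stripped from the CTC hypothesis as the previous hunk does with common_words_5k:

# Made-up data; relies on the sketches above.
biaswords = ["katherine pyle", "oliver cromwell"]
infer_sentence = "katherine pyle"  # hypothesis after lowercasing + common-word filtering

ngram_index = build_ngram_index(biaswords)
candidates = find_candidate_names(infer_sentence, ngram_index)  # {'katherine pyle'}
scores = score_candidates(candidates, infer_sentence)           # {'katherine pyle': 1.0}
keys_list = select_keys(scores, probability_threshold=0.9, word_num=15)

ocr = " ".join(keys_list).upper()  # the biasing prompt built in the hunk above
print(ocr)  # KATHERINE PYLE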
