1 change: 1 addition & 0 deletions examples/contextual_asr/README.md
@@ -26,6 +26,7 @@ They categorize the 5,000 most frequent words in the Librispeech training corpus
words, with the remainder classified as rare words. The biasing list generated for the test set consists of two segments: rare words in the transcriptions and distractors sampled from the 209.2K rare-word vocabulary. Biasing lists of varying lengths are generated by adding N = {100, 500, 1000, 2000} distractors to the lists.
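A minimal sketch of how such a per-utterance list could be assembled, assuming the common-word list and rare-word vocabulary are loaded as plain word sets (the function name and sampling details are illustrative, not the exact procedure used for the released lists):

```python
import random

def build_biasing_list(ref_words, common_words_5k, rare_vocab, n_distractors=100, seed=0):
    """Illustrative only: rare words from the reference transcription plus
    N distractors sampled from the rare-word vocabulary (names and sampling
    details are assumptions, not the released recipe)."""
    rng = random.Random(seed)
    # Rare words = reference words outside the 5,000 most frequent training words.
    rare_in_ref = sorted({w for w in ref_words if w not in common_words_5k})
    # Distractors come from the rare-word vocabulary, excluding words already in the reference.
    pool = sorted(set(rare_vocab) - set(rare_in_ref))
    distractors = rng.sample(pool, n_distractors)
    return rare_in_ref + distractors
```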


The Viterbi decoding results of our CTC fine-tuned WavLM-Large: [test-clean](https://drive.google.com/file/d/1kMzPx8oRK3aOsxNaMGski3zH8z5Otvek/view?usp=drive_link), [test-other](https://drive.google.com/file/d/12KHaatVg5O0MIBTcf8e_rNjV_i9WLBFR/view?usp=drive_link) (used as ``ctc_file`` in contextual_asr_config.py)

## Decoding with checkpoints
LLM-based ASR inference script.
2 changes: 0 additions & 2 deletions examples/contextual_asr/contextual_asr_config.py
@@ -105,8 +105,6 @@ class DataConfig:
infer_type: str = "bias"
infer_file: str = "/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/my_ref/test-clean.biasing_100.tsv"
ctc_file: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_large_libri_test_other_char.txt"
filter_type: str = "char"
phn_to_name_dict: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_ft_libri960_${ref_split}_phn.json"
common_words_5k_dir: str="/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/words/common_words_5k.txt"
probability_threshold: float = 0.9
word_num: int = 15
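As one hypothetical way to point these fields at the files linked in the README, assuming `DataConfig` is a standard Python dataclass and with all paths as placeholders for local copies:

```python
from dataclasses import replace

# Sketch only: assumes DataConfig is a dataclass; all paths below are placeholders.
cfg = replace(
    DataConfig(),
    infer_type="filter",                                    # use the n-gram + similarity filtering path
    infer_file="data/test-clean.biasing_100.tsv",           # placeholder biasing-list TSV
    ctc_file="data/wavlm_large_libri_test_clean_char.txt",  # placeholder Viterbi decode output
    probability_threshold=0.9,
    word_num=15,
)
```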
42 changes: 3 additions & 39 deletions examples/contextual_asr/dataset/hotwordsinfer_dataset.py
@@ -36,25 +36,6 @@ def find_candidate_names(sentence, ngram_index, n=2):
candidates.update(ngram_index.get(ngram, []))
return candidates

def build_ngram_index_phn(names, n=2):
"""构建N-Gram倒排索引"""
index = {}
for name in names:
phonemes = name.split()
for i in range(len(phonemes) - n + 1):
ngram = ' '.join(phonemes[i:i+n])
index.setdefault(ngram, set()).add(name)
return index

def find_candidate_names_phn(phonemes, ngram_index, n=2):
"""通过N-Gram倒排索引找到候选人名"""
candidates = set()
phonemes = phonemes.split()
for i in range(len(phonemes) - n + 1):
ngram = ' '.join(phonemes[i:i+n])
candidates.update(ngram_index.get(ngram, []))
return candidates

@lru_cache(maxsize=100000)
def similarity(name, sentence):
return Levenshtein.ratio(name, sentence)
@@ -139,11 +120,6 @@ def __init__(
# analyze
self.hotwords_num=0
self.miss_words_num=0
self.filter_type=dataset_config.filter_type
if self.filter_type=="phn":
with open(dataset_config.phn_to_name_dict, 'r') as file:
self.phn_to_name_dict = json.load(file)

self.probability_threshold = dataset_config.get("probability_threshold", 0.95)
self.word_num = dataset_config.get("word_num", 15)
self.prompt_word_num = 0
@@ -202,22 +178,14 @@ def __getitem__(self, index):
ocr = ocr.upper()
elif self.infer_type=="filter":
gt = eval(self.hotwords_list[index])
if self.filter_type == "char":
infer_sentence = self.infer_list[index].lower()
else:
infer_sentence = self.infer_list[index]

infer_sentence = self.infer_list[index].lower()
words_list = infer_sentence.split()
filtered_words = [word for word in words_list if word not in self.common_words_5k]
infer_sentence = ' '.join(filtered_words)

biaswords=eval(self.biaswords_list[index])
if self.filter_type=="char":
ngram_index = build_ngram_index(biaswords)
candidates = find_candidate_names(infer_sentence, ngram_index)
elif self.filter_type=="phn":
ngram_index = build_ngram_index_phn(biaswords)
candidates = find_candidate_names_phn(infer_sentence, ngram_index)
ngram_index = build_ngram_index(biaswords)
candidates = find_candidate_names(infer_sentence, ngram_index)
if not self.filter_infer_sentence_few:
scores = score_candidates(candidates, infer_sentence)
sorted_dict = sorted(scores.items(), key=lambda item: item[1], reverse=True)
@@ -246,10 +214,6 @@ def __getitem__(self, index):
logger.info("infer sentence: %s",infer_sentence)
logger.info("target sentence: %s", target)
logger.info("gt: %s, keys_list: %s", gt, keys_list)
# ===============================
if self.filter_type=="phn":
keys_list = [self.phn_to_name_dict[phn] for phn in keys_list]
keys_list = [item for sublist in keys_list for item in sublist]

ocr = " ".join(keys_list).upper()
