@@ -36,25 +36,6 @@ def find_candidate_names(sentence, ngram_index, n=2):
3636 candidates .update (ngram_index .get (ngram , []))
3737 return candidates
3838
def build_ngram_index_phn(names, n=2):
    """Build an inverted N-gram index over phoneme-sequence names.

    Each name is a space-separated phoneme string; every run of ``n``
    consecutive phonemes maps to the set of names that contain it.

    Args:
        names: iterable of space-separated phoneme strings.
        n: N-gram size (default 2).

    Returns:
        dict mapping each phoneme N-gram string to a ``set`` of names.
    """
    index = {}
    for name in names:
        tokens = name.split()
        # A name with fewer than n phonemes contributes no N-grams.
        last_start = len(tokens) - n
        for start in range(last_start + 1):
            key = ' '.join(tokens[start:start + n])
            bucket = index.setdefault(key, set())
            bucket.add(name)
    return index
48-
def find_candidate_names_phn(phonemes, ngram_index, n=2):
    """Look up candidate names whose phoneme N-grams overlap *phonemes*.

    Args:
        phonemes: space-separated phoneme string to query.
        ngram_index: inverted index as built by ``build_ngram_index_phn``.
        n: N-gram size; must match the one used to build the index.

    Returns:
        set of candidate names sharing at least one phoneme N-gram.
    """
    tokens = phonemes.split()
    candidates = set()
    for start in range(len(tokens) - n + 1):
        key = ' '.join(tokens[start:start + n])
        # Missing N-grams simply contribute nothing.
        candidates |= set(ngram_index.get(key, ()))
    return candidates
57-
@lru_cache(maxsize=100000)
def similarity(name, sentence):
    """Return the Levenshtein similarity ratio (0.0-1.0) of the two strings.

    Results are memoized (bounded LRU cache) because the same
    (name, sentence) pairs recur heavily during candidate scoring.
    """
    return Levenshtein.ratio(name, sentence)
@@ -139,11 +120,6 @@ def __init__(
139120 # analyze
140121 self .hotwords_num = 0
141122 self .miss_words_num = 0
142- self .filter_type = dataset_config .filter_type
143- if self .filter_type == "phn" :
144- with open (dataset_config .phn_to_name_dict , 'r' ) as file :
145- self .phn_to_name_dict = json .load (file )
146-
147123 self .probability_threshold = dataset_config .get ("probability_threshold" , 0.95 )
148124 self .word_num = dataset_config .get ("word_num" , 15 )
149125 self .prompt_word_num = 0
@@ -202,22 +178,14 @@ def __getitem__(self, index):
202178 ocr = ocr .upper ()
203179 elif self .infer_type == "filter" :
204180 gt = eval (self .hotwords_list [index ])
205- if self .filter_type == "char" :
206- infer_sentence = self .infer_list [index ].lower ()
207- else :
208- infer_sentence = self .infer_list [index ]
209-
181+ infer_sentence = self .infer_list [index ].lower ()
210182 words_list = infer_sentence .split ()
211183 filtered_words = [word for word in words_list if word not in self .common_words_5k ]
212184 infer_sentence = ' ' .join (filtered_words )
213185
214186 biaswords = eval (self .biaswords_list [index ])
215- if self .filter_type == "char" :
216- ngram_index = build_ngram_index (biaswords )
217- candidates = find_candidate_names (infer_sentence , ngram_index )
218- elif self .filter_type == "phn" :
219- ngram_index = build_ngram_index_phn (biaswords )
220- candidates = find_candidate_names_phn (infer_sentence , ngram_index )
187+ ngram_index = build_ngram_index (biaswords )
188+ candidates = find_candidate_names (infer_sentence , ngram_index )
221189 if not self .filter_infer_sentence_few :
222190 scores = score_candidates (candidates , infer_sentence )
223191 sorted_dict = sorted (scores .items (), key = lambda item : item [1 ], reverse = True )
@@ -246,10 +214,6 @@ def __getitem__(self, index):
246214 logger .info ("infer sentence: %s" ,infer_sentence )
247215 logger .info ("target sentence: %s" , target )
248216 logger .info ("gt: %s, keys_list: %s" , gt , keys_list )
249- # ===============================
250- if self .filter_type == "phn" :
251- keys_list = [self .phn_to_name_dict [phn ] for phn in keys_list ]
252- keys_list = [item for sublist in keys_list for item in sublist ]
253217
254218 ocr = " " .join (keys_list ).upper ()
255219
0 commit comments