@@ -112,26 +112,86 @@ def _compute_atten_weights(
     return weights


+def _compute_atten_weights_new(
+        model: "Whisper",
+        tokenizer: "Tokenizer",
+        text_tokens: List[int],
+        mel: torch.Tensor,
+        num_samples: int,
+        tokens: torch.Tensor,
+        cache: dict,
+        medfilt_width: int = 7,
+        qk_scale: float = 1.0,
+        *,
+        topk=20,
+        w_colnorm=1,
+        w_rownorm=1,
+        w_coverage=0
+) -> torch.Tensor:
+    """
+    Implementation of https://arxiv.org/abs/2509.09987 (https://github.com/30stomercury/whisper-char-alignment).
+    """
+    if cache['qks'] is None:
+        _compute_qks(model, tokenizer, text_tokens, mel, tokens, cache)
+    weights = torch.cat(cache['qks'])
+    weights = weights[..., :round(num_samples / N_SAMPLES_PER_TOKEN)]
+    weights = median_filter(weights, medfilt_width)
+    weights = (weights * qk_scale).softmax(dim=-1)
+
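+    # ``weights`` has shape (n_layers, n_heads, n_text_tokens, n_audio_frames).
+    # Score each (layer, head) attention matrix to find alignment-like heads:
+    # the summed row/column L2 norms grow as attention concentrates on fewer
+    # cells (a softmaxed row has norm 1.0 only when one-hot), and the optional
+    # coverage term penalizes frames receiving more than 0.5 total attention.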
+    n_layers = weights.size(0)
+    n_heads = weights.size(1)
+    score_matrix = torch.zeros(n_layers, n_heads, device=weights.device)
+    if w_colnorm > 0:
+        col_norm_sum = weights.norm(dim=-2).sum(-1)
+        score_matrix += w_colnorm * col_norm_sum
+    if w_rownorm > 0:
+        row_norm_sum = weights.norm(dim=-1).sum(-1)
+        score_matrix += w_rownorm * row_norm_sum
+    if w_coverage > 0:
+        coverage = torch.sum(weights, dim=2)
+        penalty = torch.max(coverage, coverage.clone().fill_(0.5)).sum(-1)
+        penalty = penalty - coverage.size(-1) * 0.5
+        penalty = w_coverage * penalty
+        score_matrix -= penalty
+
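+    # Average the ``topk`` highest-scoring heads into a single text-to-time
+    # matrix; each head is first divided by its column norms so columns are on
+    # a comparable scale, then the SOT rows and trailing EOT row are dropped.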
+    top_idxs = score_matrix.flatten().topk(topk).indices
+    matrix = weights[top_idxs // n_heads, top_idxs % n_heads]
+    col_norm = matrix.norm(dim=-2, keepdim=True)
+    matrix = torch.mean(matrix / col_norm, 0)
+    matrix = matrix[len(tokenizer.sot_sequence):-1]
+
+    return matrix
+
+
 def _compute_jump_indices(
         model: "Whisper",
         cache: dict,
         extra_models: List["Whisper"] = None,
+        new: bool = False,
         **kwargs
 ):
-    weights = _compute_atten_weights(model, cache=cache, **kwargs)
-    if extra_models:
-        extra_weights = [weights]
-        for mi, other_model in enumerate(extra_models):
-            m = _compute_atten_weights(other_model, cache=cache['extra_caches'][mi], **kwargs)
-            extra_weights.append(m)
-        weights = torch.cat(extra_weights, dim=0)
-        extra_text_token_probs = [c['text_token_probs'] for c in cache['extra_caches']] + [cache['text_token_probs']]
-        cache['text_token_probs'] = torch.tensor(
-            extra_text_token_probs,
-            device=extra_weights[0].device
-        ).mean(dim=0).tolist()
-
-    matrix = weights.mean(dim=0)
+    if new:
+        weights = _compute_atten_weights_new(model, cache=cache, **kwargs)
+    else:
+        weights = _compute_atten_weights(model, cache=cache, **kwargs)
+        if extra_models:
+            extra_weights = [weights]
+            for mi, other_model in enumerate(extra_models):
+                m = _compute_atten_weights(other_model, cache=cache['extra_caches'][mi], **kwargs)
+                extra_weights.append(m)
+            weights = torch.cat(extra_weights, dim=0)
+            extra_text_token_probs = (
+                [c['text_token_probs'] for c in cache['extra_caches']] + [cache['text_token_probs']]
+            )
+            cache['text_token_probs'] = torch.tensor(
+                extra_text_token_probs,
+                device=extra_weights[0].device
+            ).mean(dim=0).tolist()
+
+    if new:
+        matrix = weights
+    else:
+        matrix = weights.mean(dim=0)
     text_indices, time_indices = dtw(-matrix)

     jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
@@ -153,11 +213,14 @@ def find_alignment_stable(
         token_split=None,
         audio_features: torch.Tensor = None,
         extra_models: List["Whisper"] = None,
-        dynamic_heads: Optional[Union[bool, int, str]] = None
+        dynamic_heads: Optional[Union[bool, int, str]] = None,
+        aligner: Union[str, dict] = 'legacy'
 ) -> List[WordTiming]:
     if extra_models and (invalid_model_types := set(map(type, extra_models)) - {type(model)}):
         raise NotImplementedError(f'Got unsupported model type(s): {invalid_model_types}')

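+    # ``aligner`` selects the alignment method: 'legacy' keeps the original
+    # behavior, while 'new' uses the head-scoring aligner above. A dict also
+    # implies 'new' and its entries are forwarded to the aligner, e.g. the
+    # illustrative ``dict(topk=20, w_coverage=1.0)``.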
+    assert isinstance(aligner, dict) or aligner in ('new', 'legacy'), f'aligner must be a dict or "new"/"legacy", got "{aligner}"'
+
     if ts_num:
         warnings.warn('``ts_num`` is deprecated and will be removed in future versions.',
                       stacklevel=2)
@@ -173,13 +236,21 @@ def find_alignment_stable(
         ]
     ).to(model.device)

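+    # ``word_tokens_orig``/``itk`` are only set for char-level alignment: they
+    # hold the original word tokens and the leading-space token(s) to skip
+    # when placing word boundaries (see ``split_word_tokens``).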
+    word_tokens_orig = itk = None
     if token_split is None:
         words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
     else:
         words, word_tokens = token_split
+        if isinstance(word_tokens, dict):
+            word_tokens_orig = word_tokens['tokens_orig']
+            itk = word_tokens['ignore_tokens']
+            word_tokens = word_tokens['tokens']
+            word_tokens_orig.append([tokenizer.eot])
     words.append(tokenizer.decode([tokenizer.eot]))
     word_tokens.append([tokenizer.eot])
     word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
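+    # Char-level mode: shift a word's start boundary past its leading space
+    # token(s) so the space is not timed as part of the word.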
+    if itk:
+        word_boundaries += np.array([tk[:len(itk)] == itk for tk in word_tokens], dtype=word_boundaries.dtype)
     if dynamic_heads:
         if dynamic_heads is True:
             dynamic_heads_count = 6
@@ -203,12 +274,18 @@ def find_alignment_stable(
         tokens=tokens,
         qk_scale=qk_scale,
         medfilt_width=medfilt_width,
-        extra_models=extra_models,
-        dynamic_heads_count=dynamic_heads_count
+        extra_models=extra_models
     )
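+    # Any non-'legacy' aligner takes the new head-scoring path; a dict also
+    # supplies its keyword arguments (topk, w_colnorm, ...). Only the legacy
+    # path consumes ``dynamic_heads_count``.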
+    if aligner != 'legacy':
+        new = True
+        if isinstance(aligner, dict):
+            kwargs.update(aligner)
+    else:
+        new = False
+        kwargs['dynamic_heads_count'] = dynamic_heads_count
     cache = _new_cache(audio_features=audio_features, extras=0 if extra_models is None else len(extra_models))
     for _ in range(dynamic_iterations or 1):
-        _compute_jump_indices(cache=cache, **kwargs)
+        _compute_jump_indices(cache=cache, new=new, **kwargs)
     jump_times = cache['jump_indices'] / TOKENS_PER_SECOND
     start_times = jump_times[word_boundaries[:-1]]
     end_times = jump_times[word_boundaries[1:]]
@@ -217,6 +294,10 @@ def find_alignment_stable(
         for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
     ]

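+    # Report the original word tokens, not the char-level re-encoding, in the
+    # returned timings.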
+    if word_tokens_orig is not None:
+        assert len(word_tokens) == len(word_tokens_orig)
+        word_tokens = word_tokens_orig
+
     return [
         WordTiming(word, tokens, start, end, probability)
         for word, tokens, start, end, probability in zip(
@@ -265,7 +346,8 @@ def split_word_tokens(segments: List[dict],
                      *,
                      padding: (str, int) = None,
                      split_callback: Callable = None,
-                     pad_first_seg: bool = True):
+                     pad_first_seg: bool = True,
+                     char_split: bool = False):
     if padding is not None:
         if isinstance(padding, str):
             padding = tokenizer.encode(padding)
@@ -275,6 +357,7 @@ def split_word_tokens(segments: List[dict],
     seg_indices = []
     words = []
     word_tokens = []
+    word_char_tokens = []
     for i, s in enumerate(segments):
         temp_word_tokens = [t for t in s['tokens'] if not isinstance(t, int) or t < tokenizer.eot]
         curr_words, curr_word_tokens = (
@@ -294,10 +377,18 @@ def split_word_tokens(segments: List[dict],
             words.append(None)
             word_tokens.append(padding)
         seg_indices.extend([i] * len(curr_words))
-        tokens.extend(list(chain.from_iterable(curr_word_tokens)))
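+        # Char-level mode: re-encode each word one character at a time so the
+        # aligner operates on char tokens; a single character may still map to
+        # multiple tokens (e.g. multi-byte characters).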
+        if char_split:
+            curr_word_char_tokens = [[ct for char in word for ct in tokenizer.encode(char)] for word in curr_words]
+            word_char_tokens.extend(curr_word_char_tokens)
+            tokens.extend(list(chain.from_iterable(curr_word_char_tokens)))
+        else:
+            tokens.extend(list(chain.from_iterable(curr_word_tokens)))
         words.extend(curr_words)
         word_tokens.extend(curr_word_tokens)

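+    # In char-level mode, return a dict carrying both the char tokens (for
+    # alignment) and the original word tokens, plus the encoded space to skip
+    # at word starts.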
+    if char_split:
+        word_tokens = dict(tokens=word_char_tokens, tokens_orig=word_tokens, ignore_tokens=tokenizer.encode(' '))
+
     return tokens, (words, word_tokens), seg_indices


@@ -333,6 +424,7 @@ def add_word_timestamps_stable(
         split_callback: Callable = None,
         gap_padding: Optional[str] = ' ...',
         pad_first_seg: bool = True,
+        aligner: Union[str, dict] = 'legacy',
         **kwargs,
 ):
     if len(segments) == 0:
@@ -347,6 +439,10 @@ def add_word_timestamps_stable(
     if append_punctuations is None:
         append_punctuations = "\"'.。,,!!??::”)]}、"

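+    # ``char_split`` is popped so it is not forwarded to the aligner itself.
+    # Gap padding is disabled in this mode since padding tokens are not
+    # re-encoded per character and would desynchronize the char/original
+    # token bookkeeping.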
+    char_split = isinstance(aligner, dict) and aligner.pop('char_split', False)
+    if char_split:
+        gap_padding = None
+
     def align():
         for seg in segments:
             seg['words'] = []
@@ -356,15 +452,17 @@ def align():
             tokenizer,
             padding=gap_padding,
             split_callback=split_callback,
-            pad_first_seg=pad_first_seg
+            pad_first_seg=pad_first_seg,
+            char_split=char_split
         )

         alignment = find_alignment_stable(model, tokenizer, text_tokens, mel, num_samples,
                                           **kwargs,
                                           token_split=token_split,
                                           audio_features=audio_features,
                                           ts_num=ts_num,
-                                          ts_noise=ts_noise)
+                                          ts_noise=ts_noise,
+                                          aligner=aligner)
         alt_beginning_alignment = pop_empty_alignment(alignment, seg_indices)

         merge_punctuations(alignment, prepend_punctuations, append_punctuations)