Commit 55669d7

added new alignment algorithm

- added parameter, `aligner`, to `transcribe()`/`align()`/`align_words()` (only for vanilla models); `aligner="new"` uses an implementation of the new alignment algorithm (https://arxiv.org/abs/2509.09987)
- updated docstrings to reflect the new parameter

1 parent c63366e commit 55669d7
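For a quick sense of the change, a minimal usage sketch (the model size and audio path are placeholders; the dict keys mirror the keyword-only arguments of the new `_compute_atten_weights_new` function in `stable_whisper/timing.py` below, and per the commit note the parameter only applies to vanilla Whisper models):

```python
import stable_whisper

model = stable_whisper.load_model('base')

# default head selection, unchanged from previous releases
result = model.transcribe('audio.wav')

# select attention heads with the new algorithm (https://arxiv.org/abs/2509.09987)
result = model.transcribe('audio.wav', aligner='new')

# a dict also enables the new algorithm and is forwarded as its keyword arguments
result = model.transcribe('audio.wav', aligner=dict(topk=20, w_colnorm=1, w_rownorm=1))
```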

File tree

6 files changed (+149, -27 lines)

README.md

Lines changed: 6 additions & 0 deletions

@@ -261,6 +261,8 @@ Docstrings:
     word-timestamp extraction. Specify the number of heads or `True` for default of 6 heads.
     To specify number of iterations for finding the optimal heads,
     use string with "," to separate heads and iterations (e.g. "8,3" for 8 heads and 3 iterations).
+aligner : "legacy" or "new" or dict, default "legacy"
+    Algorithm for selecting attention heads for alignment. Use dictionary to specify keyword arguments for 'new'.
 clip_timestamps : str or list of float
     Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
     The last end timestamp defaults to the end of the file.
@@ -1007,6 +1009,8 @@ Docstring:
     word-timestamp extraction. Specify the number of heads or `True` for default of 6 heads.
     To specify number of iterations for finding the optimal heads,
     use string with "," to separate heads and iterations (e.g. "8,3" for 8 heads and 3 iterations).
+aligner : "legacy" or "new" or dict, default "legacy"
+    Algorithm for selecting attention heads for alignment. Use dictionary to specify keyword arguments for 'new'.

 Returns
 -------
@@ -1126,6 +1130,8 @@ Docstring:
     word-timestamp extraction. Specify the number of heads or `True` for default of 6 heads.
     To specify number of iterations for finding the optimal heads,
     use string with "," to separate heads and iterations (e.g. "8,3" for 8 heads and 3 iterations).
+aligner : "legacy" or "new" or dict, default "legacy"
+    Algorithm for selecting attention heads for alignment. Use dictionary to specify keyword arguments for 'new'.
 normalize_text : bool or dict, default True
     Whether to normalize text of each segment.
 inplace : bool, default True
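The `dynamic_heads` parameter that the new `aligner` entry sits next to accepts the forms the docstring describes; a short illustrative sketch (the audio path is a placeholder):

```python
# assuming `model` was loaded via stable_whisper.load_model(...)
model.transcribe('audio.wav', dynamic_heads=True)   # runtime head search with the default of 6 heads
model.transcribe('audio.wav', dynamic_heads=8)      # search with 8 heads
model.transcribe('audio.wav', dynamic_heads='8,3')  # 8 heads, 3 iterations of the search
```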

stable_whisper/alignment.py

Lines changed: 6 additions & 1 deletion

@@ -149,6 +149,8 @@ def align(
         word-timestamp extraction. Specify the number of heads or `True` for default of 6 heads.
         To specify number of iterations for finding the optimal heads,
         use string with "," to separate heads and iterations (e.g. "8,3" for 8 heads and 3 iterations).
+    aligner : "legacy" or "new" or dict, default "legacy"
+        Algorithm for selecting attention heads for alignment. Use dictionary to specify keyword arguments for 'new'.

     Returns
     -------
@@ -316,6 +318,8 @@ def align_words(
         word-timestamp extraction. Specify the number of heads or `True` for default of 6 heads.
         To specify number of iterations for finding the optimal heads,
         use string with "," to separate heads and iterations (e.g. "8,3" for 8 heads and 3 iterations).
+    aligner : "legacy" or "new" or dict, default "legacy"
+        Algorithm for selecting attention heads for alignment. Use dictionary to specify keyword arguments for 'new'.
     normalize_text : bool or dict, default True
         Whether to normalize text of each segment.
     inplace : bool, default True
@@ -419,7 +423,8 @@ def compute_timestamps(audio_segment: torch.Tensor, word_tokens: List[WordToken]
         append_punctuations='',
         gap_padding=None,
         extra_models=options.align.extra_models,
-        dynamic_heads=options.align.dynamic_heads
+        dynamic_heads=options.align.dynamic_heads,
+        aligner=options.align.aligner
     )
     return [w for seg in temp_segments for w in seg['words']]
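Since `compute_timestamps` now forwards `options.align.aligner`, the alignment entry points accept the parameter directly; a hedged sketch (the audio, transcript text, and language are placeholders, and `char_split` is the dict-only key that `add_word_timestamps_stable` pops in `stable_whisper/timing.py` below):

```python
# align plain text against audio using the new head-selection algorithm
result = model.align('audio.wav', 'the transcript text', language='en', aligner='new')

# dict form: new algorithm plus per-character token splitting
result = model.align_words('audio.wav', result, aligner=dict(char_split=True))
```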

stable_whisper/non_whisper/alignment.py

Lines changed: 7 additions & 0 deletions

@@ -181,6 +181,13 @@ def __init__(
         Only if ``presplit=True``, ``gap_padding`` is prepended to each segment for word timing alignment.
         Used to reduce the probability of model predicting timestamps earlier than the first utterance.
         Ignored if ``model`` is a faster-whisper model.
+    dynamic_heads : bool or int or str, optional
+        Whether to find optimal cross-attention heads during runtime instead of using the predefined heads for
+        word-timestamp extraction. Specify the number of heads or `True` for default of 6 heads.
+        To specify number of iterations for finding the optimal heads,
+        use string with "," to separate heads and iterations (e.g. "8,3" for 8 heads and 3 iterations).
+    aligner : "legacy" or "new" or dict, default "legacy"
+        Algorithm for selecting attention heads for alignment. Use dictionary to specify keyword arguments for 'new'.

     Notes
     -----

stable_whisper/options.py

Lines changed: 1 addition & 0 deletions

@@ -164,6 +164,7 @@ def __init__(self, **kwargs):
         self.presplit: Union[bool, List[str]] = self._pop('presplit', True)
         self.extra_models: Optional[list] = self._pop('extra_models', None)
         self.dynamic_heads: Optional[Union[bool, int, str]] = self._pop('dynamic_heads', None)
+        self.aligner: Union[str, dict] = self._pop('aligner', 'legacy')

     def to_non_vanilla(self):
         if self.extra_models:

stable_whisper/timing.py

Lines changed: 120 additions & 22 deletions

@@ -112,26 +112,86 @@ def _compute_atten_weights(
     return weights


+def _compute_atten_weights_new(
+        model: "Whisper",
+        tokenizer: "Tokenizer",
+        text_tokens: List[int],
+        mel: torch.Tensor,
+        num_samples: int,
+        tokens: torch.Tensor,
+        cache: dict,
+        medfilt_width: int = 7,
+        qk_scale: float = 1.0,
+        *,
+        topk=20,
+        w_colnorm=1,
+        w_rownorm=1,
+        w_coverage=0
+) -> torch.Tensor:
+    """
+    Implementation of https://arxiv.org/abs/2509.09987 (https://github.com/30stomercury/whisper-char-alignment).
+    """
+    if cache['qks'] is None:
+        _compute_qks(model, tokenizer, text_tokens, mel, tokens, cache)
+    weights = torch.cat(cache['qks'])
+    weights = weights[..., :round(num_samples / N_SAMPLES_PER_TOKEN)]
+    weights = median_filter(weights, medfilt_width)
+    weights = (weights * qk_scale).softmax(dim=-1)
+
+    n_layers = weights.size(0)
+    n_heads = weights.size(1)
+    score_matrix = torch.zeros(n_layers, n_heads, device=weights.device)
+    if w_colnorm > 0:
+        col_norm_sum = weights.norm(dim=-2).sum(-1)
+        score_matrix += w_colnorm * col_norm_sum
+    if w_rownorm > 0:
+        row_norm_sum = weights.norm(dim=-1).sum(-1)
+        score_matrix += w_rownorm * row_norm_sum
+    if w_coverage > 0:
+        coverage = torch.sum(weights, dim=2)
+        penalty = torch.max(coverage, coverage.clone().fill_(0.5)).sum(-1)
+        penalty = penalty - coverage.size(-1) * 0.5
+        penalty = w_coverage * penalty
+        score_matrix -= penalty
+
+    top_idxs = score_matrix.flatten().topk(topk).indices
+    matrix = weights[top_idxs // n_heads, top_idxs % n_heads]
+    col_norm = matrix.norm(dim=-2, keepdim=True)
+    matrix = torch.mean(matrix / col_norm, 0)
+    matrix = matrix[len(tokenizer.sot_sequence):-1]
+
+    return matrix
+
+
 def _compute_jump_indices(
         model: "Whisper",
         cache: dict,
         extra_models: List["Whisper"] = None,
+        new: bool = False,
         **kwargs
 ):
-    weights = _compute_atten_weights(model, cache=cache, **kwargs)
-    if extra_models:
-        extra_weights = [weights]
-        for mi, other_model in enumerate(extra_models):
-            m = _compute_atten_weights(other_model, cache=cache['extra_caches'][mi], **kwargs)
-            extra_weights.append(m)
-        weights = torch.cat(extra_weights, dim=0)
-        extra_text_token_probs = [c['text_token_probs'] for c in cache['extra_caches']] + [cache['text_token_probs']]
-        cache['text_token_probs'] = torch.tensor(
-            extra_text_token_probs,
-            device=extra_weights[0].device
-        ).mean(dim=0).tolist()
-
-    matrix = weights.mean(dim=0)
+    if new:
+        weights = _compute_atten_weights_new(model, cache=cache, **kwargs)
+    else:
+        weights = _compute_atten_weights(model, cache=cache, **kwargs)
+        if extra_models:
+            extra_weights = [weights]
+            for mi, other_model in enumerate(extra_models):
+                m = _compute_atten_weights(other_model, cache=cache['extra_caches'][mi], **kwargs)
+                extra_weights.append(m)
+            weights = torch.cat(extra_weights, dim=0)
+            extra_text_token_probs = (
+                [c['text_token_probs'] for c in cache['extra_caches']] + [cache['text_token_probs']]
+            )
+            cache['text_token_probs'] = torch.tensor(
+                extra_text_token_probs,
+                device=extra_weights[0].device
+            ).mean(dim=0).tolist()

+    if new:
+        matrix = weights
+    else:
+        matrix = weights.mean(dim=0)
     text_indices, time_indices = dtw(-matrix)

     jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
@@ -153,11 +213,14 @@ def find_alignment_stable(
         token_split=None,
         audio_features: torch.Tensor = None,
         extra_models: List["Whisper"] = None,
-        dynamic_heads: Optional[Union[bool, int, str]] = None
+        dynamic_heads: Optional[Union[bool, int, str]] = None,
+        aligner: Union[str, dict] = 'legacy'
 ) -> List[WordTiming]:
     if extra_models and (invalid_model_types := set(map(type, extra_models)) - {type(model)}):
         raise NotImplementedError(f'Got unsupported model type(s): {invalid_model_types}')

+    assert isinstance(aligner, dict) or aligner in ('new', 'legacy'), f'aligner must be "new"/"legacy", got "{aligner}"'
+
     if ts_num:
         warnings.warn('``ts_num`` is deprecated and will be removed in future versions.',
                       stacklevel=2)
@@ -173,13 +236,21 @@ def find_alignment_stable(
         ]
     ).to(model.device)

+    word_tokens_orig = itk = None
     if token_split is None:
         words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
     else:
         words, word_tokens = token_split
+        if isinstance(word_tokens, dict):
+            word_tokens_orig = word_tokens['tokens_orig']
+            itk = word_tokens['ignore_tokens']
+            word_tokens = word_tokens['tokens']
+            word_tokens_orig.append([tokenizer.eot])
     words.append(tokenizer.decode([tokenizer.eot]))
     word_tokens.append([tokenizer.eot])
     word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
+    if itk:
+        word_boundaries += np.array([tk[:len(itk)] == itk for tk in word_tokens], dtype=word_boundaries.dtype)
     if dynamic_heads:
         if dynamic_heads is True:
             dynamic_heads_count = 6
@@ -203,12 +274,18 @@ def find_alignment_stable(
         tokens=tokens,
         qk_scale=qk_scale,
         medfilt_width=medfilt_width,
-        extra_models=extra_models,
-        dynamic_heads_count=dynamic_heads_count
+        extra_models=extra_models
     )
+    if aligner != 'legacy':
+        new = True
+        if isinstance(aligner, dict):
+            kwargs.update(aligner)
+    else:
+        new = False
+        kwargs['dynamic_heads_count'] = dynamic_heads_count
     cache = _new_cache(audio_features=audio_features, extras=0 if extra_models is None else len(extra_models))
     for _ in range(dynamic_iterations or 1):
-        _compute_jump_indices(cache=cache, **kwargs)
+        _compute_jump_indices(cache=cache, new=new, **kwargs)
     jump_times = cache['jump_indices'] / TOKENS_PER_SECOND
     start_times = jump_times[word_boundaries[:-1]]
     end_times = jump_times[word_boundaries[1:]]
@@ -217,6 +294,10 @@ def find_alignment_stable(
         for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
     ]

+    if word_tokens_orig is not None:
+        assert len(word_tokens) == len(word_tokens_orig)
+        word_tokens = word_tokens_orig
+
     return [
         WordTiming(word, tokens, start, end, probability)
         for word, tokens, start, end, probability in zip(
@@ -265,7 +346,8 @@ def split_word_tokens(segments: List[dict],
                       *,
                       padding: (str, int) = None,
                       split_callback: Callable = None,
-                      pad_first_seg: bool = True):
+                      pad_first_seg: bool = True,
+                      char_split: bool = False):
     if padding is not None:
         if isinstance(padding, str):
             padding = tokenizer.encode(padding)
@@ -275,6 +357,7 @@ def split_word_tokens(segments: List[dict],
     seg_indices = []
     words = []
     word_tokens = []
+    word_char_tokens = []
     for i, s in enumerate(segments):
         temp_word_tokens = [t for t in s['tokens'] if not isinstance(t, int) or t < tokenizer.eot]
         curr_words, curr_word_tokens = (
@@ -294,10 +377,18 @@ def split_word_tokens(segments: List[dict],
             words.append(None)
             word_tokens.append(padding)
         seg_indices.extend([i] * len(curr_words))
-        tokens.extend(list(chain.from_iterable(curr_word_tokens)))
+        if char_split:
+            curr_word_char_tokens = [[ct for char in word for ct in tokenizer.encode(char)] for word in curr_words]
+            word_char_tokens.extend(curr_word_char_tokens)
+            tokens.extend(list(chain.from_iterable(curr_word_char_tokens)))
+        else:
+            tokens.extend(list(chain.from_iterable(curr_word_tokens)))
         words.extend(curr_words)
         word_tokens.extend(curr_word_tokens)

+    if char_split:
+        word_tokens = dict(tokens=word_char_tokens, tokens_orig=word_tokens, ignore_tokens=tokenizer.encode(' '))
+
     return tokens, (words, word_tokens), seg_indices
@@ -333,6 +424,7 @@ def add_word_timestamps_stable(
         split_callback: Callable = None,
         gap_padding: Optional[str] = ' ...',
         pad_first_seg: bool = True,
+        aligner: Union[str, dict] = 'legacy',
         **kwargs,
 ):
     if len(segments) == 0:
@@ -347,6 +439,10 @@ def add_word_timestamps_stable(
     if append_punctuations is None:
         append_punctuations = "\"'.。,,!!??::”)]}、"

+    char_split = isinstance(aligner, dict) and aligner.pop('char_split', False)
+    if char_split:
+        gap_padding = None
+
     def align():
         for seg in segments:
             seg['words'] = []
@@ -356,15 +452,17 @@ def align():
             tokenizer,
             padding=gap_padding,
             split_callback=split_callback,
-            pad_first_seg=pad_first_seg
+            pad_first_seg=pad_first_seg,
+            char_split=char_split
         )

         alignment = find_alignment_stable(model, tokenizer, text_tokens, mel, num_samples,
                                           **kwargs,
                                           token_split=token_split,
                                           audio_features=audio_features,
                                           ts_num=ts_num,
-                                          ts_noise=ts_noise)
+                                          ts_noise=ts_noise,
+                                          aligner=aligner)
         alt_beginning_alignment = pop_empty_alignment(alignment, seg_indices)

         merge_punctuations(alignment, prepend_punctuations, append_punctuations)
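To make the scoring step of `_compute_atten_weights_new` concrete, here is a self-contained toy version of it (illustration only: the shapes and random weights are made up; the real function scores cached cross-attention weights):

```python
import torch

# toy cross-attention weights: (layers, heads, text_tokens, audio_frames), normalized over frames
weights = torch.rand(4, 6, 10, 50).softmax(dim=-1)
topk, w_colnorm, w_rownorm = 5, 1, 1

n_layers, n_heads = weights.size(0), weights.size(1)
scores = torch.zeros(n_layers, n_heads)
scores += w_colnorm * weights.norm(dim=-2).sum(-1)  # sharp columns: few tokens attend per frame
scores += w_rownorm * weights.norm(dim=-1).sum(-1)  # sharp rows: each token attends to few frames

# keep the top-k heads across all layers, column-normalize, and average them
top_idxs = scores.flatten().topk(topk).indices
matrix = weights[top_idxs // n_heads, top_idxs % n_heads]
matrix = (matrix / matrix.norm(dim=-2, keepdim=True)).mean(0)
print(matrix.shape)  # torch.Size([10, 50]); this is what dtw(-matrix) consumes in _compute_jump_indices
```

Reading the hunk above, the optional `w_coverage` term subtracts a penalty for heads whose per-frame attention mass (summed over text tokens) rises above 0.5, i.e. heads that pile attention from many tokens onto the same frames.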

stable_whisper/whisper_word_level/original_whisper.py

Lines changed: 9 additions & 4 deletions

@@ -72,6 +72,7 @@ def transcribe_stable(
         ignore_compatibility: bool = False,
         extra_models: Optional[List["Whisper"]] = None,
         dynamic_heads: Optional[Union[bool, int, str]] = None,
+        aligner: Union[str, dict] = 'legacy',
         clip_timestamps: Optional[Union[str, List[float]]] = None,
         resume: Union[WhisperResult, str, dict, list] = None,
         **decode_options) \
@@ -199,6 +200,8 @@ def transcribe_stable(
         word-timestamp extraction. Specify the number of heads or `True` for default of 6 heads.
         To specify number of iterations for finding the optimal heads,
         use string with "," to separate heads and iterations (e.g. "8,3" for 8 heads and 3 iterations).
+    aligner : "legacy" or "new" or dict, default "legacy"
+        Algorithm for selecting attention heads for alignment. Use dictionary to specify keyword arguments for 'new'.
     clip_timestamps : str or list of float
         Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
         The last end timestamp defaults to the end of the file.
@@ -644,7 +647,8 @@ def inner_transcribe():
                     split_callback=split_callback,
                     gap_padding=gap_padding,
                     extra_models=extra_models,
-                    dynamic_heads=dynamic_heads
+                    dynamic_heads=dynamic_heads,
+                    aligner=aligner
                 )

             for i in reversed(range(len(current_segments))):
@@ -673,6 +677,10 @@ def inner_transcribe():
                 fast_forward()
                 return

+            all_tokens.extend(
+                [token for segment in current_segments for token in segment["tokens"]]
+            )
+
             if segment_silence_timing is not None:
                 for seg_i, segment in enumerate(current_segments):
                     segment = Segment(**segment, ignore_unused_args=True).suppress_silence(
@@ -692,9 +700,6 @@ def inner_transcribe():
                     for i, segment in enumerate(current_segments, start=len(all_segments))
                 ]
             )
-            all_tokens.extend(
-                [token for segment in current_segments for token in segment["tokens"]]
-            )
             if not single_timestamp_ending or avg_prob_threshold:
                 segment_samples = num_samples
