diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index b77a7f1bf..50cb3d9dc 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -415,8 +415,6 @@ def get_wildcard_emission(frame_emission, tokens, blank_id):
     Returns:
         tensor: Maximum probability score for each token position
     """
-    assert 0 <= blank_id < len(frame_emission)
-
     # Convert tokens to a tensor if they are not already
     tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens
 
@@ -424,17 +422,20 @@ def get_wildcard_emission(frame_emission, tokens, blank_id):
     wildcard_mask = (tokens == -1)
 
     # Get scores for non-wildcard positions
-    regular_scores = frame_emission[tokens.clamp(min=0)] # clamp to avoid -1 index
+    # regular_scores = frame_emission[tokens.clamp(min=0).long()] # clamp to avoid -1 index
+    V = frame_emission.size(0)
+    assert V > 0 and 0 <= blank_id < V, "empty emissions or invalid blank_id"
 
-    # Create a mask and compute the maximum value without modifying frame_emission
-    max_valid_score = frame_emission.clone() # Create a copy
-    max_valid_score[blank_id] = float('-inf') # Modify the copy to exclude the blank token
-    max_valid_score = max_valid_score.max()
+    t = torch.as_tensor(tokens)
+    wildcard_mask = t.eq(-1) # capture wildcards BEFORE clamping
+    idx = t.clamp(0, V - 1).long() # clamp both ends [0..V-1]
+    regular_scores = frame_emission[idx]
 
-    # Use where operation to combine results
-    result = torch.where(wildcard_mask, max_valid_score, regular_scores)
+    max_valid = frame_emission.clone()
+    max_valid[blank_id] = float("-inf")
+    max_valid = max_valid.max()
 
-    return result
+    return torch.where(wildcard_mask, max_valid, regular_scores)
 
 
 @dataclass