Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions whisperx/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,26 +415,27 @@ def get_wildcard_emission(frame_emission, tokens, blank_id):
Returns:
tensor: Maximum probability score for each token position
"""
assert 0 <= blank_id < len(frame_emission)

# Convert tokens to a tensor if they are not already
tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens

# Create a mask to identify wildcard positions
wildcard_mask = (tokens == -1)

# Get scores for non-wildcard positions
regular_scores = frame_emission[tokens.clamp(min=0)] # clamp to avoid -1 index
# regular_scores = frame_emission[tokens.clamp(min=0).long()] # clamp to avoid -1 index
V = frame_emission.size(0)
assert V > 0 and 0 <= blank_id < V, "empty emissions or invalid blank_id"

# Create a mask and compute the maximum value without modifying frame_emission
max_valid_score = frame_emission.clone() # Create a copy
max_valid_score[blank_id] = float('-inf') # Modify the copy to exclude the blank token
max_valid_score = max_valid_score.max()
t = torch.as_tensor(tokens)
wildcard_mask = t.eq(-1) # capture wildcards BEFORE clamping
idx = t.clamp(0, V - 1).long() # clamp both ends [0..V-1]
regular_scores = frame_emission[idx]

# Use where operation to combine results
result = torch.where(wildcard_mask, max_valid_score, regular_scores)
max_valid = frame_emission.clone()
max_valid[blank_id] = float("-inf")
max_valid = max_valid.max()

return result
return torch.where(wildcard_mask, max_valid, regular_scores)


@dataclass
Expand Down