Skip to content
Closed

Latency #1509

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pyannote/audio/tasks/segmentation/speaker_diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ def segmentation_loss(
if self.weigh_by_cardinality
else None
)

seg_loss = nll_loss(
permutated_prediction,
torch.argmax(target, dim=-1),
Expand Down Expand Up @@ -548,6 +549,12 @@ def training_step(self, batch, batch_idx: int):
warm_up_right = round(self.warm_up[1] / self.duration * num_frames)
weight[:, num_frames - warm_up_right :] = 0.0

latency = 0.1 # will be a parameter of the task (in s)
delay = int(np.floor(num_frames * latency / self.duration)) # round down

prediction = prediction[:, delay:, :]
target = target[:, :num_frames-delay, :]

if self.specifications.powerset:
multilabel = self.model.powerset.to_multilabel(prediction)
permutated_target, _ = permutate(multilabel, target)
Expand Down
21 changes: 13 additions & 8 deletions pyannote/audio/torchmetrics/audio/diarization_error_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@ class DiarizationErrorRate(Metric):
higher_is_better = False
is_differentiable = False

def __init__(self, threshold: float = 0.5):
def __init__(self, threshold: float = 0.5, per_frame: bool = False):
super().__init__()

self.threshold = threshold
self.per_frame = per_frame

self.add_state("false_alarm", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state(
Expand Down Expand Up @@ -85,21 +86,25 @@ def update(
speech_total : torch.Tensor
Diarization error rate components accumulated over the whole batch.
"""

false_alarm, missed_detection, speaker_confusion, speech_total = _der_update(
if self.per_frame:
self.false_alarm, self.missed_detection, self.speaker_confusion, self.speech_total = _der_update(preds, target,
per_frame = self.per_frame, threshold=self.threshold)
else:
false_alarm, missed_detection, speaker_confusion, speech_total = _der_update(
preds, target, threshold=self.threshold
)
self.false_alarm += false_alarm
self.missed_detection += missed_detection
self.speaker_confusion += speaker_confusion
self.speech_total += speech_total
)
self.false_alarm += false_alarm
self.missed_detection += missed_detection
self.speaker_confusion += speaker_confusion
self.speech_total += speech_total

def compute(self):
return _der_compute(
self.false_alarm,
self.missed_detection,
self.speaker_confusion,
self.speech_total,
self.per_frame
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
def _der_update(
preds: torch.Tensor,
target: torch.Tensor,
per_frame: bool = False,
threshold: Union[torch.Tensor, float] = 0.5,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute components of diarization error rate
Expand All @@ -53,7 +54,6 @@ def _der_update(
speech_total : torch.Tensor
Diarization error rate components accumulated over the whole batch.
"""

# make threshold a (num_thresholds,) tensor
scalar_threshold = isinstance(threshold, Number)
if scalar_threshold:
Expand Down Expand Up @@ -86,6 +86,9 @@ def _der_update(

speaker_confusion = torch.sum((hypothesis != target) * hypothesis, 1) - false_alarm
# (batch_size, num_frames, num_thresholds)

if per_frame:
return torch.sum(false_alarm, 0)[:,0], torch.sum(missed_detection, 0)[:,0], torch.sum(speaker_confusion, 0)[:,0], 1.0 * torch.sum(target)

false_alarm = torch.sum(torch.sum(false_alarm, 1), 0)
missed_detection = torch.sum(torch.sum(missed_detection, 1), 0)
Expand All @@ -107,6 +110,7 @@ def _der_compute(
missed_detection: torch.Tensor,
speaker_confusion: torch.Tensor,
speech_total: torch.Tensor,
per_frame: bool = False,
) -> torch.Tensor:
"""Compute diarization error rate from its components

Expand All @@ -123,7 +127,8 @@ def _der_compute(
der : (num_thresholds, )-shaped torch.Tensor
Diarization error rate.
"""

if per_frame:
return false_alarm, missed_detection, speaker_confusion, speech_total
return (false_alarm + missed_detection + speaker_confusion) / (speech_total + 1e-8)


Expand Down
2 changes: 1 addition & 1 deletion pyannote/audio/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def nll_loss(
num_classes = prediction.shape[2]

losses = F.nll_loss(
prediction.view(-1, num_classes),
prediction.reshape(-1, num_classes),
# (batch_size x num_frames, num_classes)
target.view(-1),
# (batch_size x num_frames, )
Expand Down
Loading