Skip to content

Commit f8bf929

Browse files
committed
add overlapped DER + tests
1 parent 4e0571f commit f8bf929

File tree

2 files changed

+148
-0
lines changed

2 files changed

+148
-0
lines changed

src/pyannote/metrics/diarization.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,24 @@
4141
from .types import Details, MetricComponents
4242
from .utils import UEMSupportMixin
4343

44+
from .matcher import (
45+
LabelMatcher,
46+
MATCH_TOTAL,
47+
MATCH_CORRECT,
48+
MATCH_CONFUSION,
49+
MATCH_MISSED_DETECTION,
50+
MATCH_FALSE_ALARM,
51+
)
52+
4453
if TYPE_CHECKING:
4554
pass
4655

4756
# TODO: can't we put these as class attributes?
4857
DER_NAME = "diarization error rate"
4958

59+
OVLDER_PREFIX_OVL = "ovl"
60+
OVLDER_PREFIX_NONOVL = "nonovl"
61+
5062

5163
class DiarizationErrorRate(IdentificationErrorRate):
5264
"""Diarization error rate
@@ -646,3 +658,96 @@ def compute_components(
646658
return super(DiarizationCompleteness, self).compute_components(
647659
hypothesis, reference, uem=uem, **kwargs
648660
)
661+
662+
663+
class OverlappedDiarizationErrorRate(BaseMetric):
664+
"""Diarization error rate with details for overlap and non-overlap errors.
665+
Error components will be prefixed with 'ovl' or 'nonovl' (e.g. 'ovl false alarm')
666+
667+
Parameters
668+
----------
669+
collar : float, optional
670+
Duration (in seconds) of collars removed from evaluation around
671+
boundaries of reference segments.
672+
"""
673+
674+
OVDER_NAME = "diarization error rate"
675+
676+
def __init__(self, collar: float = 0.0):
677+
super().__init__()
678+
679+
self.der_ovl = DiarizationErrorRate(collar=collar, skip_overlap=False)
680+
self.der_nonovl = DiarizationErrorRate(collar=collar, skip_overlap=False)
681+
682+
@classmethod
683+
def metric_components(cls) -> MetricComponents:
684+
comps = []
685+
for ovl in [OVLDER_PREFIX_NONOVL, OVLDER_PREFIX_OVL]:
686+
for comp in [
687+
MATCH_TOTAL,
688+
MATCH_CORRECT,
689+
MATCH_CONFUSION,
690+
MATCH_MISSED_DETECTION,
691+
MATCH_FALSE_ALARM,
692+
]:
693+
comps.append(f"{ovl} {comp}")
694+
return comps
695+
696+
@classmethod
697+
def metric_name(cls) -> str:
698+
return cls.OVDER_NAME
699+
700+
def compute_components(
701+
self, reference: Annotation, hypothesis: Annotation, uem: Timeline | None = None
702+
) -> Details:
703+
704+
# map 'hypothesis' labels to 'reference' labels
705+
mapping: dict = DiarizationErrorRate().optimal_mapping(
706+
reference, hypothesis, uem=uem
707+
)
708+
hypothesis = hypothesis.rename_labels(mapping)
709+
710+
# split uem into non-overlapping and overlapping regions
711+
overlap: Timeline = reference.get_overlap()
712+
if uem is None:
713+
uem = (
714+
reference.support()
715+
.get_timeline()
716+
.union(hypothesis.support().get_timeline())
717+
)
718+
nonovl_regions: Timeline = uem.extrude(overlap)
719+
ovl_regions: Timeline = uem.crop(overlap)
720+
721+
# update internal metrics for (non-)overlapping errors
722+
comps_nonovl = self.der_nonovl.compute_components(
723+
reference, hypothesis, uem=nonovl_regions
724+
)
725+
comps_ovl = self.der_ovl.compute_components(
726+
reference, hypothesis, uem=ovl_regions
727+
)
728+
729+
components = {}
730+
components.update(
731+
{f"{OVLDER_PREFIX_NONOVL} {k}": v for k, v in comps_nonovl.items()}
732+
)
733+
components.update({f"{OVLDER_PREFIX_OVL} {k}": v for k, v in comps_ovl.items()})
734+
735+
return components
736+
737+
def compute_metric(self, detail: Details) -> float:
738+
numerator = 0.0
739+
denominator = 0.0
740+
for ovl in [OVLDER_PREFIX_NONOVL, OVLDER_PREFIX_OVL]:
741+
numerator += (
742+
detail[f"{ovl} {MATCH_FALSE_ALARM}"]
743+
+ detail[f"{ovl} {MATCH_MISSED_DETECTION}"]
744+
+ detail[f"{ovl} {MATCH_CONFUSION}"]
745+
)
746+
denominator += detail[f"{ovl} {MATCH_TOTAL}"]
747+
if denominator == 0.0:
748+
if numerator == 0:
749+
return 0.0
750+
else:
751+
return 1.0
752+
else:
753+
return numerator / denominator

tests/test_diarization.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from calendar import c
12
import pytest
23

34
import pyannote.core
@@ -7,9 +8,17 @@
78
from pyannote.metrics.diarization import DiarizationErrorRate
89
from pyannote.metrics.diarization import DiarizationPurity
910
from pyannote.metrics.diarization import DiarizationCoverage
11+
from pyannote.metrics.diarization import (
12+
OverlappedDiarizationErrorRate,
13+
OVLDER_PREFIX_NONOVL,
14+
OVLDER_PREFIX_OVL,
15+
)
1016

1117
import numpy.testing as npt
1218

19+
from pyannote.metrics.matcher import MATCH_TOTAL
20+
from pyannote.metrics.types import Details
21+
1322

1423
@pytest.fixture
1524
def reference():
@@ -42,6 +51,40 @@ def hypothesis():
4251
return hypothesis
4352

4453

54+
def test_ovl_der(reference_with_overlap, hypothesis):
55+
der_ovl = OverlappedDiarizationErrorRate()
56+
der_regular = DiarizationErrorRate()
57+
58+
error_rate_ovl = der_ovl(reference_with_overlap, hypothesis)
59+
error_rate_regular = der_regular(reference_with_overlap, hypothesis)
60+
61+
npt.assert_almost_equal(error_rate_ovl, error_rate_regular, decimal=7)
62+
63+
64+
def test_ovl_der_components(reference_with_overlap, hypothesis):
65+
for collar in [0.0, 0.1, 0.5]:
66+
der_ovl = OverlappedDiarizationErrorRate(collar=collar)
67+
der_regular = DiarizationErrorRate(collar=collar)
68+
69+
comp_ovl: Details = der_ovl(reference_with_overlap, hypothesis, detailed=True)
70+
comp_regular: Details = der_regular(
71+
reference_with_overlap, hypothesis, detailed=True
72+
)
73+
74+
print(comp_ovl)
75+
print(comp_regular)
76+
77+
# test that for each component, the sum of non-overlapped and overlapped components is equal to the regular component
78+
# eg check that ovl confusion+nonovl confusion = confusion
79+
for component in der_regular.metric_components():
80+
ovl_compsum = comp_ovl["nonovl " + component] + comp_ovl["ovl " + component]
81+
reg_compsum = comp_regular[component]
82+
npt.assert_almost_equal(ovl_compsum, reg_compsum, decimal=7)
83+
# check there is overlapped and nonoverlapped speech
84+
assert comp_ovl[f"{OVLDER_PREFIX_NONOVL} {MATCH_TOTAL}"] > 0.0
85+
assert comp_ovl[f"{OVLDER_PREFIX_OVL} {MATCH_TOTAL}"] > 0.0
86+
87+
4588
def test_error_rate(reference, hypothesis):
4689
diarizationErrorRate = DiarizationErrorRate()
4790
error_rate = diarizationErrorRate(reference, hypothesis)

0 commit comments

Comments
 (0)