Skip to content

Commit 61501ac

Browse files
committed
fix overlapped DER and its test
(the test passed with an incorrect metric because the hypothesis contained no overlap)
1 parent f8bf929 commit 61501ac

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

src/pyannote/metrics/diarization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -676,8 +676,8 @@ class OverlappedDiarizationErrorRate(BaseMetric):
676676
def __init__(self, collar: float = 0.0):
677677
super().__init__()
678678

679-
self.der_ovl = DiarizationErrorRate(collar=collar, skip_overlap=False)
680-
self.der_nonovl = DiarizationErrorRate(collar=collar, skip_overlap=False)
679+
self.der_ovl = IdentificationErrorRate(collar=collar, skip_overlap=False)
680+
self.der_nonovl = IdentificationErrorRate(collar=collar, skip_overlap=False)
681681

682682
@classmethod
683683
def metric_components(cls) -> MetricComponents:

tests/test_diarization.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,24 +51,37 @@ def hypothesis():
5151
return hypothesis
5252

5353

54-
def test_ovl_der(reference_with_overlap, hypothesis):
54+
@pytest.fixture
55+
def hypothesis_overlap():
56+
hypothesis = Annotation()
57+
hypothesis[Segment(2, 13)] = "a"
58+
hypothesis[Segment(10, 14)] = "d"
59+
hypothesis[Segment(14, 24)] = "b"
60+
hypothesis[Segment(22, 38)] = "c"
61+
hypothesis[Segment(36, 40)] = "d"
62+
return hypothesis
63+
64+
65+
def test_ovl_der(reference_with_overlap, hypothesis_overlap: Annotation):
5566
der_ovl = OverlappedDiarizationErrorRate()
5667
der_regular = DiarizationErrorRate()
5768

58-
error_rate_ovl = der_ovl(reference_with_overlap, hypothesis)
59-
error_rate_regular = der_regular(reference_with_overlap, hypothesis)
69+
error_rate_ovl = der_ovl(reference_with_overlap, hypothesis_overlap)
70+
error_rate_regular = der_regular(reference_with_overlap, hypothesis_overlap)
6071

6172
npt.assert_almost_equal(error_rate_ovl, error_rate_regular, decimal=7)
6273

6374

64-
def test_ovl_der_components(reference_with_overlap, hypothesis):
75+
def test_ovl_der_components(reference_with_overlap, hypothesis_overlap):
6576
for collar in [0.0, 0.1, 0.5]:
6677
der_ovl = OverlappedDiarizationErrorRate(collar=collar)
6778
der_regular = DiarizationErrorRate(collar=collar)
6879

69-
comp_ovl: Details = der_ovl(reference_with_overlap, hypothesis, detailed=True)
80+
comp_ovl: Details = der_ovl(
81+
reference_with_overlap, hypothesis_overlap, detailed=True
82+
)
7083
comp_regular: Details = der_regular(
71-
reference_with_overlap, hypothesis, detailed=True
84+
reference_with_overlap, hypothesis_overlap, detailed=True
7285
)
7386

7487
print(comp_ovl)

0 commit comments

Comments
 (0)