|
41 | 41 | from .types import Details, MetricComponents |
42 | 42 | from .utils import UEMSupportMixin |
43 | 43 |
|
| 44 | +from .matcher import ( |
| 45 | + LabelMatcher, |
| 46 | + MATCH_TOTAL, |
| 47 | + MATCH_CORRECT, |
| 48 | + MATCH_CONFUSION, |
| 49 | + MATCH_MISSED_DETECTION, |
| 50 | + MATCH_FALSE_ALARM, |
| 51 | +) |
| 52 | + |
44 | 53 | if TYPE_CHECKING: |
45 | 54 | pass |
46 | 55 |
|
# TODO: can't we put these as class attributes?
# Metric name string (presumably what DiarizationErrorRate reports as its
# name -- confirm against the class body).
DER_NAME = "diarization error rate"

# Prefixes that OverlappedDiarizationErrorRate puts in front of each component
# name (separated by a space) to tell apart the values computed on regions
# where the reference has overlap ("ovl") from those computed on the
# remaining regions ("nonovl").
OVLDER_PREFIX_OVL = "ovl"
OVLDER_PREFIX_NONOVL = "nonovl"
| 61 | + |
50 | 62 |
|
51 | 63 | class DiarizationErrorRate(IdentificationErrorRate): |
52 | 64 | """Diarization error rate |
@@ -646,3 +658,96 @@ def compute_components( |
646 | 658 | return super(DiarizationCompleteness, self).compute_components( |
647 | 659 | hypothesis, reference, uem=uem, **kwargs |
648 | 660 | ) |
| 661 | + |
| 662 | + |
class OverlappedDiarizationErrorRate(BaseMetric):
    """Diarization error rate with details for overlap and non-overlap errors.

    The evaluated extent is split into the regions where the reference
    contains overlap and the regions where it does not, and the error
    components are computed separately on each part. Component names are
    prefixed with 'ovl' or 'nonovl' (e.g. 'ovl false alarm').

    Parameters
    ----------
    collar : float, optional
        Duration (in seconds) of collars removed from evaluation around
        boundaries of reference segments.
    """

    OVDER_NAME = "diarization error rate"

    def __init__(self, collar: float = 0.0):
        super().__init__()

        # One inner metric per region type. Both are configured identically;
        # they only differ in the `uem` they are evaluated on.
        self.der_ovl = DiarizationErrorRate(collar=collar, skip_overlap=False)
        self.der_nonovl = DiarizationErrorRate(collar=collar, skip_overlap=False)

    @classmethod
    def metric_components(cls) -> MetricComponents:
        """Return all component names: the 'nonovl' ones first, then 'ovl'."""
        return [
            f"{prefix} {component}"
            for prefix in (OVLDER_PREFIX_NONOVL, OVLDER_PREFIX_OVL)
            for component in (
                MATCH_TOTAL,
                MATCH_CORRECT,
                MATCH_CONFUSION,
                MATCH_MISSED_DETECTION,
                MATCH_FALSE_ALARM,
            )
        ]

    @classmethod
    def metric_name(cls) -> str:
        """Return the name of this metric."""
        return cls.OVDER_NAME

    def compute_components(
        self,
        reference: Annotation,
        hypothesis: Annotation,
        uem: Timeline | None = None,
        **kwargs,
    ) -> Details:
        """Compute error components separately on overlap / non-overlap regions.

        Parameters
        ----------
        reference : Annotation
            Reference annotation.
        hypothesis : Annotation
            Hypothesis annotation.
        uem : Timeline, optional
            Evaluation map. When omitted, the union of the reference and
            hypothesis supports is used.
        **kwargs
            Accepted for signature parity with the other metrics of this
            module (which take ``**kwargs``); currently ignored.

        Returns
        -------
        components : dict
            Components of both inner metrics, keyed by
            '<prefix> <component name>'.
        """
        # map 'hypothesis' labels to 'reference' labels
        mapping: dict = DiarizationErrorRate().optimal_mapping(
            reference, hypothesis, uem=uem
        )
        hypothesis = hypothesis.rename_labels(mapping)

        # without an explicit uem, evaluate wherever either annotation is active
        if uem is None:
            uem = (
                reference.support()
                .get_timeline()
                .union(hypothesis.support().get_timeline())
            )

        # split uem into non-overlapping and overlapping regions, where
        # "overlap" is defined by the reference only
        overlap: Timeline = reference.get_overlap()
        nonovl_regions: Timeline = uem.extrude(overlap)
        ovl_regions: Timeline = uem.crop(overlap)

        # Evaluate each inner metric on its own regions. Only
        # `compute_components` is called here; the inner metrics are not
        # called to accumulate.
        # NOTE(review): the inner metrics are DiarizationErrorRate instances,
        # which may compute their own label mapping on each region -- confirm
        # that the global mapping applied above is what ends up being scored.
        comps_nonovl = self.der_nonovl.compute_components(
            reference, hypothesis, uem=nonovl_regions
        )
        comps_ovl = self.der_ovl.compute_components(
            reference, hypothesis, uem=ovl_regions
        )

        components = {
            f"{OVLDER_PREFIX_NONOVL} {name}": value
            for name, value in comps_nonovl.items()
        }
        components.update(
            {f"{OVLDER_PREFIX_OVL} {name}": value for name, value in comps_ovl.items()}
        )
        return components

    def compute_metric(self, detail: Details) -> float:
        """Compute the error rate from the prefixed components.

        The rate is (false alarm + missed detection + confusion) / total,
        each summed over the 'nonovl' and 'ovl' parts. When the total is
        zero, 0.0 is returned if there is no error and 1.0 otherwise.
        """
        numerator = 0.0
        denominator = 0.0
        for prefix in (OVLDER_PREFIX_NONOVL, OVLDER_PREFIX_OVL):
            numerator += (
                detail[f"{prefix} {MATCH_FALSE_ALARM}"]
                + detail[f"{prefix} {MATCH_MISSED_DETECTION}"]
                + detail[f"{prefix} {MATCH_CONFUSION}"]
            )
            denominator += detail[f"{prefix} {MATCH_TOTAL}"]

        # guard against division by zero when nothing is to be evaluated
        if denominator == 0.0:
            return 0.0 if numerator == 0 else 1.0
        return numerator / denominator
0 commit comments