diff --git a/src/pyannote/core/annotation.py b/src/pyannote/core/annotation.py index a65d1f5..9e10697 100755 --- a/src/pyannote/core/annotation.py +++ b/src/pyannote/core/annotation.py @@ -106,37 +106,40 @@ See :class:`pyannote.core.Annotation` for the complete reference. """ + import itertools import warnings from collections import defaultdict from typing import ( - Hashable, - Optional, + TYPE_CHECKING, Dict, - Union, + Hashable, Iterable, + Iterator, List, + Literal, + Optional, Set, + Text, TextIO, Tuple, - Iterator, - Text, - TYPE_CHECKING, + Union, + overload, ) import numpy as np from sortedcontainers import SortedDict from . import ( + PYANNOTE_LABEL, PYANNOTE_SEGMENT, PYANNOTE_TRACK, - PYANNOTE_LABEL, ) +from .feature import SlidingWindowFeature from .segment import Segment, SlidingWindow from .timeline import Timeline -from .feature import SlidingWindowFeature -from .utils.generators import string_generator, int_generator -from .utils.types import Label, Key, Support, LabelGenerator, TrackName, CropMode +from .utils.generators import int_generator, string_generator +from .utils.types import CropMode, Key, Label, LabelGenerator, Support, TrackName if TYPE_CHECKING: import pandas as pd @@ -166,12 +169,10 @@ def from_df( uri: Optional[str] = None, modality: Optional[str] = None, ) -> "Annotation": - df = df[[PYANNOTE_SEGMENT, PYANNOTE_TRACK, PYANNOTE_LABEL]] return Annotation.from_records(df.itertuples(index=False), uri, modality) def __init__(self, uri: Optional[str] = None, modality: Optional[str] = None): - self._uri: Optional[str] = uri self.modality: Optional[str] = modality @@ -205,7 +206,6 @@ def uri(self, uri: str): self._uri = uri def _updateLabels(self): - # list of labels that needs to be updated update = set( label for label, update in self._labelNeedsUpdate.items() if update @@ -271,9 +271,20 @@ def itersegments(self): """ return iter(self._tracks) + @overload + def itertracks( + self, yield_label: Literal[False] = ... + ) -> Iterator[Tuple[Segment, TrackName]]: ... + @overload + def itertracks( + self, yield_label: Literal[True] + ) -> Iterator[Tuple[Segment, TrackName, Label]]: ... + def itertracks( self, yield_label: bool = False - ) -> Iterator[Union[Tuple[Segment, TrackName], Tuple[Segment, TrackName, Label]]]: + ) -> Union[ + Iterator[Tuple[Segment, TrackName]], Iterator[Tuple[Segment, TrackName, Label]] + ]: """Iterate over tracks (in chronological order) Parameters @@ -501,13 +512,11 @@ def crop(self, support: Support, mode: CropMode = "intersection") -> "Annotation return self.crop(support, mode=mode) elif isinstance(support, Timeline): - # if 'support' is a `Timeline`, we use its support support = support.support() cropped = self.__class__(uri=self.uri, modality=self.modality) if mode == "loose": - _tracks = {} _labels = set([]) @@ -527,14 +536,12 @@ def crop(self, support: Support, mode: CropMode = "intersection") -> "Annotation return cropped elif mode == "strict": - _tracks = {} _labels = set([]) for segment, other_segment in self.get_timeline(copy=False).co_iter( support ): - if segment not in other_segment: continue @@ -553,11 +560,9 @@ def crop(self, support: Support, mode: CropMode = "intersection") -> "Annotation return cropped elif mode == "intersection": - for segment, other_segment in self.get_timeline(copy=False).co_iter( support ): - intersection = segment & other_segment for track, label in self._tracks[segment].items(): track = cropped.new_track(intersection, candidate=track) @@ -576,7 +581,7 @@ def extrude( A simple illustration: .. code-block:: text - + A |------| |------| # B |----------| # self C |--------------| |------| # @@ -584,14 +589,14 @@ def extrude( # |-------| |-----------| # removed # - - A # + + A # B |---| # mode="intersection" - C |--| |------| # - + C |--| |------| # + A # B # mode="loose" - C |------| # + C |------| # A |------| # B |----------| # mode="strict" @@ -659,7 +664,7 @@ def get_overlap(self, labels: Optional[Iterable[Label]] = None) -> "Timeline": ------- overlap : `pyannote.core.Timeline` Timeline of the overlaps. - """ + """ if labels: annotation = self.subset(labels) else: @@ -790,9 +795,9 @@ def new_track( def __str__(self): """Human-friendly representation""" # TODO: use pandas.DataFrame - return "\n".join( - ["%s %s %s" % (s, t, l) for s, t, l in self.itertracks(yield_label=True)] - ) + return "\n".join([ + "%s %s %s" % (s, t, l) for s, t, l in self.itertracks(yield_label=True) + ]) def __delitem__(self, key: Key): """Delete one track @@ -806,7 +811,6 @@ def __delitem__(self, key: Key): # del annotation[segment] if isinstance(key, Segment): - # Pop segment out of dictionary # and get corresponding tracks # Raises KeyError if segment does not exist @@ -821,7 +825,6 @@ def __delitem__(self, key: Key): # del annotation[segment, track] elif isinstance(key, tuple) and len(key) == 2: - # get segment tracks as dictionary # if segment does not exist, get empty dictionary # Raises KeyError if segment does not exist @@ -1366,7 +1369,6 @@ def support(self, collar: float = 0.0) -> "Annotation": # with same uri and modality as original support = self.empty() for label in self.labels(): - # get timeline for current label timeline = self.label_timeline(label, copy=True) @@ -1427,7 +1429,7 @@ def __mul__(self, other: "Annotation") -> np.ndarray: if not isinstance(other, Annotation): raise TypeError( - "computing cooccurrence matrix only works with Annotation " "instances." + "computing cooccurrence matrix only works with Annotation instances." ) i_labels = self.labels() @@ -1455,11 +1457,11 @@ def discretize( duration: Optional[float] = None, ): """Discretize - + Parameters ---------- support : Segment, optional - Part of annotation to discretize. + Part of annotation to discretize. Defaults to annotation full extent. resolution : float or SlidingWindow, optional Defaults to 10ms frames.