Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 38 additions & 36 deletions src/pyannote/core/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,37 +106,40 @@

See :class:`pyannote.core.Annotation` for the complete reference.
"""

import itertools
import warnings
from collections import defaultdict
from typing import (
Hashable,
Optional,
TYPE_CHECKING,
Dict,
Union,
Hashable,
Iterable,
Iterator,
List,
Literal,
Optional,
Set,
Text,
TextIO,
Tuple,
Iterator,
Text,
TYPE_CHECKING,
Union,
overload,
)

import numpy as np
from sortedcontainers import SortedDict

from . import (
PYANNOTE_LABEL,
PYANNOTE_SEGMENT,
PYANNOTE_TRACK,
PYANNOTE_LABEL,
)
from .feature import SlidingWindowFeature
from .segment import Segment, SlidingWindow
from .timeline import Timeline
from .feature import SlidingWindowFeature
from .utils.generators import string_generator, int_generator
from .utils.types import Label, Key, Support, LabelGenerator, TrackName, CropMode
from .utils.generators import int_generator, string_generator
from .utils.types import CropMode, Key, Label, LabelGenerator, Support, TrackName

if TYPE_CHECKING:
import pandas as pd
Expand Down Expand Up @@ -166,12 +169,10 @@ def from_df(
uri: Optional[str] = None,
modality: Optional[str] = None,
) -> "Annotation":

df = df[[PYANNOTE_SEGMENT, PYANNOTE_TRACK, PYANNOTE_LABEL]]
return Annotation.from_records(df.itertuples(index=False), uri, modality)

def __init__(self, uri: Optional[str] = None, modality: Optional[str] = None):

self._uri: Optional[str] = uri
self.modality: Optional[str] = modality

Expand Down Expand Up @@ -205,7 +206,6 @@ def uri(self, uri: str):
self._uri = uri

def _updateLabels(self):

# list of labels that needs to be updated
update = set(
label for label, update in self._labelNeedsUpdate.items() if update
Expand Down Expand Up @@ -271,9 +271,20 @@ def itersegments(self):
"""
return iter(self._tracks)

@overload
def itertracks(
self, yield_label: Literal[False] = ...
) -> Iterator[Tuple[Segment, TrackName]]: ...
@overload
def itertracks(
self, yield_label: Literal[True]
) -> Iterator[Tuple[Segment, TrackName, Label]]: ...

def itertracks(
self, yield_label: bool = False
) -> Iterator[Union[Tuple[Segment, TrackName], Tuple[Segment, TrackName, Label]]]:
) -> Union[
Iterator[Tuple[Segment, TrackName]], Iterator[Tuple[Segment, TrackName, Label]]
]:
"""Iterate over tracks (in chronological order)

Parameters
Expand Down Expand Up @@ -501,13 +512,11 @@ def crop(self, support: Support, mode: CropMode = "intersection") -> "Annotation
return self.crop(support, mode=mode)

elif isinstance(support, Timeline):

# if 'support' is a `Timeline`, we use its support
support = support.support()
cropped = self.__class__(uri=self.uri, modality=self.modality)

if mode == "loose":

_tracks = {}
_labels = set([])

Expand All @@ -527,14 +536,12 @@ def crop(self, support: Support, mode: CropMode = "intersection") -> "Annotation
return cropped

elif mode == "strict":

_tracks = {}
_labels = set([])

for segment, other_segment in self.get_timeline(copy=False).co_iter(
support
):

if segment not in other_segment:
continue

Expand All @@ -553,11 +560,9 @@ def crop(self, support: Support, mode: CropMode = "intersection") -> "Annotation
return cropped

elif mode == "intersection":

for segment, other_segment in self.get_timeline(copy=False).co_iter(
support
):

intersection = segment & other_segment
for track, label in self._tracks[segment].items():
track = cropped.new_track(intersection, candidate=track)
Expand All @@ -576,22 +581,22 @@ def extrude(
A simple illustration:

.. code-block:: text

A |------| |------| #
B |----------| # self
C |--------------| |------| #

#
|-------| |-----------| # removed
#
A #

A #
B |---| # mode="intersection"
C |--| |------| #
C |--| |------| #

A #
B # mode="loose"
C |------| #
C |------| #

A |------| #
B |----------| # mode="strict"
Expand Down Expand Up @@ -659,7 +664,7 @@ def get_overlap(self, labels: Optional[Iterable[Label]] = None) -> "Timeline":
-------
overlap : `pyannote.core.Timeline`
Timeline of the overlaps.
"""
"""
if labels:
annotation = self.subset(labels)
else:
Expand Down Expand Up @@ -790,9 +795,9 @@ def new_track(
def __str__(self):
"""Human-friendly representation"""
# TODO: use pandas.DataFrame
return "\n".join(
["%s %s %s" % (s, t, l) for s, t, l in self.itertracks(yield_label=True)]
)
return "\n".join([
"%s %s %s" % (s, t, l) for s, t, l in self.itertracks(yield_label=True)
])

def __delitem__(self, key: Key):
"""Delete one track
Expand All @@ -806,7 +811,6 @@ def __delitem__(self, key: Key):

# del annotation[segment]
if isinstance(key, Segment):

# Pop segment out of dictionary
# and get corresponding tracks
# Raises KeyError if segment does not exist
Expand All @@ -821,7 +825,6 @@ def __delitem__(self, key: Key):

# del annotation[segment, track]
elif isinstance(key, tuple) and len(key) == 2:

# get segment tracks as dictionary
# if segment does not exist, get empty dictionary
# Raises KeyError if segment does not exist
Expand Down Expand Up @@ -1366,7 +1369,6 @@ def support(self, collar: float = 0.0) -> "Annotation":
# with same uri and modality as original
support = self.empty()
for label in self.labels():

# get timeline for current label
timeline = self.label_timeline(label, copy=True)

Expand Down Expand Up @@ -1427,7 +1429,7 @@ def __mul__(self, other: "Annotation") -> np.ndarray:

if not isinstance(other, Annotation):
raise TypeError(
"computing cooccurrence matrix only works with Annotation " "instances."
"computing cooccurrence matrix only works with Annotation instances."
)

i_labels = self.labels()
Expand Down Expand Up @@ -1455,11 +1457,11 @@ def discretize(
duration: Optional[float] = None,
):
"""Discretize

Parameters
----------
support : Segment, optional
Part of annotation to discretize.
Part of annotation to discretize.
Defaults to annotation full extent.
resolution : float or SlidingWindow, optional
Defaults to 10ms frames.
Expand Down