Commit 8ff6646

Merge branch 'develop' into setup/pin-torch-torchcodec
2 parents f873da3 + 491d3f4 commit 8ff6646

2 files changed, with 37 additions and 1 deletion

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 
 - BREAKING(util): make `Binarize.__call__` return `string` tracks (instead of `int`) [@benniekiss](https://github.com/benniekiss/)
 - feat(cli): add option to apply pipeline on a directory of audio files
+- feat(pipeline): add `preload` option to base `Pipeline.__call__` to force preloading audio in memory ([@antoinelaurent](https://github.com/antoinelaurent/))
 - feat(pipeline): add `Pipeline.cuda()` convenience method [@tkanarsky](https://github.com/tkanarsky/)
 - improve(util): make `permutate` faster thanks to vectorized cost function
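
From the caller's side, the new option is a keyword on the call itself. The sketch below uses a hypothetical `ToyPipeline` class (not pyannote's `Pipeline`) to illustrate the call shape this entry describes: `preload` is consumed by `__call__`, and only the remaining keyword arguments reach `apply`. The `num_speakers` keyword and the placeholder waveform are illustrative assumptions.

```python
from typing import Any, Dict


class ToyPipeline:
    """Hypothetical stand-in that only mimics the new call signature."""

    def apply(self, file: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
        # Report what `apply` received so the forwarding is visible.
        return {"preloaded": "waveform" in file, "kwargs": kwargs}

    def __call__(
        self, file: Dict[str, Any], preload: bool = False, **kwargs: Any
    ) -> Dict[str, Any]:
        file = dict(file)  # copy so the caller's dict is left untouched
        if preload:
            # Placeholder for actually loading the audio in memory.
            file["waveform"] = [0.0]
        # `preload` is a named parameter, so it is not part of `kwargs`.
        return self.apply(file, **kwargs)


pipeline = ToyPipeline()
result = pipeline({"audio": "a.wav"}, preload=True, num_speakers=2)
```

Because `preload` is declared explicitly in the signature, it never appears among the keyword arguments forwarded to `apply`.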

src/pyannote/audio/core/pipeline.py

Lines changed: 36 additions & 1 deletion
@@ -407,7 +407,23 @@ def classes(self) -> List | Iterator:
         """
         raise NotImplementedError()
 
-    def __call__(self, file: AudioFile, **kwargs):
+    def __call__(self, file: AudioFile, preload: bool = False, **kwargs):
+        """Validate file, (optionally) load it in memory, then process it
+
+        Parameters
+        ----------
+        file : AudioFile
+            File to process
+        preload : bool, optional
+            Whether to preload waveform before applying the pipeline.
+        kwargs : keyword arguments, optional
+            Additional keyword arguments passed to `self.apply(...)`
+
+        Returns
+        -------
+        output : Any
+            Whatever `self.apply(...)` returns
+        """
         fix_reproducibility(getattr(self, "device", torch.device("cpu")))
 
         if not self.instantiated:
@@ -433,9 +449,28 @@ def __call__(self, file: AudioFile, **kwargs):
 
         file = Audio.validate_file(file)
 
+        # check if the instance has preprocessors and wrap the file if so
         if hasattr(self, "preprocessors"):
             file = ProtocolFile(file, lazy=self.preprocessors)
 
+        # pre-load the audio in memory if requested
+        if preload:
+            # raise error if `waveform` is already in memory (or will be via a preprocessor)
+            if (
+                "waveform" in getattr(self, "preprocessors", dict())
+                or "waveform" in file
+            ):
+                raise ValueError(
+                    "Cannot preload audio: `waveform` key is already available or will be via a preprocessor."
+                )
+
+            # load waveform in memory (and keep track of its original sample rate)
+            file["waveform"], file["sample_rate"] = Audio()(file)
+
+            # the above line already took care of channel selection,
+            # therefore we remove the `channel` key from the file
+            file.pop("channel", None)
+
         # send file duration to telemetry as well as
         # requested number of speakers in case of diarization
         track_pipeline_apply(self, file, **kwargs)
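
The preload branch in the hunk above can be exercised in isolation. What follows is a minimal, self-contained sketch rather than pyannote's code: `preload_waveform`, `fake_loader`, and the plain-dict file are hypothetical stand-ins for the logic inside `Pipeline.__call__`, for `Audio()`, and for `ProtocolFile`. Unlike the diff, which mutates `file` in place, the sketch works on a copy.

```python
from typing import Any, Callable, Dict, List, Tuple

Loader = Callable[[Dict[str, Any]], Tuple[List[float], int]]


def fake_loader(file: Dict[str, Any]) -> Tuple[List[float], int]:
    """Hypothetical stand-in for `Audio()(file)`: returns (waveform, sample_rate)."""
    return [0.0, 0.1, -0.1], 16000


def preload_waveform(
    file: Dict[str, Any],
    preprocessors: Dict[str, Any],
    loader: Loader = fake_loader,
) -> Dict[str, Any]:
    """Mirror the guard-then-load steps of the diff on a plain dict."""
    # Refuse to preload if a waveform is already present, or if a
    # preprocessor is registered to provide one.
    if "waveform" in preprocessors or "waveform" in file:
        raise ValueError(
            "Cannot preload audio: `waveform` key is already available "
            "or will be via a preprocessor."
        )

    file = dict(file)  # copy so the caller's dict is left untouched
    file["waveform"], file["sample_rate"] = loader(file)

    # The loader is assumed to have handled channel selection already,
    # so the `channel` key is dropped, as in the diff.
    file.pop("channel", None)
    return file
```

Calling `preload_waveform({"audio": "a.wav", "channel": 0}, {})` returns a dict that holds `waveform` and `sample_rate` but no `channel`, while passing a file that already contains `waveform` raises `ValueError`.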
