Skip to content

Commit e49a686

Browse files
authored
Merge branch 'develop' into improve/on-prem-wrapper
2 parents a27fec1 + 491d3f4 commit e49a686

File tree

2 files changed

+53
-4
lines changed

2 files changed

+53
-4
lines changed

CHANGELOG.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
## next
44

5+
- BREAKING(util): make `Binarize.__call__` return `string` tracks (instead of `int`) [@benniekiss](https://github.com/benniekiss/)
56
- feat(cli): add option to apply pipeline on a directory of audio files
7+
- feat(pipeline): add `preload` option to base `Pipeline.__call__` to force preloading audio in memory ([@antoinelaurent](https://github.com/antoinelaurent/))
8+
- feat(pipeline): add `Pipeline.cuda()` convenience method [@tkanarsky](https://github.com/tkanarsky/)
69
- improve(util): make `permutate` faster thanks to vectorized cost function
7-
- BREAKING(util): make `Binarize.__call__` return `string` tracks (instead of `int`) [@benniekiss](https://github.com/benniekiss/)
8-
- improve(pyannoteAI): update on-premise wrapper to return both regular and exclusive diarization
10+
- improve(pyannoteAI): update pyannoteAI wrapper to return both regular and exclusive diarization
911

1012
## Version 4.0.1 (2025-10-10)
1113

src/pyannote/audio/core/pipeline.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2222
# SOFTWARE.
2323

24+
from __future__ import annotations
2425
import os
2526
import warnings
2627
from collections import OrderedDict
@@ -406,7 +407,23 @@ def classes(self) -> List | Iterator:
406407
"""
407408
raise NotImplementedError()
408409

409-
def __call__(self, file: AudioFile, **kwargs):
410+
def __call__(self, file: AudioFile, preload: bool = False, **kwargs):
411+
"""Validate file, (optionally) load it in memory, then process it
412+
413+
Parameters
414+
----------
415+
file : AudioFile
416+
File to process
417+
preload : bool, optional
418+
Whether to preload waveform before applying the pipeline.
419+
kwargs : keyword arguments, optional
420+
Additional keyword arguments passed to `self.apply(...)`
421+
422+
Returns
423+
-------
424+
output : Any
425+
Whatever `self.apply(...)` returns
426+
"""
410427
fix_reproducibility(getattr(self, "device", torch.device("cpu")))
411428

412429
if not self.instantiated:
@@ -432,16 +449,35 @@ def __call__(self, file: AudioFile, **kwargs):
432449

433450
file = Audio.validate_file(file)
434451

452+
# check if the instance has preprocessors and wrap the file if so
435453
if hasattr(self, "preprocessors"):
436454
file = ProtocolFile(file, lazy=self.preprocessors)
437455

456+
# pre-load the audio in memory if requested
457+
if preload:
458+
# raise error if `waveform`` is already in memory (or will be via a preprocessor)
459+
if (
460+
"waveform" in getattr(self, "preprocessors", dict())
461+
or "waveform" in file
462+
):
463+
raise ValueError(
464+
"Cannot preload audio: `waveform` key is already available or will be via a preprocessor."
465+
)
466+
467+
# load waveform in memory (and keep track of its original sample rate)
468+
file["waveform"], file["sample_rate"] = Audio()(file)
469+
470+
# the above line already took care of channel selection,
471+
# therefore we remove the `channel` key from the file
472+
file.pop("channel", None)
473+
438474
# send file duration to telemetry as well as
439475
# requested number of speakers in case of diarization
440476
track_pipeline_apply(self, file, **kwargs)
441477

442478
return self.apply(file, **kwargs)
443479

444-
def to(self, device: torch.device):
480+
def to(self, device: torch.device) -> Pipeline:
445481
"""Send pipeline to `device`"""
446482

447483
if not isinstance(device, torch.device):
@@ -462,3 +498,14 @@ def to(self, device: torch.device):
462498
self.device = device
463499

464500
return self
501+
502+
def cuda(self, device: torch.device | int | None = None) -> Pipeline:
503+
"""Send pipeline to (optionally specified) cuda device"""
504+
if device is None:
505+
return self.to(torch.device("cuda"))
506+
elif isinstance(device, int):
507+
return self.to(torch.device("cuda", device))
508+
else:
509+
if device.type != "cuda":
510+
raise ValueError("Expected CUDA device. Use `Pipeline.to(device)` for other devices.")
511+
return self.to(device)

0 commit comments

Comments
 (0)