diff --git a/lhotse/lazy.py b/lhotse/lazy.py
index f1278a347..97f9ffe6a 100644
--- a/lhotse/lazy.py
+++ b/lhotse/lazy.py
@@ -2,9 +2,13 @@
 import random
 import types
 import warnings
+import logging
+import torch
+import torch.distributed as dist
 from contextlib import contextmanager
 from functools import partial
 from typing import Any, Callable, Iterable, List, Literal, Optional, TypeVar, Union
+import math
 
 from lhotse.serialization import (
     LazyMixin,
@@ -89,6 +93,45 @@ def mux(
             )
         )
 
+    @classmethod
+    def schedule_mux(
+        cls,
+        *manifests,
+        stop_early: bool = False,
+        start_weights: Optional[List[Union[int, float]]] = None,
+        end_weights: Optional[List[Union[int, float]]] = None,
+        seed: Union[int, Literal["trng", "randomized"]] = 0,
+        scheduler_type: Literal["linear", "exponential", "cosine"] = "linear",
+        total_iterations: Optional[int] = None,
+        trainer=None,
+    ):
+        """
+        Merges multiple manifest iterables into a new iterable by lazily multiplexing them during iteration time,
+        scheduling the sampling weights from ``start_weights`` to ``end_weights`` over the course of iteration.
+        If one of the iterables is exhausted before the others, we will keep iterating until all iterables
+        are exhausted. This behavior can be changed with the ``stop_early`` parameter.
+
+        :param manifests: iterables to be multiplexed.
+            They can be either lazy or eager, but the resulting manifest will always be lazy.
+        :param stop_early: should we stop the iteration as soon as we exhaust one of the manifests.
+        :param start_weights: initial weights for each iterable. If None, uniform weights [1, 1, ...] are used.
+        :param end_weights: final weights for each iterable. If None, uniform weights [1, 1, ...] are used.
+        :param seed: the random seed, ensures deterministic order across multiple iterations.
+        :param scheduler_type: the type of weight scheduling to use: "linear", "exponential", or "cosine".
+        :param total_iterations: total number of iterations for weight scheduling.
+            If None, estimated from the sum of lengths of all manifests.
+        :param trainer: an optional trainer object exposing ``global_step`` and ``max_steps``,
+            used to resume the weight schedule from the current training step.
+        """
+        return cls(
+            LazyIteratorWeightedMultiplexer(
+                *manifests,
+                stop_early=stop_early,
+                start_weights=start_weights,
+                end_weights=end_weights,
+                seed=seed,
+                scheduler_type=scheduler_type,
+                total_iterations=total_iterations,
+                trainer=trainer,
+            )
+        )
+
     @classmethod
     def infinite_mux(
         cls,
@@ -133,6 +176,51 @@ def infinite_mux(
             )
         )
 
+    @classmethod
+    def weighted_mux(
+        cls,
+        *manifests,
+        stop_early: bool = False,
+        start_weights: Optional[List[Union[int, float]]] = None,
+        end_weights: Optional[List[Union[int, float]]] = None,
+        seed: Union[int, Literal["trng", "randomized"]] = 0,
+        scheduler_type: Literal["linear", "exponential", "cosine"] = "linear",
+        total_iterations: Optional[int] = None,
+    ):
+        """
+        Merges multiple manifest iterables into a new iterable by lazily multiplexing them during iteration time
+        with dynamic weight scheduling. Unlike ``mux()``, this method allows the weights to change dynamically
+        from start_weights to end_weights over the course of iterations, similar to a learning rate scheduler.
+
+        The weight scheduling works by interpolating between start_weights and end_weights based on the current
+        iteration count. This allows for gradual transitions in sampling behavior, which can be useful for
+        curriculum learning or adaptive sampling strategies.
+
+        :param manifests: iterables to be multiplexed.
+            They can be either lazy or eager, but the resulting manifest will always be lazy.
+        :param stop_early: should we stop the iteration as soon as we exhaust one of the manifests.
+        :param start_weights: initial weights for each iterable. If None, uniform weights [1, 1, ...] are used.
+        :param end_weights: final weights for each iterable. If None, uniform weights [1, 1, ...] are used.
+        :param seed: the random seed, ensures deterministic order across multiple iterations.
+        :param scheduler_type: the type of weight scheduling to use:
+            - "linear": linear interpolation between start and end weights
+            - "exponential": exponential interpolation (faster change at the beginning)
+            - "cosine": cosine interpolation (smooth transition)
+        :param total_iterations: total number of iterations for weight scheduling.
+            If None, estimated from the sum of lengths of all manifests.
+        """
+        return cls(
+            LazyIteratorWeightedMultiplexer(
+                *manifests,
+                stop_early=stop_early,
+                start_weights=start_weights,
+                end_weights=end_weights,
+                seed=seed,
+                scheduler_type=scheduler_type,
+                total_iterations=total_iterations,
+            )
+        )
+
     def shuffle(
         self,
         rng: Optional[random.Random] = None,
@@ -828,3 +916,150 @@ def _make_gen(reader):
     with open_best(path, read_mode) as f:
         count = sum(buf.count(b"\n") for buf in _make_gen(f.read))
     return count
+
+
+class LazyIteratorWeightedMultiplexer(Dillable):
+    """
+    A wrapper over multiple iterators that enables combining lazy manifests in Lhotse
+    with dynamic weight scheduling. During iteration, this class randomly selects the
+    iterable used to yield an item, but the weights change dynamically from ``start_weights``
+    to ``end_weights`` over the course of iterations, similar to a learning rate scheduler.
+
+    The weight scheduling works by interpolating between ``start_weights`` and ``end_weights``
+    based on the current iteration count. This allows for gradual transitions in sampling
+    behavior, which can be useful for curriculum learning or adaptive sampling strategies.
+
+    Since the iterables might be of different length, the ``start_weights`` and ``end_weights``
+    parameters let the user decide which iterables should be sampled more frequently than others.
+    When an iterable is exhausted, we will keep sampling from the other iterables, until
+    we exhaust them all, unless ``stop_early`` is set to ``True``.
+    """
+
+    def __init__(
+        self,
+        *iterators: Iterable,
+        stop_early: bool = False,
+        start_weights: Optional[List[Union[int, float]]] = None,
+        end_weights: Optional[List[Union[int, float]]] = None,
+        seed: Union[int, Literal["trng", "randomized"]] = 0,
+        scheduler_type: Literal["linear", "exponential", "cosine"] = "linear",
+        total_iterations: Optional[int] = None,
+        trainer=None,
+    ) -> None:
+        self.iterators = list(iterators)
+        self.stop_early = stop_early
+        self.seed = seed
+        self.scheduler_type = scheduler_type
+        self.total_iterations = total_iterations
+        self.current_iteration = 0
+        self.trainer = trainer
+
+        self.logger = logging.getLogger(__name__)
+
+        if self.trainer is not None:
+            self.current_iteration = self.trainer.global_step
+            self.logger.info(
+                f"Trainer is not None. Setting current iteration to {self.current_iteration} "
+                f"and total iterations to {self.total_iterations}"
+            )
+
+        assert (
+            len(self.iterators) > 1
+        ), "There have to be at least two iterables to multiplex."
+
+        # Initialize weights
+        if start_weights is None:
+            self.start_weights = [1] * len(self.iterators)
+        else:
+            self.start_weights = start_weights
+
+        if end_weights is None:
+            self.end_weights = [1] * len(self.iterators)
+        else:
+            self.end_weights = end_weights
+
+        assert len(self.iterators) == len(self.start_weights)
+        assert len(self.iterators) == len(self.end_weights)
+
+        # If total_iterations is not provided, estimate it from the sum of lengths
+        if self.total_iterations is None:
+            self.total_iterations = sum(len(it) for it in self.iterators)
+
+    def _get_current_weights(self) -> List[float]:
+        """
+        Calculate current weights based on the scheduler type and current iteration.
+        """
+        if self.current_iteration >= self.total_iterations:
+            return self.end_weights
+
+        # Calculate progress from 0 to 1
+        progress = self.current_iteration / self.total_iterations
+
+        # Apply scheduler function
+        if self.scheduler_type == "linear":
+            alpha = progress
+        elif self.scheduler_type == "exponential":
+            alpha = 1 - (1 - progress) ** 2
+        elif self.scheduler_type == "cosine":
+            alpha = 1 - (1 + math.cos(math.pi * progress)) / 2
+        else:
+            raise ValueError(f"Unknown scheduler_type: {self.scheduler_type}")
+
+        # Interpolate between start and end weights
+        current_weights = []
+        for start_w, end_w in zip(self.start_weights, self.end_weights):
+            current_w = start_w + alpha * (end_w - start_w)
+            current_weights.append(current_w)
+
+        return current_weights
+
+    def __iter__(self):
+        from lhotse.dataset.dataloading import resolve_seed
+
+        rng = random.Random(resolve_seed(self.seed))
+        iters = [iter(it) for it in self.iterators]
+        exhausted = [False for _ in range(len(iters))]
+
+        def should_continue():
+            if self.stop_early:
+                return not any(exhausted)
+            else:
+                return not all(exhausted)
+
+        while should_continue():
+            # If a trainer is attached, resume the schedule from its current training progress.
+            if (
+                self.trainer is not None
+                and self.current_iteration == 0
+                and self.trainer.global_step != 0
+            ):
+                self.current_iteration = round(
+                    self.trainer.global_step / self.trainer.max_steps * self.total_iterations
+                )
+
+            # Get current weights based on iteration count
+            current_weights = self._get_current_weights()
+
+            active_indexes, active_weights = zip(
+                *[
+                    (i, w)
+                    for i, (is_exhausted, w) in enumerate(zip(exhausted, current_weights))
+                    if not is_exhausted
+                ]
+            )
+
+            if not active_indexes:  # All iterators exhausted
+                break
+
+            idx = rng.choices(active_indexes, weights=active_weights, k=1)[0]
+            selected = iters[idx]
+            try:
+                item = next(selected)
+
+                self.current_iteration += 1
+                self.logger.debug(
+                    f"Updated iteration: {self.current_iteration} | "
+                    f"Current weights: {current_weights} | End weights: {self.end_weights}"
+                )
+
+                yield item
+            except StopIteration:
+                exhausted[idx] = True
+                continue
+
+    def __len__(self) -> int:
+        return sum(len(it) for it in self.iterators)
+
+    def __add__(self, other) -> "LazyIteratorChain":
+        return LazyIteratorChain(self, other)
diff --git a/lhotse/serialization.py b/lhotse/serialization.py
index f7de9e818..40952a8f2 100644
--- a/lhotse/serialization.py
+++ b/lhotse/serialization.py
@@ -81,7 +81,7 @@ def get_aistore_client():
 
     endpoint_url = os.environ["AIS_ENDPOINT"]
     version = parse_version(aistore.__version__)
-    return aistore.Client(endpoint_url, timeout=(1, 20)), version
+    return aistore.Client(endpoint_url, timeout=(5, 30)), version
 
 
 def save_to_yaml(data: Any, path: Pathlike) -> None:
@@ -788,6 +788,7 @@ class AIStoreIOBackend(IOBackend):
 
     def open(self, identifier: str, mode: str):
         client, version = get_aistore_client()
+
        object = client.fetch_object_by_url(identifier)
         if "r" in mode:
             try:
@@ -798,7 +799,7 @@ def open(self, identifier: str, mode: str):
                 request = object.get()
                 if version >= parse_version("1.9.1"):
                     # AIStore SDK 1.9.1 supports ObjectFile for improved read fault resiliency
-                    return request.as_file()
+                    return request.raw()
                 else:
                     return request.raw()
         if "w" in mode:
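
For context, a minimal sketch of how the ``weighted_mux`` API added above might be used. The manifest paths, weight values, and iteration count below are illustrative placeholders, not part of the patch:

```python
from lhotse import CutSet

# Hypothetical manifests; any lazy or eager CutSets can be multiplexed.
clean = CutSet.from_jsonl_lazy("clean_cuts.jsonl.gz")
noisy = CutSet.from_jsonl_lazy("noisy_cuts.jsonl.gz")

# Start by drawing mostly from `clean` and shift smoothly towards `noisy`
# over roughly 100k drawn items using the cosine schedule.
cuts = CutSet.weighted_mux(
    clean,
    noisy,
    start_weights=[0.9, 0.1],
    end_weights=[0.1, 0.9],
    scheduler_type="cosine",
    total_iterations=100_000,
    seed=0,
)

for cut in cuts:
    ...  # consume as usual, e.g. via a sampler
```

``schedule_mux`` works the same way, but additionally accepts a ``trainer`` object so the schedule can resume from the trainer's current ``global_step``.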
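
And a small self-contained sketch of the interpolation performed by ``_get_current_weights()``, useful for sanity-checking the three scheduler types; the helper function and the numbers are illustrative only:

```python
import math

def interpolate(start_w, end_w, progress, scheduler_type="linear"):
    # Mirrors LazyIteratorWeightedMultiplexer._get_current_weights for a single weight pair.
    if scheduler_type == "linear":
        alpha = progress
    elif scheduler_type == "exponential":
        alpha = 1 - (1 - progress) ** 2
    elif scheduler_type == "cosine":
        alpha = 1 - (1 + math.cos(math.pi * progress)) / 2
    else:
        raise ValueError(f"Unknown scheduler_type: {scheduler_type}")
    return start_w + alpha * (end_w - start_w)

# Halfway through the schedule (progress = 0.5), going from weight 0.9 to 0.1:
print(interpolate(0.9, 0.1, 0.5, "linear"))       # ~0.5
print(interpolate(0.9, 0.1, 0.5, "exponential"))  # ~0.3 (already closer to the end weight)
print(interpolate(0.9, 0.1, 0.5, "cosine"))       # ~0.5 (but with zero slope at progress 0 and 1)
```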