235 changes: 235 additions & 0 deletions lhotse/lazy.py
@@ -2,9 +2,13 @@
import random
import types
import warnings
import logging
import torch
import torch.distributed as dist
from contextlib import contextmanager
from functools import partial
from typing import Any, Callable, Iterable, List, Literal, Optional, TypeVar, Union
import math

from lhotse.serialization import (
LazyMixin,
@@ -89,6 +93,45 @@ def mux(
)
)

@classmethod
def schedule_mux(
Collaborator: This seems duplicated, can we remove?

cls,
*manifests,
stop_early: bool = False,
start_weights: Optional[List[Union[int, float]]] = None,
end_weights: Optional[List[Union[int, float]]] = None,
seed: Union[int, Literal["trng", "randomized"]] = 0,
scheduler_type: Literal["linear", "exponential", "cosine"] = "linear",
total_iterations: Optional[int] = None,
trainer=None,
):
"""
Merges multiple manifest iterables into a new iterable by lazily multiplexing them during iteration time.
If one of the iterables is exhausted before the others, we will keep iterating until all iterables
are exhausted. This behavior can be changed with ``stop_early`` parameter.

:param manifests: iterables to be multiplexed.
They can be either lazy or eager, but the resulting manifest will always be lazy.
:param stop_early: should we stop the iteration as soon as we exhaust one of the manifests.
:param start_weights: initial weights for each iterable. If None, uniform weights [1, 1, ...] are used.
:param end_weights: final weights for each iterable. If None, uniform weights [1, 1, ...] are used.
:param seed: the random seed, ensures deterministic order across multiple iterations.
:param scheduler_type: the type of weight scheduling to use ("linear", "exponential", or "cosine").
:param total_iterations: total number of iterations for weight scheduling.
If None, estimated from the sum of lengths of all manifests.
:param trainer: optional trainer object; when provided, its ``global_step`` is used to
resume the weight schedule from the current training step.
"""
return cls(
LazyIteratorWeightedMultiplexer(
*manifests,
stop_early=stop_early,
start_weights=start_weights,
end_weights=end_weights,
seed=seed,
scheduler_type=scheduler_type,
total_iterations=total_iterations,
trainer=trainer,
)
)

@classmethod
def infinite_mux(
cls,
@@ -133,6 +176,51 @@ def infinite_mux(
)
)

@classmethod
def weighted_mux(
Collaborator: Can we rename this to schedule_mux like it was named above? (I noticed this one has updated and more extensive doc)

cls,
*manifests,
stop_early: bool = False,
start_weights: Optional[List[Union[int, float]]] = None,
end_weights: Optional[List[Union[int, float]]] = None,
seed: Union[int, Literal["trng", "randomized"]] = 0,
scheduler_type: Literal["linear", "exponential", "cosine"] = "linear",
total_iterations: Optional[int] = None,
):
"""
Merges multiple manifest iterables into a new iterable by lazily multiplexing them during iteration time
with dynamic weight scheduling. Unlike ``mux()``, this method allows the weights to change dynamically
from start_weights to end_weights over the course of iterations, similar to a learning rate scheduler.

The weight scheduling works by interpolating between start_weights and end_weights based on the current
iteration count. This allows for gradual transitions in sampling behavior, which can be useful for
curriculum learning or adaptive sampling strategies.

:param manifests: iterables to be multiplexed.
They can be either lazy or eager, but the resulting manifest will always be lazy.
:param stop_early: should we stop the iteration as soon as we exhaust one of the manifests.
:param start_weights: initial weights for each iterable. If None, uniform weights [1, 1, ...] are used.
:param end_weights: final weights for each iterable. If None, uniform weights [1, 1, ...] are used.
:param seed: the random seed, ensures deterministic order across multiple iterations.
:param scheduler_type: the type of weight scheduling to use:
- "linear": linear interpolation between start and end weights
- "exponential": exponential interpolation (faster change at the beginning)
- "cosine": cosine interpolation (smooth transition)
:param total_iterations: total number of iterations for weight scheduling.
If None, estimated from the sum of lengths of all manifests.
"""
return cls(
LazyIteratorWeightedMultiplexer(
*manifests,
stop_early=stop_early,
start_weights=start_weights,
end_weights=end_weights,
seed=seed,
scheduler_type=scheduler_type,
total_iterations=total_iterations,
)
)

def shuffle(
self,
rng: Optional[random.Random] = None,
@@ -828,3 +916,150 @@ def _make_gen(reader):
with open_best(path, read_mode) as f:
count = sum(buf.count(b"\n") for buf in _make_gen(f.read))
return count


class LazyIteratorWeightedMultiplexer(Dillable):
"""
A wrapper over multiple iterators that enables to combine lazy manifests in Lhotse
with dynamic weight scheduling. During iteration, this class randomly selects the
iterable used to yield an item, but the weights change dynamically from start_weight
to end_weight over the course of iterations, similar to a learning rate scheduler.

The weight scheduling works by interpolating between start_weight and end_weight
based on the current iteration count. This allows for gradual transitions in sampling
behavior, which can be useful for curriculum learning or adaptive sampling strategies.

Since the iterables might be of different lengths, the ``start_weights`` and ``end_weights``
parameters let the user decide which iterables should be sampled more frequently than others.
When an iterable is exhausted, we keep sampling from the remaining iterables until all of them
are exhausted, unless ``stop_early`` is set to ``True``.
"""

def __init__(
self,
*iterators: Iterable,
stop_early: bool = False,
start_weights: Optional[List[Union[int, float]]] = None,
end_weights: Optional[List[Union[int, float]]] = None,
seed: Union[int, Literal["trng", "randomized"]] = 0,
scheduler_type: Literal["linear", "exponential", "cosine"] = "linear",
total_iterations: Optional[int] = None,
trainer=None,
) -> None:
self.iterators = list(iterators)
self.stop_early = stop_early
self.seed = seed
self.scheduler_type = scheduler_type
self.total_iterations = total_iterations
self.current_iteration = 0
self.trainer = trainer

self.logger = logging.getLogger(__name__)

if self.trainer is not None:
self.current_iteration = self.trainer.global_step
self.logger.info(
f"Trainer provided: resuming weight schedule from iteration {self.current_iteration}."
)

assert (
len(self.iterators) > 1
), "There have to be at least two iterables to multiplex."

# Initialize weights
if start_weights is None:
self.start_weights = [1] * len(self.iterators)
else:
self.start_weights = start_weights

if end_weights is None:
self.end_weights = [1] * len(self.iterators)
else:
self.end_weights = end_weights

assert len(self.iterators) == len(self.start_weights)
assert len(self.iterators) == len(self.end_weights)

# If total_iterations is not provided, estimate it from the sum of lengths
if self.total_iterations is None:
self.total_iterations = sum(len(it) for it in self.iterators)
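# Note: this estimate assumes every iterable defines __len__; purely lazy
# iterators without a known length require passing total_iterations explicitly.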


def _get_current_weights(self) -> List[float]:
"""
Calculate current weights based on the scheduler type and current iteration.
"""
if self.current_iteration >= self.total_iterations:
return self.end_weights

# Calculate progress from 0 to 1
progress = self.current_iteration / self.total_iterations

# Apply scheduler function
if self.scheduler_type == "linear":
alpha = progress
elif self.scheduler_type == "exponential":
alpha = 1 - (1 - progress) ** 2
elif self.scheduler_type == "cosine":
alpha = 1 - (1 + math.cos(math.pi * progress)) / 2
else:
raise ValueError(f"Unknown scheduler_type: {self.scheduler_type}")
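
# Illustrative values (an added note, not part of the diff): at progress p = 0.25,
# linear gives alpha = 0.25,
# exponential gives alpha = 1 - (1 - 0.25) ** 2 = 0.4375,
# cosine gives alpha = (1 - math.cos(math.pi * 0.25)) / 2 ~= 0.146,
# i.e. the exponential schedule moves toward end_weights fastest early on,
# while cosine starts slowly and accelerates around mid-schedule.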

# Interpolate between start and end weights
current_weights = []
for start_w, end_w in zip(self.start_weights, self.end_weights):
current_w = start_w + alpha * (end_w - start_w)
current_weights.append(current_w)

return current_weights

def __iter__(self):
from lhotse.dataset.dataloading import resolve_seed

rng = random.Random(resolve_seed(self.seed))
iters = [iter(it) for it in self.iterators]
exhausted = [False for _ in range(len(iters))]

def should_continue():
if self.stop_early:
return not any(exhausted)
else:
return not all(exhausted)

while should_continue():
# Resume mid-training: map the trainer's global step onto our iteration
# counter (guard against a missing trainer to avoid an AttributeError).
if (
self.trainer is not None
and self.current_iteration == 0
and self.trainer.global_step != 0
):
self.current_iteration = round(
self.trainer.global_step / self.trainer.max_steps * self.total_iterations
)

# Get current weights based on iteration count
current_weights = self._get_current_weights()

active = [
(i, w)
for i, (is_exhausted, w) in enumerate(zip(exhausted, current_weights))
if not is_exhausted
]
if not active:  # All iterators exhausted
break
active_indexes, active_weights = zip(*active)

idx = rng.choices(active_indexes, weights=active_weights, k=1)[0]
selected = iters[idx]
try:
item = next(selected)

self.current_iteration += 1
# Per-item logging is very chatty; keep it at DEBUG level.
self.logger.debug(
f"Updated iteration: {self.current_iteration} | "
f"Current weights: {current_weights} | End weights: {self.end_weights}"
)

yield item
except StopIteration:
exhausted[idx] = True
continue

def __len__(self) -> int:
return sum(len(it) for it in self.iterators)

def __add__(self, other) -> "LazyIteratorChain":
return LazyIteratorChain(self, other)
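
For reference, a minimal usage sketch of the new API (this assumes ``weighted_mux`` is exposed on ``CutSet`` the same way ``mux()`` is, and the manifest paths are hypothetical):

from lhotse import CutSet

# Hypothetical manifests; any lazy or eager CutSets work here.
clean = CutSet.from_file("clean.jsonl.gz")
noisy = CutSet.from_file("noisy.jsonl.gz")

# Begin sampling mostly from `clean` and end mostly from `noisy`,
# following a cosine schedule over roughly 10k drawn items.
mixed = CutSet.weighted_mux(
    clean,
    noisy,
    start_weights=[0.9, 0.1],
    end_weights=[0.1, 0.9],
    scheduler_type="cosine",
    total_iterations=10_000,
)

for cut in mixed:
    ...  # feed into a sampler / DataLoader as usual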
5 changes: 3 additions & 2 deletions lhotse/serialization.py
@@ -81,7 +81,7 @@ def get_aistore_client():

endpoint_url = os.environ["AIS_ENDPOINT"]
version = parse_version(aistore.__version__)
-    return aistore.Client(endpoint_url, timeout=(1, 20)), version
+    return aistore.Client(endpoint_url, timeout=(5, 30)), version
Collaborator: can you roll back all changes in this file?



def save_to_yaml(data: Any, path: Pathlike) -> None:
@@ -788,6 +788,7 @@ class AIStoreIOBackend(IOBackend):

def open(self, identifier: str, mode: str):
client, version = get_aistore_client()

object = client.fetch_object_by_url(identifier)
if "r" in mode:
try:
@@ -798,7 +799,7 @@ def open(self, identifier: str, mode: str):
request = object.get()
if version >= parse_version("1.9.1"):
# AIStore SDK 1.9.1 supports ObjectFile for improved read fault resiliency
-    return request.as_file()
+    return request.raw()
else:
return request.raw()
if "w" in mode: