Data Balance Scheduler #1507
base: master
@@ -2,9 +2,13 @@
```python
import random
import types
import warnings
import logging
import torch
import torch.distributed as dist
from contextlib import contextmanager
from functools import partial
from typing import Any, Callable, Iterable, List, Literal, Optional, TypeVar, Union
import math

from lhotse.serialization import (
    LazyMixin,
```
|
|
@@ -89,6 +93,45 @@ def mux(
```python
            )
        )

    @classmethod
    def schedule_mux(
        cls,
        *manifests,
        stop_early: bool = False,
        start_weights: Optional[List[Union[int, float]]] = None,
        end_weights: Optional[List[Union[int, float]]] = None,
        seed: Union[int, Literal["trng", "randomized"]] = 0,
        scheduler_type: Literal["linear", "exponential", "cosine"] = "linear",
        total_iterations: Optional[int] = None,
        trainer=None,
    ):
        """
        Merges multiple manifest iterables into a new iterable by lazily multiplexing them
        during iteration time with dynamic weight scheduling. If one of the iterables is
        exhausted before the others, we keep iterating until all iterables are exhausted.
        This behavior can be changed with the ``stop_early`` parameter.

        :param manifests: iterables to be multiplexed.
            They can be either lazy or eager, but the resulting manifest will always be lazy.
        :param stop_early: should we stop the iteration as soon as we exhaust one of the manifests.
        :param start_weights: initial weights for each iterable. If None, uniform weights [1, 1, ...] are used.
        :param end_weights: final weights for each iterable. If None, uniform weights [1, 1, ...] are used.
        :param seed: the random seed, ensures deterministic order across multiple iterations.
        :param scheduler_type: the type of weight scheduling to use ("linear", "exponential", or "cosine").
        :param total_iterations: total number of iterations for weight scheduling.
            If None, estimated from the sum of lengths of all manifests.
        :param trainer: an optional trainer object whose ``global_step`` is used to
            initialize and resume the weight schedule.
        """
        return cls(
            LazyIteratorWeightedMultiplexer(
                *manifests,
                stop_early=stop_early,
                start_weights=start_weights,
                end_weights=end_weights,
                seed=seed,
                scheduler_type=scheduler_type,
                total_iterations=total_iterations,
                trainer=trainer,
            )
        )
```
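A minimal usage sketch (assuming the new classmethod is exposed on ``CutSet``, like ``mux()``; the manifest paths and weight values are hypothetical):

```python
from lhotse import CutSet

# Hypothetical manifests; substitute your own.
clean = CutSet.from_file("clean_cuts.jsonl.gz")
noisy = CutSet.from_file("noisy_cuts.jsonl.gz")

# Start sampling mostly clean data and drift toward mostly noisy data,
# interpolating the sampling weights linearly over ~100k drawn cuts.
cuts = CutSet.schedule_mux(
    clean,
    noisy,
    start_weights=[0.9, 0.1],
    end_weights=[0.1, 0.9],
    scheduler_type="linear",
    total_iterations=100_000,
)
```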
|
|
```python
    @classmethod
    def infinite_mux(
        cls,
```
|
|
@@ -133,6 +176,51 @@ def infinite_mux(
```python
            )
        )

    @classmethod
    def weighted_mux(
```
|
> **Collaborator:** Can we rename this to
```python
        cls,
        *manifests,
        stop_early: bool = False,
        start_weights: Optional[List[Union[int, float]]] = None,
        end_weights: Optional[List[Union[int, float]]] = None,
        seed: Union[int, Literal["trng", "randomized"]] = 0,
        scheduler_type: Literal["linear", "exponential", "cosine"] = "linear",
        total_iterations: Optional[int] = None,
    ):
        """
        Merges multiple manifest iterables into a new iterable by lazily multiplexing them
        during iteration time with dynamic weight scheduling. Unlike ``mux()``, this method
        allows the weights to change dynamically from ``start_weights`` to ``end_weights``
        over the course of iterations, similar to a learning rate scheduler.

        The weight scheduling works by interpolating between ``start_weights`` and
        ``end_weights`` based on the current iteration count. This allows for gradual
        transitions in sampling behavior, which can be useful for curriculum learning
        or adaptive sampling strategies.

        :param manifests: iterables to be multiplexed.
            They can be either lazy or eager, but the resulting manifest will always be lazy.
        :param stop_early: should we stop the iteration as soon as we exhaust one of the manifests.
        :param start_weights: initial weights for each iterable. If None, uniform weights [1, 1, ...] are used.
        :param end_weights: final weights for each iterable. If None, uniform weights [1, 1, ...] are used.
        :param seed: the random seed, ensures deterministic order across multiple iterations.
        :param scheduler_type: the type of weight scheduling to use:
            - "linear": linear interpolation between start and end weights
            - "exponential": exponential interpolation (faster change at the beginning)
            - "cosine": cosine interpolation (smooth transition)
        :param total_iterations: total number of iterations for weight scheduling.
            If None, estimated from the sum of lengths of all manifests.
        """
        return cls(
            LazyIteratorWeightedMultiplexer(
                *manifests,
                stop_early=stop_early,
                start_weights=start_weights,
                end_weights=end_weights,
                seed=seed,
                scheduler_type=scheduler_type,
                total_iterations=total_iterations,
            )
        )
```
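To make the three options concrete, here is a small standalone sketch of the interpolation factor each schedule produces (it mirrors the formulas in ``_get_current_weights`` below; the ``alpha`` helper is ours, not part of the PR):

```python
import math

def alpha(progress: float, kind: str) -> float:
    # Interpolation factor in [0, 1]: 0 keeps start_weights, 1 reaches end_weights.
    if kind == "linear":
        return progress
    if kind == "exponential":
        return 1 - (1 - progress) ** 2
    if kind == "cosine":
        return 1 - (1 + math.cos(math.pi * progress)) / 2
    raise ValueError(f"Unknown scheduler_type: {kind}")

for p in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(p, {k: round(alpha(p, k), 3) for k in ("linear", "exponential", "cosine")})
# At p=0.25: linear=0.25, exponential=0.438, cosine=0.146;
# "exponential" front-loads the transition, while "cosine" eases in and out.
```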
|
|
```python
    def shuffle(
        self,
        rng: Optional[random.Random] = None,
```
|
|
@@ -828,3 +916,150 @@ def _make_gen(reader):
```python
    with open_best(path, read_mode) as f:
        count = sum(buf.count(b"\n") for buf in _make_gen(f.read))
    return count


class LazyIteratorWeightedMultiplexer(Dillable):
    """
    A wrapper over multiple iterators that makes it possible to combine lazy manifests
    in Lhotse with dynamic weight scheduling. During iteration, this class randomly
    selects the iterable used to yield an item, but the weights change dynamically from
    ``start_weights`` to ``end_weights`` over the course of iterations, similar to a
    learning rate scheduler.

    The weight scheduling works by interpolating between ``start_weights`` and
    ``end_weights`` based on the current iteration count. This allows for gradual
    transitions in sampling behavior, which can be useful for curriculum learning or
    adaptive sampling strategies.

    Since the iterables might be of different lengths, the ``start_weights`` and
    ``end_weights`` parameters let the user decide which iterables should be sampled
    more frequently than others. When an iterable is exhausted, we keep sampling from
    the other iterables until we exhaust them all, unless ``stop_early`` is set to
    ``True``.
    """
|
|
    def __init__(
        self,
        *iterators: Iterable,
        stop_early: bool = False,
        start_weights: Optional[List[Union[int, float]]] = None,
        end_weights: Optional[List[Union[int, float]]] = None,
        seed: Union[int, Literal["trng", "randomized"]] = 0,
        scheduler_type: Literal["linear", "exponential", "cosine"] = "linear",
        total_iterations: Optional[int] = None,
        trainer=None,
    ) -> None:
        self.iterators = list(iterators)
        self.stop_early = stop_early
        self.seed = seed
        self.scheduler_type = scheduler_type
        self.total_iterations = total_iterations
        self.current_iteration = 0
        self.trainer = trainer

        self.logger = logging.getLogger(__name__)

        if self.trainer is not None:
            self.current_iteration = self.trainer.global_step
            self.logger.info(
                f"Trainer is not None. Setting current iteration to {self.current_iteration} "
                f"and total iterations to {self.total_iterations}"
            )

        assert (
            len(self.iterators) > 1
        ), "There have to be at least two iterables to multiplex."

        # Initialize weights (uniform by default).
        if start_weights is None:
            self.start_weights = [1] * len(self.iterators)
        else:
            self.start_weights = start_weights

        if end_weights is None:
            self.end_weights = [1] * len(self.iterators)
        else:
            self.end_weights = end_weights

        assert len(self.iterators) == len(self.start_weights)
        assert len(self.iterators) == len(self.end_weights)

        # If total_iterations is not provided, estimate it from the sum of lengths.
        if self.total_iterations is None:
            self.total_iterations = sum(len(it) for it in self.iterators)
|
|
    def _get_current_weights(self) -> List[float]:
        """
        Calculate current weights based on the scheduler type and current iteration.
        """
        if self.current_iteration >= self.total_iterations:
            return self.end_weights

        # Calculate progress from 0 to 1.
        progress = self.current_iteration / self.total_iterations

        # Apply scheduler function.
        if self.scheduler_type == "linear":
            alpha = progress
        elif self.scheduler_type == "exponential":
            alpha = 1 - (1 - progress) ** 2
        elif self.scheduler_type == "cosine":
            alpha = 1 - (1 + math.cos(math.pi * progress)) / 2
        else:
            raise ValueError(f"Unknown scheduler_type: {self.scheduler_type}")

        # Interpolate between start and end weights.
        current_weights = []
        for start_w, end_w in zip(self.start_weights, self.end_weights):
            current_w = start_w + alpha * (end_w - start_w)
            current_weights.append(current_w)

        return current_weights
```
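For instance, halfway through the schedule the cosine factor is exactly 0.5, so the weights land midway between start and end (a standalone check; the toy lists below just satisfy the two-iterables and ``len()`` requirements):

```python
a, b = ["a"] * 50, ["b"] * 50  # total_iterations is estimated as 100
m = LazyIteratorWeightedMultiplexer(
    a, b, start_weights=[0.9, 0.1], end_weights=[0.1, 0.9], scheduler_type="cosine"
)
m.current_iteration = 50  # progress = 0.5 -> alpha = 0.5
print(m._get_current_weights())  # -> approximately [0.5, 0.5]
```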
|
|
```python
    def __iter__(self):
        from lhotse.dataset.dataloading import resolve_seed

        rng = random.Random(resolve_seed(self.seed))
        iters = [iter(it) for it in self.iterators]
        exhausted = [False for _ in range(len(iters))]

        def should_continue():
            if self.stop_early:
                return not any(exhausted)
            else:
                return not all(exhausted)

        while should_continue():
            # Get current weights based on the iteration count.
            # Guard against a missing trainer before touching its attributes:
            # weighted_mux() constructs this class without one.
            if (
                self.trainer is not None
                and self.current_iteration == 0
                and self.trainer.global_step != 0
            ):
                self.current_iteration = round(
                    self.trainer.global_step
                    / self.trainer.max_steps
                    * self.total_iterations
                )

            current_weights = self._get_current_weights()

            # Collect active iterators before unpacking: zip(*[]) would raise.
            active = [
                (i, w)
                for i, (is_exhausted, w) in enumerate(zip(exhausted, current_weights))
                if not is_exhausted
            ]
            if not active:  # All iterators exhausted
                break
            active_indexes, active_weights = zip(*active)

            idx = rng.choices(active_indexes, weights=active_weights, k=1)[0]
            selected = iters[idx]
            try:
                item = next(selected)
                self.current_iteration += 1
                # Debug level: this fires once per yielded item.
                self.logger.debug(
                    f"Updated iteration: {self.current_iteration} | "
                    f"Current weights: {current_weights} | End weights: {self.end_weights}"
                )
                yield item
            except StopIteration:
                exhausted[idx] = True
                continue

    def __len__(self) -> int:
        return sum(len(it) for it in self.iterators)

    def __add__(self, other) -> "LazyIteratorChain":
        return LazyIteratorChain(self, other)
```
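A quick standalone sketch of the sampling drift this produces (plain Python lists stand in for manifests; exact counts vary with the seed):

```python
from collections import Counter

a, b = ["a"] * 300, ["b"] * 300
m = LazyIteratorWeightedMultiplexer(
    a, b, start_weights=[1.0, 0.0], end_weights=[0.0, 1.0], seed=0
)
items = list(m)  # 600 items; total_iterations is estimated as 600
print(Counter(items[:100]))   # early draws are dominated by "a"
print(Counter(items[-100:]))  # late draws skew heavily toward "b"
```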
|
|
@@ -81,7 +81,7 @@ def get_aistore_client():
```diff
     endpoint_url = os.environ["AIS_ENDPOINT"]
     version = parse_version(aistore.__version__)
-    return aistore.Client(endpoint_url, timeout=(1, 20)), version
+    return aistore.Client(endpoint_url, timeout=(5, 30)), version
```
|
> **Collaborator:** can you roll back all changes in this file?
|
|
@@ -788,6 +788,7 @@ class AIStoreIOBackend(IOBackend):
```diff
     def open(self, identifier: str, mode: str):
         client, version = get_aistore_client()

         object = client.fetch_object_by_url(identifier)
         if "r" in mode:
             try:
```
|
|
@@ -798,7 +799,7 @@ def open(self, identifier: str, mode: str):
```diff
             request = object.get()
             if version >= parse_version("1.9.1"):
                 # AIStore SDK 1.9.1 supports ObjectFile for improved read fault resiliency
-                return request.as_file()
+                return request.raw()
             else:
                 return request.raw()
         if "w" in mode:
```
|
|
> This seems duplicated, can we remove?
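If the intent is indeed to always read via ``raw()``, a follow-up could presumably collapse the now-identical branches into a single return, e.g.:

```diff
             request = object.get()
-            if version >= parse_version("1.9.1"):
-                # AIStore SDK 1.9.1 supports ObjectFile for improved read fault resiliency
-                return request.raw()
-            else:
-                return request.raw()
+            return request.raw()
```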