From 3cab5334eda3642aff2c8fbf5a19482b1d6ece52 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 25 Jun 2025 12:41:30 -0400 Subject: [PATCH 01/25] first commit --- recipes/configs/llama3_2/3B_full.yaml | 40 +- recipes/full_finetune_distributed.py | 500 ++++++++---------- .../torchtune/data/test_metrics_aggregator.py | 149 ++++++ .../torchtune/data/test_metrics_transform.py | 54 ++ tests/torchtune/datasets/test_hf_iterable.py | 339 ++++++++++++ tests/torchtune/datasets/test_interleaved.py | 162 ++++++ torchtune/data/__init__.py | 12 + torchtune/data/_aggregator.py | 342 ++++++++++++ torchtune/data/_metrics.py | 95 ++++ torchtune/datasets/__init__.py | 43 +- torchtune/datasets/_alpaca.py | 65 ++- torchtune/datasets/_hf_iterable.py | 271 ++++++++++ torchtune/datasets/_interleaved.py | 115 ++++ torchtune/datasets/_iterable_base.py | 37 ++ torchtune/datasets/_sft.py | 97 +++- torchtune/datasets/_slimorca.py | 69 ++- .../checkpointing/_checkpoint_client.py | 6 + 17 files changed, 2088 insertions(+), 308 deletions(-) create mode 100644 tests/torchtune/data/test_metrics_aggregator.py create mode 100644 tests/torchtune/data/test_metrics_transform.py create mode 100644 tests/torchtune/datasets/test_hf_iterable.py create mode 100644 tests/torchtune/datasets/test_interleaved.py create mode 100644 torchtune/data/_aggregator.py create mode 100644 torchtune/data/_metrics.py create mode 100644 torchtune/datasets/_hf_iterable.py create mode 100644 torchtune/datasets/_interleaved.py create mode 100644 torchtune/datasets/_iterable_base.py diff --git a/recipes/configs/llama3_2/3B_full.yaml b/recipes/configs/llama3_2/3B_full.yaml index bb765f1917..5534b305ac 100644 --- a/recipes/configs/llama3_2/3B_full.yaml +++ b/recipes/configs/llama3_2/3B_full.yaml @@ -26,21 +26,28 @@ tokenizer: path: /tmp/Llama-3.2-3B-Instruct/original/tokenizer.model max_seq_len: null -# Dataset and Sampler +# Dataloader +dataloader: + batch_size: 4 + # num_workers and pin_memory can be added here if needed + +# Dataset - now a list to support multiple weighted sources dataset: - _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False # True increases speed - split: train[:95%] + - _component_: torchtune.datasets.slimorca_iterable_dataset + shuffle_buffer_size: 1000 + weight: 0.8 + - _component_: torchtune.datasets.alpaca_iterable_dataset + shuffle_buffer_size: 1000 + weight: 0.2 + +# Packing (TBD by follow up PR) +# packing: +# _component_: torchtune.datasets.packing.SFTPacking +# max_seq_len: 8192 + seed: null -shuffle: True -batch_size: 4 -# Validation -run_val_every_n_steps: null # Change to an integer to enable validation every N steps -dataset_val: - _component_: torchtune.datasets.alpaca_cleaned_dataset - split: train[95%:] -batch_size_val: ${batch_size} +# Validation not supported yet with iterable datasets # Model Arguments model: @@ -65,10 +72,11 @@ optimizer: loss: _component_: torchtune.modules.loss.LinearCrossEntropyLoss -# Training -epochs: 1 -max_steps_per_epoch: null -gradient_accumulation_steps: 8 # Use to increase effective batch size +# Training - now step-based +num_training_steps: 100 # Total number of training steps to run +save_every_n_steps: 200 # Save a checkpoint every N steps. Using 200 to avoid ckpt. +gradient_accumulation_steps: 1 +dataset_metrics_log_freq: 10 # Log dataset-specific metrics every N steps # Environment device: cuda diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 10e0aaeb24..f34ccc6a7e 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -9,7 +9,7 @@ import time from functools import partial -from typing import Any, Optional, Union +from typing import Any, Dict, List, Optional, Union from warnings import warn import torch @@ -25,8 +25,8 @@ from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler from torchtune import config, modules, training, utils from torchtune.config._utils import _get_component_from_path -from torchtune.data import padded_collate_packed -from torchtune.datasets import ConcatDataset +from torchtune.data import padded_collate_packed, MetricsAggregator +from torchtune.datasets import ConcatDataset, InterleavedDataset from torchtune.modules.embedding_utils import resize_token_embeddings from torchtune.modules.loss import SFTLoss from torchtune.modules.moe import utils as moe_utils @@ -207,7 +207,7 @@ def __init__(self, cfg: DictConfig) -> None: self._checkpoint_client = CheckpointClient(cfg) self._enable_fp8_training = cfg.get("enable_fp8_training", False) self._fp8_recipe_name = cfg.get("fp8_recipe_name", None) - self.save_every_n_steps = cfg.get("save_every_n_steps") + self.save_every_n_steps = cfg.get("save_every_n_steps", None) self._run_val_every_n_steps = cfg.get("run_val_every_n_steps", None) if self._run_val_every_n_steps is not None: @@ -273,18 +273,20 @@ def __init__(self, cfg: DictConfig) -> None: self.seed = training.set_seed( seed=cfg.seed, debug_mode=cfg.get("cudnn_deterministic_mode", None) ) - self.epochs_run = 0 - self.total_epochs = cfg.epochs - self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 - + + # Step-based training support + self.num_training_steps = cfg.num_training_steps + self._dataset_metrics_log_freq = cfg.get("dataset_metrics_log_freq", 100) + self._metrics_aggregator = None # Will be initialized in setup + def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None: """ Updates the recipe state from checkpoint. """ try: - self.epochs_run = ckpt_dict[training.EPOCHS_KEY] - self.global_step = ckpt_dict[training.STEPS_KEY] + # The new format stores steps directly + self.global_step = ckpt_dict["steps_run"] # on mismatch, warn the user and prevent the override if self.seed != ckpt_dict[training.SEED_KEY]: @@ -295,23 +297,6 @@ def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None: ) ) self.seed = ckpt_dict[training.SEED_KEY] - if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: - warn( - message=( - "Config value for max_steps_per_epoch does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" - ) - ) - self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] - - # on mismatch, warn the user but allow the override - if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: - warn( - message=( - "Config value for total_epochs does not match the checkpoint value, " - f"using the config value: {self.total_epochs}" - ) - ) except KeyError as e: raise KeyError( @@ -324,6 +309,9 @@ def setup(self, cfg: DictConfig) -> None: Setup the recipe. This includes training state (if resume_from_checkpoint is True), model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader. """ + if cfg.get("dataset_val") is not None: + raise NotImplementedError("Validation is not supported yet with iterable datasets.") + if self.fsdp_cpu_offload: # Utilize all available CPU cores for intra-op parallelism. This provides ~2x # speed up when benchmarking fused AdamW on CPU @@ -434,12 +422,14 @@ def setup(self, cfg: DictConfig) -> None: utils.log_rank_zero(self._logger, "Loss is initialized.") + # Initialize metrics aggregator for dataset metrics tracking + self._metrics_aggregator = MetricsAggregator() + # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after both of these are initialized collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft") self._dataloader = self._setup_data( cfg_dataset=cfg.dataset, - shuffle=cfg.shuffle, batch_size=cfg.batch_size, collate_fn=collate_name, dataloader_state_dict=( @@ -457,7 +447,6 @@ def setup(self, cfg: DictConfig) -> None: cfg_dataset=cfg.dataset_val, batch_size=batch_size_val, collate_fn=collate_name, - shuffle=False, dataloader_state_dict=( state_dict[training.VAL_DATALOADER_KEY] if training.VAL_DATALOADER_KEY in state_dict @@ -465,38 +454,13 @@ def setup(self, cfg: DictConfig) -> None: ), ) - # Finally update the recipe state which can only be correctly set after all of the - # other components have been initialized and updated. - # - # Number of training steps in each epoch depends on the number of batches produced - # by the dataloader, the max_steps_per_epoch param set by the user and the - # gradient_accumulation_steps param. This value is used for logging and tracking - # training state. The computation should happen after the dataloader has been setup - self._steps_per_epoch = ( - len(self._dataloader) // self._gradient_accumulation_steps - ) - if ( - self.max_steps_per_epoch is not None - and self.max_steps_per_epoch < self._steps_per_epoch - ): - self._steps_per_epoch = self.max_steps_per_epoch - - if self.save_every_n_steps is None: - self.save_every_n_steps = self._steps_per_epoch - self.checkpoint_dir_prefix = "epoch" - else: - self.checkpoint_dir_prefix = "step" - - if ( - self._resume_from_checkpoint - and self.global_step % self._steps_per_epoch == 0 - ): - list(self._dataloader) + # Set checkpoint dir prefix to step-based + self.checkpoint_dir_prefix = "step" # Setup lr scheduler self._lr_scheduler = self._setup_lr_scheduler( cfg_lr_scheduler=cfg.get("lr_scheduler", None), - num_training_steps=self.total_epochs * self._steps_per_epoch, + num_training_steps=self.num_training_steps, last_epoch=self.global_step - 1, ) @@ -799,53 +763,69 @@ def _setup_optimizer( def _setup_data( self, - cfg_dataset: DictConfig, - shuffle: bool, + cfg_dataset: Union[DictConfig, ListConfig], batch_size: int, collate_fn: str, dataloader_state_dict: Optional[dict[str, Any]] = None, ) -> StatefulDataLoader: """ - All data related setup happens here. This recipe currently supports only - map-style datasets. If a state_dict is provided (meaning we are resuming a training run), - it is loaded into the dataloader. + Set up the dataloader for iterable datasets. """ - if isinstance(cfg_dataset, ListConfig): - datasets = [ - config.instantiate(single_cfg_dataset, self._tokenizer) - for single_cfg_dataset in cfg_dataset - ] - ds = ConcatDataset(datasets=datasets) - packed = getattr(ds, "packed", False) - else: - ds = config.instantiate(cfg_dataset, self._tokenizer) - packed = cfg_dataset.get("packed", False) - # Instantiate collate_fn - if "left_pad_sequence" in collate_fn: - raise RuntimeError("left_pad_sequence collator is only for inference.") - collate_fn = _get_component_from_path(collate_fn) - - sampler = StatefulDistributedSampler( - ds, num_replicas=self.dp_degree, rank=self.dp_rank, shuffle=shuffle, seed=0 + # 1. Create all datasets + iterable_datasets = [] + weights = [] + cfg_dataset_list = cfg_dataset + if not isinstance(cfg_dataset_list, ListConfig): + cfg_dataset_list = [cfg_dataset_list] + + for ds_cfg in cfg_dataset_list: + ds = config.instantiate(ds_cfg, model_transform=self._tokenizer) + iterable_datasets.append(ds) + weights.append(ds_cfg.get("weight", 1.0)) + + # 2. Interleave datasets if any + if len(iterable_datasets) > 1: + ds = InterleavedDataset( + datasets=iterable_datasets, + weights=weights, + seed=self.seed, + ) + else: + ds = iterable_datasets[0] + + # 3. Apply packing + # TODO: follow up PR + packed = False + + # 4. Define a collate function wrapper to handle metrics + base_collate_fn = ( + padded_collate_packed + if packed + else _get_component_from_path(collate_fn) ) + + def _collate_with_metrics_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]: + # TODO: handling of metrics should prob be done in collate_fn. + # putting this here for now to avoid making more changes to this PR. + all_metrics = [] + clean_batch = [] + for sample in batch: + if "metrics" in sample: + all_metrics.extend(sample.pop("metrics")) + clean_batch.append(sample) + + collated_batch = base_collate_fn(clean_batch) + collated_batch["metrics"] = all_metrics + return collated_batch + + # 5. Create DataLoader dataloader = StatefulDataLoader( dataset=ds, batch_size=batch_size, - sampler=sampler, - collate_fn=( - partial( - collate_fn, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - pad_to_multiple_of=self.parallel_dims.min_seq_len_divisor, - ) - if not packed - else padded_collate_packed - ), - # dropping last avoids shape issues with compile + flex attention - drop_last=True, + collate_fn=_collate_with_metrics_wrapper, ) + if dataloader_state_dict is not None: dataloader.load_state_dict(dataloader_state_dict) @@ -917,32 +897,29 @@ def validate(self) -> dict[str, float]: self._model.train() return log_dict - def save_checkpoint(self, *, epoch: int, full_tensors: bool): - if self.global_step % self._steps_per_epoch == 0: - epoch += 1 - + def save_checkpoint(self, *, epoch: int, step: int, full_tensors: bool): + """Save checkpoint based on global step.""" self._checkpoint_client.save_checkpoint( model=self._model, - optimizer=( - self._optimizer - if not self._optimizer_in_bwd - else self._optim_ckpt_wrapper - ), - training_progress=TrainingProgress( + optimizer=(self._optimizer if not self._optimizer_in_bwd else self._optim_ckpt_wrapper), + training_progress = TrainingProgress( seed=self.seed, - epochs_run=epoch, - total_epochs=self.total_epochs, - max_steps_per_epoch=self.max_steps_per_epoch, + epochs_run=0, # TODO: not needed. To be deprecated. + total_epochs=1, # TODO: not needed. To be deprecated. + max_steps_per_epoch=-1, # TODO: not needed. To be deprecated. steps_run=self.global_step, - total_training_steps=self.total_epochs * self._steps_per_epoch, + total_training_steps=self.num_training_steps, dataloader_state_dict=self._dataloader.state_dict(), val_dataloader_state_dict=( self._val_dataloader.state_dict() if self._val_dataloader is not None else {} ), + #FIXME: add to load_ckpt and TrainingProgress too + metrics_aggregator_state_dict=self._metrics_aggregator.state_dict(), ), - epoch=epoch, + epoch=epoch, # TODO: not needed. To be deprecated. + step=step, single_device=False, full_tensors=full_tensors, dir_prefix=self.checkpoint_dir_prefix, @@ -968,180 +945,155 @@ def train(self) -> None: num_tokens = 0 self._profiler.start() - # self.epochs_run should be non-zero when we're resuming from a checkpoint - for curr_epoch in range(self.epochs_run, self.total_epochs): - inner_step_count = self.global_step % self._steps_per_epoch - pbar = tqdm( - initial=inner_step_count, - total=self._steps_per_epoch, - desc=f"{self.epochs_run}|{self.global_step}", - ) - - # Get iterator for the dataloader - self._dataloader.sampler.set_epoch(curr_epoch) - dataloader_iter = iter(self._dataloader) - batch_count = 0 - - # Continue looping until we reach max steps or exhaust the dataset - while inner_step_count < self._steps_per_epoch: - # Try to get the next batch, break if we've reached the end of the dataset - try: - batch = next(dataloader_iter) - except StopIteration: - break - - # Start tracking CUDA memory for active steps for just the first epoch - if ( - self._is_rank_zero - and curr_epoch == 0 - and self.profiler_profile_memory - and batch_count - == self.profiler_wait_steps + self.profiler_warmup_steps - and self._device.type == "cuda" - ): - torch.cuda.memory._record_memory_history() - - utils.batch_to_device(batch, self._device) - - # Calculate the number of unmasked tokens in the current batch - # and increment the total number of tokens seen in the step - current_num_tokens = ( - batch["labels"] != self._loss_fn.ignore_index - ).sum() - num_tokens += current_num_tokens - - with self.train_context( - self.context_parallel_manager(list(batch.values())) - ): - # Loss is normalized by default so we multiply by the number of tokens - # This way we can normalize by the total number of tokens if we're accumulating gradients - current_loss = self._loss_step(batch) * current_num_tokens - running_loss += current_loss - # For optimizer in backward, we need to normalize before calling backward - # This case and gradient accumulation are mutually exclusive - if self._optimizer_in_bwd: - torch.distributed.all_reduce(num_tokens) - torch.distributed.all_reduce(running_loss) - current_loss = current_loss * (self.dp_degree / num_tokens) - current_loss.backward() - - # Optimizer step (if not fused in backward call) - if (batch_count + 1) % self._gradient_accumulation_steps == 0: - if not self._optimizer_in_bwd: - # Get total number of tokens across all ranks to normalize gradients - torch.distributed.all_reduce(num_tokens) - # This will ensure that the logged loss matches what we're optimizing - torch.distributed.all_reduce(running_loss) - - # Manually scale the gradients from unnormalized loss by total # of tokens - self._grad_scaler( - list(self._model.parameters()), - self.world_size / num_tokens, - False if self.parallel_dims.tp_enabled else None, - ) - - if self._clip_grad_norm is not None: - grad_norm = torch.nn.utils.clip_grad_norm_( - self._model.parameters(), - max_norm=float(self._clip_grad_norm), - ) - # If sharded, collect the DTensor here - if isinstance(grad_norm, DTensor): - grad_norm = grad_norm.full_tensor() - self._optimizer.step() - self._optimizer.zero_grad(set_to_none=True) - - # Step the learning rate scheduler - if self._lr_scheduler is not None: - self._lr_scheduler.step() - - self.global_step += 1 - inner_step_count += 1 - - # If float8 training is enabled, perform a single all-reduce to compute the - # scale for all float8 parameters efficiently instead of doing many small - # all-reduces for each parameter - if ( - self._enable_fp8_training - and is_fp8_tensorwise_scaling(self._fp8_recipe_name) - and self.dp_degree > 1 - ): - precompute_float8_dynamic_scale_for_fsdp(self._model) - - loss_to_log = running_loss.detach().item() / num_tokens - pbar.update(1) - pbar.set_description( - f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + + pbar = tqdm(initial=self.global_step, total=self.num_training_steps, desc="Training") + + dataloader_iter = iter(self._dataloader) + batch_count = 0 + + while self.global_step < self.num_training_steps: + try: + batch = next(dataloader_iter) + except StopIteration: + self._logger.warning("Dataloader iterator exhausted unexpectedly. Ending training.") + break + + if "metrics" in batch: + self._metrics_aggregator.update(batch.pop("metrics")) + + # Start tracking CUDA memory for active steps for just the first epoch + if ( + self._is_rank_zero + and self.profiler_profile_memory + and batch_count + == self.profiler_wait_steps + self.profiler_warmup_steps + and self._device.type == "cuda" + ): + torch.cuda.memory._record_memory_history() + + utils.batch_to_device(batch, self._device) + + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() + num_tokens += current_num_tokens + + with self.train_context( + self.context_parallel_manager(list(batch.values())) + ): + # Loss is normalized by default so we multiply by the number of tokens + # This way we can normalize by the total number of tokens if we're accumulating gradients + current_loss = self._loss_step(batch) * current_num_tokens + running_loss += current_loss + # For optimizer in backward, we need to normalize before calling backward + # This case and gradient accumulation are mutually exclusive + if self._optimizer_in_bwd: + torch.distributed.all_reduce(num_tokens) + torch.distributed.all_reduce(running_loss) + current_loss = current_loss * (self.dp_degree / num_tokens) + current_loss.backward() + + # Optimizer step (if not fused in backward call) + if (batch_count + 1) % self._gradient_accumulation_steps == 0: + if not self._optimizer_in_bwd: + # Get total number of tokens across all ranks to normalize gradients + torch.distributed.all_reduce(num_tokens) + # This will ensure that the logged loss matches what we're optimizing + torch.distributed.all_reduce(running_loss) + + # Manually scale the gradients from unnormalized loss by total # of tokens + self._grad_scaler( + list(self._model.parameters()), + self.world_size / num_tokens, + False if self.parallel_dims.tp_enabled else None, ) - # Log per-step metrics - if ( - self.global_step % self._log_every_n_steps == 0 - and self._is_rank_zero - ): - time_per_step = time.perf_counter() - t0 - log_dict = { - "loss": loss_to_log, - "lr": get_lr( - ( - self._optimizer - if not self._optimizer_in_bwd - else self._optim_ckpt_wrapper - ), - ), - "tokens_per_second_per_gpu": ( - num_tokens / self.parallel_dims.non_data_parallel_size - ) - / (time_per_step * self.world_size), - } - if self._log_peak_memory_stats: - log_dict.update( - training.get_memory_stats(device=self._device) - ) - if self._clip_grad_norm is not None: - log_dict.update({"grad_norm": grad_norm}) - self._metric_logger.log_dict( - log_dict, - step=self.global_step, + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), ) - - # Save checkpoint if specified by user - if self.global_step % self.save_every_n_steps == 0: - self.save_checkpoint(epoch=curr_epoch, full_tensors=False) - - # Reset running stats for the next step - running_loss = 0 - num_tokens = 0 - t0 = time.perf_counter() - - # Stop tracking CUDA memory now that active steps are complete + # If sharded, collect the DTensor here + if isinstance(grad_norm, DTensor): + grad_norm = grad_norm.full_tensor() + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + # Step the learning rate scheduler + if self._lr_scheduler is not None: + self._lr_scheduler.step() + + self.global_step += 1 + # If float8 training is enabled, perform a single all-reduce to compute the + # scale for all float8 parameters efficiently instead of doing many small + # all-reduces for each parameter if ( - self._is_rank_zero - and curr_epoch == 0 - and self.profiler_profile_memory - and batch_count - == self.profiler_wait_steps - + self.profiler_warmup_steps - + self.profiler_active_steps - and self._device.type == "cuda" + self._enable_fp8_training + and is_fp8_tensorwise_scaling(self._fp8_recipe_name) + and self.dp_degree > 1 ): - torch.cuda.memory._record_memory_history(enabled=None) - - self._profiler.step() - batch_count += 1 - - # Run validation after gradient update - if ( - self._run_val_every_n_steps is not None - and self.global_step % self._run_val_every_n_steps == 0 - ): - pbar.refresh() - self.validate() - - self.epochs_run += 1 + precompute_float8_dynamic_scale_for_fsdp(self._model) + + loss_to_log = running_loss.detach().item() / num_tokens + pbar.update(1) + pbar.set_description(f"Step: {self.global_step}|Loss: {loss_to_log:.4f}") + + # Log per-step metrics + if self.global_step % self._log_every_n_steps == 0 and self._is_rank_zero: + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": get_lr(self._optimizer if not self._optimizer_in_bwd else self._optim_ckpt_wrapper), + "tokens_per_second_per_gpu": (num_tokens / self.parallel_dims.non_data_parallel_size) / (time_per_step * self.world_size), + } + if self._log_peak_memory_stats: + log_dict.update(training.get_memory_stats(device=self._device)) + if self._clip_grad_norm is not None: + log_dict.update({"grad_norm": grad_norm}) + self._metric_logger.log_dict(log_dict, step=self.global_step) + + # Log dataset metrics + # #TODO: it requires all_gather. Should we keep a separate log_freq for this? + if self.global_step % self._dataset_metrics_log_freq == 0 and self._is_rank_zero: + dataset_metrics = self._metrics_aggregator.get_metrics_for_logging(prefix="train") + self._metric_logger.log_dict(dataset_metrics, step=self.global_step) + + # Save checkpoint if specified by user + if self.save_every_n_steps is not None and self.global_step % self.save_every_n_steps == 0: + self.save_checkpoint(epoch=0, step=self.global_step, full_tensors=False) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + # Stop tracking CUDA memory now that active steps are complete + if ( + self._is_rank_zero + and self.profiler_profile_memory + and batch_count + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + and self._device.type == "cuda" + ): + torch.cuda.memory._record_memory_history(enabled=None) + + self._profiler.step() + batch_count += 1 + + # Run validation after gradient update + if ( + self._run_val_every_n_steps is not None + and self.global_step % self._run_val_every_n_steps == 0 + ): + pbar.refresh() + self.validate() self._profiler.stop() - self.save_checkpoint(epoch=curr_epoch, full_tensors=True) + self.save_checkpoint(epoch=0, step=self.global_step, full_tensors=True) def cleanup(self) -> None: if self._is_rank_zero: diff --git a/tests/torchtune/data/test_metrics_aggregator.py b/tests/torchtune/data/test_metrics_aggregator.py new file mode 100644 index 0000000000..382d968704 --- /dev/null +++ b/tests/torchtune/data/test_metrics_aggregator.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import collections +import pytest +from unittest.mock import patch + +from torchtune.data import AggregationType, Metric, MetricsAggregator + + +class TestMetricsAggregator: + """Focused tests for MetricsAggregator functionality.""" + + @pytest.mark.parametrize( + "agg_type,test_values,expected", + [ + (AggregationType.SUM, [1, 2, 3, 4], 10), + (AggregationType.MEAN, [10, 20, 30, 40], 25.0), + (AggregationType.MAX, [-5, 10, 3, 15], 15), + (AggregationType.MIN, [5, -2, 8, 1], -2), + ( + AggregationType.CATEGORICAL_COUNT, + ["A", "B", "A", "C", "A"], + {"A": 3, "B": 1, "C": 1}, + ), + ], + ) + def test_aggregation_types(self, agg_type, test_values, expected): + """Tests each `AggregationType` to ensure it computes the correct value.""" + aggregator = MetricsAggregator() + + metrics = [ + Metric(dataset_name="test", name="metric", value=val, agg_type=agg_type) + for val in test_values + ] + aggregator.update(metrics) + + result = aggregator.get_metrics_for_logging() + + if agg_type == AggregationType.CATEGORICAL_COUNT: + for category, count in expected.items(): + assert result[f"test/metric_{category}_count"] == count + else: + assert result["test/metric"] == expected + + def test_distribution_metrics(self): + """Tests that `AggregationType.DISTRIBUTION` computes all expected statistics (mean, min, max, p50).""" + aggregator = MetricsAggregator() + values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + metrics = [ + Metric("test", "dist_metric", val, AggregationType.DISTRIBUTION) + for val in values + ] + aggregator.update(metrics) + + result = aggregator.get_metrics_for_logging(prefix="train") + + # Verify distribution statistics + assert result["train/test/dist_metric_mean"] == 5.5 + assert result["train/test/dist_metric_min"] == 1 + assert result["train/test/dist_metric_max"] == 10 + assert result["train/test/dist_metric_p50"] == 5 # Median of 1-10 is 5 (index 4, value 5) + + def test_state_management(self): + """Test aggregator checkpointing and restoration.""" + # Create aggregator with some state + aggregator1 = MetricsAggregator() + initial_metrics = [ + Metric("ds1", "counter", 10, AggregationType.SUM), + Metric("ds1", "average", 5.0, AggregationType.MEAN), + Metric("ds2", "categories", "X", AggregationType.CATEGORICAL_COUNT), + ] + aggregator1.update(initial_metrics) + + # Save state + state = aggregator1.state_dict() + + # Create new aggregator and restore state + aggregator2 = MetricsAggregator() + aggregator2.load_state_dict(state) + + # Both should have identical metrics + metrics1 = aggregator1.get_metrics_for_logging() + metrics2 = aggregator2.get_metrics_for_logging() + assert metrics1 == metrics2 + + # Continue updating both - should remain identical + additional_metrics = [ + Metric("ds1", "counter", 5, AggregationType.SUM), + Metric("ds1", "average", 15.0, AggregationType.MEAN), + ] + aggregator1.update(additional_metrics) + aggregator2.update(additional_metrics) + + final_metrics1 = aggregator1.get_metrics_for_logging() + final_metrics2 = aggregator2.get_metrics_for_logging() + assert final_metrics1 == final_metrics2 + + # Verify expected values + assert final_metrics1["ds1/counter"] == 15 # 10 + 5 + assert final_metrics1["ds1/average"] == 10.0 # (5 + 15) / 2 + + def test_multiple_datasets(self): + """Test that metrics from multiple datasets are correctly namespaced.""" + aggregator = MetricsAggregator() + + metrics = [ + Metric("dataset1", "samples", 100, AggregationType.SUM), + Metric("dataset2", "samples", 200, AggregationType.SUM), + Metric("dataset1", "tokens", 1000, AggregationType.SUM), + Metric("dataset2", "tokens", 2000, AggregationType.SUM), + ] + aggregator.update(metrics) + + result = aggregator.get_metrics_for_logging(prefix="train") + + assert result["train/dataset1/samples"] == 100 + assert result["train/dataset2/samples"] == 200 + assert result["train/dataset1/tokens"] == 1000 + assert result["train/dataset2/tokens"] == 2000 + + def test_empty_aggregator(self): + """Test that empty aggregator returns empty metrics.""" + aggregator = MetricsAggregator() + result = aggregator.get_metrics_for_logging() + assert result == {} + + def test_prefix_handling(self): + """Test that prefix is correctly applied to metric keys.""" + aggregator = MetricsAggregator() + metrics = [ + Metric("test_ds", "metric1", 42, AggregationType.SUM), + Metric("test_ds", "metric2", 84, AggregationType.SUM), + ] + aggregator.update(metrics) + + # Test with prefix + result_with_prefix = aggregator.get_metrics_for_logging(prefix="validation") + assert result_with_prefix["validation/test_ds/metric1"] == 42 + assert result_with_prefix["validation/test_ds/metric2"] == 84 + + # Test without prefix + result_no_prefix = aggregator.get_metrics_for_logging() + assert result_no_prefix["test_ds/metric1"] == 42 + assert result_no_prefix["test_ds/metric2"] == 84 \ No newline at end of file diff --git a/tests/torchtune/data/test_metrics_transform.py b/tests/torchtune/data/test_metrics_transform.py new file mode 100644 index 0000000000..1eed534e42 --- /dev/null +++ b/tests/torchtune/data/test_metrics_transform.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest + +from torchtune.data import AggregationType, Metric, StandardMetricTransform + + +class TestStandardMetricTransform: + """Tests for StandardMetricTransform functionality.""" + + def test_dataset_name_not_set_raises_error(self): + """Test that using transform without setting dataset name raises error.""" + transform = StandardMetricTransform() + sample = {"tokens": [1, 2, 3]} + + with pytest.raises(RuntimeError, match="set_dataset_name"): + transform(sample) + + def test_basic_metrics_generation(self): + """Test that transform generates expected metrics for a sample.""" + transform = StandardMetricTransform() + transform.set_dataset_name("test_dataset") + + sample = {"tokens": [1, 2, 3, 4, 5]} + result = transform(sample) + + # Should preserve original sample data + assert result["tokens"] == [1, 2, 3, 4, 5] + + # Should add metrics + assert "metrics" in result + metrics = result["metrics"] + assert len(metrics) == 3 + + # Check each metric + for metric in metrics: + if metric.name == "samples_seen": + assert metric.dataset_name == "test_dataset" + assert metric.value == 1 + assert metric.agg_type == AggregationType.SUM + + elif metric.name == "tokens_seen": + assert metric.dataset_name == "test_dataset" + assert metric.value == 5 + assert metric.agg_type == AggregationType.SUM + + elif metric.name == "seq_len": + assert metric.dataset_name == "test_dataset" + assert metric.value == 5 + assert metric.agg_type == AggregationType.DISTRIBUTION \ No newline at end of file diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py new file mode 100644 index 0000000000..4cf303c6fd --- /dev/null +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -0,0 +1,339 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import collections +import tempfile +from pathlib import Path +from itertools import islice +from typing import Any, Callable, Dict, List, Optional +from unittest.mock import Mock, patch + +import pytest +import torch +from torch.nn.utils.rnn import pad_sequence +from torchdata.stateful_dataloader import StatefulDataLoader + +from torchtune.data import AggregationType, Metric, MetricsAggregator, StandardMetricTransform, padded_collate_sft +from torchtune.datasets import HfIterableDataset + + +# Test Constants - Avoid perfect divisions +SMALL_DATASET_SIZE = 23 +MEDIUM_DATASET_SIZE = 35 +SEED = 42 +BATCH_SIZE = 5 +DEFAULT_SHUFFLE_BUFFER_SIZE = 8 + + +def create_test_json_file(path: Path, num_samples: int, offset: int = 0) -> None: + """Creates a dummy JSON test data file with token samples of varying lengths. + + Args: + path (Path): The path to the file to create + num_samples (int): The number of samples to create + offset (int): The offset to add to the sample ID to ensure unique IDs in different datasets + """ + with open(path, "w") as f: + for i in range(num_samples): + sample_id = i + offset + # Realistic token length variation (1-3 tokens) + token_len = (i % 3) + 1 + tokens = list(range(sample_id, sample_id + token_len)) + f.write( + f'{{"id": {sample_id}, "tokens": {tokens}, "text": "sample_{sample_id}"}}\n' + ) + + +def collate_with_metrics(batch: List[Dict[str, Any]]) -> Dict[str, Any]: + """Collate function that extracts metrics and uses padded_collate_sft as base collator.""" + # Extract metrics first + all_metrics = [] + clean_batch = [] + for sample in batch: + if "metrics" in sample: + all_metrics.extend(sample.pop("metrics")) + clean_batch.append(sample) + + if not clean_batch: + return {"metrics": all_metrics} + + # Use torchtune's padded_collate_sft as base collator + collated_batch = padded_collate_sft(clean_batch) + collated_batch["metrics"] = all_metrics + return collated_batch + + +def generate_ckpt( + dataloader: StatefulDataLoader, + aggregator: MetricsAggregator, + steps_before_checkpoint: int, + steps_after_checkpoint: int, + resume_dataloader: Optional[StatefulDataLoader] = None, + resume_aggregator: Optional[MetricsAggregator] = None, +) -> Dict[str, Any]: + """ + Generates a checkpoint by running through data and saving checkpoint mid-stream. + Optionally, a second dataloader and aggregator can be given to resume from ckpt + and run steps_after_checkpoint to match the first one. + + Args: + dataloader: The dataloader to test + aggregator: The metrics aggregator to use + steps_before_checkpoint: Number of steps to run before saving checkpoint + steps_after_checkpoint: Number of steps to run after checkpoint + resume_dataloader: Optional new dataloader to test resuming. If None, returns empty resumed_batches. + resume_aggregator: Optional new aggregator to test resuming. If None, returns empty resumed_metrics. + + Returns dict with batches/metrics from both pre and post checkpoint runs. + """ + iterator = iter(dataloader) + + # Collect batches before and after checkpoint + batches = [] + checkpoint_state = None + metrics_at_checkpoint = {} + + total_steps = steps_before_checkpoint + steps_after_checkpoint + + for idx, batch in enumerate(iterator): + batches.append(batch) + + # Process metrics + if "metrics" in batch: + aggregator.update(batch.pop("metrics")) + + # Save checkpoint state after steps_before_checkpoint + if idx == steps_before_checkpoint - 1: # -1 because idx is 0-based + checkpoint_state = { + "loader": dataloader.state_dict(), + "aggregator": aggregator.state_dict(), + } + metrics_at_checkpoint = aggregator.get_metrics_for_logging(prefix="train") + + # Stop after total steps + if idx == total_steps - 1: + break + + # Split batches + pre_checkpoint_batches = batches[:steps_before_checkpoint] + post_checkpoint_batches = batches[steps_before_checkpoint:] + + # Resume with new instances if provided + resumed_batches = [] + resumed_metrics = {} + + if ( + resume_dataloader is not None + and resume_aggregator is not None + and checkpoint_state is not None + ): + # Test resuming with new instances + resume_dataloader.load_state_dict(checkpoint_state["loader"]) + resume_aggregator.load_state_dict(checkpoint_state["aggregator"]) + resume_iterator = iter(resume_dataloader) + + # Collect only the post-checkpoint batches when resuming + for idx, batch in enumerate(resume_iterator): + resumed_batches.append(batch) + + # Process metrics + if "metrics" in batch: + resume_aggregator.update(batch.pop("metrics")) + + # Stop after steps_after_checkpoint + if idx == steps_after_checkpoint - 1: + break + + resumed_metrics = resume_aggregator.get_metrics_for_logging(prefix="train") + + return { + # Original run + "pre_checkpoint_batches": pre_checkpoint_batches, + "post_checkpoint_batches": post_checkpoint_batches, + "metrics_at_checkpoint": metrics_at_checkpoint, + "final_metrics": aggregator.get_metrics_for_logging(prefix="train"), + # Resumed run + "resumed_batches": resumed_batches, + "resumed_metrics": resumed_metrics, + # Internal state for loading - only if someone needs to manually load + "_checkpoint_state": checkpoint_state, + } + + +@pytest.fixture +def tmp_data_dir(tmp_path): + """Provide temporary directory for test data files.""" + return tmp_path + + +@pytest.fixture +def small_dataset_file(tmp_data_dir): + path = tmp_data_dir / "small_data.json" + create_test_json_file(path, SMALL_DATASET_SIZE, offset=0) + return str(path) + + +@pytest.fixture +def dataset_factory(): + """Factory for creating HfIterableDataset instances with common defaults.""" + def _create_dataset( + data_file: str, + dataset_name: str = "test_dataset", + shuffle: bool = False, + **kwargs + ) -> HfIterableDataset: + return HfIterableDataset( + path="json", + data_files=data_file, + split="train", + dataset_name=dataset_name, + seed=SEED, + shuffle_buffer_size=10 if shuffle else 0, + metric_transform=StandardMetricTransform(), + num_shards_per_rank=2, + **kwargs + ) + return _create_dataset + + +class TestHfIterableDataset: + """Tests for HfIterableDataset basic functionality.""" + + def test_default_dataset_name(self, small_dataset_file): + """Test that dataset name is auto-generated from path when not provided.""" + # Create dataset without specifying name + dataset = HfIterableDataset( + path="json", + data_files=small_dataset_file, + split="train", + # dataset_name not provided - should auto-generate + seed=SEED, + metric_transform=StandardMetricTransform(), + num_shards_per_rank=4, + ) + + # Should generate name from path and split + assert dataset.dataset_name == "json_train" + + # Test giving a name + dataset2 = HfIterableDataset( + path="json", + data_files=small_dataset_file, + split="train", + dataset_name = "my_dataset", + seed=SEED, + metric_transform=StandardMetricTransform(), + num_shards_per_rank=4, + ) + + # Should generate name from path and split + assert dataset2.dataset_name == "my_dataset" + + @pytest.mark.parametrize("num_epochs", [0.5, 1.0, 2.5]) + def test_epoch_boundaries_and_checkpointing( + self, num_epochs, dataset_factory, small_dataset_file + ): + """ + Tests that for N epochs, each sample appears exactly N times (rounded down), + the epoch metric is correct, and checkpointing works as expected. + """ + + # 1. Setup Dataloaders and Aggregators for original and resumed runs + def create_loader_and_aggregator(): + dataset = dataset_factory(small_dataset_file, shuffle=False) + loader = StatefulDataLoader( + dataset, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics + ) + aggregator = MetricsAggregator() + return loader, aggregator + + loader1, aggregator1 = create_loader_and_aggregator() + loader2, aggregator2 = create_loader_and_aggregator() + + # 2. Calculate steps for the test run + total_samples = int(SMALL_DATASET_SIZE * num_epochs) + total_steps = total_samples // BATCH_SIZE + + steps_before_checkpoint = max(1, total_steps // 2) + steps_after_checkpoint = total_steps - steps_before_checkpoint + + # 3. Generate checkpoint and resume + result = generate_ckpt( + loader1, + aggregator1, + steps_before_checkpoint=steps_before_checkpoint, + steps_after_checkpoint=steps_after_checkpoint, + resume_dataloader=loader2, + resume_aggregator=aggregator2, + ) + + # 4. Verify checkpointing and resumption + orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]] + resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]] + assert ( + orig_post_ids == resumed_ids + ), "Resumed batches should be identical for deterministic run" + assert ( + result["final_metrics"] == result["resumed_metrics"] + ), "Final metrics should match" + + def test_shuffling_behavior(self, dataset_factory, small_dataset_file): + """Tests that shuffling changes data order between epochs but preserves the set of samples.""" + # Test unshuffled dataset + unshuffled_ds = dataset_factory( + small_dataset_file, dataset_name="unshuffled", shuffle=False + ) + + # Get samples from two passes through the dataset + epoch_samples = islice(iter(unshuffled_ds), SMALL_DATASET_SIZE*2) + + first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] + second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] + + # Unshuffled should have same order in both epochs + assert first_epoch_samples == list(range(SMALL_DATASET_SIZE)) + assert second_epoch_samples == list(range(SMALL_DATASET_SIZE)) + + # Test shuffled dataset + shuffled_ds = dataset_factory( + small_dataset_file, dataset_name="shuffled", shuffle=True + ) + + # Collect full epochs to compare + epoch_samples = islice(iter(shuffled_ds), SMALL_DATASET_SIZE*2) + + first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] + second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] + + # Shuffled epochs should have different order + assert first_epoch_samples != list( + range(SMALL_DATASET_SIZE) + ), f"Shuffled should not be sorted, got {first_epoch_samples}" + assert ( + first_epoch_samples != second_epoch_samples + ), f"Shuffled epochs should be shuffled differently, got {first_epoch_samples} and {second_epoch_samples}" + + # But should contain the same set of IDs + assert set(first_epoch_samples) == set(range(SMALL_DATASET_SIZE)), f"First epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {first_epoch_samples}" + assert set(second_epoch_samples) == set(range(SMALL_DATASET_SIZE)), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {second_epoch_samples}" + + def test_epoch_tracking(self, dataset_factory, small_dataset_file): + """Test that epoch number is correctly tracked across dataset restarts.""" + dataset = dataset_factory(small_dataset_file, shuffle=False) + + # Two epoch samples + epoch_samples = islice(iter(dataset), SMALL_DATASET_SIZE*2) + + first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] + second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] + + # All should have epoch 0 + epoch_values = [epoch_metric.value for epoch_metric in first_epoch_samples["metrics"]] + assert all(epoch_value == 0 for epoch_value in epoch_values), f"Epoch values should be 0, got {epoch_values}" + + # All should have epoch 1 + epoch_values = [epoch_metric.value for epoch_metric in second_epoch_samples["metrics"]] + assert all(epoch_value == 1 for epoch_value in epoch_values), f"Epoch values should be 1, got {epoch_values}" diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py new file mode 100644 index 0000000000..1190d6d774 --- /dev/null +++ b/tests/torchtune/datasets/test_interleaved.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import islice +from typing import Any, Dict, Iterator +from unittest.mock import patch + +import pytest +import torch + +from torchtune.data import AggregationType, Metric, MetricsAggregator +from torchtune.datasets import InterleavedDataset, TuneIterableDataset + + +class TestInterleavedDataset: + """Tests for multi-dataset interleaving functionality.""" + + def test_initialization_validation(self, dataset_factory, small_dataset_file): + """Tests that the dataset raises errors for invalid configurations, like duplicate names.""" + # Test duplicate dataset names + ds1 = dataset_factory(small_dataset_file, dataset_name="duplicate") + ds2 = dataset_factory(small_dataset_file, dataset_name="duplicate") + + with pytest.raises(ValueError, match="Duplicate dataset names detected"): + InterleavedDataset(datasets=[ds1, ds2], weights=[0.5, 0.5], seed=SEED) + + # Test weight normalization (should work with warning) + ds3 = dataset_factory(small_dataset_file, dataset_name="ds3") + ds4 = dataset_factory(small_dataset_file, dataset_name="ds4") + + with patch("logging.Logger.warning") as mock_warning: + interleaved = InterleavedDataset( + datasets=[ds3, ds4], + weights=[0.5, 1.5], + seed=SEED, + dataset_name="test_interleaved" # Sum = 2.0 != 1.0 + ) + + # Check that weights were normalized + assert torch.allclose(interleaved._weights, torch.tensor([0.25, 0.75])) + mock_warning.assert_called_once() + + assert interleaved.dataset_name == "test_interleaved" + + def test_sampling_ratios( + self, dataset_factory, small_dataset_file, medium_dataset_file + ): + """Tests that datasets are sampled according to their assigned weights.""" + # Create two datasets with distinct ID ranges + # ds1 has IDs 0-22 (small dataset) + # ds2 has IDs 100-134 (medium dataset with offset) + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1") + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2") + + # Test with 70/30 weighting + weights = [0.7, 0.3] + interleaved = InterleavedDataset([ds1, ds2], weights, seed=SEED) + + # Collect 300 samples + sample_count = 300 + samples = list(islice(iter(interleaved), sample_count)) + + # Count samples by checking ID ranges + # ds1 has IDs < 100, ds2 has IDs >= 100 + ds1_count = sum(1 for s in samples if s["id"] < 100) + ds2_count = sum(1 for s in samples if s["id"] >= 100) + + assert ds1_count + ds2_count == sample_count + + # Check ratios are approximately correct + ds1_ratio = ds1_count / sample_count + ds2_ratio = ds2_count / sample_count + + # Allow 10% tolerance due to randomness + assert abs(ds1_ratio - 0.7) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~0.7" + assert abs(ds2_ratio - 0.3) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~0.3" + + def test_metrics_aggregation( + self, dataset_factory, small_dataset_file, medium_dataset_file + ): + """Tests that metrics from all child datasets are collected and aggregated.""" + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1") + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2") + + interleaved = InterleavedDataset([ds1, ds2], [0.2, 0.8], seed=SEED) + aggregator = MetricsAggregator() + + # Process some samples + TOTAL_SAMPLES = 200 + for sample in islice(iter(interleaved), 200): + aggregator.update(sample["metrics"]) + + metrics = aggregator.get_metrics_for_logging() + + # Should have metrics from both datasets, with flat keys + assert "ds1/samples_seen" in metrics + assert "ds2/samples_seen" in metrics + + # Both datasets should have contributed samples + assert metrics["ds1/samples_seen"] > 0 + assert metrics["ds2/samples_seen"] > 0 + + # Total samples should equal what we processed + calculated_total_samples = ( + metrics["ds1/samples_seen"] + metrics["ds2/samples_seen"] + ) + assert calculated_total_samples == TOTAL_SAMPLES + + # Test that ratio is approximately correct + ds1_ratio = metrics["ds1/samples_seen"] / TOTAL_SAMPLES + ds2_ratio = metrics["ds2/samples_seen"] / TOTAL_SAMPLES + + # Allow 10% tolerance due to randomness + assert abs(ds1_ratio - 0.2) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~0.2" + assert abs(ds2_ratio - 0.8) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~0.8" + + def test_checkpointing( + self, dataset_factory, small_dataset_file, medium_dataset_file + ): + """Tests that interleaved dataset checkpointing preserves sampling state.""" + + def create_interleaved(): + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1") + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2") + return InterleavedDataset([ds1, ds2], [0.7, 0.3], seed=SEED) + + # Original run + interleaved1 = create_interleaved() + loader1 = StatefulDataLoader( + interleaved1, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics + ) + aggregator1 = MetricsAggregator() + + # Resumed run + interleaved2 = create_interleaved() + loader2 = StatefulDataLoader( + interleaved2, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics + ) + aggregator2 = MetricsAggregator() + + result = generate_ckpt( + loader1, + aggregator1, + steps_before_checkpoint=10, + steps_after_checkpoint=20, + resume_dataloader=loader2, + resume_aggregator=aggregator2, + ) + + orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]] + resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]] + assert ( + orig_post_ids == resumed_ids + ), "Resumed batches should be identical for deterministic run" + assert ( + result["final_metrics"] == result["resumed_metrics"] + ), "Final metrics should match" + + \ No newline at end of file diff --git a/torchtune/data/__init__.py b/torchtune/data/__init__.py index a75e16780a..e1d7d687dd 100644 --- a/torchtune/data/__init__.py +++ b/torchtune/data/__init__.py @@ -32,11 +32,23 @@ QuestionAnswerTemplate, SummarizeTemplate, ) +from torchtune.data._metrics import ( + AggregationType, + Metric, + MetricTransform, + StandardMetricTransform, +) from torchtune.data._utils import format_content_with_images, load_image, truncate +from torchtune.data._aggregator import MetricsAggregator __all__ = [ + "AggregationType", "CROSS_ENTROPY_IGNORE_IDX", "GrammarErrorCorrectionTemplate", + "Metric", + "MetricsAggregator", + "MetricTransform", + "StandardMetricTransform", "SummarizeTemplate", "OpenAIToMessages", "ShareGPTToMessages", diff --git a/torchtune/data/_aggregator.py b/torchtune/data/_aggregator.py new file mode 100644 index 0000000000..f6b962c84c --- /dev/null +++ b/torchtune/data/_aggregator.py @@ -0,0 +1,342 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import ast +import collections +import logging +from typing import Any, Dict, List, Tuple + +import torch +import torch.distributed as dist + +from torchtune.data._metrics import AggregationType, Metric + +logger = logging.getLogger(__name__) + + +class MetricsAggregator: + """ + Aggregates metrics across datasets and distributed ranks. + + The internal state `_state` is a dictionary where the key is a tuple + of `(dataset_name, metric_name)` and the value is another dictionary + holding the metric's specific state (e.g., `{'type': AggregationType.SUM, 'value': 10}`). + + Usage: + aggregator = MetricsAggregator() + aggregator.update(metrics) + # Get logger-ready metrics {key: value} + metrics = aggregator.get_metrics_for_logging(prefix="train") # {"train/dataset1/tokens": 1234, ...} + """ + + def __init__(self, dist_window_size: int = 1000): + # State shape: {(dataset_name, metric_name): {type: AggType, value/sum/counts/etc}} + self._state: Dict[Tuple[str, str], Dict[str, Any]] = {} + + # For distributions, we keep a window of values to compute percentiles + self._dist_window_size = dist_window_size + + def update(self, metrics: List[Metric]) -> None: + """Update internal state with new metrics. + + Args: + metrics: List of Metric objects + """ + for metric in metrics: + key = (metric.dataset_name, metric.name) + + if key not in self._state: + self._initialize_state(key, metric.agg_type) + + state = self._state[key] + + # Update based on aggregation type + if metric.agg_type == AggregationType.SUM: + state["value"] += metric.value + elif metric.agg_type == AggregationType.MAX: + if state["value"] is not None: + state["value"] = max(state["value"], metric.value) + else: + state["value"] = metric.value + elif metric.agg_type == AggregationType.MIN: + if state["value"] is not None: + state["value"] = min(state["value"], metric.value) + else: + state["value"] = metric.value + elif metric.agg_type == AggregationType.MEAN: + state["sum"] += metric.value + state["count"] += 1 + elif metric.agg_type == AggregationType.DISTRIBUTION: + state["values"].append(metric.value) + elif metric.agg_type == AggregationType.CATEGORICAL_COUNT: + state["counts"][metric.value] += 1 + + def _initialize_state( + self, key: Tuple[str, str], agg_type: AggregationType + ) -> None: + """Initialize state for a new metric.""" + self._state[key] = {"type": agg_type} + state = self._state[key] + + if agg_type == AggregationType.SUM: + state["value"] = 0.0 + elif agg_type in (AggregationType.MAX, AggregationType.MIN): + state["value"] = None + elif agg_type == AggregationType.MEAN: + state["sum"] = 0.0 + state["count"] = 0 + elif agg_type == AggregationType.DISTRIBUTION: + state["values"] = collections.deque(maxlen=self._dist_window_size) + elif agg_type == AggregationType.CATEGORICAL_COUNT: + state["counts"] = collections.Counter() + + def get_metrics_for_logging(self, prefix: str = "") -> Dict[str, float]: + """ + Returns aggregated metrics ready for logging to wandb/tensorboard. + + Args: + prefix: Optional prefix like "train" or "valid" for metric keys + + Returns: + Flat dictionary with keys like "train/dataset1/tokens_seen" -> float value + Ready to be logged directly: wandb.log(metrics) + """ + # Always compute local metrics first + local_metrics = self._compute_local_metrics() + + # In distributed mode, perform reduction + if dist.is_initialized() and dist.get_world_size() > 1: + metrics = self._compute_distributed_metrics(local_metrics) + else: + metrics = local_metrics + + # Format for logging with proper key structure + return self._format_for_logging(metrics, prefix) + + def _compute_local_metrics(self) -> Dict[Tuple[str, str], Dict[str, Any]]: + """ + Compute metrics from current state. + + For distributions and categoricals, expands into multiple entries. + The dict format allows future extensions with additional fields. + + Args: + None + + Returns: + Dictionary mapping (dataset_name, metric_name) -> {"value": value, "agg_type": aggregation_type} + """ + metrics = {} + + for (ds_name, metric_name), state in self._state.items(): + agg_type = state["type"] + + if agg_type in ( + AggregationType.SUM, + AggregationType.MAX, + AggregationType.MIN, + ): + # For sum, max, and min, we just need to return the value + metrics[(ds_name, metric_name)] = { + "value": state["value"], + "agg_type": agg_type, + } + + elif agg_type == AggregationType.MEAN: + if state["count"] > 0: + value = state["sum"] / state["count"] + metrics[(ds_name, metric_name)] = { + "value": value, + "agg_type": agg_type, + } + + elif agg_type == AggregationType.DISTRIBUTION: + # queue -> list + values = list(state["values"]) + + # Sort to get percentiles efficiently + sorted_values = sorted(values) + n = len(sorted_values) + + # Each stat becomes its own metric + # For percentiles, it is an approximattion by computing avg of averages + metrics[(ds_name, f"{metric_name}_mean")] = { + "value": sum(values) / n, + "agg_type": AggregationType.MEAN, + } + metrics[(ds_name, f"{metric_name}_min")] = { + "value": sorted_values[0], + "agg_type": AggregationType.MIN, + } + metrics[(ds_name, f"{metric_name}_max")] = { + "value": sorted_values[-1], + "agg_type": AggregationType.MAX, + } + metrics[(ds_name, f"{metric_name}_p05")] = { + "value": sorted_values[max(0, int(0.05 * n) - 1)], + "agg_type": AggregationType.MEAN, + } + metrics[(ds_name, f"{metric_name}_p50")] = { + "value": sorted_values[max(0, int(0.5 * n) - 1)], + "agg_type": AggregationType.MEAN, + } + metrics[(ds_name, f"{metric_name}_p95")] = { + "value": sorted_values[max(0, int(0.95 * n) - 1)], + "agg_type": AggregationType.MEAN, + } + + elif agg_type == AggregationType.CATEGORICAL_COUNT: + # Expand categorical counts into individual metrics + for category, count in state["counts"].items(): + metrics[(ds_name, f"{metric_name}_{category}_count")] = { + "value": count, + "agg_type": AggregationType.SUM, + } + + return metrics + + def _compute_distributed_metrics( + self, local_metrics: Dict[Tuple[str, str], Dict[str, Any]] + ) -> Dict[Tuple[str, str], Dict[str, Any]]: + """ + Performs distributed reduction on metrics. + + Strategy: + 1. Do a single all_gather_object to collect all metrics from all ranks + 2. Group metrics by key and aggregation type + 3. Apply the appropriate reduction operation locally + + This avoids complex tensor operations and handles all reduction in one pass. + + Args: + local_metrics: Dict mapping (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} + + Returns: + Reduced metrics in same format as input + + Example: + rank_1_metrics = + { + ("ds1", "metric1"): {"value": 10, "agg_type": AggregationType.SUM}, + ("ds2", "metric2"): {"value": 20, "agg_type": AggregationType.MEAN}, + } + rank_2_metrics = + { + ("ds1", "metric1"): {"value": 30, "agg_type": AggregationType.SUM}, + ("ds2", "metric2"): {"value": 40, "agg_type": AggregationType.MEAN}, + } + + # After reduction + result = + { + ("ds1", "metric1"): {"value": 40, "agg_type": AggregationType.SUM}, + ("ds2", "metric2"): {"value": 30, "agg_type": AggregationType.MEAN}, + } + """ + world_size = dist.get_world_size() + + # Gather all metrics from all ranks in one operation + dist.barrier() + all_metrics = [None] * world_size + dist.all_gather_object(all_metrics, local_metrics) + + # Group values by key for reduction + grouped = collections.defaultdict(list) + for rank_metrics in all_metrics: + if rank_metrics: # It's possible a rank has no metrics + for key, metric_dict in rank_metrics.items(): + # A key is a tuple (dataset, metric) + grouped[key].append(metric_dict) + + # Reduce based on aggregation type + reduced = {} + if not grouped: + return reduced + + for key, metric_dicts in grouped.items(): + # All metrics for a key should have same type, just take first + values = [m["value"] for m in metric_dicts] + agg_type = metric_dicts[0]["agg_type"] + + # Start with copy of first dict to preserve any extra fields + result_dict = metric_dicts[0].copy() + + if agg_type == AggregationType.SUM: + result_dict["value"] = sum(values) + elif agg_type == AggregationType.MAX: + result_dict["value"] = max(values) + elif agg_type == AggregationType.MIN: + result_dict["value"] = min(values) + elif agg_type == AggregationType.MEAN: + result_dict["value"] = sum(values) / len(values) + + reduced[key] = result_dict + + return reduced + + def _format_for_logging( + self, metrics: Dict[Tuple[str, str], Dict[str, Any]], prefix: str + ) -> Dict[str, float]: + """ + Format metrics for wandb/tensorboard logging. + + Args: + metrics: Dict mapping (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} + prefix: Optional prefix like "train" or "valid" + + Returns: + Flat dict with string keys like "train/dataset1/tokens_seen" -> float + """ + formatted = {} + + for (ds_name, metric_name), metric_dict in metrics.items(): + # Build key: "prefix/dataset/metric" or "dataset/metric" if no prefix + if prefix: + key = f"{prefix}/{ds_name}/{metric_name}" + else: + key = f"{ds_name}/{metric_name}" + + formatted[key] = metric_dict["value"] + + return formatted + + def state_dict(self) -> Dict[str, Any]: + """Serialize aggregator state. The state is almost directly serializable.""" + serializable_state = {} + for key, state in self._state.items(): + state_copy = state.copy() + + # Convert non-serializable types + if "values" in state_copy: + state_copy["values"] = list(state_copy["values"]) # deque → list + if "counts" in state_copy: + state_copy["counts"] = dict(state_copy["counts"]) # Counter → dict + + # Convert tuple key to string for JSON compatibility + # JSON doesn't support tuple keys, so we convert (dataset, metric) → "('dataset', 'metric')" + serializable_state[str(key)] = state_copy + return {"state": serializable_state, "dist_window_size": self._dist_window_size} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Load aggregator state from checkpoint.""" + self._dist_window_size = state_dict["dist_window_size"] + + deserialized_state = {} + for key_str, state in state_dict["state"].items(): + # Convert string keys back to tuples + # "('dataset', 'metric')" → ('dataset', 'metric') + key = ast.literal_eval(key_str) + + # Re-wrap values in their original types + if state.get("type") == AggregationType.DISTRIBUTION: + state["values"] = collections.deque( + state["values"], maxlen=self._dist_window_size + ) + if state.get("type") == AggregationType.CATEGORICAL_COUNT: + state["counts"] = collections.Counter(state["counts"]) + + deserialized_state[key] = state + self._state = deserialized_state \ No newline at end of file diff --git a/torchtune/data/_metrics.py b/torchtune/data/_metrics.py new file mode 100644 index 0000000000..f61d0e579e --- /dev/null +++ b/torchtune/data/_metrics.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from enum import Enum +from functools import partial +from typing import Any, Callable, Dict, Optional, Protocol, Union + + +class AggregationType(Enum): + """Defines how a metric's value should be aggregated.""" + + SUM = "sum" + MEAN = "mean" + DISTRIBUTION = "distribution" + CATEGORICAL_COUNT = "categorical_count" + MAX = "max" + MIN = "min" + + +@dataclass(frozen=True) +class Metric: + """A self-describing metric object.""" + + dataset_name: str + name: str + value: Union[int, float, str] + agg_type: AggregationType + + +class MetricTransform(Protocol): + """Protocol for metric transforms.""" + + def set_dataset_name(self, dataset_name: str) -> None: ... + def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: ... + + +class StandardMetricTransform(MetricTransform): + """ + Attaches per-sample metrics for tracking training progress. + + This transform is responsible for generating metrics on a per-sample + basis (e.g., tokens per sample). The actual aggregation of these metrics + (eg calculating sum of samples seen) is handled by the + `MetricsAggregator`. This separation of concerns ensures that metrics are + correctly aggregated even with multiple dataloader workers and in a + distributed setting. + + Tracked metrics include: + - samples_seen: A count of samples processed. + - tokens_seen: The cumulative sum of all tokens processed. + - seq_len: A distribution of sequence lengths. + """ + + def __init__(self): + # dataset_name is set by the dataset using set_dataset_name + self.dataset_name: Optional[str] = None + self.new_metric: Optional[Callable] = None + + def set_dataset_name(self, dataset_name: str) -> None: + """Called by dataset to set the namespace for metrics. + The dataset name is used to differentiate multiple datasets stats, + e.g. "train/dataset1/tokens_seen" and "train/dataset2/tokens_seen".""" + self.dataset_name = dataset_name + self.new_metric = partial(Metric, dataset_name=dataset_name) + + def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: + if self.dataset_name is None or self.new_metric is None: + raise RuntimeError( + "set_dataset_name() must be called before using the transform." + ) + + # Determine token key + token_key = "tokens" if "tokens" in sample else "input_ids" + token_len = len(sample.get(token_key, [])) + + # Create metrics for this sample + metrics = [ + self.new_metric(name="samples_seen", value=1, agg_type=AggregationType.SUM), + self.new_metric( + name="tokens_seen", value=token_len, agg_type=AggregationType.SUM + ), + self.new_metric( + name="seq_len", value=token_len, agg_type=AggregationType.DISTRIBUTION + ), + ] + + # Append to existing metrics list or create new one + if "metrics" not in sample: + sample["metrics"] = [] + sample["metrics"].extend(metrics) + return sample \ No newline at end of file diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index b0c7c11738..4ea863169d 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -5,18 +5,25 @@ # LICENSE file in the root directory of this source tree. from torchtune.datasets import multimodal -from torchtune.datasets._alpaca import alpaca_cleaned_dataset, alpaca_dataset +from torchtune.datasets._alpaca import ( + alpaca_cleaned_dataset, + alpaca_dataset, + alpaca_iterable_dataset, +) from torchtune.datasets._chat import chat_dataset from torchtune.datasets._cnn_dailymail import cnn_dailymail_articles_dataset from torchtune.datasets._concat import ConcatDataset from torchtune.datasets._grammar import grammar_dataset +from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset from torchtune.datasets._instruct import instruct_dataset +from torchtune.datasets._interleaved import InterleavedDataset +from torchtune.datasets._iterable_base import TuneIterableDataset from torchtune.datasets._packed import PackedDataset from torchtune.datasets._preference import preference_dataset, PreferenceDataset from torchtune.datasets._samsum import samsum_dataset -from torchtune.datasets._sft import SFTDataset -from torchtune.datasets._slimorca import slimorca_dataset +from torchtune.datasets._sft import SFTDataset, sft_iterable_dataset +from torchtune.datasets._slimorca import slimorca_dataset, slimorca_iterable_dataset from torchtune.datasets._stack_exchange_paired import stack_exchange_paired_dataset from torchtune.datasets._text_completion import ( text_completion_dataset, @@ -25,23 +32,29 @@ from torchtune.datasets._wikitext import wikitext_dataset __all__ = [ - "alpaca_dataset", "alpaca_cleaned_dataset", + "alpaca_dataset", + "alpaca_iterable_dataset", + "chat_dataset", + "cnn_dailymail_articles_dataset", + "ConcatDataset", "grammar_dataset", - "samsum_dataset", - "stack_exchange_paired_dataset", - "slimorca_dataset", + "hh_rlhf_helpful_dataset", + "HfIterableDataset", "instruct_dataset", + "InterleavedDataset", + "multimodal", + "PackedDataset", "preference_dataset", - "chat_dataset", + "PreferenceDataset", + "samsum_dataset", + "SFTDataset", + "sft_iterable_dataset", + "slimorca_dataset", + "slimorca_iterable_dataset", + "stack_exchange_paired_dataset", "text_completion_dataset", "TextCompletionDataset", - "cnn_dailymail_articles_dataset", - "PackedDataset", - "ConcatDataset", + "TuneIterableDataset", "wikitext_dataset", - "PreferenceDataset", - "SFTDataset", - "hh_rlhf_helpful_dataset", - "multimodal", ] diff --git a/torchtune/datasets/_alpaca.py b/torchtune/datasets/_alpaca.py index 1ecee62f53..4225ab4bf5 100644 --- a/torchtune/datasets/_alpaca.py +++ b/torchtune/datasets/_alpaca.py @@ -9,9 +9,11 @@ from typing import Any, Callable, Optional, Union from torchtune.data._messages import AlpacaToMessages +from torchtune.data._metrics import StandardMetricTransform +from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.datasets._packed import PackedDataset -from torchtune.datasets._sft import SFTDataset +from torchtune.datasets._sft import SFTDataset, sft_iterable_dataset from torchtune.modules.transforms.tokenizers import ModelTokenizer @@ -101,3 +103,64 @@ def alpaca_dataset( original Alpaca dataset, `yahma/alpaca-cleaned `_. See the dataset page and :func:`~torchtune.datasets.alpaca_dataset` for more details. """ + + +def alpaca_iterable_dataset( + model_transform: ModelTokenizer, + *, + source: str = "tatsu-lab/alpaca", + column_map: Optional[dict[str, str]] = None, + train_on_input: bool = True, + shuffle_buffer_size: Optional[int] = 1000, + seed: int = 42, + dataset_name: Optional[str] = None, + filter_fn: Optional[Callable] = None, + split: str = "train", + **load_dataset_kwargs: dict[str, Any], +) -> HfIterableDataset: + """ + Support for iterable version of Alpaca-style datasets. + + This returns an infinite iterable dataset that supports checkpointing + and metrics tracking, designed for step-based training. + + Args: + model_transform (ModelTokenizer): Model tokenizer used to tokenize the messages. + source (str): path to dataset repository on Hugging Face. Default is ``tatsu-lab/alpaca``. + column_map (Optional[dict[str, str]]): a mapping from the expected columns in the message transform + :class:`~torchtune.data.AlpacaToMessages` to the new column names in the dataset. Keys should be + "instruction", "input", and "output" and values should be the actual column names. + train_on_input (bool): Whether the model is trained on the prompt or not. Default is True. + shuffle_buffer_size (Optional[int]): Size of the shuffle buffer. If None or 0, no shuffling is done. + seed (int): Seed for shuffling. + dataset_name (Optional[str]): Name of the dataset for metrics tracking. If None, auto-generated. + filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. + split (str): ``split`` argument for ``datasets.load_dataset``. Default is "train". + **load_dataset_kwargs (dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. + + Returns: + HfIterableDataset: iterable dataset configured with source data and transforms + + Example: + >>> from torchdata.stateful_dataloader import StatefulDataLoader + >>> alpaca_ds = alpaca_iterable_dataset(tokenizer=tokenizer) + >>> dataloader = StatefulDataLoader(alpaca_ds, batch_size=8) + >>> for batch in dataloader: + >>> print(f"Batch size: {len(batch)}") + >>> Batch size: 8 + """ + message_transform = AlpacaToMessages( + train_on_input=train_on_input, column_map=column_map + ) + + return sft_iterable_dataset( + message_transform=message_transform, + model_transform=model_transform, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + dataset_name=dataset_name, + filter_fn=filter_fn, + split=split, + path=source, + **load_dataset_kwargs, + ) diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py new file mode 100644 index 0000000000..9a206445d5 --- /dev/null +++ b/torchtune/datasets/_hf_iterable.py @@ -0,0 +1,271 @@ +import logging +from typing import Any, Callable, Dict, Iterator, List, Optional + +import torch +import torch.distributed as dist +from datasets import load_dataset +from datasets.distributed import split_dataset_by_node + +from torchtune.data._metrics import AggregationType, Metric, StandardMetricTransform +from torchtune.datasets._iterable_base import TuneIterableDataset + +logger = logging.getLogger(__name__) + + +class HfIterableDataset(TuneIterableDataset): + """HuggingFace dataset implementation with composable metrics. + + This is an infinite dataset. After exhausting the dataset, it will restart from the beginning. + + This dataset is responsible for: + - Loading and sharding the dataset + - Shuffling at initialization and after each epoch + - Applying transforms + - Returning an infinite iterator over the dataset + + Args: + message_transform (Optional[Callable]): Transforms raw data into Message + model_transform (Optional[Callable]): Take messages and prepares it for the model. Usually the tokenizer. + output_transform (Optional[Callable]): Takes tokenized inputs and prepares it for the recipe. Usually + does some label manipulation, e.g. ignore index. Think of it as recipe-dependent, e.g. SFT, RL, DPO, etc. + metric_transform (Optional[Callable]): Takes the sample and computes metrics, e.g. token count. + If None, a default transform is used. To stop tracking metrics, set it to lambda x: x. + shuffle_buffer_size (Optional[int]): Size of the shuffle buffer. If None or 0, no shuffling is done. + seed (int): Seed for shuffling. + num_shards_per_rank (int): Target number of shards per worker (GPU). It will find a multiple + of world_size * dataloader_workers. + dataset_name (Optional[str]): Name of the dataset. If None, a default name is generated + from the path, source, and split. + filter_fn (Optional[Callable]): Filter function to apply to the dataset. + filter_kwargs (Optional[Dict[str, Any]]): Keyword arguments to pass to the filter function. + load_dataset_kwargs (Dict[str, Any]): Keyword arguments to pass to the load_dataset function. + + """ + + def __init__( + self, + *, + message_transform: Optional[Callable] = None, + model_transform: Optional[Callable] = None, + output_transform: Optional[Callable] = None, + metric_transform: Optional[Callable] = None, + shuffle_buffer_size: Optional[int] = 1000, + weight: Optional[float] = 1.0, + seed: int = 42, + num_shards_per_rank: int = 64, + dataset_name: Optional[str] = None, + filter_fn: Optional[Callable] = None, + filter_kwargs: Optional[Dict[str, Any]] = None, + **load_dataset_kwargs, + ): + # Store configuration + self._shuffle_buffer_size = shuffle_buffer_size + self._seed = seed + self._message_transform = message_transform + self._model_transform = model_transform + self._output_transform = output_transform + self._weight = weight # TODO: make it a property? + + # Create default transform if not provided + self._metric_transform = metric_transform or StandardMetricTransform() + + # Auto-generate dataset name if not provided, ensuring it's always a string. + if dataset_name is None: + path = load_dataset_kwargs.get("path", None) + source = load_dataset_kwargs.get("source", None) + split = load_dataset_kwargs.get("split", None) + name_parts = [] + for item in [path, source, split]: + if item is not None: + name_parts.append(str(item).replace("/", "_")) + self._dataset_name: str = "_".join(name_parts) + else: + self._dataset_name: str = dataset_name + + # Set dataset name on the transform if it supports it + if hasattr(self._metric_transform, "set_dataset_name"): + self._metric_transform.set_dataset_name(self._dataset_name) + + # Internal state for resumption + self._num_epochs = 0 + + # Load and setup HF dataset + self._setup_hf_dataset( + load_dataset_kwargs, num_shards_per_rank, filter_fn, filter_kwargs + ) + + @property + def dataset_name(self) -> str: + return self._dataset_name + + def _apply_transforms(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Apply transforms if they exist, otherwise return sample unchanged.""" + if self._message_transform is not None: + sample = self._message_transform(sample) + if self._model_transform is not None: + sample = self._model_transform(sample) + if self._output_transform is not None: + sample = self._output_transform(sample) + if self._metric_transform is not None: + sample = self._metric_transform(sample) + return sample + + def _setup_hf_dataset( + self, + load_dataset_kwargs: Dict[str, Any], + num_shards_per_rank: int, + filter_fn: Optional[Callable] = None, + filter_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Configures the Hugging Face dataset, including sharding, filtering, and + transform mapping. This method is called only once during initialization + to avoid expensive re-computation on each epoch. + """ + + # Distributed setup + world_size, rank = 1, 0 + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + + # Load and shard dataset + ds = load_dataset(**load_dataset_kwargs) + + # Use to_iterable_dataset for streaming datasets + if not load_dataset_kwargs.get("streaming", False): + + # Define number of shards based on (world_size, num of shards per GPU, dataloader workers) + # E.g. world_size=2, num_shards_per_rank=16, dataloader_workers=3 + # we will try 2*16 = 32 shards. Since 32 is not a multiple of 3, we will do 36 shards. + # Each rank gets 16 shards, each dataloader worker in that rankgets 6 shards. + worker_info = torch.utils.data.get_worker_info() + num_dataloader_workers = worker_info.num_workers if worker_info else 1 + + # Calculate total workers + total_workers = world_size * num_dataloader_workers + + # Calculate desired shards + desired_shards = world_size * num_shards_per_rank + + # Find the smallest multiple of total_workers that is >= desired_shards + if desired_shards % total_workers == 0: + num_shards = desired_shards + else: + num_shards = total_workers * ( + (desired_shards + total_workers - 1) // total_workers + ) + + # If the dataset is not streaming and has a defined length, + # we cannot have num_shards > dataset_size. + if not load_dataset_kwargs.get("streaming", False) and hasattr( + ds, "__len__" + ): + dataset_size = len(ds) + if num_shards > dataset_size: + raise ValueError( + f"Number of shards ({num_shards}) is greater than the dataset size ({dataset_size})." + f"Please decrease num_shards_per_rank." + ) + + ds = ds.to_iterable_dataset(num_shards=num_shards) + + # Shuffle the dataset + if self._shuffle_buffer_size and self._shuffle_buffer_size > 0: + ds = ds.shuffle(seed=self._seed, buffer_size=self._shuffle_buffer_size) + + # Distribute across ranks + if world_size > 1: + ds = split_dataset_by_node(ds, rank=rank, world_size=world_size) + + # Apply filtering if specified + if filter_fn: + filter_kwargs = filter_kwargs or {} + ds = ds.filter(filter_fn, **filter_kwargs) + + self._ds = ds + + def __iter__(self) -> Iterator[Dict[str, Any]]: + """Iterate through the dataset infinitely. + + It will restart from the beginning after exhausting the dataset. + + If shuffle_buffer_size is set, it will shuffle the dataset at the beginning of each epoch + when set_epoch is called. + + An additional metric "num_epochs" is added to the sample. + """ + epoch_ds = self._ds + + while True: # Infinite iteration + epoch_seed = self._seed + self._num_epochs + epoch_ds.set_epoch(epoch_seed) + epoch_iterator = iter(epoch_ds) + samples_yielded = 0 + + try: + for sample in epoch_iterator: + # NOTE: We apply transforms here instead of using .map() call + # to work around https://github.com/huggingface/datasets/issues/7630 + # where .map() can cause incorrect resumption from a checkpoint. + sample = self._apply_transforms(sample) + + # Track the number of epochs completed for each dataset. This is + # especially useful when interleaving multiple datasets, but + # also necessary to track dataset-level metrics. + metric_num_epochs = Metric( + dataset_name=self.dataset_name, + name="num_epochs", + value=self._num_epochs, + agg_type=AggregationType.MAX, + ) + if "metrics" not in sample: + sample["metrics"] = [] + sample["metrics"].append(metric_num_epochs) + + samples_yielded += 1 + yield sample + + except StopIteration: + pass # Iterator is exhausted, which is expected. + except Exception as e: + logger.warning( + f"Dataset {self.dataset_name} encountered an unexpected error: {e}." + ) + raise + + # Check if we got zero samples - this might indicate an issue + if samples_yielded == 0: + logger.warning( + f"Dataset {self.dataset_name} epoch {self._num_epochs} yielded 0 samples - potential issue!" + ) + + # Epoch complete - increment and continue infinite loop + self._num_epochs += 1 + + # Reset to the base dataset for the next epoch's shuffling. + epoch_ds = self._ds + + def state_dict(self) -> Dict[str, Any]: + """ + The dataset returns its own state directly, without namespacing. + """ + hf_state = self._ds.state_dict() + state = { + "num_epochs": self._num_epochs, + "seed": self._seed, + "hf_dataset_state": hf_state, + } + return state + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """ + Load state from checkpoint, including restoring the state of the + Hugging Face IterableDataset. + """ + self._num_epochs = state_dict["num_epochs"] + hf_state = state_dict["hf_dataset_state"] + + # HF is responsible for resuming the dataset state + # where it last left off + self._ds.load_state_dict(hf_state) \ No newline at end of file diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py new file mode 100644 index 0000000000..cbfe36338c --- /dev/null +++ b/torchtune/datasets/_interleaved.py @@ -0,0 +1,115 @@ +import collections +import logging +import math +from typing import Any, Dict, Iterator, List + +import torch + +from torchtune.datasets._iterable_base import TuneIterableDataset + +logger = logging.getLogger(__name__) + + +class InterleavedDataset(TuneIterableDataset): + """Infinitely interleaves multiple TuneIterableDatasets according to a list of weights. + - The weights are normalized to sum to 1.0. + - This dataset is responsible for managing the state of its child datasets + to ensure correct checkpointing and resumption. + + Args: + datasets (List[TuneIterableDataset]): List of TuneIterableDatasets to interleave. + weights (List[float]): List of weights for each dataset. Must sum to 1.0. + seed (int): Seed for sampling. + dataset_name (str): Name of the dataset. If None, defaults to "interleaved_dataset". + """ + + def __init__( + self, + datasets: List[TuneIterableDataset], + weights: List[float], + seed: int, + dataset_name: str = "interleaved_dataset", + ): + self._dataset_name = dataset_name + + # Preserve original order for weighted sampling + self._dataset_names = [ds.dataset_name for ds in datasets] + + # Create a name-to-dataset mapping for robust state management + self._datasets: Dict[str, TuneIterableDataset] = { + ds.dataset_name: ds for ds in datasets + } + + # Validate unique dataset names upfront - fail fast with clear error + names = self._dataset_names + if len(names) != len(set(names)): + duplicates = [ + name for name, count in collections.Counter(names).items() if count > 1 + ] + raise ValueError( + f"Duplicate dataset names detected: {duplicates}. All {names=}" + f"Please provide a unique 'dataset_name' for each dataset in the interleaved list." + ) + + self._sampling_generator = torch.Generator().manual_seed(seed) + + # Normalize weights to sum to 1 + #TODO: make it a property? rely on ds.weight? + total_weight = sum(weights) + self._weights = torch.tensor( + [w / total_weight for w in weights], dtype=torch.float + ) + if not math.isclose(total_weight, 1.0, rel_tol=1e-9): + logger.warning( + f"Interleaved dataset normalized weights to sum to 1.0. Found {total_weight=}. Previous {weights=}, new {self._weights.tolist()}" + ) + + @property + def dataset_name(self) -> str: + return self._dataset_name + + def __iter__(self) -> Iterator[Dict[str, Any]]: + """Interleave samples from child infinite datasets""" + child_iters = {name: iter(ds) for name, ds in self._datasets.items()} + + while True: + # Sample which dataset to use + ds_idx = torch.multinomial( + self._weights, 1, replacement=True, generator=self._sampling_generator + ).item() + + # Sample an index, then get the name for safe lookup + ds_name = self._dataset_names[ds_idx] + + try: + sample = next(child_iters[ds_name]) + yield sample + except StopIteration: + # Per the design, child datasets must be infinite. + # We re-initialize to allow for continuous operation but warn loudly + # as this may indicate a design problem in the child dataset. + logger.warning( + f"Child dataset {self._datasets[ds_name].dataset_name} was exhausted. " + "This is unexpected for an infinite dataset. Re-initializing its iterator." + ) + child_iters[ds_name] = iter(self._datasets[ds_name]) + sample = next(child_iters[ds_name]) + yield sample + + def state_dict(self) -> Dict[str, Any]: + """Save state for the interleaver and its children.""" + # The parent is responsible for namespacing the child states. + child_states = {name: ds.state_dict() for name, ds in self._datasets.items()} + return { + "sampling_generator_state": self._sampling_generator.get_state(), + "child_states": child_states, + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Load state for the interleaver and its children.""" + self._sampling_generator.set_state(state_dict["sampling_generator_state"]) + child_states = state_dict["child_states"] + for name, ds in self._datasets.items(): + if name in child_states: + # Pass the raw state dict to the child + ds.load_state_dict(child_states[name]) \ No newline at end of file diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py new file mode 100644 index 0000000000..725810541c --- /dev/null +++ b/torchtune/datasets/_iterable_base.py @@ -0,0 +1,37 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Iterator + +from torch.utils.data import IterableDataset + + +class TuneIterableDataset(IterableDataset, ABC): + """ + Abstract base class for all torchtune iterable datasets. + It defines the minimal, consistent interface required for all dataset + implementations to ensure they are compatible with the training loop, + checkpointing, and metric logging systems. + """ + + @property + @abstractmethod + def dataset_name(self) -> str: + """A unique identifier for the dataset, used for namespacing in metrics and checkpoints.""" + pass + + @abstractmethod + def __iter__(self) -> Iterator[Dict[str, Any]]: + """ + Returns an infinite iterator over the dataset. Each implementation is responsible + for its own iteration logic, including shuffling and making it an infinite stream. + """ + pass + + @abstractmethod + def state_dict(self) -> Dict[str, Any]: + """Returns a state dictionary for checkpointing""" + pass + + @abstractmethod + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Load state from a state dictionary, used when resuming from a checkpoint.""" + pass \ No newline at end of file diff --git a/torchtune/datasets/_sft.py b/torchtune/datasets/_sft.py index 2e74ec66a0..70bfb75fd5 100644 --- a/torchtune/datasets/_sft.py +++ b/torchtune/datasets/_sft.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Callable, Mapping, Optional +from typing import Any, Callable, Mapping, Optional, Dict import numpy as np from datasets import load_dataset @@ -14,6 +14,8 @@ from torchtune.data._messages import validate_messages from torchtune.modules.transforms import Transform +from torchtune.data._metrics import StandardMetricTransform +from torchtune.datasets._hf_iterable import HfIterableDataset class SFTDataset(Dataset): @@ -178,3 +180,96 @@ def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: tokenized_dict = transformed_sample return tokenized_dict + + +class SFTOutputTransform(Transform): + """ + Output transform to be used in SFT recipes as an input to TuneIterableDataset. + It takes tokenized inputs with "tokens" and "mask" keys and + creates the "labels" key for SFT training. + + The labels are created by: + 1. Shifting tokens by 1 position (for autoregressive training) + 2. Masking positions where mask[1:] is True with CROSS_ENTROPY_IGNORE_IDX + 3. Adding CROSS_ENTROPY_IGNORE_IDX at the end + """ + + def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: + # Create a copy to avoid modifying the original + tokenized_dict = dict(sample) + + if not ("tokens" in tokenized_dict and "mask" in tokenized_dict): + keys_str = ", ".join(tokenized_dict.keys()) + raise ValueError( + f"SFTOutputTransform expects 'tokens' and 'mask' keys. " + f"Got keys: {keys_str}" + ) + + # Create labels for SFT training + tokenized_dict["labels"] = list( + np.where( + tokenized_dict["mask"][1:], + CROSS_ENTROPY_IGNORE_IDX, + tokenized_dict["tokens"][1:], + ) + ) + tokenized_dict["labels"].append(CROSS_ENTROPY_IGNORE_IDX) + assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"]) + + return tokenized_dict + + +def sft_iterable_dataset( + model_transform: Transform, + *, + message_transform: Transform, + shuffle_buffer_size: Optional[int] = 1000, + seed: int = 42, + num_shards_per_rank: int = 64, + dataset_name: Optional[str] = None, + filter_fn: Optional[Callable] = None, + filter_kwargs: Optional[Dict[str, Any]] = None, + **load_dataset_kwargs: Dict[str, Any], +) -> HfIterableDataset: + """ + Creates an SFT-ready iterable dataset with appropriate output transform. + + Args: + model_transform (Transform): Usually the tokenizer + message_transform (Transform): Transform to convert raw data to messages + shuffle_buffer_size (Optional[int]): Buffer size for shuffling + seed (int): Random seed for shuffling + num_shards_per_rank (int): Target shards per worker + dataset_name (Optional[str]): Name for metrics namespacing + filter_fn (Optional[Callable]): Filter function + filter_kwargs (Optional[Dict[str, Any]]): Filter function kwargs + **load_dataset_kwargs: Args passed to load_dataset + + Returns: + HfIterableDataset: Configured for SFT training + + Example: + >>> from torchtune.data import AlpacaToMessages + >>> message_transform = AlpacaToMessages(train_on_input=False) + >>> ds = sft_iterable_dataset( + ... message_transform=message_transform, + ... model_transform=tokenizer, + ... path="tatsu-lab/alpaca" + ... ) + """ + + output_transform = SFTOutputTransform() + + return HfIterableDataset( + message_transform=message_transform, + model_transform=model_transform, + output_transform=output_transform, + metric_transform=StandardMetricTransform(), + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + num_shards_per_rank=num_shards_per_rank, + dataset_name=dataset_name, + filter_fn=filter_fn, + filter_kwargs=filter_kwargs, + **load_dataset_kwargs, + ) diff --git a/torchtune/datasets/_slimorca.py b/torchtune/datasets/_slimorca.py index ac49b56d63..77667aa579 100644 --- a/torchtune/datasets/_slimorca.py +++ b/torchtune/datasets/_slimorca.py @@ -7,9 +7,11 @@ from typing import Any, Callable, Optional, Union from torchtune.data import ShareGPTToMessages +from torchtune.data._metrics import StandardMetricTransform +from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.datasets._packed import PackedDataset -from torchtune.datasets._sft import SFTDataset +from torchtune.datasets._sft import SFTDataset, sft_iterable_dataset from torchtune.modules.transforms.tokenizers import ModelTokenizer @@ -94,3 +96,68 @@ def slimorca_dataset( ) return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len) return ds + + +def slimorca_iterable_dataset( + model_transform: ModelTokenizer, + *, + source: str = "Open-Orca/SlimOrca-Dedup", + column_map: Optional[dict[str, str]] = None, + train_on_input: bool = False, + new_system_prompt: Optional[str] = None, + shuffle_buffer_size: Optional[int] = 1000, + seed: int = 42, + num_shards_per_rank: int = 64, + dataset_name: Optional[str] = None, + filter_fn: Optional[Callable] = None, + filter_kwargs: Optional[dict[str, Any]] = None, + **load_dataset_kwargs: dict[str, Any], +) -> HfIterableDataset: + """ + Support for SlimOrca-style conversational datasets using iterable approach. + + This creates an infinite iterable dataset that automatically shards and shuffles data, + making it suitable for step-based training without explicit epoch boundaries. + + Args: + model_transform (ModelTokenizer): Model tokenizer used to tokenize the messages. + source (str): path to dataset repository on Hugging Face. Default is "Open-Orca/SlimOrca-Dedup". + column_map (Optional[dict[str, str]]): mapping from expected "conversations" column + to actual column name in dataset. If None, uses default "conversations". + train_on_input (bool): Whether to train on input or mask it. Default is False. + new_system_prompt (Optional[str]): If specified, prepend system message to every sample. + shuffle_buffer_size (Optional[int]): Size of shuffle buffer. If None or 0, no shuffling. + seed (int): Seed for shuffling. Default is 42. + num_shards_per_rank (int): Target number of shards per worker. Default is 64. + dataset_name (Optional[str]): Name for metrics. If None, auto-generated from source. + filter_fn (Optional[Callable]): Filter function to apply to dataset. + filter_kwargs (Optional[dict[str, Any]]): Kwargs for filter function. + **load_dataset_kwargs: Additional kwargs for load_dataset. + + Returns: + HfIterableDataset: Configured iterable dataset + + Example: + >>> from torchtune.datasets import slimorca_iterable_dataset + >>> ds = slimorca_iterable_dataset(shuffle_buffer_size=1000) + >>> for sample in ds: + >>> print(sample["tokens"][:10]) # First 10 tokens + """ + message_transform = ShareGPTToMessages( + train_on_input=train_on_input, + column_map=column_map, + new_system_prompt=new_system_prompt, + ) + + return sft_iterable_dataset( + source=source, + message_transform=message_transform, + model_transform=model_transform, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + num_shards_per_rank=num_shards_per_rank, + dataset_name=dataset_name, + filter_fn=filter_fn, + filter_kwargs=filter_kwargs, + **load_dataset_kwargs, + ) diff --git a/torchtune/training/checkpointing/_checkpoint_client.py b/torchtune/training/checkpointing/_checkpoint_client.py index 39b8989284..5f29e038c6 100644 --- a/torchtune/training/checkpointing/_checkpoint_client.py +++ b/torchtune/training/checkpointing/_checkpoint_client.py @@ -28,6 +28,7 @@ from torchtune.training.checkpointing._checkpointer import DistributedCheckpointer from torchtune.training.checkpointing._utils import get_most_recent_checkpoint from torchtune.training.memory import OptimizerInBackwardWrapper +from torchtune.data import MetricsAggregator log = utils.get_logger("DEBUG") import torchdata @@ -47,6 +48,7 @@ class TrainingProgress: total_training_steps: Optional[int] = None dataloader_state_dict: Optional[dict[str, Any]] = None val_dataloader_state_dict: Optional[dict[str, Any]] = None + metrics_aggregator_state_dict: Optional[dict[str, Any]] = None def state_dict(self) -> dict[str, object]: return { @@ -58,6 +60,7 @@ def state_dict(self) -> dict[str, object]: "total_training_steps": self.total_training_steps, training.DATALOADER_KEY: self.dataloader_state_dict, training.VAL_DATALOADER_KEY: self.val_dataloader_state_dict, + "metrics_aggregator_state_dict": self.metrics_aggregator_state_dict, } @@ -442,6 +445,7 @@ def load_distributed_checkpoint( adapter_config: Optional[dict[str, Any]] = None, dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader] = None, single_device: bool = False, + metrics_aggregator: Optional[MetricsAggregator] = None, ) -> dict[str, Any]: """ This method is used to resume training from a distributed checkpoint state. @@ -459,6 +463,7 @@ def load_distributed_checkpoint( checkpoint_dict: dict[str, Any] = {} model_state_dict = model.state_dict() optim_state_dict = optimizer.state_dict() + metrics_aggregator_state_dict = metrics_aggregator.state_dict() if metrics_aggregator else {} # Hack to properly initialize the learning rate scheduler # TODO: Find a better way to do this, possibly by including the following @@ -481,6 +486,7 @@ def load_distributed_checkpoint( "steps_run": 0, "total_training_steps": 0, training.DATALOADER_KEY: dataloader.state_dict() if dataloader else {}, + "metrics_aggregator_state_dict": metrics_aggregator_state_dict, } ) From 2212b19d861ed07381def0b8a93680549a888617 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 25 Jun 2025 12:59:02 -0400 Subject: [PATCH 02/25] update tests --- tests/torchtune/datasets/test_interleaved.py | 80 ++++++++++- .../torchtune/datasets/test_iterable_utils.py | 126 ++++++++++++++++++ 2 files changed, 203 insertions(+), 3 deletions(-) create mode 100644 tests/torchtune/datasets/test_iterable_utils.py diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index 1190d6d774..3c7a1c6fae 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -4,15 +4,89 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import tempfile +from pathlib import Path from itertools import islice from typing import Any, Dict, Iterator from unittest.mock import patch import pytest import torch +from torchdata.stateful_dataloader import StatefulDataLoader + +from torchtune.data import AggregationType, Metric, MetricsAggregator, StandardMetricTransform +from torchtune.datasets import InterleavedDataset, HfIterableDataset + +# Import test utilities +from .test_iterable_utils import collate_with_metrics, generate_ckpt + +# Test Constants +SMALL_DATASET_SIZE = 23 +MEDIUM_DATASET_SIZE = 35 +SEED = 42 +BATCH_SIZE = 5 + + +def create_test_json_file(path: Path, num_samples: int, offset: int = 0) -> None: + """Creates a dummy JSON test data file with token samples of varying lengths. + + Args: + path (Path): The path to the file to create + num_samples (int): The number of samples to create + offset (int): The offset to add to the sample ID to ensure unique IDs in different datasets + """ + with open(path, "w") as f: + for i in range(num_samples): + sample_id = i + offset + # Realistic token length variation (1-3 tokens) + token_len = (i % 3) + 1 + tokens = list(range(sample_id, sample_id + token_len)) + f.write( + f'{{"id": {sample_id}, "tokens": {tokens}, "text": "sample_{sample_id}"}}\n' + ) + -from torchtune.data import AggregationType, Metric, MetricsAggregator -from torchtune.datasets import InterleavedDataset, TuneIterableDataset +@pytest.fixture +def tmp_data_dir(tmp_path): + """Provide temporary directory for test data files.""" + return tmp_path + + +@pytest.fixture +def small_dataset_file(tmp_data_dir): + path = tmp_data_dir / "small_data.json" + create_test_json_file(path, SMALL_DATASET_SIZE, offset=0) + return str(path) + + +@pytest.fixture +def medium_dataset_file(tmp_data_dir): + path = tmp_data_dir / "medium_data.json" + create_test_json_file(path, MEDIUM_DATASET_SIZE, offset=100) + return str(path) + + +@pytest.fixture +def dataset_factory(): + """Factory for creating HfIterableDataset instances with common defaults.""" + def _create_dataset( + data_file: str, + dataset_name: str = "test_dataset", + shuffle: bool = False, + **kwargs + ) -> HfIterableDataset: + return HfIterableDataset( + path="json", + data_files=data_file, + split="train", + dataset_name=dataset_name, + seed=SEED, + shuffle_buffer_size=10 if shuffle else 0, + metric_transform=StandardMetricTransform(), + num_shards_per_rank=2, + **kwargs + ) + return _create_dataset class TestInterleavedDataset: @@ -90,7 +164,7 @@ def test_metrics_aggregation( # Process some samples TOTAL_SAMPLES = 200 - for sample in islice(iter(interleaved), 200): + for sample in islice(iter(interleaved), TOTAL_SAMPLES): aggregator.update(sample["metrics"]) metrics = aggregator.get_metrics_for_logging() diff --git a/tests/torchtune/datasets/test_iterable_utils.py b/tests/torchtune/datasets/test_iterable_utils.py new file mode 100644 index 0000000000..8d4d6d7849 --- /dev/null +++ b/tests/torchtune/datasets/test_iterable_utils.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Optional + +from torch.utils.data import DataLoader +from torchtune.data import padded_collate_sft +from torchtune.data._metrics import MetricsAggregator + + +def collate_with_metrics(batch: List[Dict[str, Any]]) -> Dict[str, Any]: + """Collate function that extracts metrics and uses padded_collate_sft for the rest.""" + all_metrics = [] + clean_batch = [] + for sample in batch: + if "metrics" in sample: + all_metrics.extend(sample.pop("metrics")) + clean_batch.append(sample) + + if not clean_batch: + return {"metrics": all_metrics} + + # Use torchtune's standard SFT collate function + collated = padded_collate_sft(clean_batch) + collated["metrics"] = all_metrics + return collated + + +def generate_ckpt( + dataloader: DataLoader, + aggregator: MetricsAggregator, + steps_before_checkpoint: int, + steps_after_checkpoint: int, + resume_dataloader: Optional[DataLoader] = None, + resume_aggregator: Optional[MetricsAggregator] = None, +) -> Dict[str, Any]: + """ + Generates a checkpoint by running through data and saving checkpoint mid-stream. + Optionally, a second dataloader and aggregator can be given to resume from ckpt + and run steps_after_checkpoint to match the first one. + + Args: + dataloader: The dataloader to test + aggregator: The metrics aggregator to use + steps_before_checkpoint: Number of steps to run before saving checkpoint + steps_after_checkpoint: Number of steps to run after checkpoint + resume_dataloader: Optional new dataloader to test resuming. If None, returns empty resumed_batches. + resume_aggregator: Optional new aggregator to test resuming. If None, returns empty resumed_metrics. + + Returns dict with batches/metrics from both pre and post checkpoint runs. + """ + iterator = iter(dataloader) + + # Collect batches before and after checkpoint + batches = [] + checkpoint_state = None + metrics_at_checkpoint = {} + + total_steps = steps_before_checkpoint + steps_after_checkpoint + + for idx, batch in enumerate(iterator): + batches.append(batch) + + # Process metrics + if "metrics" in batch: + aggregator.update(batch.pop("metrics")) + + # Save checkpoint state after steps_before_checkpoint + if idx == steps_before_checkpoint - 1: # -1 because idx is 0-based + checkpoint_state = { + "loader": dataloader.state_dict(), + "aggregator": aggregator.state_dict(), + } + metrics_at_checkpoint = aggregator.get_metrics_for_logging(prefix="train") + + # Stop after total steps + if idx == total_steps - 1: + break + + # Split batches + pre_checkpoint_batches = batches[:steps_before_checkpoint] + post_checkpoint_batches = batches[steps_before_checkpoint:] + + # Resume with new instances if provided + resumed_batches = [] + resumed_metrics = {} + + if ( + resume_dataloader is not None + and resume_aggregator is not None + and checkpoint_state is not None + ): + # Test resuming with new instances + resume_dataloader.load_state_dict(checkpoint_state["loader"]) + resume_aggregator.load_state_dict(checkpoint_state["aggregator"]) + resume_iterator = iter(resume_dataloader) + + # Collect only the post-checkpoint batches when resuming + for idx, batch in enumerate(resume_iterator): + resumed_batches.append(batch) + + # Process metrics + if "metrics" in batch: + resume_aggregator.update(batch.pop("metrics")) + + # Stop after steps_after_checkpoint + if idx == steps_after_checkpoint - 1: + break + + resumed_metrics = resume_aggregator.get_metrics_for_logging(prefix="train") + + return { + # Original run + "pre_checkpoint_batches": pre_checkpoint_batches, + "post_checkpoint_batches": post_checkpoint_batches, + "metrics_at_checkpoint": metrics_at_checkpoint, + "final_metrics": aggregator.get_metrics_for_logging(prefix="train"), + # Resumed run + "resumed_batches": resumed_batches, + "resumed_metrics": resumed_metrics, + # Internal state for loading - only if someone needs to manually load + "_checkpoint_state": checkpoint_state, + } \ No newline at end of file From 2eb68b6d00b301bdde940e8cbfd4efef6a5cba60 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 25 Jun 2025 12:40:43 -0700 Subject: [PATCH 03/25] linter --- recipes/full_finetune_distributed.py | 148 +++++++++++------- .../torchtune/data/test_metrics_aggregator.py | 12 +- .../torchtune/data/test_metrics_transform.py | 18 +-- tests/torchtune/datasets/test_hf_iterable.py | 86 +++++----- tests/torchtune/datasets/test_interleaved.py | 32 ++-- .../torchtune/datasets/test_iterable_utils.py | 27 ++-- torchtune/data/__init__.py | 14 +- torchtune/data/_aggregator.py | 51 +++--- torchtune/data/_metrics.py | 13 +- torchtune/datasets/__init__.py | 2 +- torchtune/datasets/_alpaca.py | 5 +- torchtune/datasets/_hf_iterable.py | 30 ++-- torchtune/datasets/_interleaved.py | 34 ++-- torchtune/datasets/_iterable_base.py | 16 +- torchtune/datasets/_sft.py | 38 ++--- torchtune/datasets/_slimorca.py | 13 +- .../checkpointing/_checkpoint_client.py | 6 +- 17 files changed, 307 insertions(+), 238 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index f34ccc6a7e..a4eb87e7d6 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -9,11 +9,11 @@ import time from functools import partial -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from warnings import warn import torch -from omegaconf import DictConfig, ListConfig +from omegaconf import dictConfig, listConfig from torch import nn from torch.distributed import destroy_process_group, init_process_group @@ -22,11 +22,10 @@ from torch.optim import Optimizer from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp from torchdata.stateful_dataloader import StatefulDataLoader -from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler from torchtune import config, modules, training, utils from torchtune.config._utils import _get_component_from_path -from torchtune.data import padded_collate_packed, MetricsAggregator -from torchtune.datasets import ConcatDataset, InterleavedDataset +from torchtune.data import MetricsAggregator, padded_collate_packed +from torchtune.datasets import InterleavedDataset from torchtune.modules.embedding_utils import resize_token_embeddings from torchtune.modules.loss import SFTLoss from torchtune.modules.moe import utils as moe_utils @@ -120,7 +119,7 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): has example commands for how to kick-off training. Args: - cfg (DictConfig): OmegaConf object parsed from yaml file + cfg (dictConfig): OmegaConf object parsed from yaml file Raises: ValueError: If ``dtype`` is set to fp16. @@ -130,7 +129,7 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. """ - def __init__(self, cfg: DictConfig) -> None: + def __init__(self, cfg: dictConfig) -> None: device_type = cfg.device self._device = utils.get_device(device=device_type) self._dtype = training.get_dtype(cfg.dtype, device=self._device) @@ -274,12 +273,12 @@ def __init__(self, cfg: DictConfig) -> None: seed=cfg.seed, debug_mode=cfg.get("cudnn_deterministic_mode", None) ) self.global_step = 0 - + # Step-based training support self.num_training_steps = cfg.num_training_steps self._dataset_metrics_log_freq = cfg.get("dataset_metrics_log_freq", 100) self._metrics_aggregator = None # Will be initialized in setup - + def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None: """ Updates the recipe state from checkpoint. @@ -304,14 +303,16 @@ def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None: "Are you sure you passed in the right recipe checkpoint?" ) from e - def setup(self, cfg: DictConfig) -> None: + def setup(self, cfg: dictConfig) -> None: """ Setup the recipe. This includes training state (if resume_from_checkpoint is True), model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader. """ if cfg.get("dataset_val") is not None: - raise NotImplementedError("Validation is not supported yet with iterable datasets.") - + raise NotImplementedError( + "Validation is not supported yet with iterable datasets." + ) + if self.fsdp_cpu_offload: # Utilize all available CPU cores for intra-op parallelism. This provides ~2x # speed up when benchmarking fused AdamW on CPU @@ -333,7 +334,7 @@ def setup(self, cfg: DictConfig) -> None: self._compile_loss = compile_bool self._compile_optimizer_step = compile_bool self._compile_scale_grads = compile_bool - if isinstance(compile, DictConfig): + if isinstance(compile, dictConfig): self._compile_model = compile.get("model", True) self._compile_loss = compile.get("loss", True) self._compile_optimizer_step = compile.get("optimizer_step", False) @@ -470,7 +471,7 @@ def setup(self, cfg: DictConfig) -> None: def _setup_lr_scheduler( self, - cfg_lr_scheduler: Optional[DictConfig], + cfg_lr_scheduler: Optional[dictConfig], num_training_steps: int, last_epoch: int, ) -> Optional[Optimizer]: @@ -479,7 +480,7 @@ def _setup_lr_scheduler( It supports both standard optimization and optimizer-in-backward cases. Args: - cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration. + cfg_lr_scheduler (Optional[dictConfig]): The learning rate scheduler configuration. num_training_steps (int): The total number of training steps. last_epoch (int): The index of the last epoch. @@ -518,14 +519,14 @@ def _setup_lr_scheduler( return lr_scheduler def _setup_profiler( - self, cfg_profiler: Optional[DictConfig] = None + self, cfg_profiler: Optional[dictConfig] = None ) -> Union[torch.profiler.profile, DummyProfiler]: """ Parses the `profiler` section of top-level `cfg` and sets up profiler """ # Missing profiler section in config, assume disabled if cfg_profiler is None: - cfg_profiler = DictConfig({"enabled": False}) + cfg_profiler = dictConfig({"enabled": False}) # Check that component is included and set correctly if cfg_profiler.get("_component_", None) is None: @@ -552,7 +553,7 @@ def _setup_profiler( def _setup_model( self, - cfg_model: DictConfig, + cfg_model: dictConfig, enable_activation_checkpointing: bool, enable_activation_offloading: bool, activation_offloading_use_streams: bool, @@ -710,7 +711,7 @@ def _setup_model( def _setup_optimizer( self, - cfg_optimizer: DictConfig, + cfg_optimizer: dictConfig, optimizer_in_bwd: bool = False, opt_state_dict: Optional[dict[str, Any]] = None, ) -> Optional[Optimizer]: @@ -763,7 +764,7 @@ def _setup_optimizer( def _setup_data( self, - cfg_dataset: Union[DictConfig, ListConfig], + cfg_dataset: Union[dictConfig, listConfig], batch_size: int, collate_fn: str, dataloader_state_dict: Optional[dict[str, Any]] = None, @@ -776,7 +777,7 @@ def _setup_data( iterable_datasets = [] weights = [] cfg_dataset_list = cfg_dataset - if not isinstance(cfg_dataset_list, ListConfig): + if not isinstance(cfg_dataset_list, listConfig): cfg_dataset_list = [cfg_dataset_list] for ds_cfg in cfg_dataset_list: @@ -787,25 +788,25 @@ def _setup_data( # 2. Interleave datasets if any if len(iterable_datasets) > 1: ds = InterleavedDataset( - datasets=iterable_datasets, + datasets=iterable_datasets, weights=weights, seed=self.seed, ) else: ds = iterable_datasets[0] - + # 3. Apply packing # TODO: follow up PR packed = False # 4. Define a collate function wrapper to handle metrics base_collate_fn = ( - padded_collate_packed - if packed - else _get_component_from_path(collate_fn) + padded_collate_packed if packed else _get_component_from_path(collate_fn) ) - def _collate_with_metrics_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]: + def _collate_with_metrics_wrapper( + batch: list[dict[str, Any]] + ) -> dict[str, Any]: # TODO: handling of metrics should prob be done in collate_fn. # putting this here for now to avoid making more changes to this PR. all_metrics = [] @@ -814,7 +815,7 @@ def _collate_with_metrics_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any] if "metrics" in sample: all_metrics.extend(sample.pop("metrics")) clean_batch.append(sample) - + collated_batch = base_collate_fn(clean_batch) collated_batch["metrics"] = all_metrics return collated_batch @@ -901,12 +902,16 @@ def save_checkpoint(self, *, epoch: int, step: int, full_tensors: bool): """Save checkpoint based on global step.""" self._checkpoint_client.save_checkpoint( model=self._model, - optimizer=(self._optimizer if not self._optimizer_in_bwd else self._optim_ckpt_wrapper), - training_progress = TrainingProgress( + optimizer=( + self._optimizer + if not self._optimizer_in_bwd + else self._optim_ckpt_wrapper + ), + training_progress=TrainingProgress( seed=self.seed, - epochs_run=0, # TODO: not needed. To be deprecated. - total_epochs=1, # TODO: not needed. To be deprecated. - max_steps_per_epoch=-1, # TODO: not needed. To be deprecated. + epochs_run=0, # TODO: not needed. To be deprecated. + total_epochs=1, # TODO: not needed. To be deprecated. + max_steps_per_epoch=-1, # TODO: not needed. To be deprecated. steps_run=self.global_step, total_training_steps=self.num_training_steps, dataloader_state_dict=self._dataloader.state_dict(), @@ -915,10 +920,10 @@ def save_checkpoint(self, *, epoch: int, step: int, full_tensors: bool): if self._val_dataloader is not None else {} ), - #FIXME: add to load_ckpt and TrainingProgress too + # FIXME: add to load_ckpt and TrainingProgress too metrics_aggregator_state_dict=self._metrics_aggregator.state_dict(), ), - epoch=epoch, # TODO: not needed. To be deprecated. + epoch=epoch, # TODO: not needed. To be deprecated. step=step, single_device=False, full_tensors=full_tensors, @@ -945,9 +950,11 @@ def train(self) -> None: num_tokens = 0 self._profiler.start() - - pbar = tqdm(initial=self.global_step, total=self.num_training_steps, desc="Training") - + + pbar = tqdm( + initial=self.global_step, total=self.num_training_steps, desc="Training" + ) + dataloader_iter = iter(self._dataloader) batch_count = 0 @@ -955,18 +962,19 @@ def train(self) -> None: try: batch = next(dataloader_iter) except StopIteration: - self._logger.warning("Dataloader iterator exhausted unexpectedly. Ending training.") + self._logger.warning( + "Dataloader iterator exhausted unexpectedly. Ending training." + ) break - + if "metrics" in batch: self._metrics_aggregator.update(batch.pop("metrics")) - - # Start tracking CUDA memory for active steps for just the first epoch + + # Start tracking CUDA memory for active steps for just the first epoch if ( self._is_rank_zero and self.profiler_profile_memory - and batch_count - == self.profiler_wait_steps + self.profiler_warmup_steps + and batch_count == self.profiler_wait_steps + self.profiler_warmup_steps and self._device.type == "cuda" ): torch.cuda.memory._record_memory_history() @@ -975,9 +983,7 @@ def train(self) -> None: # Calculate the number of unmasked tokens in the current batch # and increment the total number of tokens seen in the step - current_num_tokens = ( - batch["labels"] != self._loss_fn.ignore_index - ).sum() + current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum() num_tokens += current_num_tokens with self.train_context( @@ -1038,32 +1044,54 @@ def train(self) -> None: loss_to_log = running_loss.detach().item() / num_tokens pbar.update(1) - pbar.set_description(f"Step: {self.global_step}|Loss: {loss_to_log:.4f}") - + pbar.set_description( + f"Step: {self.global_step}|Loss: {loss_to_log:.4f}" + ) + # Log per-step metrics - if self.global_step % self._log_every_n_steps == 0 and self._is_rank_zero: + if ( + self.global_step % self._log_every_n_steps == 0 + and self._is_rank_zero + ): time_per_step = time.perf_counter() - t0 log_dict = { "loss": loss_to_log, - "lr": get_lr(self._optimizer if not self._optimizer_in_bwd else self._optim_ckpt_wrapper), - "tokens_per_second_per_gpu": (num_tokens / self.parallel_dims.non_data_parallel_size) / (time_per_step * self.world_size), + "lr": get_lr( + self._optimizer + if not self._optimizer_in_bwd + else self._optim_ckpt_wrapper + ), + "tokens_per_second_per_gpu": ( + num_tokens / self.parallel_dims.non_data_parallel_size + ) + / (time_per_step * self.world_size), } if self._log_peak_memory_stats: log_dict.update(training.get_memory_stats(device=self._device)) if self._clip_grad_norm is not None: log_dict.update({"grad_norm": grad_norm}) self._metric_logger.log_dict(log_dict, step=self.global_step) - + # Log dataset metrics # #TODO: it requires all_gather. Should we keep a separate log_freq for this? - if self.global_step % self._dataset_metrics_log_freq == 0 and self._is_rank_zero: - dataset_metrics = self._metrics_aggregator.get_metrics_for_logging(prefix="train") + if ( + self.global_step % self._dataset_metrics_log_freq == 0 + and self._is_rank_zero + ): + dataset_metrics = self._metrics_aggregator.get_metrics_for_logging( + prefix="train" + ) self._metric_logger.log_dict(dataset_metrics, step=self.global_step) - + # Save checkpoint if specified by user - if self.save_every_n_steps is not None and self.global_step % self.save_every_n_steps == 0: - self.save_checkpoint(epoch=0, step=self.global_step, full_tensors=False) - + if ( + self.save_every_n_steps is not None + and self.global_step % self.save_every_n_steps == 0 + ): + self.save_checkpoint( + epoch=0, step=self.global_step, full_tensors=False + ) + # Reset running stats for the next step running_loss = 0 num_tokens = 0 @@ -1102,7 +1130,7 @@ def cleanup(self) -> None: @config.parse -def recipe_main(cfg: DictConfig) -> None: +def recipe_main(cfg: dictConfig) -> None: """ Entry point for the recipe. diff --git a/tests/torchtune/data/test_metrics_aggregator.py b/tests/torchtune/data/test_metrics_aggregator.py index 382d968704..69a2f32967 100644 --- a/tests/torchtune/data/test_metrics_aggregator.py +++ b/tests/torchtune/data/test_metrics_aggregator.py @@ -4,9 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import collections import pytest -from unittest.mock import patch from torchtune.data import AggregationType, Metric, MetricsAggregator @@ -63,7 +61,9 @@ def test_distribution_metrics(self): assert result["train/test/dist_metric_mean"] == 5.5 assert result["train/test/dist_metric_min"] == 1 assert result["train/test/dist_metric_max"] == 10 - assert result["train/test/dist_metric_p50"] == 5 # Median of 1-10 is 5 (index 4, value 5) + assert ( + result["train/test/dist_metric_p50"] == 5 + ) # Median of 1-10 is 5 (index 4, value 5) def test_state_management(self): """Test aggregator checkpointing and restoration.""" @@ -107,7 +107,7 @@ def test_state_management(self): def test_multiple_datasets(self): """Test that metrics from multiple datasets are correctly namespaced.""" aggregator = MetricsAggregator() - + metrics = [ Metric("dataset1", "samples", 100, AggregationType.SUM), Metric("dataset2", "samples", 200, AggregationType.SUM), @@ -117,7 +117,7 @@ def test_multiple_datasets(self): aggregator.update(metrics) result = aggregator.get_metrics_for_logging(prefix="train") - + assert result["train/dataset1/samples"] == 100 assert result["train/dataset2/samples"] == 200 assert result["train/dataset1/tokens"] == 1000 @@ -146,4 +146,4 @@ def test_prefix_handling(self): # Test without prefix result_no_prefix = aggregator.get_metrics_for_logging() assert result_no_prefix["test_ds/metric1"] == 42 - assert result_no_prefix["test_ds/metric2"] == 84 \ No newline at end of file + assert result_no_prefix["test_ds/metric2"] == 84 diff --git a/tests/torchtune/data/test_metrics_transform.py b/tests/torchtune/data/test_metrics_transform.py index 1eed534e42..2c8f3023b2 100644 --- a/tests/torchtune/data/test_metrics_transform.py +++ b/tests/torchtune/data/test_metrics_transform.py @@ -6,7 +6,7 @@ import pytest -from torchtune.data import AggregationType, Metric, StandardMetricTransform +from torchtune.data import AggregationType, StandardMetricTransform class TestStandardMetricTransform: @@ -16,7 +16,7 @@ def test_dataset_name_not_set_raises_error(self): """Test that using transform without setting dataset name raises error.""" transform = StandardMetricTransform() sample = {"tokens": [1, 2, 3]} - + with pytest.raises(RuntimeError, match="set_dataset_name"): transform(sample) @@ -24,31 +24,31 @@ def test_basic_metrics_generation(self): """Test that transform generates expected metrics for a sample.""" transform = StandardMetricTransform() transform.set_dataset_name("test_dataset") - + sample = {"tokens": [1, 2, 3, 4, 5]} result = transform(sample) - + # Should preserve original sample data assert result["tokens"] == [1, 2, 3, 4, 5] - + # Should add metrics assert "metrics" in result metrics = result["metrics"] assert len(metrics) == 3 - + # Check each metric for metric in metrics: if metric.name == "samples_seen": assert metric.dataset_name == "test_dataset" assert metric.value == 1 assert metric.agg_type == AggregationType.SUM - + elif metric.name == "tokens_seen": assert metric.dataset_name == "test_dataset" assert metric.value == 5 assert metric.agg_type == AggregationType.SUM - + elif metric.name == "seq_len": assert metric.dataset_name == "test_dataset" assert metric.value == 5 - assert metric.agg_type == AggregationType.DISTRIBUTION \ No newline at end of file + assert metric.agg_type == AggregationType.DISTRIBUTION diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index 4cf303c6fd..a263258ae8 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -4,19 +4,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import collections -import tempfile -from pathlib import Path from itertools import islice -from typing import Any, Callable, Dict, List, Optional -from unittest.mock import Mock, patch +from pathlib import Path +from typing import Any, Optional import pytest -import torch -from torch.nn.utils.rnn import pad_sequence from torchdata.stateful_dataloader import StatefulDataLoader -from torchtune.data import AggregationType, Metric, MetricsAggregator, StandardMetricTransform, padded_collate_sft +from torchtune.data import ( + MetricsAggregator, + padded_collate_sft, + StandardMetricTransform, +) from torchtune.datasets import HfIterableDataset @@ -47,7 +46,7 @@ def create_test_json_file(path: Path, num_samples: int, offset: int = 0) -> None ) -def collate_with_metrics(batch: List[Dict[str, Any]]) -> Dict[str, Any]: +def collate_with_metrics(batch: list[dict[str, Any]]) -> dict[str, Any]: """Collate function that extracts metrics and uses padded_collate_sft as base collator.""" # Extract metrics first all_metrics = [] @@ -73,21 +72,24 @@ def generate_ckpt( steps_after_checkpoint: int, resume_dataloader: Optional[StatefulDataLoader] = None, resume_aggregator: Optional[MetricsAggregator] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Generates a checkpoint by running through data and saving checkpoint mid-stream. Optionally, a second dataloader and aggregator can be given to resume from ckpt and run steps_after_checkpoint to match the first one. Args: - dataloader: The dataloader to test - aggregator: The metrics aggregator to use - steps_before_checkpoint: Number of steps to run before saving checkpoint - steps_after_checkpoint: Number of steps to run after checkpoint - resume_dataloader: Optional new dataloader to test resuming. If None, returns empty resumed_batches. - resume_aggregator: Optional new aggregator to test resuming. If None, returns empty resumed_metrics. - - Returns dict with batches/metrics from both pre and post checkpoint runs. + dataloader (StatefulDataLoader): The dataloader to test + aggregator (MetricsAggregator): The metrics aggregator to use + steps_before_checkpoint (int): Number of steps to run before saving checkpoint + steps_after_checkpoint (int): Number of steps to run after checkpoint + resume_dataloader (Optional[StatefulDataLoader]): Optional new dataloader to test resuming. + If None, returns empty resumed_batches. + resume_aggregator (Optional[MetricsAggregator]): Optional new aggregator to test resuming. + If None, returns empty resumed_metrics. + + Returns: + dict[str, Any]: Dict with batches/metrics from both pre and post checkpoint runs. """ iterator = iter(dataloader) @@ -179,11 +181,12 @@ def small_dataset_file(tmp_data_dir): @pytest.fixture def dataset_factory(): """Factory for creating HfIterableDataset instances with common defaults.""" + def _create_dataset( data_file: str, dataset_name: str = "test_dataset", shuffle: bool = False, - **kwargs + **kwargs, ) -> HfIterableDataset: return HfIterableDataset( path="json", @@ -194,8 +197,9 @@ def _create_dataset( shuffle_buffer_size=10 if shuffle else 0, metric_transform=StandardMetricTransform(), num_shards_per_rank=2, - **kwargs + **kwargs, ) + return _create_dataset @@ -223,7 +227,7 @@ def test_default_dataset_name(self, small_dataset_file): path="json", data_files=small_dataset_file, split="train", - dataset_name = "my_dataset", + dataset_name="my_dataset", seed=SEED, metric_transform=StandardMetricTransform(), num_shards_per_rank=4, @@ -288,8 +292,8 @@ def test_shuffling_behavior(self, dataset_factory, small_dataset_file): ) # Get samples from two passes through the dataset - epoch_samples = islice(iter(unshuffled_ds), SMALL_DATASET_SIZE*2) - + epoch_samples = islice(iter(unshuffled_ds), SMALL_DATASET_SIZE * 2) + first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] @@ -303,8 +307,8 @@ def test_shuffling_behavior(self, dataset_factory, small_dataset_file): ) # Collect full epochs to compare - epoch_samples = islice(iter(shuffled_ds), SMALL_DATASET_SIZE*2) - + epoch_samples = islice(iter(shuffled_ds), SMALL_DATASET_SIZE * 2) + first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] @@ -317,23 +321,35 @@ def test_shuffling_behavior(self, dataset_factory, small_dataset_file): ), f"Shuffled epochs should be shuffled differently, got {first_epoch_samples} and {second_epoch_samples}" # But should contain the same set of IDs - assert set(first_epoch_samples) == set(range(SMALL_DATASET_SIZE)), f"First epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {first_epoch_samples}" - assert set(second_epoch_samples) == set(range(SMALL_DATASET_SIZE)), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {second_epoch_samples}" + assert set(first_epoch_samples) == set( + range(SMALL_DATASET_SIZE) + ), f"First epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {first_epoch_samples}" + assert set(second_epoch_samples) == set( + range(SMALL_DATASET_SIZE) + ), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {second_epoch_samples}" def test_epoch_tracking(self, dataset_factory, small_dataset_file): """Test that epoch number is correctly tracked across dataset restarts.""" dataset = dataset_factory(small_dataset_file, shuffle=False) - + # Two epoch samples - epoch_samples = islice(iter(dataset), SMALL_DATASET_SIZE*2) - + epoch_samples = islice(iter(dataset), SMALL_DATASET_SIZE * 2) + first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] # All should have epoch 0 - epoch_values = [epoch_metric.value for epoch_metric in first_epoch_samples["metrics"]] - assert all(epoch_value == 0 for epoch_value in epoch_values), f"Epoch values should be 0, got {epoch_values}" - + epoch_values = [ + epoch_metric.value for epoch_metric in first_epoch_samples["metrics"] + ] + assert all( + epoch_value == 0 for epoch_value in epoch_values + ), f"Epoch values should be 0, got {epoch_values}" + # All should have epoch 1 - epoch_values = [epoch_metric.value for epoch_metric in second_epoch_samples["metrics"]] - assert all(epoch_value == 1 for epoch_value in epoch_values), f"Epoch values should be 1, got {epoch_values}" + epoch_values = [ + epoch_metric.value for epoch_metric in second_epoch_samples["metrics"] + ] + assert all( + epoch_value == 1 for epoch_value in epoch_values + ), f"Epoch values should be 1, got {epoch_values}" diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index 3c7a1c6fae..e06ef670c1 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -4,18 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import tempfile -from pathlib import Path from itertools import islice -from typing import Any, Dict, Iterator +from pathlib import Path from unittest.mock import patch import pytest import torch from torchdata.stateful_dataloader import StatefulDataLoader -from torchtune.data import AggregationType, Metric, MetricsAggregator, StandardMetricTransform -from torchtune.datasets import InterleavedDataset, HfIterableDataset +from torchtune.data import MetricsAggregator, StandardMetricTransform +from torchtune.datasets import HfIterableDataset, InterleavedDataset # Import test utilities from .test_iterable_utils import collate_with_metrics, generate_ckpt @@ -69,11 +67,12 @@ def medium_dataset_file(tmp_data_dir): @pytest.fixture def dataset_factory(): """Factory for creating HfIterableDataset instances with common defaults.""" + def _create_dataset( data_file: str, dataset_name: str = "test_dataset", shuffle: bool = False, - **kwargs + **kwargs, ) -> HfIterableDataset: return HfIterableDataset( path="json", @@ -84,8 +83,9 @@ def _create_dataset( shuffle_buffer_size=10 if shuffle else 0, metric_transform=StandardMetricTransform(), num_shards_per_rank=2, - **kwargs + **kwargs, ) + return _create_dataset @@ -107,10 +107,10 @@ def test_initialization_validation(self, dataset_factory, small_dataset_file): with patch("logging.Logger.warning") as mock_warning: interleaved = InterleavedDataset( - datasets=[ds3, ds4], - weights=[0.5, 1.5], + datasets=[ds3, ds4], + weights=[0.5, 1.5], seed=SEED, - dataset_name="test_interleaved" # Sum = 2.0 != 1.0 + dataset_name="test_interleaved", # Sum = 2.0 != 1.0 ) # Check that weights were normalized @@ -163,8 +163,8 @@ def test_metrics_aggregation( aggregator = MetricsAggregator() # Process some samples - TOTAL_SAMPLES = 200 - for sample in islice(iter(interleaved), TOTAL_SAMPLES): + total_samples = 200 + for sample in islice(iter(interleaved), total_samples): aggregator.update(sample["metrics"]) metrics = aggregator.get_metrics_for_logging() @@ -181,11 +181,11 @@ def test_metrics_aggregation( calculated_total_samples = ( metrics["ds1/samples_seen"] + metrics["ds2/samples_seen"] ) - assert calculated_total_samples == TOTAL_SAMPLES + assert calculated_total_samples == total_samples # Test that ratio is approximately correct - ds1_ratio = metrics["ds1/samples_seen"] / TOTAL_SAMPLES - ds2_ratio = metrics["ds2/samples_seen"] / TOTAL_SAMPLES + ds1_ratio = metrics["ds1/samples_seen"] / total_samples + ds2_ratio = metrics["ds2/samples_seen"] / total_samples # Allow 10% tolerance due to randomness assert abs(ds1_ratio - 0.2) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~0.2" @@ -232,5 +232,3 @@ def create_interleaved(): assert ( result["final_metrics"] == result["resumed_metrics"] ), "Final metrics should match" - - \ No newline at end of file diff --git a/tests/torchtune/datasets/test_iterable_utils.py b/tests/torchtune/datasets/test_iterable_utils.py index 8d4d6d7849..5a4fbda8a5 100644 --- a/tests/torchtune/datasets/test_iterable_utils.py +++ b/tests/torchtune/datasets/test_iterable_utils.py @@ -4,14 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, List, Optional +from typing import Any, Optional from torch.utils.data import DataLoader from torchtune.data import padded_collate_sft from torchtune.data._metrics import MetricsAggregator -def collate_with_metrics(batch: List[Dict[str, Any]]) -> Dict[str, Any]: +def collate_with_metrics(batch: list[dict[str, Any]]) -> dict[str, Any]: """Collate function that extracts metrics and uses padded_collate_sft for the rest.""" all_metrics = [] clean_batch = [] @@ -36,21 +36,24 @@ def generate_ckpt( steps_after_checkpoint: int, resume_dataloader: Optional[DataLoader] = None, resume_aggregator: Optional[MetricsAggregator] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Generates a checkpoint by running through data and saving checkpoint mid-stream. Optionally, a second dataloader and aggregator can be given to resume from ckpt and run steps_after_checkpoint to match the first one. Args: - dataloader: The dataloader to test - aggregator: The metrics aggregator to use - steps_before_checkpoint: Number of steps to run before saving checkpoint - steps_after_checkpoint: Number of steps to run after checkpoint - resume_dataloader: Optional new dataloader to test resuming. If None, returns empty resumed_batches. - resume_aggregator: Optional new aggregator to test resuming. If None, returns empty resumed_metrics. - - Returns dict with batches/metrics from both pre and post checkpoint runs. + dataloader (DataLoader): The dataloader to test + aggregator (MetricsAggregator): The metrics aggregator to use + steps_before_checkpoint (int): Number of steps to run before saving checkpoint + steps_after_checkpoint (int): Number of steps to run after checkpoint + resume_dataloader (Optional[DataLoader]): Optional new dataloader to test resuming. + If None, returns empty resumed_batches. + resume_aggregator (Optional[MetricsAggregator]): Optional new aggregator to test resuming. + If None, returns empty resumed_metrics. + + Returns: + dict[str, Any]: Dict with batches/metrics from both pre and post checkpoint runs. """ iterator = iter(dataloader) @@ -123,4 +126,4 @@ def generate_ckpt( "resumed_metrics": resumed_metrics, # Internal state for loading - only if someone needs to manually load "_checkpoint_state": checkpoint_state, - } \ No newline at end of file + } diff --git a/torchtune/data/__init__.py b/torchtune/data/__init__.py index e1d7d687dd..09292b9ba9 100644 --- a/torchtune/data/__init__.py +++ b/torchtune/data/__init__.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from torchtune.data._aggregator import MetricsAggregator from torchtune.data._collate import ( left_pad_sequence, padded_collate, @@ -24,6 +25,12 @@ ShareGPTToMessages, validate_messages, ) +from torchtune.data._metrics import ( + AggregationType, + Metric, + MetricTransform, + StandardMetricTransform, +) from torchtune.data._prompt_templates import ( ChatMLTemplate, GrammarErrorCorrectionTemplate, @@ -32,14 +39,7 @@ QuestionAnswerTemplate, SummarizeTemplate, ) -from torchtune.data._metrics import ( - AggregationType, - Metric, - MetricTransform, - StandardMetricTransform, -) from torchtune.data._utils import format_content_with_images, load_image, truncate -from torchtune.data._aggregator import MetricsAggregator __all__ = [ "AggregationType", diff --git a/torchtune/data/_aggregator.py b/torchtune/data/_aggregator.py index f6b962c84c..9c933ddfa3 100644 --- a/torchtune/data/_aggregator.py +++ b/torchtune/data/_aggregator.py @@ -7,9 +7,8 @@ import ast import collections import logging -from typing import Any, Dict, List, Tuple +from typing import Any -import torch import torch.distributed as dist from torchtune.data._metrics import AggregationType, Metric @@ -34,16 +33,16 @@ class MetricsAggregator: def __init__(self, dist_window_size: int = 1000): # State shape: {(dataset_name, metric_name): {type: AggType, value/sum/counts/etc}} - self._state: Dict[Tuple[str, str], Dict[str, Any]] = {} + self._state: dict[tuple[str, str], dict[str, Any]] = {} # For distributions, we keep a window of values to compute percentiles self._dist_window_size = dist_window_size - def update(self, metrics: List[Metric]) -> None: + def update(self, metrics: list[Metric]) -> None: """Update internal state with new metrics. Args: - metrics: List of Metric objects + metrics (list[Metric]): list of Metric objects """ for metric in metrics: key = (metric.dataset_name, metric.name) @@ -75,7 +74,7 @@ def update(self, metrics: List[Metric]) -> None: state["counts"][metric.value] += 1 def _initialize_state( - self, key: Tuple[str, str], agg_type: AggregationType + self, key: tuple[str, str], agg_type: AggregationType ) -> None: """Initialize state for a new metric.""" self._state[key] = {"type": agg_type} @@ -93,15 +92,15 @@ def _initialize_state( elif agg_type == AggregationType.CATEGORICAL_COUNT: state["counts"] = collections.Counter() - def get_metrics_for_logging(self, prefix: str = "") -> Dict[str, float]: + def get_metrics_for_logging(self, prefix: str = "") -> dict[str, float]: """ Returns aggregated metrics ready for logging to wandb/tensorboard. Args: - prefix: Optional prefix like "train" or "valid" for metric keys + prefix (str): Optional prefix like "train" or "valid" for metric keys Returns: - Flat dictionary with keys like "train/dataset1/tokens_seen" -> float value + dict[str, float]: Flat dictionary with keys like "train/dataset1/tokens_seen" -> float value Ready to be logged directly: wandb.log(metrics) """ # Always compute local metrics first @@ -116,18 +115,16 @@ def get_metrics_for_logging(self, prefix: str = "") -> Dict[str, float]: # Format for logging with proper key structure return self._format_for_logging(metrics, prefix) - def _compute_local_metrics(self) -> Dict[Tuple[str, str], Dict[str, Any]]: + def _compute_local_metrics(self) -> dict[tuple[str, str], dict[str, Any]]: """ Compute metrics from current state. For distributions and categoricals, expands into multiple entries. The dict format allows future extensions with additional fields. - Args: - None - Returns: - Dictionary mapping (dataset_name, metric_name) -> {"value": value, "agg_type": aggregation_type} + dict[tuple[str, str], dict[str, Any]]: dictionary mapping + (dataset_name, metric_name) -> {"value": value, "agg_type": aggregation_type} """ metrics = {} @@ -199,8 +196,8 @@ def _compute_local_metrics(self) -> Dict[Tuple[str, str], Dict[str, Any]]: return metrics def _compute_distributed_metrics( - self, local_metrics: Dict[Tuple[str, str], Dict[str, Any]] - ) -> Dict[Tuple[str, str], Dict[str, Any]]: + self, local_metrics: dict[tuple[str, str], dict[str, Any]] + ) -> dict[tuple[str, str], dict[str, Any]]: """ Performs distributed reduction on metrics. @@ -212,10 +209,11 @@ def _compute_distributed_metrics( This avoids complex tensor operations and handles all reduction in one pass. Args: - local_metrics: Dict mapping (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} + local_metrics (dict[tuple[str, str], dict[str, Any]]): dict mapping + (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} Returns: - Reduced metrics in same format as input + dict[tuple[str, str], dict[str, Any]]: Reduced metrics in same format as input Example: rank_1_metrics = @@ -278,17 +276,18 @@ def _compute_distributed_metrics( return reduced def _format_for_logging( - self, metrics: Dict[Tuple[str, str], Dict[str, Any]], prefix: str - ) -> Dict[str, float]: + self, metrics: dict[tuple[str, str], dict[str, Any]], prefix: str + ) -> dict[str, float]: """ Format metrics for wandb/tensorboard logging. Args: - metrics: Dict mapping (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} - prefix: Optional prefix like "train" or "valid" + metrics (dict[tuple[str, str], dict[str, Any]]): dict mapping + (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} + prefix (str): Optional prefix like "train" or "valid" Returns: - Flat dict with string keys like "train/dataset1/tokens_seen" -> float + dict[str, float]: Flat dict with string keys like "train/dataset1/tokens_seen" -> float """ formatted = {} @@ -303,7 +302,7 @@ def _format_for_logging( return formatted - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """Serialize aggregator state. The state is almost directly serializable.""" serializable_state = {} for key, state in self._state.items(): @@ -320,7 +319,7 @@ def state_dict(self) -> Dict[str, Any]: serializable_state[str(key)] = state_copy return {"state": serializable_state, "dist_window_size": self._dist_window_size} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Load aggregator state from checkpoint.""" self._dist_window_size = state_dict["dist_window_size"] @@ -339,4 +338,4 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: state["counts"] = collections.Counter(state["counts"]) deserialized_state[key] = state - self._state = deserialized_state \ No newline at end of file + self._state = deserialized_state diff --git a/torchtune/data/_metrics.py b/torchtune/data/_metrics.py index f61d0e579e..7a38febb1e 100644 --- a/torchtune/data/_metrics.py +++ b/torchtune/data/_metrics.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from enum import Enum from functools import partial -from typing import Any, Callable, Dict, Optional, Protocol, Union +from typing import Any, Callable, Optional, Protocol, Union class AggregationType(Enum): @@ -34,8 +34,11 @@ class Metric: class MetricTransform(Protocol): """Protocol for metric transforms.""" - def set_dataset_name(self, dataset_name: str) -> None: ... - def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: ... + def set_dataset_name(self, dataset_name: str) -> None: + ... + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + ... class StandardMetricTransform(MetricTransform): @@ -67,7 +70,7 @@ def set_dataset_name(self, dataset_name: str) -> None: self.dataset_name = dataset_name self.new_metric = partial(Metric, dataset_name=dataset_name) - def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: if self.dataset_name is None or self.new_metric is None: raise RuntimeError( "set_dataset_name() must be called before using the transform." @@ -92,4 +95,4 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: if "metrics" not in sample: sample["metrics"] = [] sample["metrics"].extend(metrics) - return sample \ No newline at end of file + return sample diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index 4ea863169d..f5ecbb95ea 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -22,7 +22,7 @@ from torchtune.datasets._packed import PackedDataset from torchtune.datasets._preference import preference_dataset, PreferenceDataset from torchtune.datasets._samsum import samsum_dataset -from torchtune.datasets._sft import SFTDataset, sft_iterable_dataset +from torchtune.datasets._sft import sft_iterable_dataset, SFTDataset from torchtune.datasets._slimorca import slimorca_dataset, slimorca_iterable_dataset from torchtune.datasets._stack_exchange_paired import stack_exchange_paired_dataset from torchtune.datasets._text_completion import ( diff --git a/torchtune/datasets/_alpaca.py b/torchtune/datasets/_alpaca.py index 4225ab4bf5..bae7613729 100644 --- a/torchtune/datasets/_alpaca.py +++ b/torchtune/datasets/_alpaca.py @@ -9,11 +9,10 @@ from typing import Any, Callable, Optional, Union from torchtune.data._messages import AlpacaToMessages -from torchtune.data._metrics import StandardMetricTransform from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.datasets._packed import PackedDataset -from torchtune.datasets._sft import SFTDataset, sft_iterable_dataset +from torchtune.datasets._sft import sft_iterable_dataset, SFTDataset from torchtune.modules.transforms.tokenizers import ModelTokenizer @@ -152,7 +151,7 @@ def alpaca_iterable_dataset( message_transform = AlpacaToMessages( train_on_input=train_on_input, column_map=column_map ) - + return sft_iterable_dataset( message_transform=message_transform, model_transform=model_transform, diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 9a206445d5..0be4c0cc53 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -1,5 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import logging -from typing import Any, Callable, Dict, Iterator, List, Optional +from typing import Any, Callable, Iterator, Optional import torch import torch.distributed as dist @@ -37,8 +43,8 @@ class HfIterableDataset(TuneIterableDataset): dataset_name (Optional[str]): Name of the dataset. If None, a default name is generated from the path, source, and split. filter_fn (Optional[Callable]): Filter function to apply to the dataset. - filter_kwargs (Optional[Dict[str, Any]]): Keyword arguments to pass to the filter function. - load_dataset_kwargs (Dict[str, Any]): Keyword arguments to pass to the load_dataset function. + filter_kwargs (Optional[dict[str, Any]]): Keyword arguments to pass to the filter function. + load_dataset_kwargs (dict[str, Any]): Keyword arguments to pass to the load_dataset function. """ @@ -55,7 +61,7 @@ def __init__( num_shards_per_rank: int = 64, dataset_name: Optional[str] = None, filter_fn: Optional[Callable] = None, - filter_kwargs: Optional[Dict[str, Any]] = None, + filter_kwargs: Optional[dict[str, Any]] = None, **load_dataset_kwargs, ): # Store configuration @@ -64,7 +70,7 @@ def __init__( self._message_transform = message_transform self._model_transform = model_transform self._output_transform = output_transform - self._weight = weight # TODO: make it a property? + self._weight = weight # TODO: make it a property? # Create default transform if not provided self._metric_transform = metric_transform or StandardMetricTransform() @@ -98,7 +104,7 @@ def __init__( def dataset_name(self) -> str: return self._dataset_name - def _apply_transforms(self, sample: Dict[str, Any]) -> Dict[str, Any]: + def _apply_transforms(self, sample: dict[str, Any]) -> dict[str, Any]: """Apply transforms if they exist, otherwise return sample unchanged.""" if self._message_transform is not None: sample = self._message_transform(sample) @@ -112,10 +118,10 @@ def _apply_transforms(self, sample: Dict[str, Any]) -> Dict[str, Any]: def _setup_hf_dataset( self, - load_dataset_kwargs: Dict[str, Any], + load_dataset_kwargs: dict[str, Any], num_shards_per_rank: int, filter_fn: Optional[Callable] = None, - filter_kwargs: Optional[Dict[str, Any]] = None, + filter_kwargs: Optional[dict[str, Any]] = None, ): """ Configures the Hugging Face dataset, including sharding, filtering, and @@ -185,7 +191,7 @@ def _setup_hf_dataset( self._ds = ds - def __iter__(self) -> Iterator[Dict[str, Any]]: + def __iter__(self) -> Iterator[dict[str, Any]]: """Iterate through the dataset infinitely. It will restart from the beginning after exhausting the dataset. @@ -246,7 +252,7 @@ def __iter__(self) -> Iterator[Dict[str, Any]]: # Reset to the base dataset for the next epoch's shuffling. epoch_ds = self._ds - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """ The dataset returns its own state directly, without namespacing. """ @@ -258,7 +264,7 @@ def state_dict(self) -> Dict[str, Any]: } return state - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """ Load state from checkpoint, including restoring the state of the Hugging Face IterableDataset. @@ -268,4 +274,4 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: # HF is responsible for resuming the dataset state # where it last left off - self._ds.load_state_dict(hf_state) \ No newline at end of file + self._ds.load_state_dict(hf_state) diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py index cbfe36338c..53185993dd 100644 --- a/torchtune/datasets/_interleaved.py +++ b/torchtune/datasets/_interleaved.py @@ -1,7 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import collections import logging import math -from typing import Any, Dict, Iterator, List +from typing import Any, dict, Iterator import torch @@ -17,16 +23,19 @@ class InterleavedDataset(TuneIterableDataset): to ensure correct checkpointing and resumption. Args: - datasets (List[TuneIterableDataset]): List of TuneIterableDatasets to interleave. - weights (List[float]): List of weights for each dataset. Must sum to 1.0. + datasets (list[TuneIterableDataset]): list of TuneIterableDatasets to interleave. + weights (list[float]): list of weights for each dataset. Must sum to 1.0. seed (int): Seed for sampling. dataset_name (str): Name of the dataset. If None, defaults to "interleaved_dataset". + + Raises: + ValueError: If duplicate dataset names are detected in the provided datasets. """ def __init__( self, - datasets: List[TuneIterableDataset], - weights: List[float], + datasets: list[TuneIterableDataset], + weights: list[float], seed: int, dataset_name: str = "interleaved_dataset", ): @@ -36,7 +45,7 @@ def __init__( self._dataset_names = [ds.dataset_name for ds in datasets] # Create a name-to-dataset mapping for robust state management - self._datasets: Dict[str, TuneIterableDataset] = { + self._datasets: dict[str, TuneIterableDataset] = { ds.dataset_name: ds for ds in datasets } @@ -54,21 +63,22 @@ def __init__( self._sampling_generator = torch.Generator().manual_seed(seed) # Normalize weights to sum to 1 - #TODO: make it a property? rely on ds.weight? + # TODO: make it a property? rely on ds.weight? total_weight = sum(weights) self._weights = torch.tensor( [w / total_weight for w in weights], dtype=torch.float ) if not math.isclose(total_weight, 1.0, rel_tol=1e-9): logger.warning( - f"Interleaved dataset normalized weights to sum to 1.0. Found {total_weight=}. Previous {weights=}, new {self._weights.tolist()}" + f"Interleaved dataset normalized weights to sum to 1.0. " + f"Found {total_weight=}. Previous {weights=}, new {self._weights.tolist()}" ) @property def dataset_name(self) -> str: return self._dataset_name - def __iter__(self) -> Iterator[Dict[str, Any]]: + def __iter__(self) -> Iterator[dict[str, Any]]: """Interleave samples from child infinite datasets""" child_iters = {name: iter(ds) for name, ds in self._datasets.items()} @@ -96,7 +106,7 @@ def __iter__(self) -> Iterator[Dict[str, Any]]: sample = next(child_iters[ds_name]) yield sample - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """Save state for the interleaver and its children.""" # The parent is responsible for namespacing the child states. child_states = {name: ds.state_dict() for name, ds in self._datasets.items()} @@ -105,11 +115,11 @@ def state_dict(self) -> Dict[str, Any]: "child_states": child_states, } - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Load state for the interleaver and its children.""" self._sampling_generator.set_state(state_dict["sampling_generator_state"]) child_states = state_dict["child_states"] for name, ds in self._datasets.items(): if name in child_states: # Pass the raw state dict to the child - ds.load_state_dict(child_states[name]) \ No newline at end of file + ds.load_state_dict(child_states[name]) diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index 725810541c..9dac9ee0b1 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -1,5 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + from abc import ABC, abstractmethod -from typing import Any, Dict, Iterator +from typing import Any, dict, Iterator from torch.utils.data import IterableDataset @@ -19,7 +25,7 @@ def dataset_name(self) -> str: pass @abstractmethod - def __iter__(self) -> Iterator[Dict[str, Any]]: + def __iter__(self) -> Iterator[dict[str, Any]]: """ Returns an infinite iterator over the dataset. Each implementation is responsible for its own iteration logic, including shuffling and making it an infinite stream. @@ -27,11 +33,11 @@ def __iter__(self) -> Iterator[Dict[str, Any]]: pass @abstractmethod - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """Returns a state dictionary for checkpointing""" pass @abstractmethod - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Load state from a state dictionary, used when resuming from a checkpoint.""" - pass \ No newline at end of file + pass diff --git a/torchtune/datasets/_sft.py b/torchtune/datasets/_sft.py index 70bfb75fd5..04f78a9911 100644 --- a/torchtune/datasets/_sft.py +++ b/torchtune/datasets/_sft.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Callable, Mapping, Optional, Dict +from typing import Any, Callable, Mapping, Optional import numpy as np from datasets import load_dataset @@ -12,11 +12,11 @@ from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX from torchtune.data._messages import validate_messages - -from torchtune.modules.transforms import Transform from torchtune.data._metrics import StandardMetricTransform from torchtune.datasets._hf_iterable import HfIterableDataset +from torchtune.modules.transforms import Transform + class SFTDataset(Dataset): """ @@ -187,24 +187,24 @@ class SFTOutputTransform(Transform): Output transform to be used in SFT recipes as an input to TuneIterableDataset. It takes tokenized inputs with "tokens" and "mask" keys and creates the "labels" key for SFT training. - + The labels are created by: 1. Shifting tokens by 1 position (for autoregressive training) 2. Masking positions where mask[1:] is True with CROSS_ENTROPY_IGNORE_IDX 3. Adding CROSS_ENTROPY_IGNORE_IDX at the end """ - + def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: # Create a copy to avoid modifying the original tokenized_dict = dict(sample) - + if not ("tokens" in tokenized_dict and "mask" in tokenized_dict): keys_str = ", ".join(tokenized_dict.keys()) raise ValueError( f"SFTOutputTransform expects 'tokens' and 'mask' keys. " f"Got keys: {keys_str}" ) - + # Create labels for SFT training tokenized_dict["labels"] = list( np.where( @@ -215,12 +215,12 @@ def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: ) tokenized_dict["labels"].append(CROSS_ENTROPY_IGNORE_IDX) assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"]) - + return tokenized_dict def sft_iterable_dataset( - model_transform: Transform, + model_transform: Transform, *, message_transform: Transform, shuffle_buffer_size: Optional[int] = 1000, @@ -228,26 +228,26 @@ def sft_iterable_dataset( num_shards_per_rank: int = 64, dataset_name: Optional[str] = None, filter_fn: Optional[Callable] = None, - filter_kwargs: Optional[Dict[str, Any]] = None, - **load_dataset_kwargs: Dict[str, Any], + filter_kwargs: Optional[dict[str, Any]] = None, + **load_dataset_kwargs: dict[str, Any], ) -> HfIterableDataset: """ Creates an SFT-ready iterable dataset with appropriate output transform. - + Args: model_transform (Transform): Usually the tokenizer message_transform (Transform): Transform to convert raw data to messages - shuffle_buffer_size (Optional[int]): Buffer size for shuffling + shuffle_buffer_size (Optional[int]): Buffer size for shuffling seed (int): Random seed for shuffling num_shards_per_rank (int): Target shards per worker dataset_name (Optional[str]): Name for metrics namespacing filter_fn (Optional[Callable]): Filter function - filter_kwargs (Optional[Dict[str, Any]]): Filter function kwargs - **load_dataset_kwargs: Args passed to load_dataset - + filter_kwargs (Optional[dict[str, Any]]): Filter function kwargs + **load_dataset_kwargs (dict[str, Any]): Args passed to load_dataset + Returns: HfIterableDataset: Configured for SFT training - + Example: >>> from torchtune.data import AlpacaToMessages >>> message_transform = AlpacaToMessages(train_on_input=False) @@ -259,11 +259,11 @@ def sft_iterable_dataset( """ output_transform = SFTOutputTransform() - + return HfIterableDataset( message_transform=message_transform, model_transform=model_transform, - output_transform=output_transform, + output_transform=output_transform, metric_transform=StandardMetricTransform(), shuffle_buffer_size=shuffle_buffer_size, seed=seed, diff --git a/torchtune/datasets/_slimorca.py b/torchtune/datasets/_slimorca.py index 77667aa579..5a5e9bc94f 100644 --- a/torchtune/datasets/_slimorca.py +++ b/torchtune/datasets/_slimorca.py @@ -7,11 +7,10 @@ from typing import Any, Callable, Optional, Union from torchtune.data import ShareGPTToMessages -from torchtune.data._metrics import StandardMetricTransform from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.datasets._packed import PackedDataset -from torchtune.datasets._sft import SFTDataset, sft_iterable_dataset +from torchtune.datasets._sft import sft_iterable_dataset, SFTDataset from torchtune.modules.transforms.tokenizers import ModelTokenizer @@ -121,7 +120,7 @@ def slimorca_iterable_dataset( Args: model_transform (ModelTokenizer): Model tokenizer used to tokenize the messages. - source (str): path to dataset repository on Hugging Face. Default is "Open-Orca/SlimOrca-Dedup". + source (str): path to dataset repository on Hugging Face. Default is "Open-Orca/SlimOrca-Dedup". column_map (Optional[dict[str, str]]): mapping from expected "conversations" column to actual column name in dataset. If None, uses default "conversations". train_on_input (bool): Whether to train on input or mask it. Default is False. @@ -132,13 +131,13 @@ def slimorca_iterable_dataset( dataset_name (Optional[str]): Name for metrics. If None, auto-generated from source. filter_fn (Optional[Callable]): Filter function to apply to dataset. filter_kwargs (Optional[dict[str, Any]]): Kwargs for filter function. - **load_dataset_kwargs: Additional kwargs for load_dataset. + **load_dataset_kwargs (dict[str, Any]): Additional kwargs for load_dataset. Returns: HfIterableDataset: Configured iterable dataset Example: - >>> from torchtune.datasets import slimorca_iterable_dataset + >>> from torchtune.datasets import slimorca_iterable_dataset >>> ds = slimorca_iterable_dataset(shuffle_buffer_size=1000) >>> for sample in ds: >>> print(sample["tokens"][:10]) # First 10 tokens @@ -148,9 +147,9 @@ def slimorca_iterable_dataset( column_map=column_map, new_system_prompt=new_system_prompt, ) - + return sft_iterable_dataset( - source=source, + source=source, message_transform=message_transform, model_transform=model_transform, shuffle_buffer_size=shuffle_buffer_size, diff --git a/torchtune/training/checkpointing/_checkpoint_client.py b/torchtune/training/checkpointing/_checkpoint_client.py index 5f29e038c6..05fd46e395 100644 --- a/torchtune/training/checkpointing/_checkpoint_client.py +++ b/torchtune/training/checkpointing/_checkpoint_client.py @@ -19,6 +19,7 @@ StateDictOptions, ) from torchtune import config, training, utils +from torchtune.data import MetricsAggregator from torchtune.modules.optim import OptimizerInBackward from torchtune.modules.peft import ( get_adapter_state_dict, @@ -28,7 +29,6 @@ from torchtune.training.checkpointing._checkpointer import DistributedCheckpointer from torchtune.training.checkpointing._utils import get_most_recent_checkpoint from torchtune.training.memory import OptimizerInBackwardWrapper -from torchtune.data import MetricsAggregator log = utils.get_logger("DEBUG") import torchdata @@ -463,7 +463,9 @@ def load_distributed_checkpoint( checkpoint_dict: dict[str, Any] = {} model_state_dict = model.state_dict() optim_state_dict = optimizer.state_dict() - metrics_aggregator_state_dict = metrics_aggregator.state_dict() if metrics_aggregator else {} + metrics_aggregator_state_dict = ( + metrics_aggregator.state_dict() if metrics_aggregator else {} + ) # Hack to properly initialize the learning rate scheduler # TODO: Find a better way to do this, possibly by including the following From 2e51e04f01150a251f1defb49de6991e1d8f8256 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 25 Jun 2025 13:56:58 -0700 Subject: [PATCH 04/25] tests pass --- tests/torchtune/datasets/test_hf_iterable.py | 170 +++--------------- tests/torchtune/datasets/test_interleaved.py | 3 +- .../torchtune/datasets/test_iterable_utils.py | 25 ++- torchtune/datasets/_interleaved.py | 2 +- torchtune/datasets/_iterable_base.py | 2 +- 5 files changed, 53 insertions(+), 149 deletions(-) diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index a263258ae8..83144f0ae9 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -6,18 +6,15 @@ from itertools import islice from pathlib import Path -from typing import Any, Optional import pytest + from torchdata.stateful_dataloader import StatefulDataLoader -from torchtune.data import ( - MetricsAggregator, - padded_collate_sft, - StandardMetricTransform, -) +from torchtune.data import MetricsAggregator, StandardMetricTransform from torchtune.datasets import HfIterableDataset +from .test_iterable_utils import collate_with_metrics, generate_ckpt # Test Constants - Avoid perfect divisions SMALL_DATASET_SIZE = 23 @@ -42,129 +39,10 @@ def create_test_json_file(path: Path, num_samples: int, offset: int = 0) -> None token_len = (i % 3) + 1 tokens = list(range(sample_id, sample_id + token_len)) f.write( - f'{{"id": {sample_id}, "tokens": {tokens}, "text": "sample_{sample_id}"}}\n' + f'{{"id": {sample_id}, "tokens": {tokens}, "text": "sample_{sample_id}", "labels": {tokens}}}\n' ) -def collate_with_metrics(batch: list[dict[str, Any]]) -> dict[str, Any]: - """Collate function that extracts metrics and uses padded_collate_sft as base collator.""" - # Extract metrics first - all_metrics = [] - clean_batch = [] - for sample in batch: - if "metrics" in sample: - all_metrics.extend(sample.pop("metrics")) - clean_batch.append(sample) - - if not clean_batch: - return {"metrics": all_metrics} - - # Use torchtune's padded_collate_sft as base collator - collated_batch = padded_collate_sft(clean_batch) - collated_batch["metrics"] = all_metrics - return collated_batch - - -def generate_ckpt( - dataloader: StatefulDataLoader, - aggregator: MetricsAggregator, - steps_before_checkpoint: int, - steps_after_checkpoint: int, - resume_dataloader: Optional[StatefulDataLoader] = None, - resume_aggregator: Optional[MetricsAggregator] = None, -) -> dict[str, Any]: - """ - Generates a checkpoint by running through data and saving checkpoint mid-stream. - Optionally, a second dataloader and aggregator can be given to resume from ckpt - and run steps_after_checkpoint to match the first one. - - Args: - dataloader (StatefulDataLoader): The dataloader to test - aggregator (MetricsAggregator): The metrics aggregator to use - steps_before_checkpoint (int): Number of steps to run before saving checkpoint - steps_after_checkpoint (int): Number of steps to run after checkpoint - resume_dataloader (Optional[StatefulDataLoader]): Optional new dataloader to test resuming. - If None, returns empty resumed_batches. - resume_aggregator (Optional[MetricsAggregator]): Optional new aggregator to test resuming. - If None, returns empty resumed_metrics. - - Returns: - dict[str, Any]: Dict with batches/metrics from both pre and post checkpoint runs. - """ - iterator = iter(dataloader) - - # Collect batches before and after checkpoint - batches = [] - checkpoint_state = None - metrics_at_checkpoint = {} - - total_steps = steps_before_checkpoint + steps_after_checkpoint - - for idx, batch in enumerate(iterator): - batches.append(batch) - - # Process metrics - if "metrics" in batch: - aggregator.update(batch.pop("metrics")) - - # Save checkpoint state after steps_before_checkpoint - if idx == steps_before_checkpoint - 1: # -1 because idx is 0-based - checkpoint_state = { - "loader": dataloader.state_dict(), - "aggregator": aggregator.state_dict(), - } - metrics_at_checkpoint = aggregator.get_metrics_for_logging(prefix="train") - - # Stop after total steps - if idx == total_steps - 1: - break - - # Split batches - pre_checkpoint_batches = batches[:steps_before_checkpoint] - post_checkpoint_batches = batches[steps_before_checkpoint:] - - # Resume with new instances if provided - resumed_batches = [] - resumed_metrics = {} - - if ( - resume_dataloader is not None - and resume_aggregator is not None - and checkpoint_state is not None - ): - # Test resuming with new instances - resume_dataloader.load_state_dict(checkpoint_state["loader"]) - resume_aggregator.load_state_dict(checkpoint_state["aggregator"]) - resume_iterator = iter(resume_dataloader) - - # Collect only the post-checkpoint batches when resuming - for idx, batch in enumerate(resume_iterator): - resumed_batches.append(batch) - - # Process metrics - if "metrics" in batch: - resume_aggregator.update(batch.pop("metrics")) - - # Stop after steps_after_checkpoint - if idx == steps_after_checkpoint - 1: - break - - resumed_metrics = resume_aggregator.get_metrics_for_logging(prefix="train") - - return { - # Original run - "pre_checkpoint_batches": pre_checkpoint_batches, - "post_checkpoint_batches": post_checkpoint_batches, - "metrics_at_checkpoint": metrics_at_checkpoint, - "final_metrics": aggregator.get_metrics_for_logging(prefix="train"), - # Resumed run - "resumed_batches": resumed_batches, - "resumed_metrics": resumed_metrics, - # Internal state for loading - only if someone needs to manually load - "_checkpoint_state": checkpoint_state, - } - - @pytest.fixture def tmp_data_dir(tmp_path): """Provide temporary directory for test data files.""" @@ -292,14 +170,16 @@ def test_shuffling_behavior(self, dataset_factory, small_dataset_file): ) # Get samples from two passes through the dataset - epoch_samples = islice(iter(unshuffled_ds), SMALL_DATASET_SIZE * 2) + epoch_samples = list(islice(iter(unshuffled_ds), SMALL_DATASET_SIZE * 2)) first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] # Unshuffled should have same order in both epochs - assert first_epoch_samples == list(range(SMALL_DATASET_SIZE)) - assert second_epoch_samples == list(range(SMALL_DATASET_SIZE)) + first_epoch_ids = [sample["id"] for sample in first_epoch_samples] + second_epoch_ids = [sample["id"] for sample in second_epoch_samples] + assert first_epoch_ids == list(range(SMALL_DATASET_SIZE)) + assert second_epoch_ids == list(range(SMALL_DATASET_SIZE)) # Test shuffled dataset shuffled_ds = dataset_factory( @@ -307,48 +187,56 @@ def test_shuffling_behavior(self, dataset_factory, small_dataset_file): ) # Collect full epochs to compare - epoch_samples = islice(iter(shuffled_ds), SMALL_DATASET_SIZE * 2) + epoch_samples = list(islice(iter(shuffled_ds), SMALL_DATASET_SIZE * 2)) first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] # Shuffled epochs should have different order - assert first_epoch_samples != list( + first_epoch_ids = [sample["id"] for sample in first_epoch_samples] + second_epoch_ids = [sample["id"] for sample in second_epoch_samples] + assert first_epoch_ids != list( range(SMALL_DATASET_SIZE) - ), f"Shuffled should not be sorted, got {first_epoch_samples}" + ), f"Shuffled should not be sorted, got {first_epoch_ids}" assert ( - first_epoch_samples != second_epoch_samples - ), f"Shuffled epochs should be shuffled differently, got {first_epoch_samples} and {second_epoch_samples}" + first_epoch_ids != second_epoch_ids + ), f"Shuffled epochs should be shuffled differently, got {first_epoch_ids} and {second_epoch_ids}" # But should contain the same set of IDs - assert set(first_epoch_samples) == set( + assert set(first_epoch_ids) == set( range(SMALL_DATASET_SIZE) - ), f"First epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {first_epoch_samples}" - assert set(second_epoch_samples) == set( + ), f"First epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {first_epoch_ids}" + assert set(second_epoch_ids) == set( range(SMALL_DATASET_SIZE) - ), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {second_epoch_samples}" + ), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {second_epoch_ids}" def test_epoch_tracking(self, dataset_factory, small_dataset_file): """Test that epoch number is correctly tracked across dataset restarts.""" dataset = dataset_factory(small_dataset_file, shuffle=False) # Two epoch samples - epoch_samples = islice(iter(dataset), SMALL_DATASET_SIZE * 2) + epoch_samples = list(islice(iter(dataset), SMALL_DATASET_SIZE * 2)) first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] # All should have epoch 0 + first_epoch_metrics = [] + for sample in first_epoch_samples: + first_epoch_metrics.extend(sample["metrics"]) epoch_values = [ - epoch_metric.value for epoch_metric in first_epoch_samples["metrics"] + metric.value for metric in first_epoch_metrics if metric.name == "epoch" ] assert all( epoch_value == 0 for epoch_value in epoch_values ), f"Epoch values should be 0, got {epoch_values}" # All should have epoch 1 + second_epoch_metrics = [] + for sample in second_epoch_samples: + second_epoch_metrics.extend(sample["metrics"]) epoch_values = [ - epoch_metric.value for epoch_metric in second_epoch_samples["metrics"] + metric.value for metric in second_epoch_metrics if metric.name == "epoch" ] assert all( epoch_value == 1 for epoch_value in epoch_values diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index e06ef670c1..96139fd868 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -9,6 +9,7 @@ from unittest.mock import patch import pytest + import torch from torchdata.stateful_dataloader import StatefulDataLoader @@ -40,7 +41,7 @@ def create_test_json_file(path: Path, num_samples: int, offset: int = 0) -> None token_len = (i % 3) + 1 tokens = list(range(sample_id, sample_id + token_len)) f.write( - f'{{"id": {sample_id}, "tokens": {tokens}, "text": "sample_{sample_id}"}}\n' + f'{{"id": {sample_id}, "tokens": {tokens}, "text": "sample_{sample_id}", "labels": {tokens}}}\n' ) diff --git a/tests/torchtune/datasets/test_iterable_utils.py b/tests/torchtune/datasets/test_iterable_utils.py index 5a4fbda8a5..e160345bc1 100644 --- a/tests/torchtune/datasets/test_iterable_utils.py +++ b/tests/torchtune/datasets/test_iterable_utils.py @@ -6,13 +6,14 @@ from typing import Any, Optional +import torch + from torch.utils.data import DataLoader -from torchtune.data import padded_collate_sft -from torchtune.data._metrics import MetricsAggregator +from torchtune.data import MetricsAggregator def collate_with_metrics(batch: list[dict[str, Any]]) -> dict[str, Any]: - """Collate function that extracts metrics and uses padded_collate_sft for the rest.""" + """Simple collate that extracts metrics and pads tokens.""" all_metrics = [] clean_batch = [] for sample in batch: @@ -23,8 +24,22 @@ def collate_with_metrics(batch: list[dict[str, Any]]) -> dict[str, Any]: if not clean_batch: return {"metrics": all_metrics} - # Use torchtune's standard SFT collate function - collated = padded_collate_sft(clean_batch) + # Simple padding for tokens + ids = torch.tensor([item["id"] for item in clean_batch]) + tokens = torch.nn.utils.rnn.pad_sequence( + [torch.tensor(item["tokens"]) for item in clean_batch], + batch_first=True, + padding_value=-1, # Use -1 for padding to distinguish from valid IDs + ) + collated = { + "id": ids, + "tokens": tokens, + } + + # Add text field for non-tensor data + if "text" in clean_batch[0]: + collated["text"] = [item["text"] for item in clean_batch] + collated["metrics"] = all_metrics return collated diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py index 53185993dd..0245d4e94e 100644 --- a/torchtune/datasets/_interleaved.py +++ b/torchtune/datasets/_interleaved.py @@ -7,7 +7,7 @@ import collections import logging import math -from typing import Any, dict, Iterator +from typing import Any, Iterator import torch diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index 9dac9ee0b1..f0821dc3f1 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod -from typing import Any, dict, Iterator +from typing import Any, Iterator from torch.utils.data import IterableDataset From 93fa7436aa4ad294b6aea813c8d69528494e1d5c Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 26 Jun 2025 07:16:11 -0700 Subject: [PATCH 05/25] it works --- recipes/configs/llama3_2/3B_full.yaml | 10 ++-- recipes/full_finetune_distributed.py | 56 +++++++++---------- .../torchtune/data/test_metrics_aggregator.py | 46 +++++++-------- tests/torchtune/datasets/test_interleaved.py | 16 +++--- torchtune/data/_aggregator.py | 19 ++++--- torchtune/datasets/_alpaca.py | 6 +- torchtune/datasets/_sft.py | 3 + torchtune/datasets/_slimorca.py | 10 +++- 8 files changed, 86 insertions(+), 80 deletions(-) diff --git a/recipes/configs/llama3_2/3B_full.yaml b/recipes/configs/llama3_2/3B_full.yaml index 5534b305ac..f825e9194e 100644 --- a/recipes/configs/llama3_2/3B_full.yaml +++ b/recipes/configs/llama3_2/3B_full.yaml @@ -28,7 +28,7 @@ tokenizer: # Dataloader dataloader: - batch_size: 4 + batch_size: 16 # num_workers and pin_memory can be added here if needed # Dataset - now a list to support multiple weighted sources @@ -36,16 +36,18 @@ dataset: - _component_: torchtune.datasets.slimorca_iterable_dataset shuffle_buffer_size: 1000 weight: 0.8 + split: train[:5%] # simular 1 epoch quickly - _component_: torchtune.datasets.alpaca_iterable_dataset shuffle_buffer_size: 1000 weight: 0.2 + split: train[:5%] # simular 1 epoch quickly # Packing (TBD by follow up PR) # packing: # _component_: torchtune.datasets.packing.SFTPacking # max_seq_len: 8192 -seed: null +seed: 42 # Validation not supported yet with iterable datasets @@ -76,7 +78,7 @@ loss: num_training_steps: 100 # Total number of training steps to run save_every_n_steps: 200 # Save a checkpoint every N steps. Using 200 to avoid ckpt. gradient_accumulation_steps: 1 -dataset_metrics_log_freq: 10 # Log dataset-specific metrics every N steps +dataset_metrics_log_freq: 5 # Log dataset-specific metrics every N steps # Environment device: cuda @@ -91,7 +93,7 @@ optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_ste # Logging metric_logger: - _component_: torchtune.training.metric_logging.DiskLogger + _component_: torchtune.training.metric_logging.WandBLogger log_dir: ${output_dir}/logs log_every_n_steps: 1 log_peak_memory_stats: True diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index a4eb87e7d6..5b65a6e9f5 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -13,7 +13,7 @@ from warnings import warn import torch -from omegaconf import dictConfig, listConfig +from omegaconf import DictConfig, ListConfig from torch import nn from torch.distributed import destroy_process_group, init_process_group @@ -119,7 +119,7 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): has example commands for how to kick-off training. Args: - cfg (dictConfig): OmegaConf object parsed from yaml file + cfg (DictConfig): OmegaConf object parsed from yaml file Raises: ValueError: If ``dtype`` is set to fp16. @@ -129,7 +129,7 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. """ - def __init__(self, cfg: dictConfig) -> None: + def __init__(self, cfg: DictConfig) -> None: device_type = cfg.device self._device = utils.get_device(device=device_type) self._dtype = training.get_dtype(cfg.dtype, device=self._device) @@ -303,7 +303,7 @@ def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None: "Are you sure you passed in the right recipe checkpoint?" ) from e - def setup(self, cfg: dictConfig) -> None: + def setup(self, cfg: DictConfig) -> None: """ Setup the recipe. This includes training state (if resume_from_checkpoint is True), model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader. @@ -334,7 +334,7 @@ def setup(self, cfg: dictConfig) -> None: self._compile_loss = compile_bool self._compile_optimizer_step = compile_bool self._compile_scale_grads = compile_bool - if isinstance(compile, dictConfig): + if isinstance(compile, DictConfig): self._compile_model = compile.get("model", True) self._compile_loss = compile.get("loss", True) self._compile_optimizer_step = compile.get("optimizer_step", False) @@ -431,7 +431,7 @@ def setup(self, cfg: dictConfig) -> None: collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft") self._dataloader = self._setup_data( cfg_dataset=cfg.dataset, - batch_size=cfg.batch_size, + cfg_dataloader=cfg.dataloader, collate_fn=collate_name, dataloader_state_dict=( state_dict[training.DATALOADER_KEY] @@ -443,10 +443,9 @@ def setup(self, cfg: dictConfig) -> None: # Setup validation dataloader if validation dataset is provided self._val_dataloader = None if cfg.get("dataset_val") is not None: - batch_size_val = cfg.get("batch_size_val", cfg.batch_size) self._val_dataloader = self._setup_data( cfg_dataset=cfg.dataset_val, - batch_size=batch_size_val, + cfg_dataloader=cfg.get("dataloader_val", None), collate_fn=collate_name, dataloader_state_dict=( state_dict[training.VAL_DATALOADER_KEY] @@ -471,7 +470,7 @@ def setup(self, cfg: dictConfig) -> None: def _setup_lr_scheduler( self, - cfg_lr_scheduler: Optional[dictConfig], + cfg_lr_scheduler: Optional[DictConfig], num_training_steps: int, last_epoch: int, ) -> Optional[Optimizer]: @@ -480,7 +479,7 @@ def _setup_lr_scheduler( It supports both standard optimization and optimizer-in-backward cases. Args: - cfg_lr_scheduler (Optional[dictConfig]): The learning rate scheduler configuration. + cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration. num_training_steps (int): The total number of training steps. last_epoch (int): The index of the last epoch. @@ -519,14 +518,14 @@ def _setup_lr_scheduler( return lr_scheduler def _setup_profiler( - self, cfg_profiler: Optional[dictConfig] = None + self, cfg_profiler: Optional[DictConfig] = None ) -> Union[torch.profiler.profile, DummyProfiler]: """ Parses the `profiler` section of top-level `cfg` and sets up profiler """ # Missing profiler section in config, assume disabled if cfg_profiler is None: - cfg_profiler = dictConfig({"enabled": False}) + cfg_profiler = DictConfig({"enabled": False}) # Check that component is included and set correctly if cfg_profiler.get("_component_", None) is None: @@ -553,7 +552,7 @@ def _setup_profiler( def _setup_model( self, - cfg_model: dictConfig, + cfg_model: DictConfig, enable_activation_checkpointing: bool, enable_activation_offloading: bool, activation_offloading_use_streams: bool, @@ -711,7 +710,7 @@ def _setup_model( def _setup_optimizer( self, - cfg_optimizer: dictConfig, + cfg_optimizer: DictConfig, optimizer_in_bwd: bool = False, opt_state_dict: Optional[dict[str, Any]] = None, ) -> Optional[Optimizer]: @@ -764,8 +763,8 @@ def _setup_optimizer( def _setup_data( self, - cfg_dataset: Union[dictConfig, listConfig], - batch_size: int, + cfg_dataset: Union[DictConfig, ListConfig], + cfg_dataloader: DictConfig, collate_fn: str, dataloader_state_dict: Optional[dict[str, Any]] = None, ) -> StatefulDataLoader: @@ -777,7 +776,7 @@ def _setup_data( iterable_datasets = [] weights = [] cfg_dataset_list = cfg_dataset - if not isinstance(cfg_dataset_list, listConfig): + if not isinstance(cfg_dataset_list, ListConfig): cfg_dataset_list = [cfg_dataset_list] for ds_cfg in cfg_dataset_list: @@ -823,8 +822,8 @@ def _collate_with_metrics_wrapper( # 5. Create DataLoader dataloader = StatefulDataLoader( dataset=ds, - batch_size=batch_size, collate_fn=_collate_with_metrics_wrapper, + **cfg_dataloader, ) if dataloader_state_dict is not None: @@ -898,7 +897,7 @@ def validate(self) -> dict[str, float]: self._model.train() return log_dict - def save_checkpoint(self, *, epoch: int, step: int, full_tensors: bool): + def save_checkpoint(self, *, epoch: int, full_tensors: bool): """Save checkpoint based on global step.""" self._checkpoint_client.save_checkpoint( model=self._model, @@ -924,7 +923,6 @@ def save_checkpoint(self, *, epoch: int, step: int, full_tensors: bool): metrics_aggregator_state_dict=self._metrics_aggregator.state_dict(), ), epoch=epoch, # TODO: not needed. To be deprecated. - step=step, single_device=False, full_tensors=full_tensors, dir_prefix=self.checkpoint_dir_prefix, @@ -1074,23 +1072,21 @@ def train(self) -> None: # Log dataset metrics # #TODO: it requires all_gather. Should we keep a separate log_freq for this? - if ( - self.global_step % self._dataset_metrics_log_freq == 0 - and self._is_rank_zero - ): + if self.global_step % self._dataset_metrics_log_freq == 0: dataset_metrics = self._metrics_aggregator.get_metrics_for_logging( prefix="train" ) - self._metric_logger.log_dict(dataset_metrics, step=self.global_step) + if self._is_rank_zero: + self._metric_logger.log_dict( + dataset_metrics, step=self.global_step + ) # Save checkpoint if specified by user if ( self.save_every_n_steps is not None and self.global_step % self.save_every_n_steps == 0 ): - self.save_checkpoint( - epoch=0, step=self.global_step, full_tensors=False - ) + self.save_checkpoint(epoch=0, full_tensors=False) # Reset running stats for the next step running_loss = 0 @@ -1121,7 +1117,7 @@ def train(self) -> None: self.validate() self._profiler.stop() - self.save_checkpoint(epoch=0, step=self.global_step, full_tensors=True) + self.save_checkpoint(epoch=0, full_tensors=True) def cleanup(self) -> None: if self._is_rank_zero: @@ -1130,7 +1126,7 @@ def cleanup(self) -> None: @config.parse -def recipe_main(cfg: dictConfig) -> None: +def recipe_main(cfg: DictConfig) -> None: """ Entry point for the recipe. diff --git a/tests/torchtune/data/test_metrics_aggregator.py b/tests/torchtune/data/test_metrics_aggregator.py index 69a2f32967..a9fda513a4 100644 --- a/tests/torchtune/data/test_metrics_aggregator.py +++ b/tests/torchtune/data/test_metrics_aggregator.py @@ -36,13 +36,13 @@ def test_aggregation_types(self, agg_type, test_values, expected): ] aggregator.update(metrics) - result = aggregator.get_metrics_for_logging() + result = aggregator.get_metrics_for_logging(prefix="train") if agg_type == AggregationType.CATEGORICAL_COUNT: for category, count in expected.items(): - assert result[f"test/metric_{category}_count"] == count + assert result[f"train_test/metric_{category}_count"] == count else: - assert result["test/metric"] == expected + assert result["train_test/metric"] == expected def test_distribution_metrics(self): """Tests that `AggregationType.DISTRIBUTION` computes all expected statistics (mean, min, max, p50).""" @@ -58,11 +58,11 @@ def test_distribution_metrics(self): result = aggregator.get_metrics_for_logging(prefix="train") # Verify distribution statistics - assert result["train/test/dist_metric_mean"] == 5.5 - assert result["train/test/dist_metric_min"] == 1 - assert result["train/test/dist_metric_max"] == 10 + assert result["train_test/dist_metric_mean"] == 5.5 + assert result["train_test/dist_metric_min"] == 1 + assert result["train_test/dist_metric_max"] == 10 assert ( - result["train/test/dist_metric_p50"] == 5 + result["train_test/dist_metric_p50"] == 5 ) # Median of 1-10 is 5 (index 4, value 5) def test_state_management(self): @@ -84,8 +84,8 @@ def test_state_management(self): aggregator2.load_state_dict(state) # Both should have identical metrics - metrics1 = aggregator1.get_metrics_for_logging() - metrics2 = aggregator2.get_metrics_for_logging() + metrics1 = aggregator1.get_metrics_for_logging(prefix="train") + metrics2 = aggregator2.get_metrics_for_logging(prefix="train") assert metrics1 == metrics2 # Continue updating both - should remain identical @@ -96,13 +96,13 @@ def test_state_management(self): aggregator1.update(additional_metrics) aggregator2.update(additional_metrics) - final_metrics1 = aggregator1.get_metrics_for_logging() - final_metrics2 = aggregator2.get_metrics_for_logging() + final_metrics1 = aggregator1.get_metrics_for_logging(prefix="train") + final_metrics2 = aggregator2.get_metrics_for_logging(prefix="train") assert final_metrics1 == final_metrics2 # Verify expected values - assert final_metrics1["ds1/counter"] == 15 # 10 + 5 - assert final_metrics1["ds1/average"] == 10.0 # (5 + 15) / 2 + assert final_metrics1["train_ds1/counter"] == 15 # 10 + 5 + assert final_metrics1["train_ds1/average"] == 10.0 # (5 + 15) / 2 def test_multiple_datasets(self): """Test that metrics from multiple datasets are correctly namespaced.""" @@ -118,15 +118,15 @@ def test_multiple_datasets(self): result = aggregator.get_metrics_for_logging(prefix="train") - assert result["train/dataset1/samples"] == 100 - assert result["train/dataset2/samples"] == 200 - assert result["train/dataset1/tokens"] == 1000 - assert result["train/dataset2/tokens"] == 2000 + assert result["train_dataset1/samples"] == 100 + assert result["train_dataset2/samples"] == 200 + assert result["train_dataset1/tokens"] == 1000 + assert result["train_dataset2/tokens"] == 2000 def test_empty_aggregator(self): """Test that empty aggregator returns empty metrics.""" aggregator = MetricsAggregator() - result = aggregator.get_metrics_for_logging() + result = aggregator.get_metrics_for_logging(prefix="train") assert result == {} def test_prefix_handling(self): @@ -140,10 +140,10 @@ def test_prefix_handling(self): # Test with prefix result_with_prefix = aggregator.get_metrics_for_logging(prefix="validation") - assert result_with_prefix["validation/test_ds/metric1"] == 42 - assert result_with_prefix["validation/test_ds/metric2"] == 84 + assert result_with_prefix["validation_test_ds/metric1"] == 42 + assert result_with_prefix["validation_test_ds/metric2"] == 84 - # Test without prefix + # Test without prefix (uses default "data") result_no_prefix = aggregator.get_metrics_for_logging() - assert result_no_prefix["test_ds/metric1"] == 42 - assert result_no_prefix["test_ds/metric2"] == 84 + assert result_no_prefix["data_test_ds/metric1"] == 42 + assert result_no_prefix["data_test_ds/metric2"] == 84 diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index 96139fd868..e1ded110ac 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -168,25 +168,25 @@ def test_metrics_aggregation( for sample in islice(iter(interleaved), total_samples): aggregator.update(sample["metrics"]) - metrics = aggregator.get_metrics_for_logging() + metrics = aggregator.get_metrics_for_logging(prefix="train") # Should have metrics from both datasets, with flat keys - assert "ds1/samples_seen" in metrics - assert "ds2/samples_seen" in metrics + assert "train_ds1/samples_seen" in metrics + assert "train_ds2/samples_seen" in metrics # Both datasets should have contributed samples - assert metrics["ds1/samples_seen"] > 0 - assert metrics["ds2/samples_seen"] > 0 + assert metrics["train_ds1/samples_seen"] > 0 + assert metrics["train_ds2/samples_seen"] > 0 # Total samples should equal what we processed calculated_total_samples = ( - metrics["ds1/samples_seen"] + metrics["ds2/samples_seen"] + metrics["train_ds1/samples_seen"] + metrics["train_ds2/samples_seen"] ) assert calculated_total_samples == total_samples # Test that ratio is approximately correct - ds1_ratio = metrics["ds1/samples_seen"] / total_samples - ds2_ratio = metrics["ds2/samples_seen"] / total_samples + ds1_ratio = metrics["train_ds1/samples_seen"] / total_samples + ds2_ratio = metrics["train_ds2/samples_seen"] / total_samples # Allow 10% tolerance due to randomness assert abs(ds1_ratio - 0.2) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~0.2" diff --git a/torchtune/data/_aggregator.py b/torchtune/data/_aggregator.py index 9c933ddfa3..313c5bdc73 100644 --- a/torchtune/data/_aggregator.py +++ b/torchtune/data/_aggregator.py @@ -92,7 +92,7 @@ def _initialize_state( elif agg_type == AggregationType.CATEGORICAL_COUNT: state["counts"] = collections.Counter() - def get_metrics_for_logging(self, prefix: str = "") -> dict[str, float]: + def get_metrics_for_logging(self, prefix: str = "data") -> dict[str, float]: """ Returns aggregated metrics ready for logging to wandb/tensorboard. @@ -237,7 +237,6 @@ def _compute_distributed_metrics( world_size = dist.get_world_size() # Gather all metrics from all ranks in one operation - dist.barrier() all_metrics = [None] * world_size dist.all_gather_object(all_metrics, local_metrics) @@ -276,7 +275,10 @@ def _compute_distributed_metrics( return reduced def _format_for_logging( - self, metrics: dict[tuple[str, str], dict[str, Any]], prefix: str + self, + metrics: dict[tuple[str, str], dict[str, Any]], + prefix: str, + template: str = r"{prefix}_{ds_name}/{metric_name}", ) -> dict[str, float]: """ Format metrics for wandb/tensorboard logging. @@ -285,6 +287,7 @@ def _format_for_logging( metrics (dict[tuple[str, str], dict[str, Any]]): dict mapping (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} prefix (str): Optional prefix like "train" or "valid" + template (str): Template for metric key. Use {prefix}, {ds_name}, and {metric_name} as placeholders. Returns: dict[str, float]: Flat dict with string keys like "train/dataset1/tokens_seen" -> float @@ -292,12 +295,10 @@ def _format_for_logging( formatted = {} for (ds_name, metric_name), metric_dict in metrics.items(): - # Build key: "prefix/dataset/metric" or "dataset/metric" if no prefix - if prefix: - key = f"{prefix}/{ds_name}/{metric_name}" - else: - key = f"{ds_name}/{metric_name}" - + # Use regex format to build key + key = template.format( + prefix=prefix, ds_name=ds_name, metric_name=metric_name + ) formatted[key] = metric_dict["value"] return formatted diff --git a/torchtune/datasets/_alpaca.py b/torchtune/datasets/_alpaca.py index bae7613729..4326b7024f 100644 --- a/torchtune/datasets/_alpaca.py +++ b/torchtune/datasets/_alpaca.py @@ -107,7 +107,7 @@ def alpaca_dataset( def alpaca_iterable_dataset( model_transform: ModelTokenizer, *, - source: str = "tatsu-lab/alpaca", + path: str = "tatsu-lab/alpaca", column_map: Optional[dict[str, str]] = None, train_on_input: bool = True, shuffle_buffer_size: Optional[int] = 1000, @@ -125,7 +125,7 @@ def alpaca_iterable_dataset( Args: model_transform (ModelTokenizer): Model tokenizer used to tokenize the messages. - source (str): path to dataset repository on Hugging Face. Default is ``tatsu-lab/alpaca``. + path (str): path to dataset repository on Hugging Face. Default is ``tatsu-lab/alpaca``. column_map (Optional[dict[str, str]]): a mapping from the expected columns in the message transform :class:`~torchtune.data.AlpacaToMessages` to the new column names in the dataset. Keys should be "instruction", "input", and "output" and values should be the actual column names. @@ -160,6 +160,6 @@ def alpaca_iterable_dataset( dataset_name=dataset_name, filter_fn=filter_fn, split=split, - path=source, + path=path, **load_dataset_kwargs, ) diff --git a/torchtune/datasets/_sft.py b/torchtune/datasets/_sft.py index 04f78a9911..6dabee9bb6 100644 --- a/torchtune/datasets/_sft.py +++ b/torchtune/datasets/_sft.py @@ -222,6 +222,7 @@ def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: def sft_iterable_dataset( model_transform: Transform, *, + weight: int = 1, message_transform: Transform, shuffle_buffer_size: Optional[int] = 1000, seed: int = 42, @@ -236,6 +237,7 @@ def sft_iterable_dataset( Args: model_transform (Transform): Usually the tokenizer + weight (int): Weight of the dataset. Used for sampling when interleaving datasets. message_transform (Transform): Transform to convert raw data to messages shuffle_buffer_size (Optional[int]): Buffer size for shuffling seed (int): Random seed for shuffling @@ -266,6 +268,7 @@ def sft_iterable_dataset( output_transform=output_transform, metric_transform=StandardMetricTransform(), shuffle_buffer_size=shuffle_buffer_size, + weight=weight, seed=seed, num_shards_per_rank=num_shards_per_rank, dataset_name=dataset_name, diff --git a/torchtune/datasets/_slimorca.py b/torchtune/datasets/_slimorca.py index 5a5e9bc94f..0346e5b73a 100644 --- a/torchtune/datasets/_slimorca.py +++ b/torchtune/datasets/_slimorca.py @@ -100,7 +100,8 @@ def slimorca_dataset( def slimorca_iterable_dataset( model_transform: ModelTokenizer, *, - source: str = "Open-Orca/SlimOrca-Dedup", + path: str = "Open-Orca/SlimOrca-Dedup", + split: str = "train", column_map: Optional[dict[str, str]] = None, train_on_input: bool = False, new_system_prompt: Optional[str] = None, @@ -120,7 +121,9 @@ def slimorca_iterable_dataset( Args: model_transform (ModelTokenizer): Model tokenizer used to tokenize the messages. - source (str): path to dataset repository on Hugging Face. Default is "Open-Orca/SlimOrca-Dedup". + path (str): path to dataset repository on Hugging Face. Default is "Open-Orca/SlimOrca-Dedup". + split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a + subset of a given split, e.g. ``split="train[:10%]"``. Default is "train column_map (Optional[dict[str, str]]): mapping from expected "conversations" column to actual column name in dataset. If None, uses default "conversations". train_on_input (bool): Whether to train on input or mask it. Default is False. @@ -149,7 +152,8 @@ def slimorca_iterable_dataset( ) return sft_iterable_dataset( - source=source, + path=path, + split=split, message_transform=message_transform, model_transform=model_transform, shuffle_buffer_size=shuffle_buffer_size, From aa9e6f417bd82af6fa78cbd8b1e2233751ab6981 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 26 Jun 2025 07:46:10 -0700 Subject: [PATCH 06/25] remove code --- torchtune/datasets/_hf_iterable.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 0be4c0cc53..4b7c04d04c 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -201,12 +201,11 @@ def __iter__(self) -> Iterator[dict[str, Any]]: An additional metric "num_epochs" is added to the sample. """ - epoch_ds = self._ds while True: # Infinite iteration epoch_seed = self._seed + self._num_epochs - epoch_ds.set_epoch(epoch_seed) - epoch_iterator = iter(epoch_ds) + self._ds.set_epoch(epoch_seed) + epoch_iterator = iter(self._ds) samples_yielded = 0 try: @@ -249,9 +248,6 @@ def __iter__(self) -> Iterator[dict[str, Any]]: # Epoch complete - increment and continue infinite loop self._num_epochs += 1 - # Reset to the base dataset for the next epoch's shuffling. - epoch_ds = self._ds - def state_dict(self) -> dict[str, Any]: """ The dataset returns its own state directly, without namespacing. From 5b188ed9c66c80796679a2f8161b3353b8226380 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 1 Jul 2025 21:18:04 -0400 Subject: [PATCH 07/25] update metrics to use handlers --- recipes/full_finetune_distributed.py | 3 +- .../torchtune/data/test_metrics_aggregator.py | 2 +- .../torchtune/data/test_metrics_transform.py | 10 +- tests/torchtune/datasets/test_hf_iterable.py | 2 +- tests/torchtune/datasets/test_interleaved.py | 2 +- torchtune/data/__init__.py | 12 - torchtune/data/_aggregator.py | 5 +- torchtune/data/_metrics.py | 98 ---- torchtune/data/metrics/__init__.py | 39 ++ .../data/metrics/_metric_agg_handlers.py | 433 ++++++++++++++++++ torchtune/data/metrics/_metric_aggregator.py | 271 +++++++++++ torchtune/data/metrics/_metric_transform.py | 124 +++++ torchtune/data/metrics/readme.md | 176 +++++++ torchtune/datasets/_hf_iterable.py | 2 +- torchtune/datasets/_sft.py | 2 +- .../checkpointing/_checkpoint_client.py | 2 +- 16 files changed, 1059 insertions(+), 124 deletions(-) delete mode 100644 torchtune/data/_metrics.py create mode 100644 torchtune/data/metrics/__init__.py create mode 100644 torchtune/data/metrics/_metric_agg_handlers.py create mode 100644 torchtune/data/metrics/_metric_aggregator.py create mode 100644 torchtune/data/metrics/_metric_transform.py create mode 100644 torchtune/data/metrics/readme.md diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 5b65a6e9f5..4c41a81a5b 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -24,7 +24,8 @@ from torchdata.stateful_dataloader import StatefulDataLoader from torchtune import config, modules, training, utils from torchtune.config._utils import _get_component_from_path -from torchtune.data import MetricsAggregator, padded_collate_packed +from torchtune.data import padded_collate_packed +from torchtune.data.metrics import MetricsAggregator from torchtune.datasets import InterleavedDataset from torchtune.modules.embedding_utils import resize_token_embeddings from torchtune.modules.loss import SFTLoss diff --git a/tests/torchtune/data/test_metrics_aggregator.py b/tests/torchtune/data/test_metrics_aggregator.py index a9fda513a4..0691d9c32d 100644 --- a/tests/torchtune/data/test_metrics_aggregator.py +++ b/tests/torchtune/data/test_metrics_aggregator.py @@ -6,7 +6,7 @@ import pytest -from torchtune.data import AggregationType, Metric, MetricsAggregator +from torchtune.data.metrics import AggregationType, Metric, MetricsAggregator class TestMetricsAggregator: diff --git a/tests/torchtune/data/test_metrics_transform.py b/tests/torchtune/data/test_metrics_transform.py index 2c8f3023b2..8a4a86d7dd 100644 --- a/tests/torchtune/data/test_metrics_transform.py +++ b/tests/torchtune/data/test_metrics_transform.py @@ -6,15 +6,15 @@ import pytest -from torchtune.data import AggregationType, StandardMetricTransform +from torchtune.data.metrics import AggregationType, DefaultTrainingMetricTransform -class TestStandardMetricTransform: - """Tests for StandardMetricTransform functionality.""" +class TestDefaultTrainingMetricTransform: + """Tests for DefaultTrainingMetricTransform functionality.""" def test_dataset_name_not_set_raises_error(self): """Test that using transform without setting dataset name raises error.""" - transform = StandardMetricTransform() + transform = DefaultTrainingMetricTransform() sample = {"tokens": [1, 2, 3]} with pytest.raises(RuntimeError, match="set_dataset_name"): @@ -22,7 +22,7 @@ def test_dataset_name_not_set_raises_error(self): def test_basic_metrics_generation(self): """Test that transform generates expected metrics for a sample.""" - transform = StandardMetricTransform() + transform = DefaultTrainingMetricTransform() transform.set_dataset_name("test_dataset") sample = {"tokens": [1, 2, 3, 4, 5]} diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index 83144f0ae9..067fdc1294 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -11,7 +11,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader -from torchtune.data import MetricsAggregator, StandardMetricTransform +from torchtune.data.metrics import MetricsAggregator, StandardMetricTransform from torchtune.datasets import HfIterableDataset from .test_iterable_utils import collate_with_metrics, generate_ckpt diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index e1ded110ac..d8afcd2263 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -13,7 +13,7 @@ import torch from torchdata.stateful_dataloader import StatefulDataLoader -from torchtune.data import MetricsAggregator, StandardMetricTransform +from torchtune.data.metrics import MetricsAggregator, StandardMetricTransform from torchtune.datasets import HfIterableDataset, InterleavedDataset # Import test utilities diff --git a/torchtune/data/__init__.py b/torchtune/data/__init__.py index 09292b9ba9..a75e16780a 100644 --- a/torchtune/data/__init__.py +++ b/torchtune/data/__init__.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchtune.data._aggregator import MetricsAggregator from torchtune.data._collate import ( left_pad_sequence, padded_collate, @@ -25,12 +24,6 @@ ShareGPTToMessages, validate_messages, ) -from torchtune.data._metrics import ( - AggregationType, - Metric, - MetricTransform, - StandardMetricTransform, -) from torchtune.data._prompt_templates import ( ChatMLTemplate, GrammarErrorCorrectionTemplate, @@ -42,13 +35,8 @@ from torchtune.data._utils import format_content_with_images, load_image, truncate __all__ = [ - "AggregationType", "CROSS_ENTROPY_IGNORE_IDX", "GrammarErrorCorrectionTemplate", - "Metric", - "MetricsAggregator", - "MetricTransform", - "StandardMetricTransform", "SummarizeTemplate", "OpenAIToMessages", "ShareGPTToMessages", diff --git a/torchtune/data/_aggregator.py b/torchtune/data/_aggregator.py index 313c5bdc73..66162f826b 100644 --- a/torchtune/data/_aggregator.py +++ b/torchtune/data/_aggregator.py @@ -11,7 +11,7 @@ import torch.distributed as dist -from torchtune.data._metrics import AggregationType, Metric +from torchtune.data.metrics import AggregationType, Metric logger = logging.getLogger(__name__) @@ -159,7 +159,8 @@ def _compute_local_metrics(self) -> dict[tuple[str, str], dict[str, Any]]: n = len(sorted_values) # Each stat becomes its own metric - # For percentiles, it is an approximattion by computing avg of averages + # so that we can all gather O(5) values across ranks + # instead of the entire distribution metrics[(ds_name, f"{metric_name}_mean")] = { "value": sum(values) / n, "agg_type": AggregationType.MEAN, diff --git a/torchtune/data/_metrics.py b/torchtune/data/_metrics.py deleted file mode 100644 index 7a38febb1e..0000000000 --- a/torchtune/data/_metrics.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from dataclasses import dataclass -from enum import Enum -from functools import partial -from typing import Any, Callable, Optional, Protocol, Union - - -class AggregationType(Enum): - """Defines how a metric's value should be aggregated.""" - - SUM = "sum" - MEAN = "mean" - DISTRIBUTION = "distribution" - CATEGORICAL_COUNT = "categorical_count" - MAX = "max" - MIN = "min" - - -@dataclass(frozen=True) -class Metric: - """A self-describing metric object.""" - - dataset_name: str - name: str - value: Union[int, float, str] - agg_type: AggregationType - - -class MetricTransform(Protocol): - """Protocol for metric transforms.""" - - def set_dataset_name(self, dataset_name: str) -> None: - ... - - def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - ... - - -class StandardMetricTransform(MetricTransform): - """ - Attaches per-sample metrics for tracking training progress. - - This transform is responsible for generating metrics on a per-sample - basis (e.g., tokens per sample). The actual aggregation of these metrics - (eg calculating sum of samples seen) is handled by the - `MetricsAggregator`. This separation of concerns ensures that metrics are - correctly aggregated even with multiple dataloader workers and in a - distributed setting. - - Tracked metrics include: - - samples_seen: A count of samples processed. - - tokens_seen: The cumulative sum of all tokens processed. - - seq_len: A distribution of sequence lengths. - """ - - def __init__(self): - # dataset_name is set by the dataset using set_dataset_name - self.dataset_name: Optional[str] = None - self.new_metric: Optional[Callable] = None - - def set_dataset_name(self, dataset_name: str) -> None: - """Called by dataset to set the namespace for metrics. - The dataset name is used to differentiate multiple datasets stats, - e.g. "train/dataset1/tokens_seen" and "train/dataset2/tokens_seen".""" - self.dataset_name = dataset_name - self.new_metric = partial(Metric, dataset_name=dataset_name) - - def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - if self.dataset_name is None or self.new_metric is None: - raise RuntimeError( - "set_dataset_name() must be called before using the transform." - ) - - # Determine token key - token_key = "tokens" if "tokens" in sample else "input_ids" - token_len = len(sample.get(token_key, [])) - - # Create metrics for this sample - metrics = [ - self.new_metric(name="samples_seen", value=1, agg_type=AggregationType.SUM), - self.new_metric( - name="tokens_seen", value=token_len, agg_type=AggregationType.SUM - ), - self.new_metric( - name="seq_len", value=token_len, agg_type=AggregationType.DISTRIBUTION - ), - ] - - # Append to existing metrics list or create new one - if "metrics" not in sample: - sample["metrics"] = [] - sample["metrics"].extend(metrics) - return sample diff --git a/torchtune/data/metrics/__init__.py b/torchtune/data/metrics/__init__.py new file mode 100644 index 0000000000..17e359d697 --- /dev/null +++ b/torchtune/data/metrics/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtune.data.metrics._metric_aggregator import MetricsAggregator +from torchtune.data.metrics._metric_agg_handlers import ( + AggregationHandler, + CategoricalCountAggHandler, + DistributionAggHandler, + MaxAggHandler, + MeanAggHandler, + MetricState, + MinAggHandler, + SumAggHandler, +) +from torchtune.data.metrics._metric_transform import ( + AggregationType, + DefaultTrainingMetricTransform, + Metric, + MetricTransform, +) + +__all__ = [ + "AggregationType", + "AggregationHandler", + "CategoricalCountAggHandler", + "DefaultTrainingMetricTransform", + "DistributionAggHandler", + "MaxAggHandler", + "MeanAggHandler", + "Metric", + "MetricState", + "MetricsAggregator", + "MetricTransform", + "MinAggHandler", + "SumAggHandler", +] diff --git a/torchtune/data/metrics/_metric_agg_handlers.py b/torchtune/data/metrics/_metric_agg_handlers.py new file mode 100644 index 0000000000..1a1557c803 --- /dev/null +++ b/torchtune/data/metrics/_metric_agg_handlers.py @@ -0,0 +1,433 @@ +import logging +from abc import ABC, abstractmethod +from collections import Counter, deque +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Union + +import torch + +from torchtune.data.metrics._metric_transform import Metric, AggregationType + +logger = logging.getLogger(__name__) + +@dataclass +class MetricState: + """Mutable state object representing aggregated metric for (dataset, metric) on a single rank. + + Args: + dataset_name (str): Name of the dataset. + metric_name (str): Name of the metric. + value (float): Current aggregated value, whose meaning depends on the aggregation type + (e.g., running sum, current max). + agg_type (AggregationType): Aggregation type. + metadata (dict[str, Any]): Additional state like count, list of values, etc. + """ + dataset_name: str + metric_name: str + value: float + agg_type: AggregationType + metadata: dict[str, Any] = field(default_factory=dict) + +class AggregationHandler(ABC): + """Base class for handling metric aggregation in MetricsAggregator. + + This class defines the interface for different aggregation strategies (e.g., SUM, MEAN). + Each handler is responsible for: + - Initializing the state for a new (dataset, metric) pair. + - Updating the state with new values. + - Finalizing the value for local (single-rank) logging. + - Reducing the values from all ranks in a distributed setting. + - Serializing and deserializing the metric state for checkpointing. + """ + + @abstractmethod + def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + """Create a new MetricState for a (dataset_name, metric_name) pair. + + Args: + dataset_name (str): Name of the dataset. Especially useful when tracking multiple datasets. + metric_name (str): Name of the metric. + agg_type (AggregationType): Aggregation type. + + Returns: + MetricState: New MetricState for this (dataset_name, metric_name) pair. + """ + pass + + @abstractmethod + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + """Update cumulative MetricState with new metric info. + + Args: + local_agg_metric (MetricState): Cumulative state of the aggregation for this metric in the local rank. + metric (Metric): Input metric info. + """ + pass + + @abstractmethod + def finalize_local_agg( + self, local_agg_metric: MetricState + ) -> Union[MetricState, list[MetricState]]: + """ + Computes the final value from the locally aggregated state. + + In a distributed setting, this is called before the reduction step. + This method can also expand a single metric into multiple, for instance, + a distribution into mean, min, max, and percentiles. + + Args: + local_agg_metric (MetricState): The locally aggregated metric state to finalize. + + Returns: + A single `MetricState` or a list of them if the metric expands. + """ + pass + + @abstractmethod + def finalize_dist_agg( + self, local_agg_metrics: list[MetricState] + ) -> MetricState: + """ + Merge MetricStates from all ranks into final result. + + Args: + local_agg_metrics (list[MetricState]): list of MetricStates for this (dataset_name, metric_name) pair. + + Returns: + MetricState: Final result for this (dataset_name, metric_name) pair. + """ + pass + + def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert handler-specific metadata to serializable format. + + Args: + metadata (dict[str, Any]): AggHandler-specific metadata. + + Override this when using non-serializable types like deque or Counter. + For example, convert deque to list, Counter to dict. + """ + return metadata.copy() + + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Restore handler-specific metadata from serialized format. + + Args: + metadata (dict[str, Any]): AggHandler-specific metadata. + + Override this to reverse the serialize_metadata transformation. + For example, convert list back to deque, dict back to Counter. + """ + return metadata.copy() + + +class SumAggHandler(AggregationHandler): + """AggHandler for SUM aggregation. Initializes with 0.0 and accumulates metric values.""" + + def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=0.0, + agg_type=agg_type + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.value += metric.value + + def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: + return local_agg_metric + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + if not local_agg_metrics: + raise ValueError("Cannot aggregate empty list of metrics") + + total = sum(metric.value for metric in local_agg_metrics) + return MetricState( + dataset_name=local_agg_metrics[0].dataset_name, + metric_name=local_agg_metrics[0].metric_name, + value=total, + agg_type=local_agg_metrics[0].agg_type, + metadata=local_agg_metrics[0].metadata.copy() + ) + + +class MaxAggHandler(AggregationHandler): + """AggHandler for MAX aggregation. Tracks maximum value across all updates.""" + + def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=float('-inf'), + agg_type=agg_type, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.value = max(local_agg_metric.value, metric.value) + + def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: + return local_agg_metric + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + max_value = max(r.value for r in local_agg_metrics) + return MetricState( + dataset_name=local_agg_metrics[0].dataset_name, + metric_name=local_agg_metrics[0].metric_name, + value=max_value, + agg_type=local_agg_metrics[0].agg_type, + metadata=local_agg_metrics[0].metadata.copy() + ) + + +class MinAggHandler(AggregationHandler): + """AggHandler for MIN aggregation. Tracks minimum value across all updates.""" + + def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=float('inf'), + agg_type=agg_type, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.value = min(local_agg_metric.value, metric.value) + + def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: + return local_agg_metric + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + min_value = min(r.value for r in local_agg_metrics) + return MetricState( + dataset_name=local_agg_metrics[0].dataset_name, + metric_name=local_agg_metrics[0].metric_name, + value=min_value, + agg_type=local_agg_metrics[0].agg_type, + metadata=local_agg_metrics[0].metadata.copy() + ) + + +class MeanAggHandler(AggregationHandler): + """AggHandler for MEAN aggregation. Maintains running sum and count to compute average.""" + + def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=0.0, + agg_type=agg_type, + metadata={"sum": 0.0, "count": 0}, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.metadata["sum"] += metric.value + local_agg_metric.metadata["count"] += 1 + + def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: + count = local_agg_metric.metadata["count"] + local_agg_metric.value = local_agg_metric.metadata["sum"] / count if count > 0 else 0.0 + return local_agg_metric + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + total_sum = sum(metric.metadata["sum"] for metric in local_agg_metrics) + total_count = sum(metric.metadata["count"] for metric in local_agg_metrics) + + return MetricState( + dataset_name=local_agg_metrics[0].dataset_name, + metric_name=local_agg_metrics[0].metric_name, + value=total_sum / total_count if total_count > 0 else 0.0, + agg_type=local_agg_metrics[0].agg_type, + metadata={"sum": total_sum, "count": total_count}, + ) + + +class DistributionAggHandler(AggregationHandler): + """AggHandler for DISTRIBUTION aggregation. Maintains a sliding window of values + and expands into multiple statistical metrics (mean, min, max, percentiles, std). + + Note: Percentiles and standard deviation are approximated in distributed settings by averaging local + percentiles and standard deviations across ranks. This is mathematically imprecise but provides a + reasonable approximation for monitoring purposes. + """ + + def __init__(self, window_size: int = 1000): + """Initialize handler with specified window size for value retention. + + Args: + window_size (int): Maximum number of recent values to retain for statistics. + """ + if window_size <= 0: + raise ValueError(f"window_size must be positive, got {window_size}") + self.window_size = window_size + + def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=0.0, + agg_type=agg_type, + metadata={"values": deque(maxlen=self.window_size)} + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.metadata["values"].append(metric.value) + + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + values = list(local_agg_metric.metadata["values"]) + if not values: + return [] + + return self._compute_distribution_stats(local_agg_metric, values) + + def _compute_distribution_stats( + self, local_agg_metric: MetricState, values: list[float] + ) -> list[MetricState]: + """Compute statistical metrics from distribution values using torch for efficiency.""" + if not values: + return [] + + # Use float64 for precision matching python's float + values_tensor = torch.tensor(values, dtype=torch.float64) + n = len(values_tensor) + + # Compute all stats from the tensor + sum_val = torch.sum(values_tensor).item() + mean_val = sum_val / n + min_val = torch.min(values_tensor).item() + max_val = torch.max(values_tensor).item() + + # Compute all percentiles in one go + percentile_definitions = torch.tensor([0.05, 0.5, 0.95], dtype=torch.float64) + p05_val, p50_val, p95_val = torch.quantile(values_tensor, percentile_definitions).tolist() + + # Return multiple MetricStates with proper agg_types for distributed reduction + # NOTE: Percentiles use MEAN aggregation which approximates global percentiles + # by averaging local percentiles. + metrics = [ + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_mean", + value=mean_val, + agg_type=AggregationType.MEAN, + metadata={"sum": sum_val, "count": n}, + ), + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_min", + value=min_val, + agg_type=AggregationType.MIN, + metadata={}, + ), + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_max", + value=max_val, + agg_type=AggregationType.MAX, + metadata={}, + ), + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_p05", + value=p05_val, + agg_type=AggregationType.MEAN, + metadata={"sum": p05_val, "count": 1}, + ), + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_p50", + value=p50_val, + agg_type=AggregationType.MEAN, + metadata={"sum": p50_val, "count": 1}, + ), + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_p95", + value=p95_val, + agg_type=AggregationType.MEAN, + metadata={"sum": p95_val, "count": 1}, + ), + ] + # Standard deviation is only well-defined for n > 1 + if n > 1: + std_val = torch.std(values_tensor).item() + metrics.append( + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_std", + value=std_val, + agg_type=AggregationType.MEAN, + metadata={"sum": std_val, "count": 1}, + ) + ) + return metrics + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + raise NotImplementedError( + "Metrics with AggregationType.DISTRIBUTION are converted to other " + "AggregationTypes for distributed reduction. finalize_dist_agg should not be called." + ) + + def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert deque to list for serialization.""" + serialized = metadata.copy() + if "values" in serialized: + serialized["values"] = list(serialized["values"]) + return serialized + + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert list back to deque.""" + deserialized = metadata.copy() + if "values" in deserialized: + deserialized["values"] = deque(deserialized["values"], maxlen=self.window_size) + return deserialized + + +class CategoricalCountAggHandler(AggregationHandler): + """AggHandler for CATEGORICAL_COUNT aggregation. Counts occurrences of categorical values + and expands into individual count metrics for each category.""" + + def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=0.0, + agg_type=agg_type, + metadata={"counts": Counter()} + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.metadata["counts"][metric.value] += 1 + + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + # Expand categorical counts into individual metrics + results = [] + for category, count in local_agg_metric.metadata["counts"].items(): + results.append(MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_{category}_count", + value=count, + agg_type=AggregationType.SUM + )) + return results + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + raise NotImplementedError( + "Metrics with AggregationType.CATEGORICAL_COUNT are converted to other " + "AggregationTypes for distributed reduction. finalize_dist_agg should not be called." + ) + + def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert Counter to dict for serialization.""" + serialized = metadata.copy() + if "counts" in serialized: + serialized["counts"] = dict(serialized["counts"]) + return serialized + + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert dict back to Counter.""" + deserialized = metadata.copy() + if "counts" in deserialized: + deserialized["counts"] = Counter(deserialized["counts"]) + return deserialized \ No newline at end of file diff --git a/torchtune/data/metrics/_metric_aggregator.py b/torchtune/data/metrics/_metric_aggregator.py new file mode 100644 index 0000000000..c07f0dea36 --- /dev/null +++ b/torchtune/data/metrics/_metric_aggregator.py @@ -0,0 +1,271 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import ast +from collections import defaultdict +from typing import Any, tuple + +import torch.distributed as dist + +from torchtune.data.metrics._metric_agg_handlers import ( + AggregationHandler, + CategoricalCountAggHandler, + DistributionAggHandler, + MetricState, + MaxAggHandler, + MeanAggHandler, + MinAggHandler, + SumAggHandler, +) +from torchtune.data.metrics._metric_transform import Metric, AggregationType + +class MetricsAggregator: + """Aggregates metrics across datasets and distributed ranks using pluggable handlers. + + Uses a handler-based strategy pattern where each aggregation type (SUM, MEAN, etc.) + has its own handler. Maintains only one state per (dataset, metric) pair. + + When preparing for logging, uses a two-phase approach: + 1. Local aggregation: Each rank aggregates its metrics independently + 2. Distributed reduction: Results combined across ranks + + The aggregator is checkpointable and restores from state_dict for training resumption. + + Args: + dist_window_size (int): Window size for DistributionAggHandler tracking. + + Example: + >>> from torchtune.data.metrics import MetricsAggregator, Metric, AggregationType + >>> + >>> # Create aggregator + >>> aggregator = MetricsAggregator() + >>> + >>> # Sample metrics from different batches + >>> batch1_metrics = [ + ... Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), + ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), + ... ] + >>> + >>> batch2_metrics = [ + ... Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), + ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), + ... ] + >>> + >>> # Update with metrics + >>> aggregator.update(batch1_metrics) + >>> aggregator.update(batch2_metrics) + >>> + >>> # Get final results + >>> results = aggregator.get_metrics_for_logging(prefix="train") + >>> # {"train_alpaca/tokens_seen": 200.0, "train_alpaca/avg_tokens_seen": 100.0} + """ + + def __init__(self, dist_window_size: int = 1000): + if dist_window_size <= 0: + raise ValueError(f"dist_window_size must be positive, got {dist_window_size}") + + # Storage: {(dataset, metric): MetricState} - O(unique metrics) not O(samples) + self._metric_states: dict[tuple[str, str], MetricState] = {} + self._dist_window_size = dist_window_size + + # Create handler registry - all handlers initialized upfront + self._handlers: dict[AggregationType, AggregationHandler] = { + AggregationType.SUM: SumAggHandler(), + AggregationType.MAX: MaxAggHandler(), + AggregationType.MIN: MinAggHandler(), + AggregationType.MEAN: MeanAggHandler(), + AggregationType.DISTRIBUTION: DistributionAggHandler(dist_window_size), + AggregationType.CATEGORICAL_COUNT: CategoricalCountAggHandler(), + } + + def register_handler(self, agg_type: AggregationType, handler: AggregationHandler) -> None: + """Register custom aggregation handler for specified type. + + Args: + agg_type (AggregationType): The aggregation type to handle + handler (AggregationHandler): Handler instance implementing the Ag∂gregationHandler interface + """ + self._handlers[agg_type] = handler + + def update(self, metrics: list[Metric]) -> None: + """Update (dataset_name, metric_name) metric state with new values. + + Args: + metrics (list[Metric]): List of metrics to update the state with + """ + for metric in metrics: + metric_key = (metric.dataset_name, metric.name) + handler = self._handlers.get(metric.agg_type) + + if handler is None: + raise ValueError(f"No handler registered for aggregation type: {metric.agg_type}") + + if metric_key not in self._metric_states: + self._metric_states[metric_key] = handler.initialize_metric_state( + metric.dataset_name, metric.name, metric.agg_type + ) + + local_agg_metric = self._metric_states[metric_key] + handler.update(local_agg_metric, metric) # Mutates local_agg_metric + + def get_metrics_for_logging(self, prefix: str = "data") -> dict[str, float]: + """Get final metrics for logging in standard format. + + Args: + prefix (str): Prefix for metric names in the returned dictionary + + Returns: + dict[str, float]: Dictionary with keys like "{prefix}_{dataset_name}/{metric_name}" + and float values. For example, with `prefix="train"`, `dataset_name="alpaca"`, + `metric_name="loss"`, the key would be `train_alpaca/loss`. + """ + final_results = self._compute_unified_metrics() + + return { + f"{prefix}_{result.dataset_name}/{result.metric_name}": result.value + for result in final_results + } + + def _compute_unified_metrics(self) -> list[MetricState]: + """ + Compute metrics handling both local and distributed cases uniformly. + + Returns: + list[MetricState]: Final results ready for logging + """ + # Step 1: Get local results from all handlers (may expand distributions/categoricals) + prepared_results = [] + for local_agg_metric in self._metric_states.values(): + handler = self._handlers[local_agg_metric.agg_type] + prepared = handler.finalize_local_agg(local_agg_metric) + if isinstance(prepared, list): # Distribution/categorical expands to multiple + prepared_results.extend(prepared) + else: + prepared_results.append(prepared) + + # Step 2: Apply distributed reduction if needed + if dist.is_initialized() and dist.get_world_size() > 1: + prepared_results = self._finalize_dist_agg(prepared_results) + + return prepared_results + + def _finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> list[MetricState]: + """Apply distributed reduction to local results. + + Args: + local_agg_metrics (list[MetricState]): (dataset_name, metric_name) metric pairs from this rank + + Returns: + list[MetricState]: Reduced results combining all ranks + """ + world_size = dist.get_world_size() + + # Gather all results from all ranks + all_results = [None] * world_size + dist.all_gather_object(all_results, local_agg_metrics) + + # Group by (dataset_name, metric_name) for reduction + grouped = defaultdict(list) + for rank_results in all_results: + if rank_results: # Handle ranks with no metrics + for result in rank_results: + result_key = (result.dataset_name, result.metric_name) + grouped[result_key].append(result) + + # Apply handler-specific distributed reduction + reduced_results = [] + for result_key, results_list in grouped.items(): + if not results_list: + continue # Skip empty groups + + # All results for a key should have same agg_type + agg_type = results_list[0].agg_type + handler = self._handlers[agg_type] + reduced_result = handler.finalize_dist_agg(results_list) + reduced_results.append(reduced_result) + + return reduced_results + + def state_dict(self) -> dict[str, Any]: + """Serialize aggregator state for checkpointing. + + Returns: + dict[str, Any]: Serializable dictionary containing all aggregator state + """ + serializable_state = {} + required_agg_types = set() # Track aggregation types used in saved states + + for metric_key, local_agg_metric in self._metric_states.items(): + # Get handler for this result's aggregation type + handler = self._handlers[local_agg_metric.agg_type] + required_agg_types.add(local_agg_metric.agg_type) + + # Convert MetricState to serializable dict + result_dict = { + "dataset_name": local_agg_metric.dataset_name, + "metric_name": local_agg_metric.metric_name, + "value": local_agg_metric.value, + "agg_type": local_agg_metric.agg_type, + "metadata": handler.serialize_metadata(local_agg_metric.metadata) + } + + # Convert tuple key to string for JSON compatibility + serializable_state[str(metric_key)] = result_dict + + return { + "state": serializable_state, + "dist_window_size": self._dist_window_size, + "required_agg_types": list(required_agg_types) # Save which handlers are needed + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load aggregator state from checkpoint. + + Args: + state_dict (dict[str, Any]): Dictionary containing serialized aggregator state + + Raises: + ValueError: If required handlers are missing after checkpoint restore + """ + self._dist_window_size = state_dict.get("dist_window_size", 1000) + + # Sanity check: Ensure all required handlers are available + required_agg_types = state_dict.get("required_agg_types", []) + missing_handlers = [] + for agg_type in required_agg_types: + if agg_type not in self._handlers: + missing_handlers.append(agg_type) + + if missing_handlers: + raise ValueError( + f"Missing handlers for aggregation types: {missing_handlers}. " + f"Custom handlers must be re-registered before checkpoint restore." + ) + + deserialized_state = {} + for key_str, result_dict in state_dict["state"].items(): + # Convert string keys back to tuples + metric_key = ast.literal_eval(key_str) + + # Get handler for this aggregation type + agg_type = result_dict["agg_type"] + handler = self._handlers[agg_type] + + # Restore metadata using handler-specific deserialization + metadata = handler.deserialize_metadata(result_dict["metadata"]) + + # Create MetricState from dict + local_agg_metric = MetricState( + dataset_name=result_dict["dataset_name"], + metric_name=result_dict["metric_name"], + value=result_dict["value"], + agg_type=result_dict["agg_type"], + metadata=metadata + ) + + deserialized_state[metric_key] = local_agg_metric + + self._metric_states = deserialized_state diff --git a/torchtune/data/metrics/_metric_transform.py b/torchtune/data/metrics/_metric_transform.py new file mode 100644 index 0000000000..0e252cb2e6 --- /dev/null +++ b/torchtune/data/metrics/_metric_transform.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from enum import Enum +from functools import partial +from typing import Any, Callable, Optional, Union + +from torchtune.modules.transforms import Transform + +@dataclass(frozen=True) +class Metric: + dataset_name: str + name: str + value: Union[int, float, str] + agg_type: "AggregationType" + +class AggregationType(Enum): + """Defines how a metric's value should be aggregated.""" + + SUM = "sum" + MEAN = "mean" + DISTRIBUTION = "distribution" + CATEGORICAL_COUNT = "categorical_count" + MAX = "max" + MIN = "min" + +class MetricTransform(Transform): + """Applied to each sample to generate per-sample metrics for training tracking. + + Creates Metric objects that are later aggregated by 'MetricsAggregator'. This separation + of concerns ensures metrics are correctly aggregated even with multiple dataloader + workers and in distributed settings.""" + + def __init__(self): + # dataset_name is set by the dataset using set_dataset_name + self.dataset_name: Optional[str] = None + self.new_metric: Optional[Callable] = None + + def set_dataset_name(self, dataset_name: str) -> None: + """Called by dataset to set the namespace for metrics. + + The dataset name is used to differentiate multiple datasets stats, + e.g. "train/dataset1/tokens_seen" and "train/dataset2/tokens_seen". + + Args: + dataset_name (str): Name of the dataset for metric namespacing + """ + self.dataset_name = dataset_name + # Create a partial to make it easier to create new metrics + self.new_metric = partial(Metric, dataset_name=dataset_name) + + def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: + """Generate metrics for a single sample. Must be implemented by subclasses. + + Args: + sample (dict[str, Any]): The sample dictionary to generate metrics from + + Returns: + list[Metric]: List of metrics generated for this sample + """ + raise NotImplementedError( + "Subclasses must implement _generate_metrics method" + ) + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + """Apply transform to sample, adding generated metrics. """ + if self.dataset_name is None or self.new_metric is None: + raise RuntimeError( + "set_dataset_name() must be called before using the transform." + ) + + # Generate metrics for this sample + metrics = self._generate_metrics(sample) + + # Add to existing metrics list or create new one + if "metrics" not in sample: + sample["metrics"] = [] + sample["metrics"].extend(metrics) + return sample + + +class DefaultTrainingMetricTransform(MetricTransform): + """Generates training metrics: samples_seen, tokens_seen, seq_len distribution. + + For details about MetricTransform base class behavior, see the parent class docstring. + + Tracked metrics: + - samples_seen: Cumulative count of samples processed (SUM aggregation) + - tokens_seen: Cumulative sum of all tokens processed (SUM aggregation) + - seq_len: Distribution of sequence lengths (DISTRIBUTION aggregation) + + Example: + >>> transform = DefaultTrainingMetricTransform() + >>> transform.set_dataset_name("alpaca") + >>> + >>> sample = {"tokens": [1, 2, 3, 4, 5]} # 5 tokens + >>> metrics = transform._generate_metrics(sample) + >>> # Creates: + >>> # [ + >>> # Metric(dataset_name="alpaca", name="samples_seen", value=1, agg_type=AggregationType.SUM), + >>> # Metric(dataset_name="alpaca", name="tokens_seen", value=5, agg_type=AggregationType.SUM), + >>> # Metric(dataset_name="alpaca", name="seq_len", value=5, agg_type=AggregationType.DISTRIBUTION) + >>> # ] + """ + + def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: + # Determine token key + token_key = "tokens" if "tokens" in sample else "input_ids" + token_len = len(sample.get(token_key, [])) + + # Create metrics for this sample + return [ + self.new_metric(name="samples_seen", value=1, agg_type=AggregationType.SUM), + self.new_metric( + name="tokens_seen", value=token_len, agg_type=AggregationType.SUM + ), + self.new_metric( + name="seq_len", value=token_len, agg_type=AggregationType.DISTRIBUTION + ), + ] diff --git a/torchtune/data/metrics/readme.md b/torchtune/data/metrics/readme.md new file mode 100644 index 0000000000..5b79c5e98b --- /dev/null +++ b/torchtune/data/metrics/readme.md @@ -0,0 +1,176 @@ +# TorchTune Metrics Module + +## Overview + +The metrics module provides a robust system for tracking and aggregating training metrics across multiple datasets and distributed environments. It follows a **strategy pattern** design with pluggable aggregation handlers to efficiently handle different types of metrics. + +## Architecture Overview + +``` +┌────────────────────────────────────────────────────┐ +│ Training Loop │ +└─────────────────────┬──────────────────────────────┘ + │ +┌─────────────────────▼──────────────────────────────┐ +│ MetricTransform │ +│ • Applied to each sample │ +│ • Generates per-sample metrics │ +│ • Examples: tokens_seen, seq_len, samples_seen │ +└─────────────────────┬──────────────────────────────┘ + │ list[Metric] +┌─────────────────────▼──────────────────────────────┐ +│ MetricsAggregator │ +│ • Aggregates metrics across samples and ranks │ +│ • Uses pluggable AggregationHandlers │ +│ • Handles distributed reduction │ +└─────────────────────┬──────────────────────────────┘ + │ {prefix_dataset/metric: value} +┌─────────────────────▼──────────────────────────────┐ +│ Logging System │ +│ • W&B, TensorBoard, etc. │ +│ • Gets formatted metrics ready for logging │ +└────────────────────────────────────────────────────┘ +``` + +## File Structure + +- **`_metric_transform.py`**: Defines `Metric`, `AggregationType`, and transform classes +- **`_metric_agg_handlers.py`**: Aggregation strategy implementations +- **`_metric_aggregator.py`**: Main aggregator orchestrating the handlers + +## Customizing metrics + +- **Custom transforms**: Extend `MetricTransform` for domain-specific metrics +- **Handler registration**: Register custom handlers for specialized aggregation needs + +####### +## TODO +## Move this from here to website docs +####### + +## Core Components + +### 1. MetricTransform +Generates per-sample metrics during data processing. + +**Key Features:** +- Applied to each sample in the dataset +- Creates `Metric` objects with dataset name, metric name, value, and aggregation type +- Handles dataset namespacing for multi-dataset scenarios + +**Example Usage:** +```python +from torchtune.data.metrics import DefaultTrainingMetricTransform, AggregationType + +transform = DefaultTrainingMetricTransform() +transform.set_dataset_name("alpaca") + +# Applied to each sample +sample = {"tokens": [1, 2, 3, 4, 5]} +sample = transform(sample) +# sample["metrics"] now contains: +# [ +# Metric(dataset_name="alpaca", name="samples_seen", value=1, agg_type=AggregationType.SUM), +# Metric(dataset_name="alpaca", name="tokens_seen", value=5, agg_type=AggregationType.SUM), +# Metric(dataset_name="alpaca", name="seq_len", value=5, agg_type=AggregationType.DISTRIBUTION) +# ] +``` + +### 2. MetricsAggregator +Efficiently aggregates metrics across samples and distributed ranks. + +**Key Features:** +- Handler-based strategy pattern for different aggregation types +- Distributed-aware with automatic rank reduction +- Checkpointable state for training resumption +- Keep track of (metric, dataset) pairs + +**Aggregation Types (at the time of writing):** +- `SUM`: Cumulative totals (e.g., total tokens processed) +- `MEAN`: Running averages (e.g., average loss) +- `MAX/MIN`: Extrema tracking (e.g., max sequence length seen) +- `DISTRIBUTION`: Statistical summaries (mean, min, max, percentiles) +- `CATEGORICAL_COUNT`: Category cumulative counts (e.g. num of samples from a given category) + +**Example Usage:** +```python +from torchtune.data.metrics import MetricsAggregator, Metric, AggregationType + +# Create aggregator +aggregator = MetricsAggregator() + +# Sample metrics from different batches +batch1_metrics = [ + Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), + Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), +] + +batch2_metrics = [ + Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), + Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), +] + +# Update with metrics +aggregator.update(batch1_metrics) +aggregator.update(batch2_metrics) + +# Get final results +results = aggregator.get_metrics_for_logging(prefix="train") +# {"train_alpaca/tokens_seen": 200.0, "train_alpaca/avg_tokens_seen": 100.0} +``` + +### 3. AggregationHandlers +Pluggable strategies for different aggregation patterns. + +``` +AggregationHandler (ABC) +├── SumAggHandler # value += metric.value +├── MeanAggHandler # tracks sum and count +├── MaxAggHandler # value = max(value, metric.value) +├── MinAggHandler # value = min(value, metric.value) +├── DistributionAggHandler # maintains value window + stats +└── CategoricalCountAggHandler # Counter for categories +``` + +**Custom Handler Example:** +```python +class CustomAggHandler(AggregationHandler): + def initialize_metric_state(self, dataset_name, metric_name, agg_type): + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=, # should change + agg_type=agg_type, + metadata={} # may need to change + ) + + def update(self, local_agg_metric, metric): + ... + + def finalize_local_agg(self, local_agg_metric): + ... + + def finalize_dist_agg(self, local_agg_metrics): + ... + +# Register with aggregator +aggregator.register_handler(AggregationType.CUSTOM, CustomAggHandler()) +``` + +## Distributed Training Support + +The metrics system automatically handles distributed environments: + +1. **Local Aggregation**: Each rank aggregates its own metrics +2. **Distributed Reduction**: Results are combined across ranks using `all_gather_object` +3. **Type-Aware Reduction**: Each aggregation type uses appropriate reduction (sum, mean, max, etc.) + +**Distributed Flow:** +``` +Rank 0: [(ds1, metric1), (ds1, metric2)] → LocalAgg → [(ds1, metric1), (ds1, metric2)] +Rank 1: [(ds1, metric1), (ds1, metric2)] → LocalAgg → [(ds1, metric1), (ds1, metric2)] + ↓ + AllGather + Reduce + ↓ + Final Results [(ds1, metric1), (ds1, metric2)] +``` \ No newline at end of file diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 0be4c0cc53..a9495827cf 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -12,7 +12,7 @@ from datasets import load_dataset from datasets.distributed import split_dataset_by_node -from torchtune.data._metrics import AggregationType, Metric, StandardMetricTransform +from torchtune.data.metrics import AggregationType, Metric, StandardMetricTransform from torchtune.datasets._iterable_base import TuneIterableDataset logger = logging.getLogger(__name__) diff --git a/torchtune/datasets/_sft.py b/torchtune/datasets/_sft.py index 6dabee9bb6..72289c14dc 100644 --- a/torchtune/datasets/_sft.py +++ b/torchtune/datasets/_sft.py @@ -12,7 +12,7 @@ from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX from torchtune.data._messages import validate_messages -from torchtune.data._metrics import StandardMetricTransform +from torchtune.data.metrics import StandardMetricTransform from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.modules.transforms import Transform diff --git a/torchtune/training/checkpointing/_checkpoint_client.py b/torchtune/training/checkpointing/_checkpoint_client.py index 05fd46e395..d943ad697f 100644 --- a/torchtune/training/checkpointing/_checkpoint_client.py +++ b/torchtune/training/checkpointing/_checkpoint_client.py @@ -19,7 +19,7 @@ StateDictOptions, ) from torchtune import config, training, utils -from torchtune.data import MetricsAggregator +from torchtune.data.metrics import MetricsAggregator from torchtune.modules.optim import OptimizerInBackward from torchtune.modules.peft import ( get_adapter_state_dict, From 2eab08db738be364b0b9edfe4218cb6c6fb8f281 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 1 Jul 2025 21:21:32 -0400 Subject: [PATCH 08/25] remove file after refactoring --- torchtune/data/_aggregator.py | 343 ---------------------------------- 1 file changed, 343 deletions(-) delete mode 100644 torchtune/data/_aggregator.py diff --git a/torchtune/data/_aggregator.py b/torchtune/data/_aggregator.py deleted file mode 100644 index 66162f826b..0000000000 --- a/torchtune/data/_aggregator.py +++ /dev/null @@ -1,343 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import ast -import collections -import logging -from typing import Any - -import torch.distributed as dist - -from torchtune.data.metrics import AggregationType, Metric - -logger = logging.getLogger(__name__) - - -class MetricsAggregator: - """ - Aggregates metrics across datasets and distributed ranks. - - The internal state `_state` is a dictionary where the key is a tuple - of `(dataset_name, metric_name)` and the value is another dictionary - holding the metric's specific state (e.g., `{'type': AggregationType.SUM, 'value': 10}`). - - Usage: - aggregator = MetricsAggregator() - aggregator.update(metrics) - # Get logger-ready metrics {key: value} - metrics = aggregator.get_metrics_for_logging(prefix="train") # {"train/dataset1/tokens": 1234, ...} - """ - - def __init__(self, dist_window_size: int = 1000): - # State shape: {(dataset_name, metric_name): {type: AggType, value/sum/counts/etc}} - self._state: dict[tuple[str, str], dict[str, Any]] = {} - - # For distributions, we keep a window of values to compute percentiles - self._dist_window_size = dist_window_size - - def update(self, metrics: list[Metric]) -> None: - """Update internal state with new metrics. - - Args: - metrics (list[Metric]): list of Metric objects - """ - for metric in metrics: - key = (metric.dataset_name, metric.name) - - if key not in self._state: - self._initialize_state(key, metric.agg_type) - - state = self._state[key] - - # Update based on aggregation type - if metric.agg_type == AggregationType.SUM: - state["value"] += metric.value - elif metric.agg_type == AggregationType.MAX: - if state["value"] is not None: - state["value"] = max(state["value"], metric.value) - else: - state["value"] = metric.value - elif metric.agg_type == AggregationType.MIN: - if state["value"] is not None: - state["value"] = min(state["value"], metric.value) - else: - state["value"] = metric.value - elif metric.agg_type == AggregationType.MEAN: - state["sum"] += metric.value - state["count"] += 1 - elif metric.agg_type == AggregationType.DISTRIBUTION: - state["values"].append(metric.value) - elif metric.agg_type == AggregationType.CATEGORICAL_COUNT: - state["counts"][metric.value] += 1 - - def _initialize_state( - self, key: tuple[str, str], agg_type: AggregationType - ) -> None: - """Initialize state for a new metric.""" - self._state[key] = {"type": agg_type} - state = self._state[key] - - if agg_type == AggregationType.SUM: - state["value"] = 0.0 - elif agg_type in (AggregationType.MAX, AggregationType.MIN): - state["value"] = None - elif agg_type == AggregationType.MEAN: - state["sum"] = 0.0 - state["count"] = 0 - elif agg_type == AggregationType.DISTRIBUTION: - state["values"] = collections.deque(maxlen=self._dist_window_size) - elif agg_type == AggregationType.CATEGORICAL_COUNT: - state["counts"] = collections.Counter() - - def get_metrics_for_logging(self, prefix: str = "data") -> dict[str, float]: - """ - Returns aggregated metrics ready for logging to wandb/tensorboard. - - Args: - prefix (str): Optional prefix like "train" or "valid" for metric keys - - Returns: - dict[str, float]: Flat dictionary with keys like "train/dataset1/tokens_seen" -> float value - Ready to be logged directly: wandb.log(metrics) - """ - # Always compute local metrics first - local_metrics = self._compute_local_metrics() - - # In distributed mode, perform reduction - if dist.is_initialized() and dist.get_world_size() > 1: - metrics = self._compute_distributed_metrics(local_metrics) - else: - metrics = local_metrics - - # Format for logging with proper key structure - return self._format_for_logging(metrics, prefix) - - def _compute_local_metrics(self) -> dict[tuple[str, str], dict[str, Any]]: - """ - Compute metrics from current state. - - For distributions and categoricals, expands into multiple entries. - The dict format allows future extensions with additional fields. - - Returns: - dict[tuple[str, str], dict[str, Any]]: dictionary mapping - (dataset_name, metric_name) -> {"value": value, "agg_type": aggregation_type} - """ - metrics = {} - - for (ds_name, metric_name), state in self._state.items(): - agg_type = state["type"] - - if agg_type in ( - AggregationType.SUM, - AggregationType.MAX, - AggregationType.MIN, - ): - # For sum, max, and min, we just need to return the value - metrics[(ds_name, metric_name)] = { - "value": state["value"], - "agg_type": agg_type, - } - - elif agg_type == AggregationType.MEAN: - if state["count"] > 0: - value = state["sum"] / state["count"] - metrics[(ds_name, metric_name)] = { - "value": value, - "agg_type": agg_type, - } - - elif agg_type == AggregationType.DISTRIBUTION: - # queue -> list - values = list(state["values"]) - - # Sort to get percentiles efficiently - sorted_values = sorted(values) - n = len(sorted_values) - - # Each stat becomes its own metric - # so that we can all gather O(5) values across ranks - # instead of the entire distribution - metrics[(ds_name, f"{metric_name}_mean")] = { - "value": sum(values) / n, - "agg_type": AggregationType.MEAN, - } - metrics[(ds_name, f"{metric_name}_min")] = { - "value": sorted_values[0], - "agg_type": AggregationType.MIN, - } - metrics[(ds_name, f"{metric_name}_max")] = { - "value": sorted_values[-1], - "agg_type": AggregationType.MAX, - } - metrics[(ds_name, f"{metric_name}_p05")] = { - "value": sorted_values[max(0, int(0.05 * n) - 1)], - "agg_type": AggregationType.MEAN, - } - metrics[(ds_name, f"{metric_name}_p50")] = { - "value": sorted_values[max(0, int(0.5 * n) - 1)], - "agg_type": AggregationType.MEAN, - } - metrics[(ds_name, f"{metric_name}_p95")] = { - "value": sorted_values[max(0, int(0.95 * n) - 1)], - "agg_type": AggregationType.MEAN, - } - - elif agg_type == AggregationType.CATEGORICAL_COUNT: - # Expand categorical counts into individual metrics - for category, count in state["counts"].items(): - metrics[(ds_name, f"{metric_name}_{category}_count")] = { - "value": count, - "agg_type": AggregationType.SUM, - } - - return metrics - - def _compute_distributed_metrics( - self, local_metrics: dict[tuple[str, str], dict[str, Any]] - ) -> dict[tuple[str, str], dict[str, Any]]: - """ - Performs distributed reduction on metrics. - - Strategy: - 1. Do a single all_gather_object to collect all metrics from all ranks - 2. Group metrics by key and aggregation type - 3. Apply the appropriate reduction operation locally - - This avoids complex tensor operations and handles all reduction in one pass. - - Args: - local_metrics (dict[tuple[str, str], dict[str, Any]]): dict mapping - (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} - - Returns: - dict[tuple[str, str], dict[str, Any]]: Reduced metrics in same format as input - - Example: - rank_1_metrics = - { - ("ds1", "metric1"): {"value": 10, "agg_type": AggregationType.SUM}, - ("ds2", "metric2"): {"value": 20, "agg_type": AggregationType.MEAN}, - } - rank_2_metrics = - { - ("ds1", "metric1"): {"value": 30, "agg_type": AggregationType.SUM}, - ("ds2", "metric2"): {"value": 40, "agg_type": AggregationType.MEAN}, - } - - # After reduction - result = - { - ("ds1", "metric1"): {"value": 40, "agg_type": AggregationType.SUM}, - ("ds2", "metric2"): {"value": 30, "agg_type": AggregationType.MEAN}, - } - """ - world_size = dist.get_world_size() - - # Gather all metrics from all ranks in one operation - all_metrics = [None] * world_size - dist.all_gather_object(all_metrics, local_metrics) - - # Group values by key for reduction - grouped = collections.defaultdict(list) - for rank_metrics in all_metrics: - if rank_metrics: # It's possible a rank has no metrics - for key, metric_dict in rank_metrics.items(): - # A key is a tuple (dataset, metric) - grouped[key].append(metric_dict) - - # Reduce based on aggregation type - reduced = {} - if not grouped: - return reduced - - for key, metric_dicts in grouped.items(): - # All metrics for a key should have same type, just take first - values = [m["value"] for m in metric_dicts] - agg_type = metric_dicts[0]["agg_type"] - - # Start with copy of first dict to preserve any extra fields - result_dict = metric_dicts[0].copy() - - if agg_type == AggregationType.SUM: - result_dict["value"] = sum(values) - elif agg_type == AggregationType.MAX: - result_dict["value"] = max(values) - elif agg_type == AggregationType.MIN: - result_dict["value"] = min(values) - elif agg_type == AggregationType.MEAN: - result_dict["value"] = sum(values) / len(values) - - reduced[key] = result_dict - - return reduced - - def _format_for_logging( - self, - metrics: dict[tuple[str, str], dict[str, Any]], - prefix: str, - template: str = r"{prefix}_{ds_name}/{metric_name}", - ) -> dict[str, float]: - """ - Format metrics for wandb/tensorboard logging. - - Args: - metrics (dict[tuple[str, str], dict[str, Any]]): dict mapping - (dataset, metric) -> {"value": value, "agg_type": agg_type, ...} - prefix (str): Optional prefix like "train" or "valid" - template (str): Template for metric key. Use {prefix}, {ds_name}, and {metric_name} as placeholders. - - Returns: - dict[str, float]: Flat dict with string keys like "train/dataset1/tokens_seen" -> float - """ - formatted = {} - - for (ds_name, metric_name), metric_dict in metrics.items(): - # Use regex format to build key - key = template.format( - prefix=prefix, ds_name=ds_name, metric_name=metric_name - ) - formatted[key] = metric_dict["value"] - - return formatted - - def state_dict(self) -> dict[str, Any]: - """Serialize aggregator state. The state is almost directly serializable.""" - serializable_state = {} - for key, state in self._state.items(): - state_copy = state.copy() - - # Convert non-serializable types - if "values" in state_copy: - state_copy["values"] = list(state_copy["values"]) # deque → list - if "counts" in state_copy: - state_copy["counts"] = dict(state_copy["counts"]) # Counter → dict - - # Convert tuple key to string for JSON compatibility - # JSON doesn't support tuple keys, so we convert (dataset, metric) → "('dataset', 'metric')" - serializable_state[str(key)] = state_copy - return {"state": serializable_state, "dist_window_size": self._dist_window_size} - - def load_state_dict(self, state_dict: dict[str, Any]) -> None: - """Load aggregator state from checkpoint.""" - self._dist_window_size = state_dict["dist_window_size"] - - deserialized_state = {} - for key_str, state in state_dict["state"].items(): - # Convert string keys back to tuples - # "('dataset', 'metric')" → ('dataset', 'metric') - key = ast.literal_eval(key_str) - - # Re-wrap values in their original types - if state.get("type") == AggregationType.DISTRIBUTION: - state["values"] = collections.deque( - state["values"], maxlen=self._dist_window_size - ) - if state.get("type") == AggregationType.CATEGORICAL_COUNT: - state["counts"] = collections.Counter(state["counts"]) - - deserialized_state[key] = state - self._state = deserialized_state From 58491f1f90aabdab242b59bf79094b9cc3ad82b9 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 2 Jul 2025 10:05:50 -0400 Subject: [PATCH 09/25] add distributed tsts --- .../torchtune/data/test_metrics_aggregator.py | 159 ++++++++++++++++++ tests/torchtune/datasets/test_hf_iterable.py | 101 +++++++++++ tests/torchtune/datasets/test_interleaved.py | 110 ++++++++++++ 3 files changed, 370 insertions(+) diff --git a/tests/torchtune/data/test_metrics_aggregator.py b/tests/torchtune/data/test_metrics_aggregator.py index 0691d9c32d..bbfb9821cd 100644 --- a/tests/torchtune/data/test_metrics_aggregator.py +++ b/tests/torchtune/data/test_metrics_aggregator.py @@ -5,6 +5,9 @@ # LICENSE file in the root directory of this source tree. import pytest +import torch.distributed as dist +from torch.testing._internal.common_fsdp import FSDPTest +from tests.test_utils import gpu_test from torchtune.data.metrics import AggregationType, Metric, MetricsAggregator @@ -147,3 +150,159 @@ def test_prefix_handling(self): result_no_prefix = aggregator.get_metrics_for_logging() assert result_no_prefix["data_test_ds/metric1"] == 42 assert result_no_prefix["data_test_ds/metric2"] == 84 + + +class TestDistributedMetricsAggregator(FSDPTest): + """Distributed tests for MetricsAggregator using FSDPTest infrastructure.""" + + @property + def world_size(self) -> int: + return 2 + + @gpu_test(gpu_count=2) + def test_distributed_all_aggregation_types(self): + """ + Test that all aggregation types work correctly in distributed setting. + Each rank contributes different values to ensure proper reduction across ranks. + """ + aggregator = MetricsAggregator() + rank = dist.get_rank() + + # Each rank contributes different values to test cross-rank aggregation + base_value = (rank + 1) * 10 # rank 0: 10, rank 1: 20 + + metrics = [ + Metric("test", "sum_metric", base_value, AggregationType.SUM), + Metric("test", "mean_metric", base_value + 5, AggregationType.MEAN), + Metric("test", "max_metric", base_value * 10, AggregationType.MAX), + Metric("test", "min_metric", base_value // 2, AggregationType.MIN), + ] + + # DISTRIBUTION: Each rank adds 5 values for distribution statistics + # rank 0: [0, 1, 2, 3, 4], rank 1: [10, 11, 12, 13, 14] + for i in range(5): + metrics.append( + Metric("test", "dist_metric", rank * 10 + i, AggregationType.DISTRIBUTION) + ) + + # CATEGORICAL_COUNT: Different categories per rank to test counting + # rank 0: 3 of cat_A, 2 of cat_B + # rank 1: 1 of cat_A, 4 of cat_C + if rank == 0: + metrics.extend([ + Metric("test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT), + ]) + else: + metrics.extend([ + Metric("test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT), + ]) + + # Update aggregator and get results + aggregator.update(metrics) + result = aggregator.get_metrics_for_logging(prefix="train") + + # Verify aggregation results across all ranks + # SUM: rank 0 adds 10, rank 1 adds 20 -> total 30 + # MEAN: rank 0 has 15, rank 1 has 25 -> avg 20 + # MAX: rank 0 has 100, rank 1 has 200 -> max 200 + # MIN: rank 0 has 5, rank 1 has 10 -> min 5 + assert result["train_test/sum_metric"] == 30 + assert result["train_test/mean_metric"] == 20 + assert result["train_test/max_metric"] == 200 + assert result["train_test/min_metric"] == 5 + + # DISTRIBUTION: Combined values [0,1,2,3,4,10,11,12,13,14] + # Mean should be average of local means: (2 + 12) / 2 = 7 + assert result["train_test/dist_metric_mean"] == 7 + assert result["train_test/dist_metric_min"] == 0 + assert result["train_test/dist_metric_max"] == 14 + + # CATEGORICAL_COUNT: Total counts across ranks + # cat_A: 3(rank0) + 1(rank1) = 4, cat_B: 2(rank0) + 0(rank1) = 2, cat_C: 0(rank0) + 4(rank1) = 4 + assert result["train_test/cat_metric_cat_A_count"] == 4 + assert result["train_test/cat_metric_cat_B_count"] == 2 + assert result["train_test/cat_metric_cat_C_count"] == 4 + + @gpu_test(gpu_count=2) + def test_distributed_state_dict_resumption(self): + """ + Test that MetricsAggregator state_dict save/restore works correctly in distributed setting. + Verifies: + - State can be saved after partial updates across ranks + - State can be restored consistently across ranks + - Continued updates after restore produce identical results + - Distributed aggregation works correctly after restoration + """ + rank = dist.get_rank() + + # Phase 1: Create aggregator and add initial metrics + aggregator1 = MetricsAggregator() + + # Each rank contributes different initial values + base_value = rank * 100 # rank 0: 0, rank 1: 100 + + initial_metrics = [ + Metric("test", "sum_metric", base_value, AggregationType.SUM), + Metric("test", "mean_metric", base_value // 2, AggregationType.MEAN), + Metric("test", "max_metric", base_value * 2, AggregationType.MAX), + ] + + # Add some DISTRIBUTION values - each rank adds 3 values + for i in range(3): + initial_metrics.append( + Metric("test", "dist_metric", rank * 100 + i, AggregationType.DISTRIBUTION) + ) + + # Add CATEGORICAL_COUNT values + if rank == 0: + initial_metrics.extend([ + Metric("test", "cat_metric", "type_A", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "type_A", AggregationType.CATEGORICAL_COUNT), + ]) + else: + initial_metrics.extend([ + Metric("test", "cat_metric", "type_B", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "type_B", AggregationType.CATEGORICAL_COUNT), + Metric("test", "cat_metric", "type_B", AggregationType.CATEGORICAL_COUNT), + ]) + + aggregator1.update(initial_metrics) + + # Save state_dict after initial update + state_dict = aggregator1.state_dict() + + # Phase 2: Create new aggregator and restore from state_dict + aggregator2 = MetricsAggregator() + aggregator2.load_state_dict(state_dict) + + # Verify both aggregators produce identical results after restore + result1 = aggregator1.get_metrics_for_logging(prefix="checkpoint") + result2 = aggregator2.get_metrics_for_logging(prefix="checkpoint") + assert result1 == result2, ( + f"Rank {rank}: Aggregators differ after state_dict restore" + ) + + # Phase 3: Add more metrics to both aggregators + additional_metrics = [ + Metric("test", "sum_metric", rank * 1000, AggregationType.SUM), + Metric("test", "min_metric", rank * 1000, AggregationType.MIN), + ] + + # Update both aggregators with additional metrics + aggregator1.update(additional_metrics) + aggregator2.update(additional_metrics) + + # Phase 4: Verify final results are identical across both aggregators + final_result1 = aggregator1.get_metrics_for_logging(prefix="final") + final_result2 = aggregator2.get_metrics_for_logging(prefix="final") + assert final_result1 == final_result2, ( + f"Rank {rank}: Final results differ after continued updates" + ) diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index 067fdc1294..ba644515c2 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -4,10 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import math +import tempfile +import shutil from itertools import islice from pathlib import Path import pytest +import torch.distributed as dist +from torch.testing._internal.common_fsdp import FSDPTest +from tests.test_utils import gpu_test from torchdata.stateful_dataloader import StatefulDataLoader @@ -241,3 +247,98 @@ def test_epoch_tracking(self, dataset_factory, small_dataset_file): assert all( epoch_value == 1 for epoch_value in epoch_values ), f"Epoch values should be 1, got {epoch_values}" + + +class TestDistributedHfIterableDataset(FSDPTest): + @property + def world_size(self) -> int: + return 2 + + @gpu_test(gpu_count=2) + def test_distributed_epoch_boundary_checkpointing(self): + """ + Test epoch boundary handling with checkpointing in distributed setting. + Ensures proper handling of: + - Checkpointing at 0.9, 1.0, and 2.5 epoch boundaries + - Correct sample distribution across epochs + - Proper state restoration after checkpointing + """ + rank = dist.get_rank() + + # Create shared temp directory (only rank 0 creates it) + if rank == 0: + temp_dir = tempfile.mkdtemp(prefix="epoch_test_") + else: + temp_dir = "" + + # Broadcast temp directory path to all ranks + temp_dir_list = [temp_dir] + dist.broadcast_object_list(temp_dir_list, src=0) + temp_dir = temp_dir_list[0] + tmp_path = Path(temp_dir) + + try: + medium_dataset_file = tmp_path / "medium_data.json" + + # Only rank 0 creates the data file, all ranks read from it + if rank == 0: + create_test_json_file(medium_dataset_file, MEDIUM_DATASET_SIZE) + dist.barrier() # Wait for file creation + + # Test multiple epoch boundaries + for num_epochs in [0.9, 1.0, 2.5]: + def create_loader_and_aggregator(): + dataset = HfIterableDataset( + path="json", + data_files=str(medium_dataset_file), + split="train", + dataset_name="epoch_test", + seed=SEED, + shuffle_buffer_size=0, # No shuffle for determinism + metric_transform=StandardMetricTransform(), + num_shards_per_rank=2, + ) + loader = StatefulDataLoader( + dataset, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics, num_workers=0 + ) + return loader, MetricsAggregator() + + loader1, aggregator1 = create_loader_and_aggregator() + loader2, aggregator2 = create_loader_and_aggregator() + + # Calculate steps to reach desired epoch boundary + samples_per_rank = MEDIUM_DATASET_SIZE // dist.get_world_size() + total_samples = int(samples_per_rank * num_epochs) + total_steps = total_samples // BATCH_SIZE + + if total_steps < 2: + raise ValueError(f"Not enough steps for meaningful test: {total_steps}") + + # Split steps between before and after checkpoint + steps_before = max(1, total_steps // 2) + steps_after = total_steps - steps_before + + result = generate_ckpt( + loader1, aggregator1, steps_before, steps_after, + resume_dataloader=loader2, resume_aggregator=aggregator2 + ) + + # Verify deterministic resumption - critical for distributed training + orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]] + resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]] + assert orig_post_ids == resumed_ids, ( + f"Rank {rank}: Non-deterministic resume for {num_epochs} epochs. " + f"This indicates checkpoint/resume state is not properly preserved." + ) + + # Verify epoch metric is correctly tracked + final_metrics = result["final_metrics"] + expected_epoch = math.floor(num_epochs - 1e-9) # -1e-9 so 1.0 epochs -> 0 + assert final_metrics[f"train_epoch_test/num_epochs"] == expected_epoch, ( + f"Epoch count incorrect for {num_epochs} epochs test scenario" + ) + + finally: + # Clean up temp directory (only rank 0) + if rank == 0: + shutil.rmtree(temp_dir) diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index d8afcd2263..d02912cd0f 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -4,11 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import tempfile +import shutil from itertools import islice from pathlib import Path from unittest.mock import patch import pytest +import torch.distributed as dist +from torch.testing._internal.common_fsdp import FSDPTest +from tests.test_utils import gpu_test import torch from torchdata.stateful_dataloader import StatefulDataLoader @@ -233,3 +238,108 @@ def create_interleaved(): assert ( result["final_metrics"] == result["resumed_metrics"] ), "Final metrics should match" + + +class TestDistributedInterleavedDataset(FSDPTest): + @property + def world_size(self) -> int: + return 2 + + @gpu_test(gpu_count=2) + def test_distributed_interleaved_checkpointing(self): + """ + Test interleaved dataset checkpointing with distributed settings. + Assertions: + - Each rank processes non-overlapping data shards + - Sampling ratios (70/30) are maintained across ranks + - Checkpoint/resume produces identical batches (deterministic) + - Metrics correctly aggregate across ranks + """ + rank = dist.get_rank() + + # Create shared temp directory (only rank 0 creates it) + if rank == 0: + temp_dir = tempfile.mkdtemp(prefix="interleaved_test_") + else: + temp_dir = None + + # Broadcast temp directory to all ranks + temp_dir_list = [temp_dir] + dist.broadcast_object_list(temp_dir_list, src=0) + temp_dir = temp_dir_list[0] + tmp_path = Path(temp_dir) + + try: + def create_dataset(): + file1 = tmp_path / "ds1.json" + file2 = tmp_path / "ds2.json" + + # Only rank 0 creates the data files + if rank == 0: + create_test_json_file(file1, SMALL_DATASET_SIZE) # IDs 0-22 + create_test_json_file(file2, MEDIUM_DATASET_SIZE, offset=100) # IDs 100-134 + dist.barrier() # Wait for file creation + + ds1 = HfIterableDataset( + path="json", data_files=str(file1), split="train", dataset_name="ds1", + shuffle_buffer_size=0, # No shuffle for determinism + metric_transform=StandardMetricTransform(), num_shards_per_rank=2, + ) + ds2 = HfIterableDataset( + path="json", data_files=str(file2), split="train", dataset_name="ds2", + shuffle_buffer_size=0, # No shuffle for determinism + metric_transform=StandardMetricTransform(), num_shards_per_rank=2, + ) + + # Create interleaved dataset with 70/30 weighting + return InterleavedDataset([ds1, ds2], [0.8, 0.2], seed=SEED) + + def create_dataloader(dataset): + loader = StatefulDataLoader( + dataset, batch_size=BATCH_SIZE, + num_workers=0, # Avoid multiprocessing in distributed tests + collate_fn=collate_with_metrics + ) + return loader, MetricsAggregator() + + # Run checkpointing test with small number of steps + loader1, aggregator1 = create_dataloader(create_dataset()) + loader2, aggregator2 = create_dataloader(create_dataset()) + + result = generate_ckpt( + loader1, aggregator1, 3, 3, # 3 steps before, 3 steps after checkpoint + resume_dataloader=loader2, resume_aggregator=aggregator2 + ) + + # Verify deterministic resumption + orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]] + resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]] + assert orig_post_ids == resumed_ids, ( + f"Rank {rank}: Non-deterministic interleaved resume. " + f"This indicates sampling state is not properly preserved." + ) + assert result["final_metrics"] == result["resumed_metrics"], ( + "Final metrics don't match resumed metrics - aggregator state issue" + ) + + # Verify sampling ratio is approximately maintained (80/20 split) + all_ids = [] + for batch in result["pre_checkpoint_batches"] + result["post_checkpoint_batches"]: + all_ids.extend(batch["id"].tolist()) + + # Count samples by ID ranges: ds1 has IDs < 100, ds2 has IDs >= 100 + ds1_samples = sum(1 for id in all_ids if id < 100) + ds2_samples = sum(1 for id in all_ids if id >= 100) + total_samples = ds1_samples + ds2_samples + + if total_samples > 0: + ds1_ratio = ds1_samples / total_samples + assert 0.6 < ds1_ratio < 1.0, ( + f"Rank {rank}: Dataset sampling ratio {ds1_ratio:.2f} outside expected " + f"range for 80/20 split. Got {ds1_samples}, {ds2_samples} samples." + ) + + finally: + # Clean up temp directory (only rank 0) + if rank == 0: + shutil.rmtree(temp_dir) From 96424d0df3868c45211e080a09e1a9e5ef59b0b8 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 2 Jul 2025 09:30:01 -0700 Subject: [PATCH 10/25] tests pass --- .../torchtune/data/test_metrics_aggregator.py | 153 ++++++++---- tests/torchtune/datasets/test_hf_iterable.py | 48 ++-- tests/torchtune/datasets/test_interleaved.py | 60 +++-- .../torchtune/datasets/test_iterable_utils.py | 2 +- torchtune/data/metrics/__init__.py | 2 +- .../data/metrics/_metric_agg_handlers.py | 233 +++++++++++------- torchtune/data/metrics/_metric_aggregator.py | 159 ++++++------ torchtune/data/metrics/_metric_transform.py | 51 ++-- torchtune/data/metrics/readme.md | 16 +- torchtune/datasets/_hf_iterable.py | 8 +- torchtune/datasets/_sft.py | 4 +- 11 files changed, 452 insertions(+), 284 deletions(-) diff --git a/tests/torchtune/data/test_metrics_aggregator.py b/tests/torchtune/data/test_metrics_aggregator.py index bbfb9821cd..b65c11f533 100644 --- a/tests/torchtune/data/test_metrics_aggregator.py +++ b/tests/torchtune/data/test_metrics_aggregator.py @@ -6,8 +6,8 @@ import pytest import torch.distributed as dist -from torch.testing._internal.common_fsdp import FSDPTest from tests.test_utils import gpu_test +from torch.testing._internal.common_fsdp import FSDPTest from torchtune.data.metrics import AggregationType, Metric, MetricsAggregator @@ -64,9 +64,7 @@ def test_distribution_metrics(self): assert result["train_test/dist_metric_mean"] == 5.5 assert result["train_test/dist_metric_min"] == 1 assert result["train_test/dist_metric_max"] == 10 - assert ( - result["train_test/dist_metric_p50"] == 5 - ) # Median of 1-10 is 5 (index 4, value 5) + assert result["train_test/dist_metric_p50"] == 5.5 def test_state_management(self): """Test aggregator checkpointing and restoration.""" @@ -182,28 +180,54 @@ def test_distributed_all_aggregation_types(self): # rank 0: [0, 1, 2, 3, 4], rank 1: [10, 11, 12, 13, 14] for i in range(5): metrics.append( - Metric("test", "dist_metric", rank * 10 + i, AggregationType.DISTRIBUTION) + Metric( + "test", "dist_metric", rank * 10 + i, AggregationType.DISTRIBUTION + ) ) # CATEGORICAL_COUNT: Different categories per rank to test counting # rank 0: 3 of cat_A, 2 of cat_B # rank 1: 1 of cat_A, 4 of cat_C if rank == 0: - metrics.extend([ - Metric("test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT), - ]) + metrics.extend( + [ + Metric( + "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT + ), + ] + ) else: - metrics.extend([ - Metric("test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT), - ]) + metrics.extend( + [ + Metric( + "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT + ), + ] + ) # Update aggregator and get results aggregator.update(metrics) @@ -211,7 +235,7 @@ def test_distributed_all_aggregation_types(self): # Verify aggregation results across all ranks # SUM: rank 0 adds 10, rank 1 adds 20 -> total 30 - # MEAN: rank 0 has 15, rank 1 has 25 -> avg 20 + # MEAN: rank 0 has 15, rank 1 has 25 -> avg 20 # MAX: rank 0 has 100, rank 1 has 200 -> max 200 # MIN: rank 0 has 5, rank 1 has 10 -> min 5 assert result["train_test/sum_metric"] == 30 @@ -237,7 +261,7 @@ def test_distributed_state_dict_resumption(self): Test that MetricsAggregator state_dict save/restore works correctly in distributed setting. Verifies: - State can be saved after partial updates across ranks - - State can be restored consistently across ranks + - State can be restored consistently across ranks - Continued updates after restore produce identical results - Distributed aggregation works correctly after restoration """ @@ -245,64 +269,95 @@ def test_distributed_state_dict_resumption(self): # Phase 1: Create aggregator and add initial metrics aggregator1 = MetricsAggregator() - + # Each rank contributes different initial values base_value = rank * 100 # rank 0: 0, rank 1: 100 - + initial_metrics = [ Metric("test", "sum_metric", base_value, AggregationType.SUM), Metric("test", "mean_metric", base_value // 2, AggregationType.MEAN), Metric("test", "max_metric", base_value * 2, AggregationType.MAX), ] - + # Add some DISTRIBUTION values - each rank adds 3 values for i in range(3): initial_metrics.append( - Metric("test", "dist_metric", rank * 100 + i, AggregationType.DISTRIBUTION) + Metric( + "test", "dist_metric", rank * 100 + i, AggregationType.DISTRIBUTION + ) ) - + # Add CATEGORICAL_COUNT values if rank == 0: - initial_metrics.extend([ - Metric("test", "cat_metric", "type_A", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "type_A", AggregationType.CATEGORICAL_COUNT), - ]) + initial_metrics.extend( + [ + Metric( + "test", + "cat_metric", + "type_A", + AggregationType.CATEGORICAL_COUNT, + ), + Metric( + "test", + "cat_metric", + "type_A", + AggregationType.CATEGORICAL_COUNT, + ), + ] + ) else: - initial_metrics.extend([ - Metric("test", "cat_metric", "type_B", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "type_B", AggregationType.CATEGORICAL_COUNT), - Metric("test", "cat_metric", "type_B", AggregationType.CATEGORICAL_COUNT), - ]) - + initial_metrics.extend( + [ + Metric( + "test", + "cat_metric", + "type_B", + AggregationType.CATEGORICAL_COUNT, + ), + Metric( + "test", + "cat_metric", + "type_B", + AggregationType.CATEGORICAL_COUNT, + ), + Metric( + "test", + "cat_metric", + "type_B", + AggregationType.CATEGORICAL_COUNT, + ), + ] + ) + aggregator1.update(initial_metrics) - + # Save state_dict after initial update state_dict = aggregator1.state_dict() - + # Phase 2: Create new aggregator and restore from state_dict aggregator2 = MetricsAggregator() aggregator2.load_state_dict(state_dict) - + # Verify both aggregators produce identical results after restore result1 = aggregator1.get_metrics_for_logging(prefix="checkpoint") result2 = aggregator2.get_metrics_for_logging(prefix="checkpoint") - assert result1 == result2, ( - f"Rank {rank}: Aggregators differ after state_dict restore" - ) - - # Phase 3: Add more metrics to both aggregators + assert ( + result1 == result2 + ), f"Rank {rank}: Aggregators differ after state_dict restore" + + # Phase 3: Add more metrics to both aggregators additional_metrics = [ Metric("test", "sum_metric", rank * 1000, AggregationType.SUM), Metric("test", "min_metric", rank * 1000, AggregationType.MIN), ] - + # Update both aggregators with additional metrics aggregator1.update(additional_metrics) aggregator2.update(additional_metrics) - + # Phase 4: Verify final results are identical across both aggregators final_result1 = aggregator1.get_metrics_for_logging(prefix="final") final_result2 = aggregator2.get_metrics_for_logging(prefix="final") - assert final_result1 == final_result2, ( - f"Rank {rank}: Final results differ after continued updates" - ) + assert ( + final_result1 == final_result2 + ), f"Rank {rank}: Final results differ after continued updates" diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index ba644515c2..901234af6f 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -5,19 +5,19 @@ # LICENSE file in the root directory of this source tree. import math -import tempfile import shutil +import tempfile from itertools import islice from pathlib import Path import pytest import torch.distributed as dist -from torch.testing._internal.common_fsdp import FSDPTest from tests.test_utils import gpu_test +from torch.testing._internal.common_fsdp import FSDPTest from torchdata.stateful_dataloader import StatefulDataLoader -from torchtune.data.metrics import MetricsAggregator, StandardMetricTransform +from torchtune.data.metrics import DefaultTrainingMetricTransform, MetricsAggregator from torchtune.datasets import HfIterableDataset from .test_iterable_utils import collate_with_metrics, generate_ckpt @@ -79,7 +79,7 @@ def _create_dataset( dataset_name=dataset_name, seed=SEED, shuffle_buffer_size=10 if shuffle else 0, - metric_transform=StandardMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, **kwargs, ) @@ -99,7 +99,7 @@ def test_default_dataset_name(self, small_dataset_file): split="train", # dataset_name not provided - should auto-generate seed=SEED, - metric_transform=StandardMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=4, ) @@ -113,7 +113,7 @@ def test_default_dataset_name(self, small_dataset_file): split="train", dataset_name="my_dataset", seed=SEED, - metric_transform=StandardMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=4, ) @@ -287,6 +287,7 @@ def test_distributed_epoch_boundary_checkpointing(self): # Test multiple epoch boundaries for num_epochs in [0.9, 1.0, 2.5]: + def create_loader_and_aggregator(): dataset = HfIterableDataset( path="json", @@ -295,11 +296,14 @@ def create_loader_and_aggregator(): dataset_name="epoch_test", seed=SEED, shuffle_buffer_size=0, # No shuffle for determinism - metric_transform=StandardMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, ) loader = StatefulDataLoader( - dataset, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics, num_workers=0 + dataset, + batch_size=BATCH_SIZE, + collate_fn=collate_with_metrics, + num_workers=0, ) return loader, MetricsAggregator() @@ -310,21 +314,29 @@ def create_loader_and_aggregator(): samples_per_rank = MEDIUM_DATASET_SIZE // dist.get_world_size() total_samples = int(samples_per_rank * num_epochs) total_steps = total_samples // BATCH_SIZE - + if total_steps < 2: - raise ValueError(f"Not enough steps for meaningful test: {total_steps}") + raise ValueError( + f"Not enough steps for meaningful test: {total_steps}" + ) # Split steps between before and after checkpoint steps_before = max(1, total_steps // 2) steps_after = total_steps - steps_before result = generate_ckpt( - loader1, aggregator1, steps_before, steps_after, - resume_dataloader=loader2, resume_aggregator=aggregator2 + loader1, + aggregator1, + steps_before, + steps_after, + resume_dataloader=loader2, + resume_aggregator=aggregator2, ) # Verify deterministic resumption - critical for distributed training - orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]] + orig_post_ids = [ + b["id"].tolist() for b in result["post_checkpoint_batches"] + ] resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]] assert orig_post_ids == resumed_ids, ( f"Rank {rank}: Non-deterministic resume for {num_epochs} epochs. " @@ -333,10 +345,12 @@ def create_loader_and_aggregator(): # Verify epoch metric is correctly tracked final_metrics = result["final_metrics"] - expected_epoch = math.floor(num_epochs - 1e-9) # -1e-9 so 1.0 epochs -> 0 - assert final_metrics[f"train_epoch_test/num_epochs"] == expected_epoch, ( - f"Epoch count incorrect for {num_epochs} epochs test scenario" - ) + expected_epoch = math.floor( + num_epochs - 1e-9 + ) # -1e-9 so 1.0 epochs -> 0 + assert ( + final_metrics["train_epoch_test/num_epochs"] == expected_epoch + ), f"Epoch count incorrect for {num_epochs} epochs test scenario" finally: # Clean up temp directory (only rank 0) diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index d02912cd0f..98e9207047 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -4,21 +4,21 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import tempfile import shutil +import tempfile from itertools import islice from pathlib import Path from unittest.mock import patch import pytest -import torch.distributed as dist -from torch.testing._internal.common_fsdp import FSDPTest -from tests.test_utils import gpu_test import torch +import torch.distributed as dist +from tests.test_utils import gpu_test +from torch.testing._internal.common_fsdp import FSDPTest from torchdata.stateful_dataloader import StatefulDataLoader -from torchtune.data.metrics import MetricsAggregator, StandardMetricTransform +from torchtune.data.metrics import DefaultTrainingMetricTransform, MetricsAggregator from torchtune.datasets import HfIterableDataset, InterleavedDataset # Import test utilities @@ -87,7 +87,7 @@ def _create_dataset( dataset_name=dataset_name, seed=SEED, shuffle_buffer_size=10 if shuffle else 0, - metric_transform=StandardMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, **kwargs, ) @@ -270,6 +270,7 @@ def test_distributed_interleaved_checkpointing(self): tmp_path = Path(temp_dir) try: + def create_dataset(): file1 = tmp_path / "ds1.json" file2 = tmp_path / "ds2.json" @@ -277,18 +278,28 @@ def create_dataset(): # Only rank 0 creates the data files if rank == 0: create_test_json_file(file1, SMALL_DATASET_SIZE) # IDs 0-22 - create_test_json_file(file2, MEDIUM_DATASET_SIZE, offset=100) # IDs 100-134 + create_test_json_file( + file2, MEDIUM_DATASET_SIZE, offset=100 + ) # IDs 100-134 dist.barrier() # Wait for file creation ds1 = HfIterableDataset( - path="json", data_files=str(file1), split="train", dataset_name="ds1", + path="json", + data_files=str(file1), + split="train", + dataset_name="ds1", shuffle_buffer_size=0, # No shuffle for determinism - metric_transform=StandardMetricTransform(), num_shards_per_rank=2, + metric_transform=DefaultTrainingMetricTransform(), + num_shards_per_rank=2, ) ds2 = HfIterableDataset( - path="json", data_files=str(file2), split="train", dataset_name="ds2", + path="json", + data_files=str(file2), + split="train", + dataset_name="ds2", shuffle_buffer_size=0, # No shuffle for determinism - metric_transform=StandardMetricTransform(), num_shards_per_rank=2, + metric_transform=DefaultTrainingMetricTransform(), + num_shards_per_rank=2, ) # Create interleaved dataset with 70/30 weighting @@ -296,9 +307,10 @@ def create_dataset(): def create_dataloader(dataset): loader = StatefulDataLoader( - dataset, batch_size=BATCH_SIZE, + dataset, + batch_size=BATCH_SIZE, num_workers=0, # Avoid multiprocessing in distributed tests - collate_fn=collate_with_metrics + collate_fn=collate_with_metrics, ) return loader, MetricsAggregator() @@ -307,24 +319,32 @@ def create_dataloader(dataset): loader2, aggregator2 = create_dataloader(create_dataset()) result = generate_ckpt( - loader1, aggregator1, 3, 3, # 3 steps before, 3 steps after checkpoint - resume_dataloader=loader2, resume_aggregator=aggregator2 + loader1, + aggregator1, + 3, + 3, # 3 steps before, 3 steps after checkpoint + resume_dataloader=loader2, + resume_aggregator=aggregator2, ) # Verify deterministic resumption - orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]] + orig_post_ids = [ + b["id"].tolist() for b in result["post_checkpoint_batches"] + ] resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]] assert orig_post_ids == resumed_ids, ( f"Rank {rank}: Non-deterministic interleaved resume. " f"This indicates sampling state is not properly preserved." ) - assert result["final_metrics"] == result["resumed_metrics"], ( - "Final metrics don't match resumed metrics - aggregator state issue" - ) + assert ( + result["final_metrics"] == result["resumed_metrics"] + ), "Final metrics don't match resumed metrics - aggregator state issue" # Verify sampling ratio is approximately maintained (80/20 split) all_ids = [] - for batch in result["pre_checkpoint_batches"] + result["post_checkpoint_batches"]: + for batch in ( + result["pre_checkpoint_batches"] + result["post_checkpoint_batches"] + ): all_ids.extend(batch["id"].tolist()) # Count samples by ID ranges: ds1 has IDs < 100, ds2 has IDs >= 100 diff --git a/tests/torchtune/datasets/test_iterable_utils.py b/tests/torchtune/datasets/test_iterable_utils.py index e160345bc1..28c6d8e464 100644 --- a/tests/torchtune/datasets/test_iterable_utils.py +++ b/tests/torchtune/datasets/test_iterable_utils.py @@ -9,7 +9,7 @@ import torch from torch.utils.data import DataLoader -from torchtune.data import MetricsAggregator +from torchtune.data.metrics import MetricsAggregator def collate_with_metrics(batch: list[dict[str, Any]]) -> dict[str, Any]: diff --git a/torchtune/data/metrics/__init__.py b/torchtune/data/metrics/__init__.py index 17e359d697..778245f83a 100644 --- a/torchtune/data/metrics/__init__.py +++ b/torchtune/data/metrics/__init__.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchtune.data.metrics._metric_aggregator import MetricsAggregator from torchtune.data.metrics._metric_agg_handlers import ( AggregationHandler, CategoricalCountAggHandler, @@ -15,6 +14,7 @@ MinAggHandler, SumAggHandler, ) +from torchtune.data.metrics._metric_aggregator import MetricsAggregator from torchtune.data.metrics._metric_transform import ( AggregationType, DefaultTrainingMetricTransform, diff --git a/torchtune/data/metrics/_metric_agg_handlers.py b/torchtune/data/metrics/_metric_agg_handlers.py index 1a1557c803..c3415aba7d 100644 --- a/torchtune/data/metrics/_metric_agg_handlers.py +++ b/torchtune/data/metrics/_metric_agg_handlers.py @@ -1,21 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import logging from abc import ABC, abstractmethod from collections import Counter, deque from dataclasses import dataclass, field -from enum import Enum from typing import Any, Union import torch -from torchtune.data.metrics._metric_transform import Metric, AggregationType +from torchtune.data.metrics._metric_transform import AggregationType, Metric logger = logging.getLogger(__name__) + @dataclass class MetricState: """Mutable state object representing aggregated metric for (dataset, metric) on a single rank. - - Args: + + Attributes: dataset_name (str): Name of the dataset. metric_name (str): Name of the metric. value (float): Current aggregated value, whose meaning depends on the aggregation type @@ -23,15 +29,17 @@ class MetricState: agg_type (AggregationType): Aggregation type. metadata (dict[str, Any]): Additional state like count, list of values, etc. """ + dataset_name: str metric_name: str value: float agg_type: AggregationType metadata: dict[str, Any] = field(default_factory=dict) + class AggregationHandler(ABC): """Base class for handling metric aggregation in MetricsAggregator. - + This class defines the interface for different aggregation strategies (e.g., SUM, MEAN). Each handler is responsible for: - Initializing the state for a new (dataset, metric) pair. @@ -40,31 +48,33 @@ class AggregationHandler(ABC): - Reducing the values from all ranks in a distributed setting. - Serializing and deserializing the metric state for checkpointing. """ - + @abstractmethod - def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: """Create a new MetricState for a (dataset_name, metric_name) pair. - + Args: dataset_name (str): Name of the dataset. Especially useful when tracking multiple datasets. metric_name (str): Name of the metric. agg_type (AggregationType): Aggregation type. - + Returns: MetricState: New MetricState for this (dataset_name, metric_name) pair. """ pass - + @abstractmethod def update(self, local_agg_metric: MetricState, metric: Metric) -> None: """Update cumulative MetricState with new metric info. - + Args: local_agg_metric (MetricState): Cumulative state of the aggregation for this metric in the local rank. metric (Metric): Input metric info. """ pass - + @abstractmethod def finalize_local_agg( self, local_agg_metric: MetricState @@ -83,17 +93,15 @@ def finalize_local_agg( A single `MetricState` or a list of them if the metric expands. """ pass - + @abstractmethod - def finalize_dist_agg( - self, local_agg_metrics: list[MetricState] - ) -> MetricState: + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: """ Merge MetricStates from all ranks into final result. - + Args: local_agg_metrics (list[MetricState]): list of MetricStates for this (dataset_name, metric_name) pair. - + Returns: MetricState: Final result for this (dataset_name, metric_name) pair. """ @@ -101,21 +109,27 @@ def finalize_dist_agg( def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: """Convert handler-specific metadata to serializable format. - + Args: metadata (dict[str, Any]): AggHandler-specific metadata. + Returns: + dict[str, Any]: Serializable metadata. + Override this when using non-serializable types like deque or Counter. For example, convert deque to list, Counter to dict. """ return metadata.copy() - + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: """Restore handler-specific metadata from serialized format. - + Args: metadata (dict[str, Any]): AggHandler-specific metadata. + Returns: + dict[str, Any]: Deserialized metadata. + Override this to reverse the serialize_metadata transformation. For example, convert list back to deque, dict back to Counter. """ @@ -124,116 +138,138 @@ def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: class SumAggHandler(AggregationHandler): """AggHandler for SUM aggregation. Initializes with 0.0 and accumulates metric values.""" - - def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: return MetricState( dataset_name=dataset_name, metric_name=metric_name, value=0.0, - agg_type=agg_type + agg_type=agg_type, ) - + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + if not isinstance(metric.value, (int, float)): + raise ValueError( + f"SumAggHandler expects numeric values, got {type(metric.value)}" + ) local_agg_metric.value += metric.value - + def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: return local_agg_metric - + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: if not local_agg_metrics: raise ValueError("Cannot aggregate empty list of metrics") - + total = sum(metric.value for metric in local_agg_metrics) return MetricState( dataset_name=local_agg_metrics[0].dataset_name, metric_name=local_agg_metrics[0].metric_name, value=total, agg_type=local_agg_metrics[0].agg_type, - metadata=local_agg_metrics[0].metadata.copy() + metadata=local_agg_metrics[0].metadata.copy(), ) class MaxAggHandler(AggregationHandler): """AggHandler for MAX aggregation. Tracks maximum value across all updates.""" - - def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: return MetricState( dataset_name=dataset_name, metric_name=metric_name, - value=float('-inf'), + value=float("-inf"), agg_type=agg_type, ) - + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + if not isinstance(metric.value, (int, float)): + raise ValueError( + f"MaxAggHandler expects numeric values, got {type(metric.value)}" + ) local_agg_metric.value = max(local_agg_metric.value, metric.value) - + def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: return local_agg_metric - - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: max_value = max(r.value for r in local_agg_metrics) return MetricState( dataset_name=local_agg_metrics[0].dataset_name, metric_name=local_agg_metrics[0].metric_name, value=max_value, agg_type=local_agg_metrics[0].agg_type, - metadata=local_agg_metrics[0].metadata.copy() + metadata=local_agg_metrics[0].metadata.copy(), ) class MinAggHandler(AggregationHandler): """AggHandler for MIN aggregation. Tracks minimum value across all updates.""" - - def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: return MetricState( dataset_name=dataset_name, metric_name=metric_name, - value=float('inf'), + value=float("inf"), agg_type=agg_type, ) - + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + if not isinstance(metric.value, (int, float)): + raise ValueError( + f"MinAggHandler expects numeric values, got {type(metric.value)}" + ) local_agg_metric.value = min(local_agg_metric.value, metric.value) - + def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: return local_agg_metric - - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: min_value = min(r.value for r in local_agg_metrics) return MetricState( dataset_name=local_agg_metrics[0].dataset_name, metric_name=local_agg_metrics[0].metric_name, value=min_value, agg_type=local_agg_metrics[0].agg_type, - metadata=local_agg_metrics[0].metadata.copy() + metadata=local_agg_metrics[0].metadata.copy(), ) class MeanAggHandler(AggregationHandler): """AggHandler for MEAN aggregation. Maintains running sum and count to compute average.""" - - def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: return MetricState( - dataset_name=dataset_name, - metric_name=metric_name, + dataset_name=dataset_name, + metric_name=metric_name, value=0.0, agg_type=agg_type, metadata={"sum": 0.0, "count": 0}, ) - + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: local_agg_metric.metadata["sum"] += metric.value local_agg_metric.metadata["count"] += 1 - + def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: count = local_agg_metric.metadata["count"] - local_agg_metric.value = local_agg_metric.metadata["sum"] / count if count > 0 else 0.0 + local_agg_metric.value = ( + local_agg_metric.metadata["sum"] / count if count > 0 else 0.0 + ) return local_agg_metric - + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: total_sum = sum(metric.metadata["sum"] for metric in local_agg_metrics) total_count = sum(metric.metadata["count"] for metric in local_agg_metrics) - + return MetricState( dataset_name=local_agg_metrics[0].dataset_name, metric_name=local_agg_metrics[0].metric_name, @@ -244,43 +280,46 @@ def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState class DistributionAggHandler(AggregationHandler): - """AggHandler for DISTRIBUTION aggregation. Maintains a sliding window of values + """AggHandler for DISTRIBUTION aggregation. Maintains a sliding window of values and expands into multiple statistical metrics (mean, min, max, percentiles, std). - - Note: Percentiles and standard deviation are approximated in distributed settings by averaging local - percentiles and standard deviations across ranks. This is mathematically imprecise but provides a + + Note: Percentiles and standard deviation are approximated in distributed settings by averaging local + percentiles and standard deviations across ranks. This is mathematically imprecise but provides a reasonable approximation for monitoring purposes. + + Args: + window_size (int): Maximum number of recent values to retain for statistics. + + Raises: + ValueError: If window_size is not positive. """ - + def __init__(self, window_size: int = 1000): - """Initialize handler with specified window size for value retention. - - Args: - window_size (int): Maximum number of recent values to retain for statistics. - """ if window_size <= 0: raise ValueError(f"window_size must be positive, got {window_size}") self.window_size = window_size - - def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: return MetricState( - dataset_name=dataset_name, - metric_name=metric_name, + dataset_name=dataset_name, + metric_name=metric_name, value=0.0, agg_type=agg_type, - metadata={"values": deque(maxlen=self.window_size)} + metadata={"values": deque(maxlen=self.window_size)}, ) - + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: local_agg_metric.metadata["values"].append(metric.value) - + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: values = list(local_agg_metric.metadata["values"]) if not values: return [] - + return self._compute_distribution_stats(local_agg_metric, values) - + def _compute_distribution_stats( self, local_agg_metric: MetricState, values: list[float] ) -> list[MetricState]: @@ -300,7 +339,9 @@ def _compute_distribution_stats( # Compute all percentiles in one go percentile_definitions = torch.tensor([0.05, 0.5, 0.95], dtype=torch.float64) - p05_val, p50_val, p95_val = torch.quantile(values_tensor, percentile_definitions).tolist() + p05_val, p50_val, p95_val = torch.quantile( + values_tensor, percentile_definitions + ).tolist() # Return multiple MetricStates with proper agg_types for distributed reduction # NOTE: Percentiles use MEAN aggregation which approximates global percentiles @@ -362,7 +403,7 @@ def _compute_distribution_stats( ) ) return metrics - + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: raise NotImplementedError( "Metrics with AggregationType.DISTRIBUTION are converted to other " @@ -375,43 +416,49 @@ def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: if "values" in serialized: serialized["values"] = list(serialized["values"]) return serialized - + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: """Convert list back to deque.""" deserialized = metadata.copy() if "values" in deserialized: - deserialized["values"] = deque(deserialized["values"], maxlen=self.window_size) + deserialized["values"] = deque( + deserialized["values"], maxlen=self.window_size + ) return deserialized class CategoricalCountAggHandler(AggregationHandler): """AggHandler for CATEGORICAL_COUNT aggregation. Counts occurrences of categorical values and expands into individual count metrics for each category.""" - - def initialize_metric_state(self, dataset_name: str, metric_name: str, agg_type: AggregationType) -> MetricState: + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: return MetricState( - dataset_name=dataset_name, - metric_name=metric_name, + dataset_name=dataset_name, + metric_name=metric_name, value=0.0, agg_type=agg_type, - metadata={"counts": Counter()} + metadata={"counts": Counter()}, ) - + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: local_agg_metric.metadata["counts"][metric.value] += 1 - + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: # Expand categorical counts into individual metrics results = [] for category, count in local_agg_metric.metadata["counts"].items(): - results.append(MetricState( - dataset_name=local_agg_metric.dataset_name, - metric_name=f"{local_agg_metric.metric_name}_{category}_count", - value=count, - agg_type=AggregationType.SUM - )) + results.append( + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_{category}_count", + value=count, + agg_type=AggregationType.SUM, + ) + ) return results - + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: raise NotImplementedError( "Metrics with AggregationType.CATEGORICAL_COUNT are converted to other " @@ -424,10 +471,10 @@ def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: if "counts" in serialized: serialized["counts"] = dict(serialized["counts"]) return serialized - + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: """Convert dict back to Counter.""" deserialized = metadata.copy() if "counts" in deserialized: deserialized["counts"] = Counter(deserialized["counts"]) - return deserialized \ No newline at end of file + return deserialized diff --git a/torchtune/data/metrics/_metric_aggregator.py b/torchtune/data/metrics/_metric_aggregator.py index c07f0dea36..633d5c6b80 100644 --- a/torchtune/data/metrics/_metric_aggregator.py +++ b/torchtune/data/metrics/_metric_aggregator.py @@ -6,7 +6,7 @@ import ast from collections import defaultdict -from typing import Any, tuple +from typing import Any import torch.distributed as dist @@ -14,63 +14,69 @@ AggregationHandler, CategoricalCountAggHandler, DistributionAggHandler, - MetricState, MaxAggHandler, MeanAggHandler, + MetricState, MinAggHandler, SumAggHandler, ) -from torchtune.data.metrics._metric_transform import Metric, AggregationType +from torchtune.data.metrics._metric_transform import AggregationType, Metric + class MetricsAggregator: """Aggregates metrics across datasets and distributed ranks using pluggable handlers. - - Uses a handler-based strategy pattern where each aggregation type (SUM, MEAN, etc.) + + Uses a handler-based strategy pattern where each aggregation type (SUM, MEAN, etc.) has its own handler. Maintains only one state per (dataset, metric) pair. - + When preparing for logging, uses a two-phase approach: 1. Local aggregation: Each rank aggregates its metrics independently 2. Distributed reduction: Results combined across ranks - + The aggregator is checkpointable and restores from state_dict for training resumption. - + Args: dist_window_size (int): Window size for DistributionAggHandler tracking. - + Example: >>> from torchtune.data.metrics import MetricsAggregator, Metric, AggregationType - >>> + >>> >>> # Create aggregator >>> aggregator = MetricsAggregator() - >>> + >>> >>> # Sample metrics from different batches >>> batch1_metrics = [ ... Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), - ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), + ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), ... ] - >>> + >>> >>> batch2_metrics = [ ... Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), - ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), + ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), ... ] - >>> + >>> >>> # Update with metrics >>> aggregator.update(batch1_metrics) >>> aggregator.update(batch2_metrics) - >>> + >>> >>> # Get final results >>> results = aggregator.get_metrics_for_logging(prefix="train") >>> # {"train_alpaca/tokens_seen": 200.0, "train_alpaca/avg_tokens_seen": 100.0} + + Raises: + ValueError: If dist_window_size is not positive. """ - + def __init__(self, dist_window_size: int = 1000): if dist_window_size <= 0: - raise ValueError(f"dist_window_size must be positive, got {dist_window_size}") - + raise ValueError( + f"dist_window_size must be positive, got {dist_window_size}" + ) + # Storage: {(dataset, metric): MetricState} - O(unique metrics) not O(samples) self._metric_states: dict[tuple[str, str], MetricState] = {} self._dist_window_size = dist_window_size - + # Create handler registry - all handlers initialized upfront self._handlers: dict[AggregationType, AggregationHandler] = { AggregationType.SUM: SumAggHandler(), @@ -80,59 +86,66 @@ def __init__(self, dist_window_size: int = 1000): AggregationType.DISTRIBUTION: DistributionAggHandler(dist_window_size), AggregationType.CATEGORICAL_COUNT: CategoricalCountAggHandler(), } - - def register_handler(self, agg_type: AggregationType, handler: AggregationHandler) -> None: + + def register_handler( + self, agg_type: AggregationType, handler: AggregationHandler + ) -> None: """Register custom aggregation handler for specified type. - + Args: agg_type (AggregationType): The aggregation type to handle handler (AggregationHandler): Handler instance implementing the Ag∂gregationHandler interface """ self._handlers[agg_type] = handler - + def update(self, metrics: list[Metric]) -> None: """Update (dataset_name, metric_name) metric state with new values. - + Args: metrics (list[Metric]): List of metrics to update the state with + + Raises: + ValueError: If no handler is registered for a metric's aggregation type. """ for metric in metrics: metric_key = (metric.dataset_name, metric.name) handler = self._handlers.get(metric.agg_type) - + if handler is None: - raise ValueError(f"No handler registered for aggregation type: {metric.agg_type}") - + raise ValueError( + f"No handler registered for aggregation type: {metric.agg_type}" + ) + if metric_key not in self._metric_states: self._metric_states[metric_key] = handler.initialize_metric_state( metric.dataset_name, metric.name, metric.agg_type ) - + local_agg_metric = self._metric_states[metric_key] handler.update(local_agg_metric, metric) # Mutates local_agg_metric - + def get_metrics_for_logging(self, prefix: str = "data") -> dict[str, float]: """Get final metrics for logging in standard format. - + Args: prefix (str): Prefix for metric names in the returned dictionary - + Returns: - dict[str, float]: Dictionary with keys like "{prefix}_{dataset_name}/{metric_name}" - and float values. For example, with `prefix="train"`, `dataset_name="alpaca"`, + dict[str, float]: Dictionary with keys like "{prefix}_{dataset_name}/{metric_name}" + and float values. For example, with `prefix="train"`, `dataset_name="alpaca"`, `metric_name="loss"`, the key would be `train_alpaca/loss`. """ final_results = self._compute_unified_metrics() - + return { - f"{prefix}_{result.dataset_name}/{result.metric_name}": result.value + f"{prefix}_{result.dataset_name}/{result.metric_name}": result.value for result in final_results } - + def _compute_unified_metrics(self) -> list[MetricState]: """ Compute metrics handling both local and distributed cases uniformly. - + Returns: list[MetricState]: Final results ready for logging """ @@ -141,32 +154,36 @@ def _compute_unified_metrics(self) -> list[MetricState]: for local_agg_metric in self._metric_states.values(): handler = self._handlers[local_agg_metric.agg_type] prepared = handler.finalize_local_agg(local_agg_metric) - if isinstance(prepared, list): # Distribution/categorical expands to multiple + if isinstance( + prepared, list + ): # Distribution/categorical expands to multiple prepared_results.extend(prepared) else: prepared_results.append(prepared) - + # Step 2: Apply distributed reduction if needed if dist.is_initialized() and dist.get_world_size() > 1: prepared_results = self._finalize_dist_agg(prepared_results) - + return prepared_results - - def _finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> list[MetricState]: + + def _finalize_dist_agg( + self, local_agg_metrics: list[MetricState] + ) -> list[MetricState]: """Apply distributed reduction to local results. - + Args: local_agg_metrics (list[MetricState]): (dataset_name, metric_name) metric pairs from this rank - + Returns: list[MetricState]: Reduced results combining all ranks """ world_size = dist.get_world_size() - + # Gather all results from all ranks all_results = [None] * world_size dist.all_gather_object(all_results, local_agg_metrics) - + # Group by (dataset_name, metric_name) for reduction grouped = defaultdict(list) for rank_results in all_results: @@ -174,98 +191,100 @@ def _finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> list[Metri for result in rank_results: result_key = (result.dataset_name, result.metric_name) grouped[result_key].append(result) - + # Apply handler-specific distributed reduction reduced_results = [] for result_key, results_list in grouped.items(): if not results_list: continue # Skip empty groups - + # All results for a key should have same agg_type agg_type = results_list[0].agg_type handler = self._handlers[agg_type] reduced_result = handler.finalize_dist_agg(results_list) reduced_results.append(reduced_result) - + return reduced_results - + def state_dict(self) -> dict[str, Any]: """Serialize aggregator state for checkpointing. - + Returns: dict[str, Any]: Serializable dictionary containing all aggregator state """ serializable_state = {} required_agg_types = set() # Track aggregation types used in saved states - + for metric_key, local_agg_metric in self._metric_states.items(): # Get handler for this result's aggregation type handler = self._handlers[local_agg_metric.agg_type] required_agg_types.add(local_agg_metric.agg_type) - + # Convert MetricState to serializable dict result_dict = { "dataset_name": local_agg_metric.dataset_name, "metric_name": local_agg_metric.metric_name, "value": local_agg_metric.value, "agg_type": local_agg_metric.agg_type, - "metadata": handler.serialize_metadata(local_agg_metric.metadata) + "metadata": handler.serialize_metadata(local_agg_metric.metadata), } - + # Convert tuple key to string for JSON compatibility serializable_state[str(metric_key)] = result_dict - + return { "state": serializable_state, "dist_window_size": self._dist_window_size, - "required_agg_types": list(required_agg_types) # Save which handlers are needed + "required_agg_types": list( + required_agg_types + ), # Save which handlers are needed } - + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Load aggregator state from checkpoint. - + Args: state_dict (dict[str, Any]): Dictionary containing serialized aggregator state - + Raises: ValueError: If required handlers are missing after checkpoint restore """ self._dist_window_size = state_dict.get("dist_window_size", 1000) - + # Sanity check: Ensure all required handlers are available required_agg_types = state_dict.get("required_agg_types", []) missing_handlers = [] for agg_type in required_agg_types: if agg_type not in self._handlers: missing_handlers.append(agg_type) - + if missing_handlers: raise ValueError( f"Missing handlers for aggregation types: {missing_handlers}. " f"Custom handlers must be re-registered before checkpoint restore." ) - + deserialized_state = {} for key_str, result_dict in state_dict["state"].items(): # Convert string keys back to tuples metric_key = ast.literal_eval(key_str) - + # Get handler for this aggregation type agg_type = result_dict["agg_type"] handler = self._handlers[agg_type] - + # Restore metadata using handler-specific deserialization metadata = handler.deserialize_metadata(result_dict["metadata"]) - + # Create MetricState from dict local_agg_metric = MetricState( dataset_name=result_dict["dataset_name"], metric_name=result_dict["metric_name"], value=result_dict["value"], agg_type=result_dict["agg_type"], - metadata=metadata + metadata=metadata, ) - + deserialized_state[metric_key] = local_agg_metric - + self._metric_states = deserialized_state diff --git a/torchtune/data/metrics/_metric_transform.py b/torchtune/data/metrics/_metric_transform.py index 0e252cb2e6..9ae73488e4 100644 --- a/torchtune/data/metrics/_metric_transform.py +++ b/torchtune/data/metrics/_metric_transform.py @@ -7,10 +7,11 @@ from dataclasses import dataclass from enum import Enum from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Mapping, Optional, Union from torchtune.modules.transforms import Transform + @dataclass(frozen=True) class Metric: dataset_name: str @@ -18,6 +19,7 @@ class Metric: value: Union[int, float, str] agg_type: "AggregationType" + class AggregationType(Enum): """Defines how a metric's value should be aggregated.""" @@ -28,11 +30,12 @@ class AggregationType(Enum): MAX = "max" MIN = "min" + class MetricTransform(Transform): """Applied to each sample to generate per-sample metrics for training tracking. - - Creates Metric objects that are later aggregated by 'MetricsAggregator'. This separation - of concerns ensures metrics are correctly aggregated even with multiple dataloader + + Creates Metric objects that are later aggregated by 'MetricsAggregator'. This separation + of concerns ensures metrics are correctly aggregated even with multiple dataloader workers and in distributed settings.""" def __init__(self): @@ -42,10 +45,10 @@ def __init__(self): def set_dataset_name(self, dataset_name: str) -> None: """Called by dataset to set the namespace for metrics. - + The dataset name is used to differentiate multiple datasets stats, e.g. "train/dataset1/tokens_seen" and "train/dataset2/tokens_seen". - + Args: dataset_name (str): Name of the dataset for metric namespacing """ @@ -53,21 +56,22 @@ def set_dataset_name(self, dataset_name: str) -> None: # Create a partial to make it easier to create new metrics self.new_metric = partial(Metric, dataset_name=dataset_name) - def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: + def _generate_metrics(self, sample: Mapping[str, Any]) -> list[Metric]: """Generate metrics for a single sample. Must be implemented by subclasses. - + Args: - sample (dict[str, Any]): The sample dictionary to generate metrics from - + sample (Mapping[str, Any]): The sample dictionary to generate metrics from + Returns: list[Metric]: List of metrics generated for this sample + + Raises: + NotImplementedError: If subclass does not implement this method. """ - raise NotImplementedError( - "Subclasses must implement _generate_metrics method" - ) + raise NotImplementedError("Subclasses must implement _generate_metrics method") - def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - """Apply transform to sample, adding generated metrics. """ + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """Apply transform to sample, adding generated metrics.""" if self.dataset_name is None or self.new_metric is None: raise RuntimeError( "set_dataset_name() must be called before using the transform." @@ -85,18 +89,18 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: class DefaultTrainingMetricTransform(MetricTransform): """Generates training metrics: samples_seen, tokens_seen, seq_len distribution. - + For details about MetricTransform base class behavior, see the parent class docstring. - + Tracked metrics: - samples_seen: Cumulative count of samples processed (SUM aggregation) - - tokens_seen: Cumulative sum of all tokens processed (SUM aggregation) + - tokens_seen: Cumulative sum of all tokens processed (SUM aggregation) - seq_len: Distribution of sequence lengths (DISTRIBUTION aggregation) - + Example: >>> transform = DefaultTrainingMetricTransform() >>> transform.set_dataset_name("alpaca") - >>> + >>> >>> sample = {"tokens": [1, 2, 3, 4, 5]} # 5 tokens >>> metrics = transform._generate_metrics(sample) >>> # Creates: @@ -107,7 +111,12 @@ class DefaultTrainingMetricTransform(MetricTransform): >>> # ] """ - def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: + def _generate_metrics(self, sample: Mapping[str, Any]) -> list[Metric]: + if self.new_metric is None: + raise RuntimeError( + "set_dataset_name() must be called before using the transform." + ) + # Determine token key token_key = "tokens" if "tokens" in sample else "input_ids" token_len = len(sample.get(token_key, [])) diff --git a/torchtune/data/metrics/readme.md b/torchtune/data/metrics/readme.md index 5b79c5e98b..6c6c413246 100644 --- a/torchtune/data/metrics/readme.md +++ b/torchtune/data/metrics/readme.md @@ -102,12 +102,12 @@ aggregator = MetricsAggregator() # Sample metrics from different batches batch1_metrics = [ Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), - Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), + Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), ] batch2_metrics = [ Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), - Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), + Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), ] # Update with metrics @@ -125,7 +125,7 @@ Pluggable strategies for different aggregation patterns. ``` AggregationHandler (ABC) ├── SumAggHandler # value += metric.value -├── MeanAggHandler # tracks sum and count +├── MeanAggHandler # tracks sum and count ├── MaxAggHandler # value = max(value, metric.value) ├── MinAggHandler # value = min(value, metric.value) ├── DistributionAggHandler # maintains value window + stats @@ -138,18 +138,18 @@ class CustomAggHandler(AggregationHandler): def initialize_metric_state(self, dataset_name, metric_name, agg_type): return MetricState( dataset_name=dataset_name, - metric_name=metric_name, + metric_name=metric_name, value=, # should change agg_type=agg_type, metadata={} # may need to change ) - + def update(self, local_agg_metric, metric): ... - + def finalize_local_agg(self, local_agg_metric): ... - + def finalize_dist_agg(self, local_agg_metrics): ... @@ -173,4 +173,4 @@ Rank 1: [(ds1, metric1), (ds1, metric2)] → LocalAgg → [(ds1, metric1), (ds1, AllGather + Reduce ↓ Final Results [(ds1, metric1), (ds1, metric2)] -``` \ No newline at end of file +``` diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 81710169d0..7aac8adcc4 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -12,7 +12,11 @@ from datasets import load_dataset from datasets.distributed import split_dataset_by_node -from torchtune.data.metrics import AggregationType, Metric, StandardMetricTransform +from torchtune.data.metrics import ( + AggregationType, + DefaultTrainingMetricTransform, + Metric, +) from torchtune.datasets._iterable_base import TuneIterableDataset logger = logging.getLogger(__name__) @@ -73,7 +77,7 @@ def __init__( self._weight = weight # TODO: make it a property? # Create default transform if not provided - self._metric_transform = metric_transform or StandardMetricTransform() + self._metric_transform = metric_transform or DefaultTrainingMetricTransform() # Auto-generate dataset name if not provided, ensuring it's always a string. if dataset_name is None: diff --git a/torchtune/datasets/_sft.py b/torchtune/datasets/_sft.py index 72289c14dc..a0aab0b27b 100644 --- a/torchtune/datasets/_sft.py +++ b/torchtune/datasets/_sft.py @@ -12,7 +12,7 @@ from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX from torchtune.data._messages import validate_messages -from torchtune.data.metrics import StandardMetricTransform +from torchtune.data.metrics import DefaultTrainingMetricTransform from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.modules.transforms import Transform @@ -266,7 +266,7 @@ def sft_iterable_dataset( message_transform=message_transform, model_transform=model_transform, output_transform=output_transform, - metric_transform=StandardMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), shuffle_buffer_size=shuffle_buffer_size, weight=weight, seed=seed, From 853147b5a5c2945661b6016da620db46e325dad1 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 2 Jul 2025 13:49:09 -0700 Subject: [PATCH 11/25] optimize SFTOutputTransform --- torchtune/data/metrics/_metric_transform.py | 12 ++-- torchtune/datasets/_sft.py | 73 ++++++++++++--------- 2 files changed, 48 insertions(+), 37 deletions(-) diff --git a/torchtune/data/metrics/_metric_transform.py b/torchtune/data/metrics/_metric_transform.py index 9ae73488e4..529fff8c5c 100644 --- a/torchtune/data/metrics/_metric_transform.py +++ b/torchtune/data/metrics/_metric_transform.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from enum import Enum from functools import partial -from typing import Any, Callable, Mapping, Optional, Union +from typing import Any, Callable, Optional, Union from torchtune.modules.transforms import Transform @@ -32,7 +32,7 @@ class AggregationType(Enum): class MetricTransform(Transform): - """Applied to each sample to generate per-sample metrics for training tracking. + """Applied to each dataset sample to generate per-sample metrics for training tracking. Creates Metric objects that are later aggregated by 'MetricsAggregator'. This separation of concerns ensures metrics are correctly aggregated even with multiple dataloader @@ -56,11 +56,11 @@ def set_dataset_name(self, dataset_name: str) -> None: # Create a partial to make it easier to create new metrics self.new_metric = partial(Metric, dataset_name=dataset_name) - def _generate_metrics(self, sample: Mapping[str, Any]) -> list[Metric]: + def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: """Generate metrics for a single sample. Must be implemented by subclasses. Args: - sample (Mapping[str, Any]): The sample dictionary to generate metrics from + sample (dict[str, Any]): The sample dictionary to generate metrics from Returns: list[Metric]: List of metrics generated for this sample @@ -70,7 +70,7 @@ def _generate_metrics(self, sample: Mapping[str, Any]) -> list[Metric]: """ raise NotImplementedError("Subclasses must implement _generate_metrics method") - def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: """Apply transform to sample, adding generated metrics.""" if self.dataset_name is None or self.new_metric is None: raise RuntimeError( @@ -111,7 +111,7 @@ class DefaultTrainingMetricTransform(MetricTransform): >>> # ] """ - def _generate_metrics(self, sample: Mapping[str, Any]) -> list[Metric]: + def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: if self.new_metric is None: raise RuntimeError( "set_dataset_name() must be called before using the transform." diff --git a/torchtune/datasets/_sft.py b/torchtune/datasets/_sft.py index a0aab0b27b..f7638bb609 100644 --- a/torchtune/datasets/_sft.py +++ b/torchtune/datasets/_sft.py @@ -4,9 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Callable, Mapping, Optional +from typing import Any, Callable, Optional import numpy as np +import torch from datasets import load_dataset from torch.utils.data import Dataset @@ -145,7 +146,7 @@ def __init__( self._message_transform = message_transform self._model_transform = model_transform - def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: if self._message_transform is not None: transformed_sample = self._message_transform(sample) if "messages" in transformed_sample: @@ -183,40 +184,50 @@ def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: class SFTOutputTransform(Transform): - """ - Output transform to be used in SFT recipes as an input to TuneIterableDataset. - It takes tokenized inputs with "tokens" and "mask" keys and - creates the "labels" key for SFT training. - - The labels are created by: - 1. Shifting tokens by 1 position (for autoregressive training) - 2. Masking positions where mask[1:] is True with CROSS_ENTROPY_IGNORE_IDX - 3. Adding CROSS_ENTROPY_IGNORE_IDX at the end + """Applied to each dataset sample to build the `"labels"` tensor for causal-LM SFT training. + + Expects sample to contain 1-D torch tensors + "tokens": token IDs, dtype=torch.long + "mask": bool/int where **True** marks positions to ignore + + If they are not tensors, they are converted to tensors. + + Produces ``"labels"`` of the same shape such that + labels[t] = tokens[t+1] # shift left + labels[t] = IGNORE_IDX if mask[t+1] # respect mask + labels[-1] = IGNORE_IDX # last token has no target + + All ops are vectorised; only one fresh tensor (`labels`) is allocated. """ - def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: - # Create a copy to avoid modifying the original - tokenized_dict = dict(sample) + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - if not ("tokens" in tokenized_dict and "mask" in tokenized_dict): - keys_str = ", ".join(tokenized_dict.keys()) - raise ValueError( - f"SFTOutputTransform expects 'tokens' and 'mask' keys. " - f"Got keys: {keys_str}" - ) + tokens = sample["tokens"] + mask = sample["mask"] - # Create labels for SFT training - tokenized_dict["labels"] = list( - np.where( - tokenized_dict["mask"][1:], - CROSS_ENTROPY_IGNORE_IDX, - tokenized_dict["tokens"][1:], - ) - ) - tokenized_dict["labels"].append(CROSS_ENTROPY_IGNORE_IDX) - assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"]) + # Sanity checks + if not isinstance(tokens, torch.Tensor): + tokens = torch.tensor(tokens) + if not isinstance(mask, torch.Tensor): + mask = torch.tensor(mask) - return tokenized_dict + if tokens.ndim != 1 or mask.ndim != 1: + raise ValueError("Both 'tokens' and 'mask' must be 1-D tensors.") + + # build labels + # pre-fill with IGNORE so we don’t need extra assignments later + labels = tokens.new_full(tokens.shape, CROSS_ENTROPY_IGNORE_IDX) + + # left-shift via cheap views (no copy) + labels[:-1].copy_(tokens[1:]) + + # apply mask in-place (single fused kernel on GPU/CPU) + labels[:-1].masked_fill_(mask[1:].bool(), CROSS_ENTROPY_IGNORE_IDX) + + # return a shallow-copied mapping so the original sample stays intact + out = dict(sample) + out["labels"] = labels + return out def sft_iterable_dataset( From 96bc3172a0e437e2b574aa79ba4d32a3223b38ae Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 2 Jul 2025 18:04:27 -0400 Subject: [PATCH 12/25] use ds.sampling_weight --- recipes/full_finetune_distributed.py | 3 -- tests/torchtune/datasets/test_hf_iterable.py | 13 +++++- tests/torchtune/datasets/test_interleaved.py | 48 +++++++++++--------- torchtune/datasets/_hf_iterable.py | 4 ++ torchtune/datasets/_interleaved.py | 20 ++++++-- torchtune/datasets/_iterable_base.py | 12 +++++ 6 files changed, 69 insertions(+), 31 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 4c41a81a5b..fadb3f7c23 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -775,7 +775,6 @@ def _setup_data( # 1. Create all datasets iterable_datasets = [] - weights = [] cfg_dataset_list = cfg_dataset if not isinstance(cfg_dataset_list, ListConfig): cfg_dataset_list = [cfg_dataset_list] @@ -783,13 +782,11 @@ def _setup_data( for ds_cfg in cfg_dataset_list: ds = config.instantiate(ds_cfg, model_transform=self._tokenizer) iterable_datasets.append(ds) - weights.append(ds_cfg.get("weight", 1.0)) # 2. Interleave datasets if any if len(iterable_datasets) > 1: ds = InterleavedDataset( datasets=iterable_datasets, - weights=weights, seed=self.seed, ) else: diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index 901234af6f..94b55b87be 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -105,13 +105,18 @@ def test_default_dataset_name(self, small_dataset_file): # Should generate name from path and split assert dataset.dataset_name == "json_train" + # Test default sampling weight + assert dataset.sampling_weight == 1.0 + assert isinstance(dataset.sampling_weight, float) - # Test giving a name + # Test giving a name and custom weight + custom_weight = 2.5 dataset2 = HfIterableDataset( path="json", data_files=small_dataset_file, split="train", dataset_name="my_dataset", + weight=custom_weight, seed=SEED, metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=4, @@ -119,6 +124,8 @@ def test_default_dataset_name(self, small_dataset_file): # Should generate name from path and split assert dataset2.dataset_name == "my_dataset" + # Test custom sampling weight + assert dataset2.sampling_weight == custom_weight @pytest.mark.parametrize("num_epochs", [0.5, 1.0, 2.5]) def test_epoch_boundaries_and_checkpointing( @@ -198,9 +205,11 @@ def test_shuffling_behavior(self, dataset_factory, small_dataset_file): first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] - # Shuffled epochs should have different order + # Extract IDs for comparison first_epoch_ids = [sample["id"] for sample in first_epoch_samples] second_epoch_ids = [sample["id"] for sample in second_epoch_samples] + + # Shuffled epochs should have different order assert first_epoch_ids != list( range(SMALL_DATASET_SIZE) ), f"Shuffled should not be sorted, got {first_epoch_ids}" diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index 98e9207047..38d92bc2d3 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -101,30 +101,35 @@ class TestInterleavedDataset: def test_initialization_validation(self, dataset_factory, small_dataset_file): """Tests that the dataset raises errors for invalid configurations, like duplicate names.""" # Test duplicate dataset names - ds1 = dataset_factory(small_dataset_file, dataset_name="duplicate") - ds2 = dataset_factory(small_dataset_file, dataset_name="duplicate") + ds1 = dataset_factory(small_dataset_file, dataset_name="duplicate", weight=0.5) + ds2 = dataset_factory(small_dataset_file, dataset_name="duplicate", weight=0.5) with pytest.raises(ValueError, match="Duplicate dataset names detected"): - InterleavedDataset(datasets=[ds1, ds2], weights=[0.5, 0.5], seed=SEED) + InterleavedDataset(datasets=[ds1, ds2], seed=SEED) # Test weight normalization (should work with warning) - ds3 = dataset_factory(small_dataset_file, dataset_name="ds3") - ds4 = dataset_factory(small_dataset_file, dataset_name="ds4") + ds3 = dataset_factory(small_dataset_file, dataset_name="ds3", weight=0.5) + ds4 = dataset_factory(small_dataset_file, dataset_name="ds4", weight=1.5) with patch("logging.Logger.warning") as mock_warning: interleaved = InterleavedDataset( datasets=[ds3, ds4], - weights=[0.5, 1.5], seed=SEED, dataset_name="test_interleaved", # Sum = 2.0 != 1.0 ) - # Check that weights were normalized - assert torch.allclose(interleaved._weights, torch.tensor([0.25, 0.75])) - mock_warning.assert_called_once() assert interleaved.dataset_name == "test_interleaved" + # Test sampling_weight property returns normalized weights + sampling_weights = interleaved.sampling_weight + assert isinstance(sampling_weights, dict) + assert "ds3" in sampling_weights + assert "ds4" in sampling_weights + assert abs(sampling_weights["ds3"] - 0.25) < 1e-6 + assert abs(sampling_weights["ds4"] - 0.75) < 1e-6 + assert abs(sum(sampling_weights.values()) - 1.0) < 1e-6 + def test_sampling_ratios( self, dataset_factory, small_dataset_file, medium_dataset_file ): @@ -132,12 +137,11 @@ def test_sampling_ratios( # Create two datasets with distinct ID ranges # ds1 has IDs 0-22 (small dataset) # ds2 has IDs 100-134 (medium dataset with offset) - ds1 = dataset_factory(small_dataset_file, dataset_name="ds1") - ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2") + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.7) + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.3) # Test with 70/30 weighting - weights = [0.7, 0.3] - interleaved = InterleavedDataset([ds1, ds2], weights, seed=SEED) + interleaved = InterleavedDataset([ds1, ds2], seed=SEED) # Collect 300 samples sample_count = 300 @@ -162,10 +166,10 @@ def test_metrics_aggregation( self, dataset_factory, small_dataset_file, medium_dataset_file ): """Tests that metrics from all child datasets are collected and aggregated.""" - ds1 = dataset_factory(small_dataset_file, dataset_name="ds1") - ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2") + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.2) + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.8) - interleaved = InterleavedDataset([ds1, ds2], [0.2, 0.8], seed=SEED) + interleaved = InterleavedDataset([ds1, ds2], seed=SEED) aggregator = MetricsAggregator() # Process some samples @@ -203,9 +207,9 @@ def test_checkpointing( """Tests that interleaved dataset checkpointing preserves sampling state.""" def create_interleaved(): - ds1 = dataset_factory(small_dataset_file, dataset_name="ds1") - ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2") - return InterleavedDataset([ds1, ds2], [0.7, 0.3], seed=SEED) + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.7) + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.3) + return InterleavedDataset([ds1, ds2], seed=SEED) # Original run interleaved1 = create_interleaved() @@ -291,6 +295,7 @@ def create_dataset(): shuffle_buffer_size=0, # No shuffle for determinism metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, + weight=0.8, ) ds2 = HfIterableDataset( path="json", @@ -300,10 +305,11 @@ def create_dataset(): shuffle_buffer_size=0, # No shuffle for determinism metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, + weight=0.2, ) - # Create interleaved dataset with 70/30 weighting - return InterleavedDataset([ds1, ds2], [0.8, 0.2], seed=SEED) + # Create interleaved dataset with 80/20 weighting + return InterleavedDataset([ds1, ds2], seed=SEED) def create_dataloader(dataset): loader = StatefulDataLoader( diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 7aac8adcc4..3949e93508 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -108,6 +108,10 @@ def __init__( def dataset_name(self) -> str: return self._dataset_name + @property + def sampling_weight(self) -> float: + return self._weight + def _apply_transforms(self, sample: dict[str, Any]) -> dict[str, Any]: """Apply transforms if they exist, otherwise return sample unchanged.""" if self._message_transform is not None: diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py index 0245d4e94e..71a8f72674 100644 --- a/torchtune/datasets/_interleaved.py +++ b/torchtune/datasets/_interleaved.py @@ -17,14 +17,13 @@ class InterleavedDataset(TuneIterableDataset): - """Infinitely interleaves multiple TuneIterableDatasets according to a list of weights. - - The weights are normalized to sum to 1.0. + """Infinitely interleaves multiple TuneIterableDatasets according to their sampling weights. + - The weights are extracted from each dataset's sampling_weight property and normalized to sum to 1.0. - This dataset is responsible for managing the state of its child datasets to ensure correct checkpointing and resumption. Args: datasets (list[TuneIterableDataset]): list of TuneIterableDatasets to interleave. - weights (list[float]): list of weights for each dataset. Must sum to 1.0. seed (int): Seed for sampling. dataset_name (str): Name of the dataset. If None, defaults to "interleaved_dataset". @@ -35,7 +34,6 @@ class InterleavedDataset(TuneIterableDataset): def __init__( self, datasets: list[TuneIterableDataset], - weights: list[float], seed: int, dataset_name: str = "interleaved_dataset", ): @@ -62,8 +60,16 @@ def __init__( self._sampling_generator = torch.Generator().manual_seed(seed) + # Extract weights from datasets' sampling_weight property + weights = [] + for ds in datasets: + weight = ds.sampling_weight + if isinstance(weight, dict): + # For composite datasets, sum up their weights + weight = sum(weight.values()) + weights.append(weight) + # Normalize weights to sum to 1 - # TODO: make it a property? rely on ds.weight? total_weight = sum(weights) self._weights = torch.tensor( [w / total_weight for w in weights], dtype=torch.float @@ -78,6 +84,10 @@ def __init__( def dataset_name(self) -> str: return self._dataset_name + @property + def sampling_weight(self) -> dict[str, float]: + return {name: weight.item() for name, weight in zip(self._dataset_names, self._weights)} + def __iter__(self) -> Iterator[dict[str, Any]]: """Interleave samples from child infinite datasets""" child_iters = {name: iter(ds) for name, ds in self._datasets.items()} diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index f0821dc3f1..6630761f0d 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -24,6 +24,18 @@ def dataset_name(self) -> str: """A unique identifier for the dataset, used for namespacing in metrics and checkpoints.""" pass + @property + @abstractmethod + def sampling_weight(self) -> float | dict[str, float]: + """ + Returns the sampling weight for this dataset when used in multi-dataset scenarios. + + For leaf datasets: returns a float representing the relative weight. + For composite datasets: returns a dict mapping child dataset names to their weights. + Used by interleaving logic to determine sampling probabilities. + """ + pass + @abstractmethod def __iter__(self) -> Iterator[dict[str, Any]]: """ From 3c9d161629d55cee1a9e55b49df98734f985db2f Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 2 Jul 2025 18:31:02 -0400 Subject: [PATCH 13/25] add sampling log to interlead dataset --- tests/torchtune/datasets/test_interleaved.py | 19 +++++++++++++++++ torchtune/datasets/_interleaved.py | 22 +++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index 38d92bc2d3..361714523c 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -243,6 +243,25 @@ def create_interleaved(): result["final_metrics"] == result["resumed_metrics"] ), "Final metrics should match" + # Test sampling log functionality + # Check that sampling log contains tuples of (iteration_count, dataset_name) + state_dict = interleaved1.state_dict() + sampling_log = state_dict["sampling_log"] + iteration_count = state_dict["iteration_count"] + + assert len(sampling_log) > 0, "Sampling log should not be empty" + assert iteration_count > 0, "Iteration count should be positive" + + # Check sampling ratios match expected weights (70/30) + ds1_count = sum(1 for _, ds_name in sampling_log if ds_name == "ds1") + ds2_count = sum(1 for _, ds_name in sampling_log if ds_name == "ds2") + total_samples = ds1_count + ds2_count + + ds1_ratio = ds1_count / total_samples + ds2_ratio = ds2_count / total_samples + + assert abs(ds1_ratio - 0.7) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~0.7" + assert abs(ds2_ratio - 0.3) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~0.3" class TestDistributedInterleavedDataset(FSDPTest): @property diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py index 71a8f72674..13859dd2d0 100644 --- a/torchtune/datasets/_interleaved.py +++ b/torchtune/datasets/_interleaved.py @@ -7,6 +7,7 @@ import collections import logging import math +from collections import deque from typing import Any, Iterator import torch @@ -26,7 +27,8 @@ class InterleavedDataset(TuneIterableDataset): datasets (list[TuneIterableDataset]): list of TuneIterableDatasets to interleave. seed (int): Seed for sampling. dataset_name (str): Name of the dataset. If None, defaults to "interleaved_dataset". - + sampling_log_maxlen (int): Maximum length of the sampling log. + Raises: ValueError: If duplicate dataset names are detected in the provided datasets. """ @@ -36,8 +38,10 @@ def __init__( datasets: list[TuneIterableDataset], seed: int, dataset_name: str = "interleaved_dataset", + sampling_log_maxlen: int = 10000, ): self._dataset_name = dataset_name + self._sampling_log_maxlen = sampling_log_maxlen # Preserve original order for weighted sampling self._dataset_names = [ds.dataset_name for ds in datasets] @@ -60,6 +64,10 @@ def __init__( self._sampling_generator = torch.Generator().manual_seed(seed) + # Track sampling decisions for debugging and analysis + self._sampling_log: deque[tuple[int, str]] = deque(maxlen=self._sampling_log_maxlen) + self._iteration_count = 0 + # Extract weights from datasets' sampling_weight property weights = [] for ds in datasets: @@ -101,6 +109,10 @@ def __iter__(self) -> Iterator[dict[str, Any]]: # Sample an index, then get the name for safe lookup ds_name = self._dataset_names[ds_idx] + # Log this sampling decision + self._sampling_log.append((self._iteration_count, ds_name)) + self._iteration_count += 1 + try: sample = next(child_iters[ds_name]) yield sample @@ -123,6 +135,8 @@ def state_dict(self) -> dict[str, Any]: return { "sampling_generator_state": self._sampling_generator.get_state(), "child_states": child_states, + "sampling_log": list(self._sampling_log), + "iteration_count": self._iteration_count, } def load_state_dict(self, state_dict: dict[str, Any]) -> None: @@ -133,3 +147,9 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: if name in child_states: # Pass the raw state dict to the child ds.load_state_dict(child_states[name]) + + # Load sampling log and iteration count + self._sampling_log = deque( + state_dict.get("sampling_log", []), maxlen=self._sampling_log_maxlen + ) + self._iteration_count = state_dict.get("iteration_count", 0) From 4804663bf375baa9cd9bfe0eff83fa4831e0f9ac Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 2 Jul 2025 22:32:09 -0400 Subject: [PATCH 14/25] fix nested interleave --- tests/torchtune/datasets/test_interleaved.py | 8 +++++++- torchtune/datasets/_interleaved.py | 9 ++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index 361714523c..4dedfca5ee 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -107,10 +107,16 @@ def test_initialization_validation(self, dataset_factory, small_dataset_file): with pytest.raises(ValueError, match="Duplicate dataset names detected"): InterleavedDataset(datasets=[ds1, ds2], seed=SEED) - # Test weight normalization (should work with warning) + # Test nested interleaved datasets are rejected ds3 = dataset_factory(small_dataset_file, dataset_name="ds3", weight=0.5) ds4 = dataset_factory(small_dataset_file, dataset_name="ds4", weight=1.5) + nested_interleaved = InterleavedDataset([ds3, ds4], seed=SEED, dataset_name="nested") + + with pytest.raises(ValueError, match="returned a dict for sampling_weight"): + # This should fail because nested_interleaved.sampling_weight returns a dict + InterleavedDataset([nested_interleaved, ds3], seed=SEED) + # Test weight normalization (should work with warning) with patch("logging.Logger.warning") as mock_warning: interleaved = InterleavedDataset( datasets=[ds3, ds4], diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py index 13859dd2d0..46c4cfb7dd 100644 --- a/torchtune/datasets/_interleaved.py +++ b/torchtune/datasets/_interleaved.py @@ -68,13 +68,16 @@ def __init__( self._sampling_log: deque[tuple[int, str]] = deque(maxlen=self._sampling_log_maxlen) self._iteration_count = 0 - # Extract weights from datasets' sampling_weight property + # Extract weights from child datasets weights = [] for ds in datasets: weight = ds.sampling_weight if isinstance(weight, dict): - # For composite datasets, sum up their weights - weight = sum(weight.values()) + raise ValueError( + f"Child dataset '{ds.dataset_name}' returned a dict for sampling_weight, " + f"indicating it's a composite dataset (likely InterleavedDataset). " + f"Nested interleaving is not supported. Please flatten the dataset hierarchy." + ) weights.append(weight) # Normalize weights to sum to 1 From 2fe4b401bf6f98931f9cd5eb60780bfa8503a731 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 2 Jul 2025 22:57:24 -0400 Subject: [PATCH 15/25] changes to TuneIterableDataset --- torchtune/datasets/_iterable_base.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index 6630761f0d..dafa8d9fce 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -7,12 +7,8 @@ from abc import ABC, abstractmethod from typing import Any, Iterator -from torch.utils.data import IterableDataset - - -class TuneIterableDataset(IterableDataset, ABC): - """ - Abstract base class for all torchtune iterable datasets. +class TuneIterableDataset(ABC): + """Abstract base class for all torchtune iterable datasets. It defines the minimal, consistent interface required for all dataset implementations to ensure they are compatible with the training loop, checkpointing, and metric logging systems. @@ -25,22 +21,19 @@ def dataset_name(self) -> str: pass @property - @abstractmethod def sampling_weight(self) -> float | dict[str, float]: - """ - Returns the sampling weight for this dataset when used in multi-dataset scenarios. + """Returns the sampling weight for this dataset, especially useful in multi-dataset scenarios. For leaf datasets: returns a float representing the relative weight. For composite datasets: returns a dict mapping child dataset names to their weights. - Used by interleaving logic to determine sampling probabilities. """ - pass + return 1.0 @abstractmethod def __iter__(self) -> Iterator[dict[str, Any]]: - """ - Returns an infinite iterator over the dataset. Each implementation is responsible - for its own iteration logic, including shuffling and making it an infinite stream. + """Returns an infinite iterator over the dataset. Each implementation is responsible + for its own iteration logic, including shuffling, distribution of data across ranks, + and making it an infinite stream. """ pass From 72211c99cac9254c785ead30ce7bfddcf0229d60 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 2 Jul 2025 20:01:47 -0700 Subject: [PATCH 16/25] add IterableDataset back --- torchtune/datasets/_iterable_base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index dafa8d9fce..92290d658e 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -7,7 +7,10 @@ from abc import ABC, abstractmethod from typing import Any, Iterator -class TuneIterableDataset(ABC): +from torch.utils.data import IterableDataset + + +class TuneIterableDataset(IterableDataset, ABC): """Abstract base class for all torchtune iterable datasets. It defines the minimal, consistent interface required for all dataset implementations to ensure they are compatible with the training loop, @@ -23,7 +26,7 @@ def dataset_name(self) -> str: @property def sampling_weight(self) -> float | dict[str, float]: """Returns the sampling weight for this dataset, especially useful in multi-dataset scenarios. - + For leaf datasets: returns a float representing the relative weight. For composite datasets: returns a dict mapping child dataset names to their weights. """ @@ -33,8 +36,7 @@ def sampling_weight(self) -> float | dict[str, float]: def __iter__(self) -> Iterator[dict[str, Any]]: """Returns an infinite iterator over the dataset. Each implementation is responsible for its own iteration logic, including shuffling, distribution of data across ranks, - and making it an infinite stream. - """ + and making it an infinite stream.""" pass @abstractmethod From b350ac7b5e13f9121bbf4d8397eebebc24cbcd8b Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sat, 5 Jul 2025 21:59:44 -0400 Subject: [PATCH 17/25] nested interleaved + dataset.info --- tests/torchtune/datasets/test_hf_iterable.py | 11 +- tests/torchtune/datasets/test_interleaved.py | 294 +++++++++++++------ torchtune/datasets/__init__.py | 4 +- torchtune/datasets/_hf_iterable.py | 33 +-- torchtune/datasets/_interleaved.py | 141 ++++----- torchtune/datasets/_iterable_base.py | 47 ++- 6 files changed, 327 insertions(+), 203 deletions(-) diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index 94b55b87be..ec3ca26936 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -104,10 +104,9 @@ def test_default_dataset_name(self, small_dataset_file): ) # Should generate name from path and split - assert dataset.dataset_name == "json_train" + assert dataset.info.name == "json_train" # Test default sampling weight - assert dataset.sampling_weight == 1.0 - assert isinstance(dataset.sampling_weight, float) + assert dataset.info.weight == 1.0 # Test giving a name and custom weight custom_weight = 2.5 @@ -122,10 +121,10 @@ def test_default_dataset_name(self, small_dataset_file): num_shards_per_rank=4, ) - # Should generate name from path and split - assert dataset2.dataset_name == "my_dataset" + # Should use provided name and weight + assert dataset2.info.name == "my_dataset" # Test custom sampling weight - assert dataset2.sampling_weight == custom_weight + assert dataset2.info.weight == custom_weight @pytest.mark.parametrize("num_epochs", [0.5, 1.0, 2.5]) def test_epoch_boundaries_and_checkpointing( diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index 4dedfca5ee..c2ea53c970 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -27,6 +27,7 @@ # Test Constants SMALL_DATASET_SIZE = 23 MEDIUM_DATASET_SIZE = 35 +LARGE_DATASET_SIZE = 47 SEED = 42 BATCH_SIZE = 5 @@ -70,6 +71,13 @@ def medium_dataset_file(tmp_data_dir): return str(path) +@pytest.fixture +def large_dataset_file(tmp_data_dir): + path = tmp_data_dir / "large_data.json" + create_test_json_file(path, LARGE_DATASET_SIZE, offset=1000) + return str(path) + + @pytest.fixture def dataset_factory(): """Factory for creating HfIterableDataset instances with common defaults.""" @@ -100,122 +108,201 @@ class TestInterleavedDataset: def test_initialization_validation(self, dataset_factory, small_dataset_file): """Tests that the dataset raises errors for invalid configurations, like duplicate names.""" - # Test duplicate dataset names - ds1 = dataset_factory(small_dataset_file, dataset_name="duplicate", weight=0.5) - ds2 = dataset_factory(small_dataset_file, dataset_name="duplicate", weight=0.5) + + # Test 1: Duplicate dataset names should raise an error + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.5) + ds2 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.5) - with pytest.raises(ValueError, match="Duplicate dataset names detected"): + with pytest.raises(ValueError, match="Duplicate dataset names found in hierarchy"): InterleavedDataset(datasets=[ds1, ds2], seed=SEED) - # Test nested interleaved datasets are rejected - ds3 = dataset_factory(small_dataset_file, dataset_name="ds3", weight=0.5) - ds4 = dataset_factory(small_dataset_file, dataset_name="ds4", weight=1.5) - nested_interleaved = InterleavedDataset([ds3, ds4], seed=SEED, dataset_name="nested") + # Test 2: Nested interleaved datasets should be supported + ds3 = dataset_factory(small_dataset_file, dataset_name="ds3", weight=1.5) + interleaved_child = InterleavedDataset([ds1, ds3], seed=SEED, dataset_name="interleaved_child") - with pytest.raises(ValueError, match="returned a dict for sampling_weight"): - # This should fail because nested_interleaved.sampling_weight returns a dict - InterleavedDataset([nested_interleaved, ds3], seed=SEED) - - # Test weight normalization (should work with warning) + # Create a parent interleaved dataset containing the nested one + ds4 = dataset_factory(small_dataset_file, dataset_name="ds4", weight=0.5) + + # Test 3: Weight normalization should work with a warning with patch("logging.Logger.warning") as mock_warning: - interleaved = InterleavedDataset( - datasets=[ds3, ds4], - seed=SEED, - dataset_name="test_interleaved", # Sum = 2.0 != 1.0 - ) - - - assert interleaved.dataset_name == "test_interleaved" - - # Test sampling_weight property returns normalized weights - sampling_weights = interleaved.sampling_weight - assert isinstance(sampling_weights, dict) - assert "ds3" in sampling_weights - assert "ds4" in sampling_weights - assert abs(sampling_weights["ds3"] - 0.25) < 1e-6 - assert abs(sampling_weights["ds4"] - 0.75) < 1e-6 - assert abs(sum(sampling_weights.values()) - 1.0) < 1e-6 - + interleaved_parent = InterleavedDataset([interleaved_child, ds4], seed=SEED, dataset_name="interleaved_parent") + + # Verify that a warning was logged about weight normalization + mock_warning.assert_called_once() + warning_message = mock_warning.call_args[0][0] + assert "normalized" in warning_message.lower() + + # Verify the hierarchical structure is correct + assert interleaved_parent.info.name == "interleaved_parent" + assert len(interleaved_parent.info.children) == 2 + assert interleaved_parent.info.children[0].name == "interleaved_child" + assert interleaved_parent.info.children[1].name == "ds4" + + # Verify the nested structure within the nested dataset + nested_info = interleaved_parent.info.children[0] + assert len(nested_info.children) == 2 + assert nested_info.children[0].name == "ds1" + assert nested_info.children[1].name == "ds3" + + # Verify that sampling weights are normalized to sum to 1.0 + # Access the internal normalized weights tensor + normalized_weights = interleaved_parent._normalized_weights + assert isinstance(normalized_weights, torch.Tensor) + assert len(normalized_weights) == 2 + + # nested: 2.0/2.5 = 0.8, ds4: 0.5/2.5 = 0.2 + assert abs(normalized_weights[0].item() - 0.8) < 1e-6 + assert abs(normalized_weights[1].item() - 0.2) < 1e-6 + assert abs(normalized_weights.sum().item() - 1.0) < 1e-6 + + # Verify that original weights in info remain unnormalized + child_weights = [child.weight for child in interleaved_parent.info.children] + assert abs(child_weights[0] - 2.0) < 1e-6 # nested original weight + assert abs(child_weights[1] - 0.5) < 1e-6 # ds4 original weight + + + def test_single_dataset(self, dataset_factory, small_dataset_file): + """Tests that InterleavedDataset works correctly with a single dataset.""" + # Create a single dataset + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.5) + + # Should work without issues + interleaved = InterleavedDataset([ds1], seed=SEED) + + # Verify the hierarchical structure + assert interleaved.info.name == "interleaved_dataset" # default name + assert len(interleaved.info.children) == 1 + assert interleaved.info.children[0].name == "ds1" + assert interleaved.info.children[0].weight == 0.5 + + # Verify normalized weights sum to 1.0 (single dataset gets weight 1.0) + normalized_weights = interleaved._normalized_weights + assert isinstance(normalized_weights, torch.Tensor) + assert len(normalized_weights) == 1 + assert abs(normalized_weights[0].item() - 1.0) < 1e-6 + + # Test that iteration works correctly + samples = list(islice(iter(interleaved), 10)) + assert len(samples) == 10 + + # All samples should come from the single dataset + sample_ids = {sample["id"] for sample in samples} + expected_ids = set(range(23)) # ds1 has IDs 0-22 + assert sample_ids == expected_ids + def test_sampling_ratios( - self, dataset_factory, small_dataset_file, medium_dataset_file + self, dataset_factory, small_dataset_file, medium_dataset_file, large_dataset_file ): - """Tests that datasets are sampled according to their assigned weights.""" - # Create two datasets with distinct ID ranges - # ds1 has IDs 0-22 (small dataset) - # ds2 has IDs 100-134 (medium dataset with offset) - ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.7) - ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.3) + """Tests that datasets are sampled according to their assigned weights in nested structure.""" + # Create three datasets with distinct ID ranges + # ds1 has IDs 0-22, ds2 has IDs 100-134, ds3 has IDs 1000-1046 + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.3) + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.7) + ds3 = dataset_factory(large_dataset_file, dataset_name="ds3", weight=1.0) - # Test with 70/30 weighting - interleaved = InterleavedDataset([ds1, ds2], seed=SEED) + # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) + child_interleaved = InterleavedDataset([ds1, ds2], seed=SEED, dataset_name="child") + parent_interleaved = InterleavedDataset([child_interleaved, ds3], seed=SEED, dataset_name="parent") - # Collect 300 samples - sample_count = 300 - samples = list(islice(iter(interleaved), sample_count)) + # Collect 400 samples + sample_count = 400 + samples = list(islice(iter(parent_interleaved), sample_count)) # Count samples by checking ID ranges - # ds1 has IDs < 100, ds2 has IDs >= 100 - ds1_count = sum(1 for s in samples if s["id"] < 100) - ds2_count = sum(1 for s in samples if s["id"] >= 100) + ds1_count = sum(1 for s in samples if 0 <= s["id"] < SMALL_DATASET_SIZE) + ds2_count = sum(1 for s in samples if 100 <= s["id"] < (MEDIUM_DATASET_SIZE + 100)) + ds3_count = sum(1 for s in samples if 1000 <= s["id"] < (LARGE_DATASET_SIZE + 1000)) - assert ds1_count + ds2_count == sample_count + assert ds1_count + ds2_count + ds3_count == sample_count - # Check ratios are approximately correct + # Calculate ratios ds1_ratio = ds1_count / sample_count ds2_ratio = ds2_count / sample_count + ds3_ratio = ds3_count / sample_count + + # Expected ratios based on nested weighting: + # Inner weights: ds1=0.3, ds2=0.7 -> inner total=1.0 + # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each + # Final ratios: ds1=0.5*0.3=0.15, ds2=0.5*0.7=0.35, ds3=0.5 + expected_ds1_ratio = 0.15 + expected_ds2_ratio = 0.35 + expected_ds3_ratio = 0.5 # Allow 10% tolerance due to randomness - assert abs(ds1_ratio - 0.7) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~0.7" - assert abs(ds2_ratio - 0.3) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~0.3" + assert abs(ds1_ratio - expected_ds1_ratio) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert abs(ds2_ratio - expected_ds2_ratio) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert abs(ds3_ratio - expected_ds3_ratio) < 0.1, f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" def test_metrics_aggregation( - self, dataset_factory, small_dataset_file, medium_dataset_file + self, dataset_factory, small_dataset_file, medium_dataset_file, large_dataset_file ): - """Tests that metrics from all child datasets are collected and aggregated.""" + """Tests that metrics from all child datasets are collected and aggregated in nested structure.""" ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.2) ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.8) + ds3 = dataset_factory(large_dataset_file, dataset_name="ds3", weight=1.0) - interleaved = InterleavedDataset([ds1, ds2], seed=SEED) + # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) + child_interleaved = InterleavedDataset([ds1, ds2], seed=SEED, dataset_name="child") + parent_interleaved = InterleavedDataset([child_interleaved, ds3], seed=SEED, dataset_name="parent") + aggregator = MetricsAggregator() # Process some samples - total_samples = 200 - for sample in islice(iter(interleaved), total_samples): + total_samples = 300 + for sample in islice(iter(parent_interleaved), total_samples): aggregator.update(sample["metrics"]) metrics = aggregator.get_metrics_for_logging(prefix="train") - # Should have metrics from both datasets, with flat keys + # Should have metrics from all three datasets, with flat keys assert "train_ds1/samples_seen" in metrics assert "train_ds2/samples_seen" in metrics + assert "train_ds3/samples_seen" in metrics - # Both datasets should have contributed samples + # All datasets should have contributed samples assert metrics["train_ds1/samples_seen"] > 0 assert metrics["train_ds2/samples_seen"] > 0 + assert metrics["train_ds3/samples_seen"] > 0 # Total samples should equal what we processed calculated_total_samples = ( - metrics["train_ds1/samples_seen"] + metrics["train_ds2/samples_seen"] + metrics["train_ds1/samples_seen"] + + metrics["train_ds2/samples_seen"] + + metrics["train_ds3/samples_seen"] ) assert calculated_total_samples == total_samples - # Test that ratio is approximately correct + # Test that ratios are approximately correct based on nested weighting ds1_ratio = metrics["train_ds1/samples_seen"] / total_samples ds2_ratio = metrics["train_ds2/samples_seen"] / total_samples + ds3_ratio = metrics["train_ds3/samples_seen"] / total_samples + + # Expected ratios based on nested weighting: + # Inner weights: ds1=0.2, ds2=0.8 -> inner total=1.0 + # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each + # Final ratios: ds1=0.5*0.2=0.1, ds2=0.5*0.8=0.4, ds3=0.5 + expected_ds1_ratio = 0.1 + expected_ds2_ratio = 0.4 + expected_ds3_ratio = 0.5 # Allow 10% tolerance due to randomness - assert abs(ds1_ratio - 0.2) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~0.2" - assert abs(ds2_ratio - 0.8) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~0.8" + assert abs(ds1_ratio - expected_ds1_ratio) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert abs(ds2_ratio - expected_ds2_ratio) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert abs(ds3_ratio - expected_ds3_ratio) < 0.1, f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" def test_checkpointing( - self, dataset_factory, small_dataset_file, medium_dataset_file + self, dataset_factory, small_dataset_file, medium_dataset_file, large_dataset_file ): - """Tests that interleaved dataset checkpointing preserves sampling state.""" + """Tests that interleaved dataset checkpointing preserves sampling state in nested structure.""" def create_interleaved(): - ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.7) - ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.3) - return InterleavedDataset([ds1, ds2], seed=SEED) + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.3) + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.7) + ds3 = dataset_factory(large_dataset_file, dataset_name="ds3", weight=1.0) + + # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) + child_interleaved = InterleavedDataset([ds1, ds2], seed=SEED, dataset_name="child") + return InterleavedDataset([child_interleaved, ds3], seed=SEED, dataset_name="parent") # Original run interleaved1 = create_interleaved() @@ -258,16 +345,27 @@ def create_interleaved(): assert len(sampling_log) > 0, "Sampling log should not be empty" assert iteration_count > 0, "Iteration count should be positive" - # Check sampling ratios match expected weights (70/30) + # Check sampling ratios match expected weights for nested structure ds1_count = sum(1 for _, ds_name in sampling_log if ds_name == "ds1") ds2_count = sum(1 for _, ds_name in sampling_log if ds_name == "ds2") - total_samples = ds1_count + ds2_count + ds3_count = sum(1 for _, ds_name in sampling_log if ds_name == "ds3") + total_samples = ds1_count + ds2_count + ds3_count ds1_ratio = ds1_count / total_samples ds2_ratio = ds2_count / total_samples + ds3_ratio = ds3_count / total_samples + + # Expected ratios based on nested weighting: + # Inner weights: ds1=0.3, ds2=0.7 -> inner total=1.0 + # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each + # Final ratios: ds1=0.5*0.3=0.15, ds2=0.5*0.7=0.35, ds3=0.5 + expected_ds1_ratio = 0.15 + expected_ds2_ratio = 0.35 + expected_ds3_ratio = 0.5 - assert abs(ds1_ratio - 0.7) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~0.7" - assert abs(ds2_ratio - 0.3) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~0.3" + assert abs(ds1_ratio - expected_ds1_ratio) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert abs(ds2_ratio - expected_ds2_ratio) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert abs(ds3_ratio - expected_ds3_ratio) < 0.1, f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" class TestDistributedInterleavedDataset(FSDPTest): @property @@ -277,10 +375,10 @@ def world_size(self) -> int: @gpu_test(gpu_count=2) def test_distributed_interleaved_checkpointing(self): """ - Test interleaved dataset checkpointing with distributed settings. + Test interleaved dataset checkpointing with distributed settings using nested structure. Assertions: - Each rank processes non-overlapping data shards - - Sampling ratios (70/30) are maintained across ranks + - Sampling ratios for nested structure (ds1: 15%, ds2: 35%, ds3: 50%) are maintained across ranks - Checkpoint/resume produces identical batches (deterministic) - Metrics correctly aggregate across ranks """ @@ -303,6 +401,7 @@ def test_distributed_interleaved_checkpointing(self): def create_dataset(): file1 = tmp_path / "ds1.json" file2 = tmp_path / "ds2.json" + file3 = tmp_path / "ds3.json" # Only rank 0 creates the data files if rank == 0: @@ -310,6 +409,9 @@ def create_dataset(): create_test_json_file( file2, MEDIUM_DATASET_SIZE, offset=100 ) # IDs 100-134 + create_test_json_file( + file3, LARGE_DATASET_SIZE, offset=1000 + ) # IDs 1000-1046 dist.barrier() # Wait for file creation ds1 = HfIterableDataset( @@ -320,7 +422,7 @@ def create_dataset(): shuffle_buffer_size=0, # No shuffle for determinism metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, - weight=0.8, + weight=0.3, ) ds2 = HfIterableDataset( path="json", @@ -330,11 +432,22 @@ def create_dataset(): shuffle_buffer_size=0, # No shuffle for determinism metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, - weight=0.2, + weight=0.7, + ) + ds3 = HfIterableDataset( + path="json", + data_files=str(file3), + split="train", + dataset_name="ds3", + shuffle_buffer_size=0, # No shuffle for determinism + metric_transform=DefaultTrainingMetricTransform(), + num_shards_per_rank=2, + weight=1.0, ) - # Create interleaved dataset with 80/20 weighting - return InterleavedDataset([ds1, ds2], seed=SEED) + # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) + child_interleaved = InterleavedDataset([ds1, ds2], seed=SEED, dataset_name="child") + return InterleavedDataset([child_interleaved, ds3], seed=SEED, dataset_name="parent") def create_dataloader(dataset): loader = StatefulDataLoader( @@ -371,24 +484,35 @@ def create_dataloader(dataset): result["final_metrics"] == result["resumed_metrics"] ), "Final metrics don't match resumed metrics - aggregator state issue" - # Verify sampling ratio is approximately maintained (80/20 split) + # Verify sampling ratio is approximately maintained for nested structure all_ids = [] for batch in ( result["pre_checkpoint_batches"] + result["post_checkpoint_batches"] ): all_ids.extend(batch["id"].tolist()) - # Count samples by ID ranges: ds1 has IDs < 100, ds2 has IDs >= 100 - ds1_samples = sum(1 for id in all_ids if id < 100) - ds2_samples = sum(1 for id in all_ids if id >= 100) - total_samples = ds1_samples + ds2_samples + # Count samples by ID ranges: ds1 has IDs 0-22, ds2 has IDs 100-134, ds3 has IDs 1000-1046 + ds1_samples = sum(1 for id in all_ids if 0 <= id < SMALL_DATASET_SIZE) + ds2_samples = sum(1 for id in all_ids if 100 <= id < (MEDIUM_DATASET_SIZE + 100)) + ds3_samples = sum(1 for id in all_ids if 1000 <= id < (LARGE_DATASET_SIZE + 1000)) + total_samples = ds1_samples + ds2_samples + ds3_samples if total_samples > 0: ds1_ratio = ds1_samples / total_samples - assert 0.6 < ds1_ratio < 1.0, ( - f"Rank {rank}: Dataset sampling ratio {ds1_ratio:.2f} outside expected " - f"range for 80/20 split. Got {ds1_samples}, {ds2_samples} samples." - ) + ds2_ratio = ds2_samples / total_samples + ds3_ratio = ds3_samples / total_samples + + # Expected ratios based on nested weighting: + # Inner weights: ds1=0.3, ds2=0.7 -> inner total=1.0 + # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each + # Final ratios: ds1=0.5*0.3=0.15, ds2=0.5*0.7=0.35, ds3=0.5 + expected_ds1_ratio = 0.15 + expected_ds2_ratio = 0.35 + expected_ds3_ratio = 0.5 + + assert abs(ds1_ratio - expected_ds1_ratio) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert abs(ds2_ratio - expected_ds2_ratio) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert abs(ds3_ratio - expected_ds3_ratio) < 0.1, f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" finally: # Clean up temp directory (only rank 0) diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index f5ecbb95ea..e0afccdde4 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -18,7 +18,7 @@ from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset from torchtune.datasets._instruct import instruct_dataset from torchtune.datasets._interleaved import InterleavedDataset -from torchtune.datasets._iterable_base import TuneIterableDataset +from torchtune.datasets._iterable_base import DatasetInfo, InfiniteTuneIterableDataset, TuneIterableDataset from torchtune.datasets._packed import PackedDataset from torchtune.datasets._preference import preference_dataset, PreferenceDataset from torchtune.datasets._samsum import samsum_dataset @@ -38,6 +38,7 @@ "chat_dataset", "cnn_dailymail_articles_dataset", "ConcatDataset", + "DatasetInfo", "grammar_dataset", "hh_rlhf_helpful_dataset", "HfIterableDataset", @@ -55,6 +56,7 @@ "stack_exchange_paired_dataset", "text_completion_dataset", "TextCompletionDataset", + "InfiniteTuneIterableDataset", "TuneIterableDataset", "wikitext_dataset", ] diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 3949e93508..8856f04699 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -17,12 +17,12 @@ DefaultTrainingMetricTransform, Metric, ) -from torchtune.datasets._iterable_base import TuneIterableDataset +from torchtune.datasets._iterable_base import DatasetInfo, InfiniteTuneIterableDataset logger = logging.getLogger(__name__) -class HfIterableDataset(TuneIterableDataset): +class HfIterableDataset(InfiniteTuneIterableDataset): """HuggingFace dataset implementation with composable metrics. This is an infinite dataset. After exhausting the dataset, it will restart from the beginning. @@ -46,6 +46,7 @@ class HfIterableDataset(TuneIterableDataset): of world_size * dataloader_workers. dataset_name (Optional[str]): Name of the dataset. If None, a default name is generated from the path, source, and split. + weight (Optional[float]): Weight for this dataset. Defaults to 1.0. filter_fn (Optional[Callable]): Filter function to apply to the dataset. filter_kwargs (Optional[dict[str, Any]]): Keyword arguments to pass to the filter function. load_dataset_kwargs (dict[str, Any]): Keyword arguments to pass to the load_dataset function. @@ -74,12 +75,12 @@ def __init__( self._message_transform = message_transform self._model_transform = model_transform self._output_transform = output_transform - self._weight = weight # TODO: make it a property? + self._weight = weight # Create default transform if not provided self._metric_transform = metric_transform or DefaultTrainingMetricTransform() - # Auto-generate dataset name if not provided, ensuring it's always a string. + # Auto-generate dataset name if not provided, ensuring it's always a string if dataset_name is None: path = load_dataset_kwargs.get("path", None) source = load_dataset_kwargs.get("source", None) @@ -88,13 +89,14 @@ def __init__( for item in [path, source, split]: if item is not None: name_parts.append(str(item).replace("/", "_")) - self._dataset_name: str = "_".join(name_parts) - else: - self._dataset_name: str = dataset_name + dataset_name = "_".join(name_parts) + + # Build the hierarchical info object for this dataset + self._info = DatasetInfo(name=dataset_name, weight=weight) # Set dataset name on the transform if it supports it if hasattr(self._metric_transform, "set_dataset_name"): - self._metric_transform.set_dataset_name(self._dataset_name) + self._metric_transform.set_dataset_name(dataset_name) # Internal state for resumption self._num_epochs = 0 @@ -105,12 +107,9 @@ def __init__( ) @property - def dataset_name(self) -> str: - return self._dataset_name - - @property - def sampling_weight(self) -> float: - return self._weight + def info(self) -> DatasetInfo: + """Returns info for this leaf dataset, which has no children.""" + return self._info def _apply_transforms(self, sample: dict[str, Any]) -> dict[str, Any]: """Apply transforms if they exist, otherwise return sample unchanged.""" @@ -227,7 +226,7 @@ def __iter__(self) -> Iterator[dict[str, Any]]: # especially useful when interleaving multiple datasets, but # also necessary to track dataset-level metrics. metric_num_epochs = Metric( - dataset_name=self.dataset_name, + dataset_name=self.info.name, name="num_epochs", value=self._num_epochs, agg_type=AggregationType.MAX, @@ -243,14 +242,14 @@ def __iter__(self) -> Iterator[dict[str, Any]]: pass # Iterator is exhausted, which is expected. except Exception as e: logger.warning( - f"Dataset {self.dataset_name} encountered an unexpected error: {e}." + f"Dataset {self.info.name} encountered an unexpected error: {e}." ) raise # Check if we got zero samples - this might indicate an issue if samples_yielded == 0: logger.warning( - f"Dataset {self.dataset_name} epoch {self._num_epochs} yielded 0 samples - potential issue!" + f"Dataset {self.info.name} epoch {self._num_epochs} yielded 0 samples - potential issue!" ) # Epoch complete - increment and continue infinite loop diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py index 46c4cfb7dd..2dee36e557 100644 --- a/torchtune/datasets/_interleaved.py +++ b/torchtune/datasets/_interleaved.py @@ -4,137 +4,111 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import collections +from collections import deque import logging import math -from collections import deque from typing import Any, Iterator import torch -from torchtune.datasets._iterable_base import TuneIterableDataset +from torchtune.datasets._iterable_base import ( + DatasetInfo, + InfiniteTuneIterableDataset, +) logger = logging.getLogger(__name__) -class InterleavedDataset(TuneIterableDataset): +class InterleavedDataset(InfiniteTuneIterableDataset): """Infinitely interleaves multiple TuneIterableDatasets according to their sampling weights. - - The weights are extracted from each dataset's sampling_weight property and normalized to sum to 1.0. + - The weights are extracted from each dataset's info.weight property and normalized to sum to 1.0. - This dataset is responsible for managing the state of its child datasets to ensure correct checkpointing and resumption. Args: - datasets (list[TuneIterableDataset]): list of TuneIterableDatasets to interleave. + datasets (list[InfiniteTuneIterableDataset]): list of datasets to interleave. seed (int): Seed for sampling. - dataset_name (str): Name of the dataset. If None, defaults to "interleaved_dataset". + weight (float): Weight for this dataset. Defaults to 1.0. + dataset_name (str): Name of the dataset. Defaults to "interleaved_dataset". sampling_log_maxlen (int): Maximum length of the sampling log. Raises: - ValueError: If duplicate dataset names are detected in the provided datasets. + ValueError: If duplicate dataset names are detected in the hierarchy. """ def __init__( self, - datasets: list[TuneIterableDataset], + datasets: list[InfiniteTuneIterableDataset], seed: int, + weight: float = 1.0, dataset_name: str = "interleaved_dataset", sampling_log_maxlen: int = 10000, ): - self._dataset_name = dataset_name + self._datasets = sorted(datasets, key=lambda ds: ds.info.name) self._sampling_log_maxlen = sampling_log_maxlen - # Preserve original order for weighted sampling - self._dataset_names = [ds.dataset_name for ds in datasets] - - # Create a name-to-dataset mapping for robust state management - self._datasets: dict[str, TuneIterableDataset] = { - ds.dataset_name: ds for ds in datasets - } - - # Validate unique dataset names upfront - fail fast with clear error - names = self._dataset_names - if len(names) != len(set(names)): - duplicates = [ - name for name, count in collections.Counter(names).items() if count > 1 - ] - raise ValueError( - f"Duplicate dataset names detected: {duplicates}. All {names=}" - f"Please provide a unique 'dataset_name' for each dataset in the interleaved list." - ) - - self._sampling_generator = torch.Generator().manual_seed(seed) + # Build the hierarchical info object for this dataset + self._info = DatasetInfo( + name=dataset_name, + weight=weight, + children=tuple(ds.info for ds in self._datasets), + ) - # Track sampling decisions for debugging and analysis - self._sampling_log: deque[tuple[int, str]] = deque(maxlen=self._sampling_log_maxlen) - self._iteration_count = 0 + # Validate the entire hierarchy using the base class method + self._validate_unique_dataset_names() - # Extract weights from child datasets - weights = [] - for ds in datasets: - weight = ds.sampling_weight - if isinstance(weight, dict): - raise ValueError( - f"Child dataset '{ds.dataset_name}' returned a dict for sampling_weight, " - f"indicating it's a composite dataset (likely InterleavedDataset). " - f"Nested interleaving is not supported. Please flatten the dataset hierarchy." - ) - weights.append(weight) - - # Normalize weights to sum to 1 - total_weight = sum(weights) - self._weights = torch.tensor( - [w / total_weight for w in weights], dtype=torch.float - ) + # Extract weights from direct children and normalize them + child_weights = [info.weight for info in self._info.children] + total_weight = sum(child_weights) if not math.isclose(total_weight, 1.0, rel_tol=1e-9): logger.warning( f"Interleaved dataset normalized weights to sum to 1.0. " - f"Found {total_weight=}. Previous {weights=}, new {self._weights.tolist()}" + f"Previous weights={child_weights}, " + f"new weights={[w / total_weight for w in child_weights]}" ) + self._normalized_weights = torch.tensor( + [w / total_weight for w in child_weights], dtype=torch.float + ) + + # Track sampling decisions for debugging and analysis + self._sampling_log: deque[tuple[int, str]] = deque( + maxlen=self._sampling_log_maxlen + ) + self._iteration_count = 0 + self._sampling_generator = torch.Generator().manual_seed(seed) @property - def dataset_name(self) -> str: - return self._dataset_name - - @property - def sampling_weight(self) -> dict[str, float]: - return {name: weight.item() for name, weight in zip(self._dataset_names, self._weights)} + def info(self) -> DatasetInfo: + return self._info def __iter__(self) -> Iterator[dict[str, Any]]: """Interleave samples from child infinite datasets""" - child_iters = {name: iter(ds) for name, ds in self._datasets.items()} + # Create a dictionary of iterators for each child dataset + child_iters = {ds.info.name: iter(ds) for ds in self._datasets} while True: - # Sample which dataset to use + # Sample a child dataset based on the normalized weights ds_idx = torch.multinomial( - self._weights, 1, replacement=True, generator=self._sampling_generator + self._normalized_weights, + 1, + replacement=True, + generator=self._sampling_generator, ).item() - # Sample an index, then get the name for safe lookup - ds_name = self._dataset_names[ds_idx] + selected_ds = self._datasets[ds_idx] + ds_name = selected_ds.info.name - # Log this sampling decision + # Log self._sampling_log.append((self._iteration_count, ds_name)) self._iteration_count += 1 - try: - sample = next(child_iters[ds_name]) - yield sample - except StopIteration: - # Per the design, child datasets must be infinite. - # We re-initialize to allow for continuous operation but warn loudly - # as this may indicate a design problem in the child dataset. - logger.warning( - f"Child dataset {self._datasets[ds_name].dataset_name} was exhausted. " - "This is unexpected for an infinite dataset. Re-initializing its iterator." - ) - child_iters[ds_name] = iter(self._datasets[ds_name]) - sample = next(child_iters[ds_name]) - yield sample + # Yield the next sample from the selected child iterator + yield next(child_iters[ds_name]) def state_dict(self) -> dict[str, Any]: """Save state for the interleaver and its children.""" - # The parent is responsible for namespacing the child states. - child_states = {name: ds.state_dict() for name, ds in self._datasets.items()} + # The parent is responsible for namespacing the child states + child_states = {ds.info.name: ds.state_dict() for ds in self._datasets} return { "sampling_generator_state": self._sampling_generator.get_state(), "child_states": child_states, @@ -146,11 +120,10 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Load state for the interleaver and its children.""" self._sampling_generator.set_state(state_dict["sampling_generator_state"]) child_states = state_dict["child_states"] - for name, ds in self._datasets.items(): - if name in child_states: - # Pass the raw state dict to the child - ds.load_state_dict(child_states[name]) - + + for ds in self._datasets: + ds.load_state_dict(child_states[ds.info.name]) + # Load sampling log and iteration count self._sampling_log = deque( state_dict.get("sampling_log", []), maxlen=self._sampling_log_maxlen diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index 92290d658e..51dec07990 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -5,11 +5,23 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod +from dataclasses import dataclass, field from typing import Any, Iterator from torch.utils.data import IterableDataset +@dataclass(frozen=True) +class DatasetInfo: + """Represents hierarchical information about a dataset, including its name, + sampling weight and children. Children is a common case when composing datasets, + e.g. Packed(InterleavedDataset([ds1, ds2])). + """ + name: str + weight: float = 1.0 + children: tuple["DatasetInfo", ...] = field(default_factory=tuple) + + class TuneIterableDataset(IterableDataset, ABC): """Abstract base class for all torchtune iterable datasets. It defines the minimal, consistent interface required for all dataset @@ -19,22 +31,32 @@ class TuneIterableDataset(IterableDataset, ABC): @property @abstractmethod - def dataset_name(self) -> str: - """A unique identifier for the dataset, used for namespacing in metrics and checkpoints.""" + def info(self) -> DatasetInfo: + """Returns a hierarchical structure of all dataset information, including + this dataset and its children.""" pass - @property - def sampling_weight(self) -> float | dict[str, float]: - """Returns the sampling weight for this dataset, especially useful in multi-dataset scenarios. + def _validate_unique_dataset_names(self) -> None: + """Traverses the DatasetInfo tree and raises ValueError on duplicate names.""" + root_info = self.info + names = [] + to_process = [root_info] + + while to_process: + node = to_process.pop(0) + names.append(node.name) + to_process.extend(node.children) - For leaf datasets: returns a float representing the relative weight. - For composite datasets: returns a dict mapping child dataset names to their weights. - """ - return 1.0 + # Check for duplicates after traversing the whole tree + duplicates = [name for name in set(names) if names.count(name) > 1] + if duplicates: + raise ValueError( + f"Duplicate dataset names found in hierarchy: {duplicates=}, all names={names}" + ) @abstractmethod def __iter__(self) -> Iterator[dict[str, Any]]: - """Returns an infinite iterator over the dataset. Each implementation is responsible + """Returns an iterator over the dataset. Each implementation is responsible for its own iteration logic, including shuffling, distribution of data across ranks, and making it an infinite stream.""" pass @@ -48,3 +70,8 @@ def state_dict(self) -> dict[str, Any]: def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Load state from a state dictionary, used when resuming from a checkpoint.""" pass + +class InfiniteTuneIterableDataset(TuneIterableDataset): + """Abstract base class for infinite datasets, which yield samples indefinitely. + It only purpose is to make it explicit that the dataset is expected to be infinite.""" + pass \ No newline at end of file From f9a1aecee41abf4e3fe2f7c399442885e30a9920 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sat, 5 Jul 2025 23:13:40 -0400 Subject: [PATCH 18/25] nits hf_iterable --- torchtune/datasets/_hf_iterable.py | 29 ++++++++++++++++------------ torchtune/datasets/_iterable_base.py | 5 ++++- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 8856f04699..5c759039d5 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -145,13 +145,22 @@ def _setup_hf_dataset( # Load and shard dataset ds = load_dataset(**load_dataset_kwargs) - # Use to_iterable_dataset for streaming datasets - if not load_dataset_kwargs.get("streaming", False): - + # Use to_iterable_dataset for non-streaming datasets + is_streaming = load_dataset_kwargs.get("streaming", False) + if is_streaming: + logger.warning( + f"Streaming datasets were not yet tested for distributed training. " + f"split_dataset_by_node is applied, but no resharding was done manually. " + f"Dataset '{self.info.name}' has " + f"{getattr(ds, 'num_shards', 'unknown')}, and your training has {world_size} ranks." + f"See: https://huggingface.co/docs/datasets/en/package_reference/main_classes?#datasets.IterableDataset.shard" + f"Consider setting streaming=False, which should also be faster." + ) + if not is_streaming: # Define number of shards based on (world_size, num of shards per GPU, dataloader workers) # E.g. world_size=2, num_shards_per_rank=16, dataloader_workers=3 - # we will try 2*16 = 32 shards. Since 32 is not a multiple of 3, we will do 36 shards. - # Each rank gets 16 shards, each dataloader worker in that rankgets 6 shards. + # we will try 2*16 = 32 shards. Since 32 is not a multiple of 6, we will do 36 shards. + # Each rank gets 18 shards, each dataloader worker in that rank gets 6 shards. worker_info = torch.utils.data.get_worker_info() num_dataloader_workers = worker_info.num_workers if worker_info else 1 @@ -171,14 +180,12 @@ def _setup_hf_dataset( # If the dataset is not streaming and has a defined length, # we cannot have num_shards > dataset_size. - if not load_dataset_kwargs.get("streaming", False) and hasattr( - ds, "__len__" - ): + if hasattr(ds, "__len__"): dataset_size = len(ds) if num_shards > dataset_size: raise ValueError( f"Number of shards ({num_shards}) is greater than the dataset size ({dataset_size})." - f"Please decrease num_shards_per_rank." + f"Please decrease one of {num_shards_per_rank=} or {num_dataloader_workers=} or {world_size=}." ) ds = ds.to_iterable_dataset(num_shards=num_shards) @@ -210,8 +217,7 @@ def __iter__(self) -> Iterator[dict[str, Any]]: """ while True: # Infinite iteration - epoch_seed = self._seed + self._num_epochs - self._ds.set_epoch(epoch_seed) + self._ds.set_epoch(self._num_epochs) epoch_iterator = iter(self._ds) samples_yielded = 0 @@ -262,7 +268,6 @@ def state_dict(self) -> dict[str, Any]: hf_state = self._ds.state_dict() state = { "num_epochs": self._num_epochs, - "seed": self._seed, "hf_dataset_state": hf_state, } return state diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index 51dec07990..fee09a5123 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -71,7 +71,10 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Load state from a state dictionary, used when resuming from a checkpoint.""" pass + class InfiniteTuneIterableDataset(TuneIterableDataset): """Abstract base class for infinite datasets, which yield samples indefinitely. - It only purpose is to make it explicit that the dataset is expected to be infinite.""" + It only purpose is to make it explicit that the dataset is expected to be infinite, i.e. + it never exhausts. This is helpful to avoid complexity due to some rank hanging because + of lack of data""" pass \ No newline at end of file From f7a3aa76453af770bbb48c30e7092816a7059054 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sat, 5 Jul 2025 23:14:40 -0400 Subject: [PATCH 19/25] update readme --- torchtune/data/metrics/readme.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/data/metrics/readme.md b/torchtune/data/metrics/readme.md index 6c6c413246..bf6cb8d27b 100644 --- a/torchtune/data/metrics/readme.md +++ b/torchtune/data/metrics/readme.md @@ -24,7 +24,7 @@ The metrics module provides a robust system for tracking and aggregating trainin │ • Uses pluggable AggregationHandlers │ │ • Handles distributed reduction │ └─────────────────────┬──────────────────────────────┘ - │ {prefix_dataset/metric: value} + │ {prefix}_{dataset_name}/{metric_name} # prefix is "train", "val", etc. ┌─────────────────────▼──────────────────────────────┐ │ Logging System │ │ • W&B, TensorBoard, etc. │ From 17878bf9f498434c9af7469b043a729ef8d9a9af Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sat, 5 Jul 2025 23:23:31 -0400 Subject: [PATCH 20/25] make metric dataset name explicit --- torchtune/data/metrics/_metric_transform.py | 31 +++++++++++++-------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/torchtune/data/metrics/_metric_transform.py b/torchtune/data/metrics/_metric_transform.py index 529fff8c5c..5affaca046 100644 --- a/torchtune/data/metrics/_metric_transform.py +++ b/torchtune/data/metrics/_metric_transform.py @@ -6,8 +6,7 @@ from dataclasses import dataclass from enum import Enum -from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, Optional, Union from torchtune.modules.transforms import Transform @@ -41,7 +40,6 @@ class MetricTransform(Transform): def __init__(self): # dataset_name is set by the dataset using set_dataset_name self.dataset_name: Optional[str] = None - self.new_metric: Optional[Callable] = None def set_dataset_name(self, dataset_name: str) -> None: """Called by dataset to set the namespace for metrics. @@ -53,8 +51,6 @@ def set_dataset_name(self, dataset_name: str) -> None: dataset_name (str): Name of the dataset for metric namespacing """ self.dataset_name = dataset_name - # Create a partial to make it easier to create new metrics - self.new_metric = partial(Metric, dataset_name=dataset_name) def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: """Generate metrics for a single sample. Must be implemented by subclasses. @@ -72,7 +68,7 @@ def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: """Apply transform to sample, adding generated metrics.""" - if self.dataset_name is None or self.new_metric is None: + if self.dataset_name is None: raise RuntimeError( "set_dataset_name() must be called before using the transform." ) @@ -112,7 +108,7 @@ class DefaultTrainingMetricTransform(MetricTransform): """ def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: - if self.new_metric is None: + if self.dataset_name is None: raise RuntimeError( "set_dataset_name() must be called before using the transform." ) @@ -123,11 +119,22 @@ def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: # Create metrics for this sample return [ - self.new_metric(name="samples_seen", value=1, agg_type=AggregationType.SUM), - self.new_metric( - name="tokens_seen", value=token_len, agg_type=AggregationType.SUM + Metric( + dataset_name=self.dataset_name, + name="samples_seen", + value=1, + agg_type=AggregationType.SUM, ), - self.new_metric( - name="seq_len", value=token_len, agg_type=AggregationType.DISTRIBUTION + Metric( + dataset_name=self.dataset_name, + name="tokens_seen", + value=token_len, + agg_type=AggregationType.SUM, + ), + Metric( + dataset_name=self.dataset_name, + name="seq_len", + value=token_len, + agg_type=AggregationType.DISTRIBUTION, ), ] From 101e96e205db8ca56cb72c654a1aef316a2e3758 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sat, 5 Jul 2025 23:29:07 -0400 Subject: [PATCH 21/25] update recipe to share log freq + validagtion msg --- recipes/full_finetune_distributed.py | 56 +++++++++++++--------------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index fadb3f7c23..95d5069eaa 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -277,7 +277,6 @@ def __init__(self, cfg: DictConfig) -> None: # Step-based training support self.num_training_steps = cfg.num_training_steps - self._dataset_metrics_log_freq = cfg.get("dataset_metrics_log_freq", 100) self._metrics_aggregator = None # Will be initialized in setup def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None: @@ -311,7 +310,7 @@ def setup(self, cfg: DictConfig) -> None: """ if cfg.get("dataset_val") is not None: raise NotImplementedError( - "Validation is not supported yet with iterable datasets." + "Validation is not supported yet with iterable datasets since it currently requiresinfinite datasets." ) if self.fsdp_cpu_offload: @@ -1045,39 +1044,34 @@ def train(self) -> None: ) # Log per-step metrics - if ( - self.global_step % self._log_every_n_steps == 0 - and self._is_rank_zero - ): - time_per_step = time.perf_counter() - t0 - log_dict = { - "loss": loss_to_log, - "lr": get_lr( - self._optimizer - if not self._optimizer_in_bwd - else self._optim_ckpt_wrapper - ), - "tokens_per_second_per_gpu": ( - num_tokens / self.parallel_dims.non_data_parallel_size - ) - / (time_per_step * self.world_size), - } - if self._log_peak_memory_stats: - log_dict.update(training.get_memory_stats(device=self._device)) - if self._clip_grad_norm is not None: - log_dict.update({"grad_norm": grad_norm}) - self._metric_logger.log_dict(log_dict, step=self.global_step) - - # Log dataset metrics - # #TODO: it requires all_gather. Should we keep a separate log_freq for this? - if self.global_step % self._dataset_metrics_log_freq == 0: + if self.global_step % self._log_every_n_steps == 0: + # Get dataset metrics outside of rank zero check since it involves all_gather dataset_metrics = self._metrics_aggregator.get_metrics_for_logging( prefix="train" ) + if self._is_rank_zero: - self._metric_logger.log_dict( - dataset_metrics, step=self.global_step - ) + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": get_lr( + self._optimizer + if not self._optimizer_in_bwd + else self._optim_ckpt_wrapper + ), + "tokens_per_second_per_gpu": ( + num_tokens / self.parallel_dims.non_data_parallel_size + ) + / (time_per_step * self.world_size), + } + if dataset_metrics: + log_dict.update(dataset_metrics) + if self._log_peak_memory_stats: + log_dict.update(training.get_memory_stats(device=self._device)) + if self._clip_grad_norm is not None: + log_dict.update({"grad_norm": grad_norm}) + self._metric_logger.log_dict(log_dict, step=self.global_step) + # Save checkpoint if specified by user if ( From 1b3f3fcc4f37f0746c68d57cc9a99cd27a64daa1 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sun, 6 Jul 2025 07:44:15 -0700 Subject: [PATCH 22/25] update interleaved tests to do nesting --- tests/torchtune/datasets/test_interleaved.py | 229 +++++++++++++------ 1 file changed, 158 insertions(+), 71 deletions(-) diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index c2ea53c970..96c825d2f1 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -108,38 +108,46 @@ class TestInterleavedDataset: def test_initialization_validation(self, dataset_factory, small_dataset_file): """Tests that the dataset raises errors for invalid configurations, like duplicate names.""" - + # Test 1: Duplicate dataset names should raise an error ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.5) ds2 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.5) - with pytest.raises(ValueError, match="Duplicate dataset names found in hierarchy"): + with pytest.raises( + ValueError, match="Duplicate dataset names found in hierarchy" + ): InterleavedDataset(datasets=[ds1, ds2], seed=SEED) # Test 2: Nested interleaved datasets should be supported ds3 = dataset_factory(small_dataset_file, dataset_name="ds3", weight=1.5) - interleaved_child = InterleavedDataset([ds1, ds3], seed=SEED, dataset_name="interleaved_child") - + interleaved_child = InterleavedDataset( + [ds1, ds3], seed=SEED, dataset_name="interleaved_child" + ) + # Create a parent interleaved dataset containing the nested one ds4 = dataset_factory(small_dataset_file, dataset_name="ds4", weight=0.5) - + # Test 3: Weight normalization should work with a warning with patch("logging.Logger.warning") as mock_warning: - interleaved_parent = InterleavedDataset([interleaved_child, ds4], seed=SEED, dataset_name="interleaved_parent") - + interleaved_parent = InterleavedDataset( + [interleaved_child, ds4], seed=SEED, dataset_name="interleaved_parent" + ) + # Verify that a warning was logged about weight normalization mock_warning.assert_called_once() warning_message = mock_warning.call_args[0][0] assert "normalized" in warning_message.lower() - + # Verify the hierarchical structure is correct assert interleaved_parent.info.name == "interleaved_parent" assert len(interleaved_parent.info.children) == 2 - assert interleaved_parent.info.children[0].name == "interleaved_child" - assert interleaved_parent.info.children[1].name == "ds4" - + # Datasets are sorted alphabetically, so ds4 comes before interleaved_child + assert interleaved_parent.info.children[0].name == "ds4" + assert interleaved_parent.info.children[1].name == "interleaved_child" + # Verify the nested structure within the nested dataset - nested_info = interleaved_parent.info.children[0] + # interleaved_child is at index 1 due to alphabetical sorting + nested_info = interleaved_parent.info.children[1] assert len(nested_info.children) == 2 assert nested_info.children[0].name == "ds1" assert nested_info.children[1].name == "ds3" @@ -149,49 +157,54 @@ def test_initialization_validation(self, dataset_factory, small_dataset_file): normalized_weights = interleaved_parent._normalized_weights assert isinstance(normalized_weights, torch.Tensor) assert len(normalized_weights) == 2 - - # nested: 2.0/2.5 = 0.8, ds4: 0.5/2.5 = 0.2 - assert abs(normalized_weights[0].item() - 0.8) < 1e-6 - assert abs(normalized_weights[1].item() - 0.2) < 1e-6 + + # ds4: 0.5/(0.5+1.0) = 1/3, interleaved_child: 1.0/(0.5+1.0) = 2/3 + assert abs(normalized_weights[0].item() - 1 / 3) < 1e-3 + assert abs(normalized_weights[1].item() - 2 / 3) < 1e-3 assert abs(normalized_weights.sum().item() - 1.0) < 1e-6 # Verify that original weights in info remain unnormalized child_weights = [child.weight for child in interleaved_parent.info.children] - assert abs(child_weights[0] - 2.0) < 1e-6 # nested original weight - assert abs(child_weights[1] - 0.5) < 1e-6 # ds4 original weight - + assert abs(child_weights[0] - 0.5) < 1e-6 # ds4 original weight + assert ( + abs(child_weights[1] - 1.0) < 1e-6 + ) # interleaved_child original weight def test_single_dataset(self, dataset_factory, small_dataset_file): """Tests that InterleavedDataset works correctly with a single dataset.""" # Create a single dataset ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.5) - + # Should work without issues interleaved = InterleavedDataset([ds1], seed=SEED) - + # Verify the hierarchical structure assert interleaved.info.name == "interleaved_dataset" # default name assert len(interleaved.info.children) == 1 assert interleaved.info.children[0].name == "ds1" assert interleaved.info.children[0].weight == 0.5 - + # Verify normalized weights sum to 1.0 (single dataset gets weight 1.0) normalized_weights = interleaved._normalized_weights assert isinstance(normalized_weights, torch.Tensor) assert len(normalized_weights) == 1 assert abs(normalized_weights[0].item() - 1.0) < 1e-6 - + # Test that iteration works correctly samples = list(islice(iter(interleaved), 10)) assert len(samples) == 10 - - # All samples should come from the single dataset + + # All samples should come from the single dataset (ds1 has IDs 0-22) sample_ids = {sample["id"] for sample in samples} - expected_ids = set(range(23)) # ds1 has IDs 0-22 + expected_ids = set(range(10)) # ds1 has IDs 0-22 assert sample_ids == expected_ids - + def test_sampling_ratios( - self, dataset_factory, small_dataset_file, medium_dataset_file, large_dataset_file + self, + dataset_factory, + small_dataset_file, + medium_dataset_file, + large_dataset_file, ): """Tests that datasets are sampled according to their assigned weights in nested structure.""" # Create three datasets with distinct ID ranges @@ -201,8 +214,12 @@ def test_sampling_ratios( ds3 = dataset_factory(large_dataset_file, dataset_name="ds3", weight=1.0) # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) - child_interleaved = InterleavedDataset([ds1, ds2], seed=SEED, dataset_name="child") - parent_interleaved = InterleavedDataset([child_interleaved, ds3], seed=SEED, dataset_name="parent") + child_interleaved = InterleavedDataset( + [ds1, ds2], seed=SEED, dataset_name="child" + ) + parent_interleaved = InterleavedDataset( + [child_interleaved, ds3], seed=SEED, dataset_name="parent" + ) # Collect 400 samples sample_count = 400 @@ -210,8 +227,12 @@ def test_sampling_ratios( # Count samples by checking ID ranges ds1_count = sum(1 for s in samples if 0 <= s["id"] < SMALL_DATASET_SIZE) - ds2_count = sum(1 for s in samples if 100 <= s["id"] < (MEDIUM_DATASET_SIZE + 100)) - ds3_count = sum(1 for s in samples if 1000 <= s["id"] < (LARGE_DATASET_SIZE + 1000)) + ds2_count = sum( + 1 for s in samples if 100 <= s["id"] < (MEDIUM_DATASET_SIZE + 100) + ) + ds3_count = sum( + 1 for s in samples if 1000 <= s["id"] < (LARGE_DATASET_SIZE + 1000) + ) assert ds1_count + ds2_count + ds3_count == sample_count @@ -219,7 +240,7 @@ def test_sampling_ratios( ds1_ratio = ds1_count / sample_count ds2_ratio = ds2_count / sample_count ds3_ratio = ds3_count / sample_count - + # Expected ratios based on nested weighting: # Inner weights: ds1=0.3, ds2=0.7 -> inner total=1.0 # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each @@ -229,12 +250,22 @@ def test_sampling_ratios( expected_ds3_ratio = 0.5 # Allow 10% tolerance due to randomness - assert abs(ds1_ratio - expected_ds1_ratio) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" - assert abs(ds2_ratio - expected_ds2_ratio) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" - assert abs(ds3_ratio - expected_ds3_ratio) < 0.1, f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" + assert ( + abs(ds1_ratio - expected_ds1_ratio) < 0.1 + ), f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert ( + abs(ds2_ratio - expected_ds2_ratio) < 0.1 + ), f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert ( + abs(ds3_ratio - expected_ds3_ratio) < 0.1 + ), f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" def test_metrics_aggregation( - self, dataset_factory, small_dataset_file, medium_dataset_file, large_dataset_file + self, + dataset_factory, + small_dataset_file, + medium_dataset_file, + large_dataset_file, ): """Tests that metrics from all child datasets are collected and aggregated in nested structure.""" ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.2) @@ -242,9 +273,13 @@ def test_metrics_aggregation( ds3 = dataset_factory(large_dataset_file, dataset_name="ds3", weight=1.0) # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) - child_interleaved = InterleavedDataset([ds1, ds2], seed=SEED, dataset_name="child") - parent_interleaved = InterleavedDataset([child_interleaved, ds3], seed=SEED, dataset_name="parent") - + child_interleaved = InterleavedDataset( + [ds1, ds2], seed=SEED, dataset_name="child" + ) + parent_interleaved = InterleavedDataset( + [child_interleaved, ds3], seed=SEED, dataset_name="parent" + ) + aggregator = MetricsAggregator() # Process some samples @@ -266,9 +301,9 @@ def test_metrics_aggregation( # Total samples should equal what we processed calculated_total_samples = ( - metrics["train_ds1/samples_seen"] + - metrics["train_ds2/samples_seen"] + - metrics["train_ds3/samples_seen"] + metrics["train_ds1/samples_seen"] + + metrics["train_ds2/samples_seen"] + + metrics["train_ds3/samples_seen"] ) assert calculated_total_samples == total_samples @@ -286,12 +321,22 @@ def test_metrics_aggregation( expected_ds3_ratio = 0.5 # Allow 10% tolerance due to randomness - assert abs(ds1_ratio - expected_ds1_ratio) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" - assert abs(ds2_ratio - expected_ds2_ratio) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" - assert abs(ds3_ratio - expected_ds3_ratio) < 0.1, f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" + assert ( + abs(ds1_ratio - expected_ds1_ratio) < 0.1 + ), f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert ( + abs(ds2_ratio - expected_ds2_ratio) < 0.1 + ), f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert ( + abs(ds3_ratio - expected_ds3_ratio) < 0.1 + ), f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" def test_checkpointing( - self, dataset_factory, small_dataset_file, medium_dataset_file, large_dataset_file + self, + dataset_factory, + small_dataset_file, + medium_dataset_file, + large_dataset_file, ): """Tests that interleaved dataset checkpointing preserves sampling state in nested structure.""" @@ -299,10 +344,14 @@ def create_interleaved(): ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.3) ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.7) ds3 = dataset_factory(large_dataset_file, dataset_name="ds3", weight=1.0) - + # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) - child_interleaved = InterleavedDataset([ds1, ds2], seed=SEED, dataset_name="child") - return InterleavedDataset([child_interleaved, ds3], seed=SEED, dataset_name="parent") + child_interleaved = InterleavedDataset( + [ds1, ds2], seed=SEED, dataset_name="child" + ) + return InterleavedDataset( + [child_interleaved, ds3], seed=SEED, dataset_name="parent" + ) # Original run interleaved1 = create_interleaved() @@ -341,20 +390,36 @@ def create_interleaved(): state_dict = interleaved1.state_dict() sampling_log = state_dict["sampling_log"] iteration_count = state_dict["iteration_count"] - + assert len(sampling_log) > 0, "Sampling log should not be empty" assert iteration_count > 0, "Iteration count should be positive" - - # Check sampling ratios match expected weights for nested structure - ds1_count = sum(1 for _, ds_name in sampling_log if ds_name == "ds1") - ds2_count = sum(1 for _, ds_name in sampling_log if ds_name == "ds2") - ds3_count = sum(1 for _, ds_name in sampling_log if ds_name == "ds3") + + # Check sampling ratios by analyzing the actual samples processed during the test + # Since the sampling log only shows immediate children ("child", "ds3"), + # we need to look at the actual sample IDs to determine leaf dataset usage + + # Collect all sample IDs from the batches processed during checkpointing + all_sample_ids = [] + for batch_list in [ + result["pre_checkpoint_batches"], + result["post_checkpoint_batches"], + ]: + for batch in batch_list: + all_sample_ids.extend(batch["id"].tolist()) + + # Count samples by ID ranges: ds1 has IDs 0-22, ds2 has IDs 100-134, ds3 has IDs 1000-1046 + ds1_count = sum(1 for id in all_sample_ids if 0 <= id < SMALL_DATASET_SIZE) + ds2_count = sum( + 1 for id in all_sample_ids if 100 <= id < (MEDIUM_DATASET_SIZE + 100) + ) + ds3_count = sum( + 1 for id in all_sample_ids if 1000 <= id < (LARGE_DATASET_SIZE + 1000) + ) total_samples = ds1_count + ds2_count + ds3_count - ds1_ratio = ds1_count / total_samples ds2_ratio = ds2_count / total_samples ds3_ratio = ds3_count / total_samples - + # Expected ratios based on nested weighting: # Inner weights: ds1=0.3, ds2=0.7 -> inner total=1.0 # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each @@ -362,10 +427,18 @@ def create_interleaved(): expected_ds1_ratio = 0.15 expected_ds2_ratio = 0.35 expected_ds3_ratio = 0.5 - - assert abs(ds1_ratio - expected_ds1_ratio) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" - assert abs(ds2_ratio - expected_ds2_ratio) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" - assert abs(ds3_ratio - expected_ds3_ratio) < 0.1, f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" + + # Allow larger tolerance due to small sample size in checkpointing test + assert ( + abs(ds1_ratio - expected_ds1_ratio) < 0.2 + ), f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert ( + abs(ds2_ratio - expected_ds2_ratio) < 0.2 + ), f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert ( + abs(ds3_ratio - expected_ds3_ratio) < 0.2 + ), f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" + class TestDistributedInterleavedDataset(FSDPTest): @property @@ -446,8 +519,12 @@ def create_dataset(): ) # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) - child_interleaved = InterleavedDataset([ds1, ds2], seed=SEED, dataset_name="child") - return InterleavedDataset([child_interleaved, ds3], seed=SEED, dataset_name="parent") + child_interleaved = InterleavedDataset( + [ds1, ds2], seed=SEED, dataset_name="child" + ) + return InterleavedDataset( + [child_interleaved, ds3], seed=SEED, dataset_name="parent" + ) def create_dataloader(dataset): loader = StatefulDataLoader( @@ -493,15 +570,19 @@ def create_dataloader(dataset): # Count samples by ID ranges: ds1 has IDs 0-22, ds2 has IDs 100-134, ds3 has IDs 1000-1046 ds1_samples = sum(1 for id in all_ids if 0 <= id < SMALL_DATASET_SIZE) - ds2_samples = sum(1 for id in all_ids if 100 <= id < (MEDIUM_DATASET_SIZE + 100)) - ds3_samples = sum(1 for id in all_ids if 1000 <= id < (LARGE_DATASET_SIZE + 1000)) + ds2_samples = sum( + 1 for id in all_ids if 100 <= id < (MEDIUM_DATASET_SIZE + 100) + ) + ds3_samples = sum( + 1 for id in all_ids if 1000 <= id < (LARGE_DATASET_SIZE + 1000) + ) total_samples = ds1_samples + ds2_samples + ds3_samples if total_samples > 0: ds1_ratio = ds1_samples / total_samples ds2_ratio = ds2_samples / total_samples ds3_ratio = ds3_samples / total_samples - + # Expected ratios based on nested weighting: # Inner weights: ds1=0.3, ds2=0.7 -> inner total=1.0 # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each @@ -509,10 +590,16 @@ def create_dataloader(dataset): expected_ds1_ratio = 0.15 expected_ds2_ratio = 0.35 expected_ds3_ratio = 0.5 - - assert abs(ds1_ratio - expected_ds1_ratio) < 0.1, f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" - assert abs(ds2_ratio - expected_ds2_ratio) < 0.1, f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" - assert abs(ds3_ratio - expected_ds3_ratio) < 0.1, f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" + + assert ( + abs(ds1_ratio - expected_ds1_ratio) < 0.1 + ), f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert ( + abs(ds2_ratio - expected_ds2_ratio) < 0.1 + ), f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert ( + abs(ds3_ratio - expected_ds3_ratio) < 0.1 + ), f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" finally: # Clean up temp directory (only rank 0) From fac3fd5596964acf5fce51248166d70f78f739bc Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sun, 6 Jul 2025 11:27:49 -0700 Subject: [PATCH 23/25] lint --- recipes/full_finetune_distributed.py | 7 ++++--- tests/torchtune/datasets/test_interleaved.py | 2 +- torchtune/datasets/__init__.py | 6 +++++- torchtune/datasets/_hf_iterable.py | 9 +++++---- torchtune/datasets/_interleaved.py | 14 ++++---------- torchtune/datasets/_iterable_base.py | 4 +++- 6 files changed, 22 insertions(+), 20 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 95d5069eaa..444596619b 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -1049,7 +1049,7 @@ def train(self) -> None: dataset_metrics = self._metrics_aggregator.get_metrics_for_logging( prefix="train" ) - + if self._is_rank_zero: time_per_step = time.perf_counter() - t0 log_dict = { @@ -1067,12 +1067,13 @@ def train(self) -> None: if dataset_metrics: log_dict.update(dataset_metrics) if self._log_peak_memory_stats: - log_dict.update(training.get_memory_stats(device=self._device)) + log_dict.update( + training.get_memory_stats(device=self._device) + ) if self._clip_grad_norm is not None: log_dict.update({"grad_norm": grad_norm}) self._metric_logger.log_dict(log_dict, step=self.global_step) - # Save checkpoint if specified by user if ( self.save_every_n_steps is not None diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index 96c825d2f1..db4ec95035 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -464,7 +464,7 @@ def test_distributed_interleaved_checkpointing(self): temp_dir = None # Broadcast temp directory to all ranks - temp_dir_list = [temp_dir] + temp_dir_list = [temp_dir] if temp_dir is not None else [""] dist.broadcast_object_list(temp_dir_list, src=0) temp_dir = temp_dir_list[0] tmp_path = Path(temp_dir) diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index e0afccdde4..b38663578e 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -18,7 +18,11 @@ from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset from torchtune.datasets._instruct import instruct_dataset from torchtune.datasets._interleaved import InterleavedDataset -from torchtune.datasets._iterable_base import DatasetInfo, InfiniteTuneIterableDataset, TuneIterableDataset +from torchtune.datasets._iterable_base import ( + DatasetInfo, + InfiniteTuneIterableDataset, + TuneIterableDataset, +) from torchtune.datasets._packed import PackedDataset from torchtune.datasets._preference import preference_dataset, PreferenceDataset from torchtune.datasets._samsum import samsum_dataset diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 5c759039d5..7b6f790914 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -16,6 +16,7 @@ AggregationType, DefaultTrainingMetricTransform, Metric, + MetricTransform, ) from torchtune.datasets._iterable_base import DatasetInfo, InfiniteTuneIterableDataset @@ -38,7 +39,7 @@ class HfIterableDataset(InfiniteTuneIterableDataset): model_transform (Optional[Callable]): Take messages and prepares it for the model. Usually the tokenizer. output_transform (Optional[Callable]): Takes tokenized inputs and prepares it for the recipe. Usually does some label manipulation, e.g. ignore index. Think of it as recipe-dependent, e.g. SFT, RL, DPO, etc. - metric_transform (Optional[Callable]): Takes the sample and computes metrics, e.g. token count. + metric_transform (Optional[MetricTransform]): Takes the sample and computes metrics, e.g. token count. If None, a default transform is used. To stop tracking metrics, set it to lambda x: x. shuffle_buffer_size (Optional[int]): Size of the shuffle buffer. If None or 0, no shuffling is done. seed (int): Seed for shuffling. @@ -59,7 +60,7 @@ def __init__( message_transform: Optional[Callable] = None, model_transform: Optional[Callable] = None, output_transform: Optional[Callable] = None, - metric_transform: Optional[Callable] = None, + metric_transform: Optional[MetricTransform] = None, shuffle_buffer_size: Optional[int] = 1000, weight: Optional[float] = 1.0, seed: int = 42, @@ -75,7 +76,7 @@ def __init__( self._message_transform = message_transform self._model_transform = model_transform self._output_transform = output_transform - self._weight = weight + self._weight = weight if weight is not None else 1.0 # Create default transform if not provided self._metric_transform = metric_transform or DefaultTrainingMetricTransform() @@ -92,7 +93,7 @@ def __init__( dataset_name = "_".join(name_parts) # Build the hierarchical info object for this dataset - self._info = DatasetInfo(name=dataset_name, weight=weight) + self._info = DatasetInfo(name=dataset_name, weight=self._weight) # Set dataset name on the transform if it supports it if hasattr(self._metric_transform, "set_dataset_name"): diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py index 2dee36e557..2267696ef4 100644 --- a/torchtune/datasets/_interleaved.py +++ b/torchtune/datasets/_interleaved.py @@ -4,17 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from collections import deque import logging import math +from collections import deque from typing import Any, Iterator import torch -from torchtune.datasets._iterable_base import ( - DatasetInfo, - InfiniteTuneIterableDataset, -) +from torchtune.datasets._iterable_base import DatasetInfo, InfiniteTuneIterableDataset logger = logging.getLogger(__name__) @@ -31,9 +28,6 @@ class InterleavedDataset(InfiniteTuneIterableDataset): weight (float): Weight for this dataset. Defaults to 1.0. dataset_name (str): Name of the dataset. Defaults to "interleaved_dataset". sampling_log_maxlen (int): Maximum length of the sampling log. - - Raises: - ValueError: If duplicate dataset names are detected in the hierarchy. """ def __init__( @@ -69,7 +63,7 @@ def __init__( self._normalized_weights = torch.tensor( [w / total_weight for w in child_weights], dtype=torch.float ) - + # Track sampling decisions for debugging and analysis self._sampling_log: deque[tuple[int, str]] = deque( maxlen=self._sampling_log_maxlen @@ -88,7 +82,7 @@ def __iter__(self) -> Iterator[dict[str, Any]]: while True: # Sample a child dataset based on the normalized weights - ds_idx = torch.multinomial( + ds_idx: int = torch.multinomial( self._normalized_weights, 1, replacement=True, diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index fee09a5123..a26c22eafa 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -17,6 +17,7 @@ class DatasetInfo: sampling weight and children. Children is a common case when composing datasets, e.g. Packed(InterleavedDataset([ds1, ds2])). """ + name: str weight: float = 1.0 children: tuple["DatasetInfo", ...] = field(default_factory=tuple) @@ -77,4 +78,5 @@ class InfiniteTuneIterableDataset(TuneIterableDataset): It only purpose is to make it explicit that the dataset is expected to be infinite, i.e. it never exhausts. This is helpful to avoid complexity due to some rank hanging because of lack of data""" - pass \ No newline at end of file + + pass From 29ba1cb85e8e399fd79a3f4f4eff95e677d6b1f0 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sun, 6 Jul 2025 17:46:34 -0700 Subject: [PATCH 24/25] error if duplicated metric name --- .../torchtune/data/test_metrics_aggregator.py | 104 ++++++++++++++++-- .../torchtune/data/test_metrics_transform.py | 6 +- tests/torchtune/datasets/test_hf_iterable.py | 8 +- .../data/metrics/_metric_agg_handlers.py | 43 ++++---- torchtune/data/metrics/_metric_aggregator.py | 71 ++++++++++-- torchtune/data/metrics/_metric_transform.py | 14 +-- torchtune/datasets/_hf_iterable.py | 2 +- 7 files changed, 188 insertions(+), 60 deletions(-) diff --git a/tests/torchtune/data/test_metrics_aggregator.py b/tests/torchtune/data/test_metrics_aggregator.py index b65c11f533..c2e8141ff9 100644 --- a/tests/torchtune/data/test_metrics_aggregator.py +++ b/tests/torchtune/data/test_metrics_aggregator.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging + import pytest import torch.distributed as dist from tests.test_utils import gpu_test @@ -34,7 +36,9 @@ def test_aggregation_types(self, agg_type, test_values, expected): aggregator = MetricsAggregator() metrics = [ - Metric(dataset_name="test", name="metric", value=val, agg_type=agg_type) + Metric( + dataset_name="test", metric_name="metric", value=val, agg_type=agg_type + ) for val in test_values ] aggregator.update(metrics) @@ -43,7 +47,7 @@ def test_aggregation_types(self, agg_type, test_values, expected): if agg_type == AggregationType.CATEGORICAL_COUNT: for category, count in expected.items(): - assert result[f"train_test/metric_{category}_count"] == count + assert result[f"train_test/metric_count_{category}"] == count else: assert result["train_test/metric"] == expected @@ -61,10 +65,10 @@ def test_distribution_metrics(self): result = aggregator.get_metrics_for_logging(prefix="train") # Verify distribution statistics - assert result["train_test/dist_metric_mean"] == 5.5 - assert result["train_test/dist_metric_min"] == 1 - assert result["train_test/dist_metric_max"] == 10 - assert result["train_test/dist_metric_p50"] == 5.5 + assert result["train_test/dist_metric_stat_mean"] == 5.5 + assert result["train_test/dist_metric_stat_min"] == 1 + assert result["train_test/dist_metric_stat_max"] == 10 + assert result["train_test/dist_metric_stat_p50"] == 5.5 def test_state_management(self): """Test aggregator checkpointing and restoration.""" @@ -149,6 +153,82 @@ def test_prefix_handling(self): assert result_no_prefix["data_test_ds/metric1"] == 42 assert result_no_prefix["data_test_ds/metric2"] == 84 + def test_metric_consistency_validation(self): + """Test that same metric name must use same aggregation type.""" + aggregator = MetricsAggregator() + + # First metric with SUM aggregation + metrics1 = [Metric("test", "my_metric", 10, AggregationType.SUM)] + aggregator.update(metrics1) + + # Try to use same metric name with different aggregation type - should fail + metrics2 = [Metric("test", "my_metric", 5.0, AggregationType.MEAN)] + with pytest.raises( + ValueError, match="is already registered with aggregation type sum" + ): + aggregator.update(metrics2) + + # Same metric name with same aggregation type should work + metrics3 = [Metric("test", "my_metric", 20, AggregationType.SUM)] + aggregator.update(metrics3) # Should not raise + + result = aggregator.get_metrics_for_logging(prefix="train") + assert result["train_test/my_metric"] == 30 # 10 + 20 + + def test_metric_consistency_across_datasets(self): + """Test that same metric name can use different aggregation types across different datasets.""" + aggregator = MetricsAggregator() + + # Same metric name but different datasets - should be allowed + metrics = [ + Metric("dataset1", "metric", 10, AggregationType.SUM), + Metric("dataset2", "metric", 5.0, AggregationType.MEAN), + ] + aggregator.update(metrics) # Should not raise + + result = aggregator.get_metrics_for_logging(prefix="train") + assert result["train_dataset1/metric"] == 10 + assert result["train_dataset2/metric"] == 5.0 + + def test_handler_generated_metric_validation(self): + """Test that handler-generated metrics are validated for consistency.""" + aggregator = MetricsAggregator() + + # Create a user-defined metric that will conflict with distribution stats + user_metrics = [ + Metric("test", "dist_metric_stat_mean", 42, AggregationType.SUM) + ] + aggregator.update(user_metrics) + + # Now try to add a distribution metric that will generate conflicting stat names + dist_metrics = [Metric("test", "dist_metric", 10, AggregationType.DISTRIBUTION)] + aggregator.update(dist_metrics) + + # This should fail when trying to get metrics for logging because the handler + # will try to create "dist_metric_stat_mean" which conflicts with the user metric + with pytest.raises( + ValueError, match="is already registered with aggregation type sum" + ): + aggregator.get_metrics_for_logging(prefix="train") + + def test_handler_replacement_warning(self, caplog): + """Test that replacing handlers in use generates a warning.""" + aggregator = MetricsAggregator() + + # Add a metric that uses SUM aggregation + metrics = [Metric("test", "sum_metric", 10, AggregationType.SUM)] + aggregator.update(metrics) + + # Replace the SUM handler - should generate warning + from torchtune.data.metrics._metric_agg_handlers import SumAggHandler + + with caplog.at_level(logging.WARNING): + aggregator.register_handler(AggregationType.SUM, SumAggHandler()) + + # Check that the expected warning was logged + assert len(caplog.records) == 1 + assert "Replacing handler for AggregationType.SUM" in caplog.records[0].message + class TestDistributedMetricsAggregator(FSDPTest): """Distributed tests for MetricsAggregator using FSDPTest infrastructure.""" @@ -245,15 +325,15 @@ def test_distributed_all_aggregation_types(self): # DISTRIBUTION: Combined values [0,1,2,3,4,10,11,12,13,14] # Mean should be average of local means: (2 + 12) / 2 = 7 - assert result["train_test/dist_metric_mean"] == 7 - assert result["train_test/dist_metric_min"] == 0 - assert result["train_test/dist_metric_max"] == 14 + assert result["train_test/dist_metric_stat_mean"] == 7 + assert result["train_test/dist_metric_stat_min"] == 0 + assert result["train_test/dist_metric_stat_max"] == 14 # CATEGORICAL_COUNT: Total counts across ranks # cat_A: 3(rank0) + 1(rank1) = 4, cat_B: 2(rank0) + 0(rank1) = 2, cat_C: 0(rank0) + 4(rank1) = 4 - assert result["train_test/cat_metric_cat_A_count"] == 4 - assert result["train_test/cat_metric_cat_B_count"] == 2 - assert result["train_test/cat_metric_cat_C_count"] == 4 + assert result["train_test/cat_metric_count_cat_A"] == 4 + assert result["train_test/cat_metric_count_cat_B"] == 2 + assert result["train_test/cat_metric_count_cat_C"] == 4 @gpu_test(gpu_count=2) def test_distributed_state_dict_resumption(self): diff --git a/tests/torchtune/data/test_metrics_transform.py b/tests/torchtune/data/test_metrics_transform.py index 8a4a86d7dd..eb7a1e951e 100644 --- a/tests/torchtune/data/test_metrics_transform.py +++ b/tests/torchtune/data/test_metrics_transform.py @@ -38,17 +38,17 @@ def test_basic_metrics_generation(self): # Check each metric for metric in metrics: - if metric.name == "samples_seen": + if metric.metric_name == "samples_seen": assert metric.dataset_name == "test_dataset" assert metric.value == 1 assert metric.agg_type == AggregationType.SUM - elif metric.name == "tokens_seen": + elif metric.metric_name == "tokens_seen": assert metric.dataset_name == "test_dataset" assert metric.value == 5 assert metric.agg_type == AggregationType.SUM - elif metric.name == "seq_len": + elif metric.metric_name == "seq_len": assert metric.dataset_name == "test_dataset" assert metric.value == 5 assert metric.agg_type == AggregationType.DISTRIBUTION diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index ec3ca26936..9a54a0cc99 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -239,7 +239,9 @@ def test_epoch_tracking(self, dataset_factory, small_dataset_file): for sample in first_epoch_samples: first_epoch_metrics.extend(sample["metrics"]) epoch_values = [ - metric.value for metric in first_epoch_metrics if metric.name == "epoch" + metric.value + for metric in first_epoch_metrics + if metric.metric_name == "epoch" ] assert all( epoch_value == 0 for epoch_value in epoch_values @@ -250,7 +252,9 @@ def test_epoch_tracking(self, dataset_factory, small_dataset_file): for sample in second_epoch_samples: second_epoch_metrics.extend(sample["metrics"]) epoch_values = [ - metric.value for metric in second_epoch_metrics if metric.name == "epoch" + metric.value + for metric in second_epoch_metrics + if metric.metric_name == "epoch" ] assert all( epoch_value == 1 for epoch_value in epoch_values diff --git a/torchtune/data/metrics/_metric_agg_handlers.py b/torchtune/data/metrics/_metric_agg_handlers.py index c3415aba7d..d5c4122228 100644 --- a/torchtune/data/metrics/_metric_agg_handlers.py +++ b/torchtune/data/metrics/_metric_agg_handlers.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from collections import Counter, deque from dataclasses import dataclass, field -from typing import Any, Union +from typing import Any import torch @@ -76,21 +76,18 @@ def update(self, local_agg_metric: MetricState, metric: Metric) -> None: pass @abstractmethod - def finalize_local_agg( - self, local_agg_metric: MetricState - ) -> Union[MetricState, list[MetricState]]: + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: """ Computes the final value from the locally aggregated state. - In a distributed setting, this is called before the reduction step. - This method can also expand a single metric into multiple, for instance, + This method may expand a single metric into multiple, for instance, a distribution into mean, min, max, and percentiles. Args: local_agg_metric (MetricState): The locally aggregated metric state to finalize. Returns: - A single `MetricState` or a list of them if the metric expands. + list[MetricState]: List of finalized metric states. """ pass @@ -156,8 +153,8 @@ def update(self, local_agg_metric: MetricState, metric: Metric) -> None: ) local_agg_metric.value += metric.value - def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: - return local_agg_metric + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + return [local_agg_metric] def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: if not local_agg_metrics: @@ -193,8 +190,8 @@ def update(self, local_agg_metric: MetricState, metric: Metric) -> None: ) local_agg_metric.value = max(local_agg_metric.value, metric.value) - def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: - return local_agg_metric + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + return [local_agg_metric] def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: max_value = max(r.value for r in local_agg_metrics) @@ -227,8 +224,8 @@ def update(self, local_agg_metric: MetricState, metric: Metric) -> None: ) local_agg_metric.value = min(local_agg_metric.value, metric.value) - def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: - return local_agg_metric + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + return [local_agg_metric] def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: min_value = min(r.value for r in local_agg_metrics) @@ -259,12 +256,12 @@ def update(self, local_agg_metric: MetricState, metric: Metric) -> None: local_agg_metric.metadata["sum"] += metric.value local_agg_metric.metadata["count"] += 1 - def finalize_local_agg(self, local_agg_metric: MetricState) -> MetricState: + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: count = local_agg_metric.metadata["count"] local_agg_metric.value = ( local_agg_metric.metadata["sum"] / count if count > 0 else 0.0 ) - return local_agg_metric + return [local_agg_metric] def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: total_sum = sum(metric.metadata["sum"] for metric in local_agg_metrics) @@ -349,42 +346,42 @@ def _compute_distribution_stats( metrics = [ MetricState( dataset_name=local_agg_metric.dataset_name, - metric_name=f"{local_agg_metric.metric_name}_mean", + metric_name=f"{local_agg_metric.metric_name}_stat_mean", value=mean_val, agg_type=AggregationType.MEAN, metadata={"sum": sum_val, "count": n}, ), MetricState( dataset_name=local_agg_metric.dataset_name, - metric_name=f"{local_agg_metric.metric_name}_min", + metric_name=f"{local_agg_metric.metric_name}_stat_min", value=min_val, agg_type=AggregationType.MIN, metadata={}, ), MetricState( dataset_name=local_agg_metric.dataset_name, - metric_name=f"{local_agg_metric.metric_name}_max", + metric_name=f"{local_agg_metric.metric_name}_stat_max", value=max_val, agg_type=AggregationType.MAX, metadata={}, ), MetricState( dataset_name=local_agg_metric.dataset_name, - metric_name=f"{local_agg_metric.metric_name}_p05", + metric_name=f"{local_agg_metric.metric_name}_stat_p05", value=p05_val, agg_type=AggregationType.MEAN, metadata={"sum": p05_val, "count": 1}, ), MetricState( dataset_name=local_agg_metric.dataset_name, - metric_name=f"{local_agg_metric.metric_name}_p50", + metric_name=f"{local_agg_metric.metric_name}_stat_p50", value=p50_val, agg_type=AggregationType.MEAN, metadata={"sum": p50_val, "count": 1}, ), MetricState( dataset_name=local_agg_metric.dataset_name, - metric_name=f"{local_agg_metric.metric_name}_p95", + metric_name=f"{local_agg_metric.metric_name}_stat_p95", value=p95_val, agg_type=AggregationType.MEAN, metadata={"sum": p95_val, "count": 1}, @@ -396,7 +393,7 @@ def _compute_distribution_stats( metrics.append( MetricState( dataset_name=local_agg_metric.dataset_name, - metric_name=f"{local_agg_metric.metric_name}_std", + metric_name=f"{local_agg_metric.metric_name}_stat_std", value=std_val, agg_type=AggregationType.MEAN, metadata={"sum": std_val, "count": 1}, @@ -452,7 +449,7 @@ def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState] results.append( MetricState( dataset_name=local_agg_metric.dataset_name, - metric_name=f"{local_agg_metric.metric_name}_{category}_count", + metric_name=f"{local_agg_metric.metric_name}_count_{category}", value=count, agg_type=AggregationType.SUM, ) diff --git a/torchtune/data/metrics/_metric_aggregator.py b/torchtune/data/metrics/_metric_aggregator.py index 633d5c6b80..cb4f78abf0 100644 --- a/torchtune/data/metrics/_metric_aggregator.py +++ b/torchtune/data/metrics/_metric_aggregator.py @@ -5,8 +5,9 @@ # LICENSE file in the root directory of this source tree. import ast +import logging from collections import defaultdict -from typing import Any +from typing import Any, Union import torch.distributed as dist @@ -22,6 +23,8 @@ ) from torchtune.data.metrics._metric_transform import AggregationType, Metric +logger = logging.getLogger(__name__) + class MetricsAggregator: """Aggregates metrics across datasets and distributed ranks using pluggable handlers. @@ -77,6 +80,9 @@ def __init__(self, dist_window_size: int = 1000): self._metric_states: dict[tuple[str, str], MetricState] = {} self._dist_window_size = dist_window_size + # Track aggregation types for validation - prevents same metric name with different agg types + self._metric_agg_types: dict[tuple[str, str], AggregationType] = {} + # Create handler registry - all handlers initialized upfront self._handlers: dict[AggregationType, AggregationHandler] = { AggregationType.SUM: SumAggHandler(), @@ -87,6 +93,24 @@ def __init__(self, dist_window_size: int = 1000): AggregationType.CATEGORICAL_COUNT: CategoricalCountAggHandler(), } + def _validate_metric_consistency(self, metric: Union[Metric, MetricState]) -> None: + """Validate that metric name uses consistent aggregation type.""" + metric_key = (metric.dataset_name, metric.metric_name) + metric_name = metric.metric_name + + if metric_key in self._metric_agg_types: + existing_agg_type = self._metric_agg_types[metric_key] + if existing_agg_type != metric.agg_type: + raise ValueError( + f"Metric '{metric_name}' in dataset '{metric.dataset_name}' " + f"is already registered with aggregation type {existing_agg_type.value}, " + f"but a handler or user code tried to use it with type {metric.agg_type.value}. " + f"Use different metric names for different aggregation types." + ) + else: + # Track this metric's aggregation type + self._metric_agg_types[metric_key] = metric.agg_type + def register_handler( self, agg_type: AggregationType, handler: AggregationHandler ) -> None: @@ -94,8 +118,17 @@ def register_handler( Args: agg_type (AggregationType): The aggregation type to handle - handler (AggregationHandler): Handler instance implementing the Ag∂gregationHandler interface + handler (AggregationHandler): Handler instance implementing the AggregationHandler interface """ + # Warn if replacing a handler that's already in use + if agg_type in self._handlers and any( + state.agg_type == agg_type for state in self._metric_states.values() + ): + logger.warning( + f"Replacing handler for {agg_type} - aggregation type already in use by existing metrics. " + f"This may affect existing metric behavior." + ) + self._handlers[agg_type] = handler def update(self, metrics: list[Metric]) -> None: @@ -105,10 +138,14 @@ def update(self, metrics: list[Metric]) -> None: metrics (list[Metric]): List of metrics to update the state with Raises: - ValueError: If no handler is registered for a metric's aggregation type. + ValueError: If no handler is registered for a metric's aggregation type, + or if metric name conflicts with existing aggregation type. """ for metric in metrics: - metric_key = (metric.dataset_name, metric.name) + # Same metric name must use same aggregation type + self._validate_metric_consistency(metric) + + metric_key = (metric.dataset_name, metric.metric_name) handler = self._handlers.get(metric.agg_type) if handler is None: @@ -118,7 +155,7 @@ def update(self, metrics: list[Metric]) -> None: if metric_key not in self._metric_states: self._metric_states[metric_key] = handler.initialize_metric_state( - metric.dataset_name, metric.name, metric.agg_type + metric.dataset_name, metric.metric_name, metric.agg_type ) local_agg_metric = self._metric_states[metric_key] @@ -153,13 +190,13 @@ def _compute_unified_metrics(self) -> list[MetricState]: prepared_results = [] for local_agg_metric in self._metric_states.values(): handler = self._handlers[local_agg_metric.agg_type] - prepared = handler.finalize_local_agg(local_agg_metric) - if isinstance( - prepared, list - ): # Distribution/categorical expands to multiple - prepared_results.extend(prepared) - else: - prepared_results.append(prepared) + generated_metrics = handler.finalize_local_agg(local_agg_metric) + + # Validate each newly generated metric state immediately + for gen_metric in generated_metrics: + self._validate_metric_consistency(gen_metric) + + prepared_results.extend(generated_metrics) # Step 2: Apply distributed reduction if needed if dist.is_initialized() and dist.get_world_size() > 1: @@ -238,6 +275,10 @@ def state_dict(self) -> dict[str, Any]: "required_agg_types": list( required_agg_types ), # Save which handlers are needed + # Save which aggregation types are used for each metric + "metric_agg_types": { + str(k): v.value for k, v in self._metric_agg_types.items() + }, } def load_state_dict(self, state_dict: dict[str, Any]) -> None: @@ -288,3 +329,9 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: deserialized_state[metric_key] = local_agg_metric self._metric_states = deserialized_state + + # Restore validation state + self._metric_agg_types = {} + for key_str, agg_type_str in state_dict.get("metric_agg_types", {}).items(): + key = ast.literal_eval(key_str) + self._metric_agg_types[key] = AggregationType(agg_type_str) diff --git a/torchtune/data/metrics/_metric_transform.py b/torchtune/data/metrics/_metric_transform.py index 5affaca046..f6a7fbf7e2 100644 --- a/torchtune/data/metrics/_metric_transform.py +++ b/torchtune/data/metrics/_metric_transform.py @@ -14,7 +14,7 @@ @dataclass(frozen=True) class Metric: dataset_name: str - name: str + metric_name: str value: Union[int, float, str] agg_type: "AggregationType" @@ -101,9 +101,9 @@ class DefaultTrainingMetricTransform(MetricTransform): >>> metrics = transform._generate_metrics(sample) >>> # Creates: >>> # [ - >>> # Metric(dataset_name="alpaca", name="samples_seen", value=1, agg_type=AggregationType.SUM), - >>> # Metric(dataset_name="alpaca", name="tokens_seen", value=5, agg_type=AggregationType.SUM), - >>> # Metric(dataset_name="alpaca", name="seq_len", value=5, agg_type=AggregationType.DISTRIBUTION) + >>> # Metric(dataset_name="alpaca", metric_name="samples_seen", value=1, agg_type=AggregationType.SUM), + >>> # Metric(dataset_name="alpaca", metric_name="tokens_seen", value=5, agg_type=AggregationType.SUM), + >>> # Metric(dataset_name="alpaca", metric_name="seq_len", value=5, agg_type=AggregationType.DISTRIBUTION) >>> # ] """ @@ -121,19 +121,19 @@ def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: return [ Metric( dataset_name=self.dataset_name, - name="samples_seen", + metric_name="samples_seen", value=1, agg_type=AggregationType.SUM, ), Metric( dataset_name=self.dataset_name, - name="tokens_seen", + metric_name="tokens_seen", value=token_len, agg_type=AggregationType.SUM, ), Metric( dataset_name=self.dataset_name, - name="seq_len", + metric_name="seq_len", value=token_len, agg_type=AggregationType.DISTRIBUTION, ), diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index 7b6f790914..bb0f647508 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -234,7 +234,7 @@ def __iter__(self) -> Iterator[dict[str, Any]]: # also necessary to track dataset-level metrics. metric_num_epochs = Metric( dataset_name=self.info.name, - name="num_epochs", + metric_name="num_epochs", value=self._num_epochs, agg_type=AggregationType.MAX, ) From f89eefe90cf6f78559e054847f9ac92fa0c5ce42 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Sun, 6 Jul 2025 19:36:34 -0700 Subject: [PATCH 25/25] improve docs --- .../torchtune/data/test_metrics_aggregator.py | 30 +++++- .../torchtune/data/test_metrics_transform.py | 21 ++++- tests/torchtune/datasets/test_hf_iterable.py | 13 +++ tests/torchtune/datasets/test_interleaved.py | 20 +++- .../data/metrics/_metric_agg_handlers.py | 17 ++-- torchtune/data/metrics/_metric_aggregator.py | 21 +++-- torchtune/data/metrics/_metric_transform.py | 56 ++++++++--- torchtune/datasets/_hf_iterable.py | 93 ++++++++++--------- torchtune/datasets/_interleaved.py | 21 +++-- torchtune/datasets/_iterable_base.py | 57 +++++++++--- 10 files changed, 243 insertions(+), 106 deletions(-) diff --git a/tests/torchtune/data/test_metrics_aggregator.py b/tests/torchtune/data/test_metrics_aggregator.py index c2e8141ff9..db2ab3f617 100644 --- a/tests/torchtune/data/test_metrics_aggregator.py +++ b/tests/torchtune/data/test_metrics_aggregator.py @@ -4,6 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +""" +Tests for MetricsAggregator functionality. + +This module tests the metrics collection and aggregation system including: +- All aggregation types (SUM, MEAN, MAX, MIN, DISTRIBUTION, CATEGORICAL_COUNT) +- State management and checkpointing +- Multi-dataset metric namespacing +- Distributed metrics aggregation +- Metric consistency validation + +Uses synthetic metrics to verify correct aggregation behavior across scenarios. +""" + import logging import pytest @@ -15,7 +28,7 @@ class TestMetricsAggregator: - """Focused tests for MetricsAggregator functionality.""" + """Tests for MetricsAggregator core functionality and edge cases.""" @pytest.mark.parametrize( "agg_type,test_values,expected", @@ -32,7 +45,14 @@ class TestMetricsAggregator: ], ) def test_aggregation_types(self, agg_type, test_values, expected): - """Tests each `AggregationType` to ensure it computes the correct value.""" + """Tests each AggregationType with representative data to verify correct computation. + + Covers aggregation types: + - SUM: Simple addition across values + - MEAN: Average computation with proper count tracking + - MAX/MIN: Extrema identification + - CATEGORICAL_COUNT: Category frequency counting + """ aggregator = MetricsAggregator() metrics = [ @@ -52,7 +72,7 @@ def test_aggregation_types(self, agg_type, test_values, expected): assert result["train_test/metric"] == expected def test_distribution_metrics(self): - """Tests that `AggregationType.DISTRIBUTION` computes all expected statistics (mean, min, max, p50).""" + """Tests that DISTRIBUTION aggregation computes statistics (mean, min, max, percentiles).""" aggregator = MetricsAggregator() values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] @@ -71,8 +91,8 @@ def test_distribution_metrics(self): assert result["train_test/dist_metric_stat_p50"] == 5.5 def test_state_management(self): - """Test aggregator checkpointing and restoration.""" - # Create aggregator with some state + """Test metrics aggregator state persistence and restoration for checkpointing scenarios.""" + # Create aggregator with mixed metric types to test state saving aggregator1 = MetricsAggregator() initial_metrics = [ Metric("ds1", "counter", 10, AggregationType.SUM), diff --git a/tests/torchtune/data/test_metrics_transform.py b/tests/torchtune/data/test_metrics_transform.py index eb7a1e951e..ebfb1c81a1 100644 --- a/tests/torchtune/data/test_metrics_transform.py +++ b/tests/torchtune/data/test_metrics_transform.py @@ -4,6 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +""" +Tests cover: +- DefaultTrainingMetricTransform +- Basic metric generation (samples_seen, tokens_seen, seq_len) +- Dataset name validation and requirements +- Proper metric type assignment and aggregation configuration +""" + import pytest from torchtune.data.metrics import AggregationType, DefaultTrainingMetricTransform @@ -13,7 +21,9 @@ class TestDefaultTrainingMetricTransform: """Tests for DefaultTrainingMetricTransform functionality.""" def test_dataset_name_not_set_raises_error(self): - """Test that using transform without setting dataset name raises error.""" + """Test that the transform raises a RuntimeError if used before + `set_dataset_name` is called, ensuring that metrics are always + correctly attributed to a dataset.""" transform = DefaultTrainingMetricTransform() sample = {"tokens": [1, 2, 3]} @@ -21,22 +31,23 @@ def test_dataset_name_not_set_raises_error(self): transform(sample) def test_basic_metrics_generation(self): - """Test that transform generates expected metrics for a sample.""" + """Test that transform generates expected training metrics for input samples.""" transform = DefaultTrainingMetricTransform() + # Set dataset name required for metric generation transform.set_dataset_name("test_dataset") sample = {"tokens": [1, 2, 3, 4, 5]} result = transform(sample) - # Should preserve original sample data + # Transform should preserve original sample data unchanged assert result["tokens"] == [1, 2, 3, 4, 5] - # Should add metrics + # Should generate exactly 3 metrics: samples_seen, tokens_seen, seq_len assert "metrics" in result metrics = result["metrics"] assert len(metrics) == 3 - # Check each metric + # Verify each metric has correct properties and aggregation type for metric in metrics: if metric.metric_name == "samples_seen": assert metric.dataset_name == "test_dataset" diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py index 9a54a0cc99..adcede297a 100644 --- a/tests/torchtune/datasets/test_hf_iterable.py +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -4,6 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +""" +Tests for HfIterableDataset core functionality. + +This module tests the foundational iterable dataset capabilities including: +- Basic iteration and data loading +- Epoch boundary handling and tracking +- Shuffling behavior across epochs +- Checkpointing and state restoration +- Distributed training scenarios + +Uses synthetic JSON data with predictable patterns to verify correct behavior. +""" + import math import shutil import tempfile diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py index db4ec95035..37bda9adcc 100644 --- a/tests/torchtune/datasets/test_interleaved.py +++ b/tests/torchtune/datasets/test_interleaved.py @@ -4,6 +4,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +""" +Tests for InterleavedDataset functionality. + +This module tests the multi-dataset interleaving capabilities, including: +- Dataset composition with weighted sampling +- Nested interleaving structures +- Metrics collection and aggregation across datasets +- Checkpointing and state restoration +- Distributed training scenarios + +The tests use synthetic JSON data with distinct ID ranges per dataset +to verify correct sampling ratios and data isolation. +""" + import shutil import tempfile from itertools import islice @@ -53,12 +67,14 @@ def create_test_json_file(path: Path, num_samples: int, offset: int = 0) -> None @pytest.fixture def tmp_data_dir(tmp_path): - """Provide temporary directory for test data files.""" + """Provide temporary directory for test data files. + All test datasets are created in this isolated directory to avoid conflicts.""" return tmp_path @pytest.fixture def small_dataset_file(tmp_data_dir): + """Create small dataset (23 samples) with IDs 0-22 for testing basic functionality.""" path = tmp_data_dir / "small_data.json" create_test_json_file(path, SMALL_DATASET_SIZE, offset=0) return str(path) @@ -66,6 +82,7 @@ def small_dataset_file(tmp_data_dir): @pytest.fixture def medium_dataset_file(tmp_data_dir): + """Create medium dataset (35 samples) with IDs 100-134 for multi-dataset testing.""" path = tmp_data_dir / "medium_data.json" create_test_json_file(path, MEDIUM_DATASET_SIZE, offset=100) return str(path) @@ -73,6 +90,7 @@ def medium_dataset_file(tmp_data_dir): @pytest.fixture def large_dataset_file(tmp_data_dir): + """Create large dataset (47 samples) with IDs 1000-1046 for nested interleaving tests.""" path = tmp_data_dir / "large_data.json" create_test_json_file(path, LARGE_DATASET_SIZE, offset=1000) return str(path) diff --git a/torchtune/data/metrics/_metric_agg_handlers.py b/torchtune/data/metrics/_metric_agg_handlers.py index d5c4122228..ac3f9a2fd7 100644 --- a/torchtune/data/metrics/_metric_agg_handlers.py +++ b/torchtune/data/metrics/_metric_agg_handlers.py @@ -38,15 +38,14 @@ class MetricState: class AggregationHandler(ABC): - """Base class for handling metric aggregation in MetricsAggregator. - - This class defines the interface for different aggregation strategies (e.g., SUM, MEAN). - Each handler is responsible for: - - Initializing the state for a new (dataset, metric) pair. - - Updating the state with new values. - - Finalizing the value for local (single-rank) logging. - - Reducing the values from all ranks in a distributed setting. - - Serializing and deserializing the metric state for checkpointing. + """Base class for handling metric aggregation using the Strategy pattern. + + Each handler implements a specific aggregation strategy (SUM, MEAN, DISTRIBUTION, etc.) + and manages the complete lifecycle: initialization, updates, local finalization, + and distributed reduction. Handlers also handle serialization for checkpointing. + + The handler architecture allows pluggable aggregation strategies while maintaining + consistent interfaces for the MetricsAggregator. """ @abstractmethod diff --git a/torchtune/data/metrics/_metric_aggregator.py b/torchtune/data/metrics/_metric_aggregator.py index cb4f78abf0..da6b152350 100644 --- a/torchtune/data/metrics/_metric_aggregator.py +++ b/torchtune/data/metrics/_metric_aggregator.py @@ -29,14 +29,22 @@ class MetricsAggregator: """Aggregates metrics across datasets and distributed ranks using pluggable handlers. - Uses a handler-based strategy pattern where each aggregation type (SUM, MEAN, etc.) - has its own handler. Maintains only one state per (dataset, metric) pair. + This class uses a handler-based strategy, where each aggregation type (SUM, MEAN, etc.) + has a corresponding AggregationHandler. It maintains a single state object for each + (dataset, metric) pair. - When preparing for logging, uses a two-phase approach: - 1. Local aggregation: Each rank aggregates its metrics independently - 2. Distributed reduction: Results combined across ranks + Internal State Visualization: + { + ("alpaca", "tokens_seen"): MetricState(value=200.0, agg_type=SUM, ...), + ("alpaca", "avg_loss"): MetricState(value=0.01, agg_type=MEAN, metadata={'sum': ..., 'count': ...}), + ("slim_orca", "seq_len"): MetricState(agg_type=DISTRIBUTION, metadata={'values': deque([...])}), + } - The aggregator is checkpointable and restores from state_dict for training resumption. + When preparing metrics for logging, the aggregator follows a two-phase process: + 1. Local Aggregation: Each rank aggregates its metrics independently + 2. Distributed Reduction: If in distributed mode, results are combined across ranks + + The aggregator's state is checkpointable, allowing training resumption. Args: dist_window_size (int): Window size for DistributionAggHandler tracking. @@ -44,7 +52,6 @@ class MetricsAggregator: Example: >>> from torchtune.data.metrics import MetricsAggregator, Metric, AggregationType >>> - >>> # Create aggregator >>> aggregator = MetricsAggregator() >>> >>> # Sample metrics from different batches diff --git a/torchtune/data/metrics/_metric_transform.py b/torchtune/data/metrics/_metric_transform.py index f6a7fbf7e2..8521f6e6dd 100644 --- a/torchtune/data/metrics/_metric_transform.py +++ b/torchtune/data/metrics/_metric_transform.py @@ -20,7 +20,11 @@ class Metric: class AggregationType(Enum): - """Defines how a metric's value should be aggregated.""" + """Defines how a metric's value should be aggregated by the MetricsAggregator. + + Each type corresponds to a specific AggregationHandler that implements the logic + for initialization, updates, and distributed reduction. + """ SUM = "sum" MEAN = "mean" @@ -33,22 +37,33 @@ class AggregationType(Enum): class MetricTransform(Transform): """Applied to each dataset sample to generate per-sample metrics for training tracking. - Creates Metric objects that are later aggregated by 'MetricsAggregator'. This separation + Creates Metric objects that are later aggregated by MetricsAggregator. This separation of concerns ensures metrics are correctly aggregated even with multiple dataloader - workers and in distributed settings.""" + workers and in distributed settings. + + The transform must be configured with a dataset name via set_dataset_name() before use. + Each call to __call__ adds metrics to the sample's "metrics" key. + + Example: + >>> transform = DefaultTrainingMetricTransform() + >>> transform.set_dataset_name("alpaca") + >>> sample = {"tokens": [1, 2, 3]} + >>> result = transform(sample) + >>> # result["metrics"] contains list of Metric objects + """ def __init__(self): # dataset_name is set by the dataset using set_dataset_name self.dataset_name: Optional[str] = None def set_dataset_name(self, dataset_name: str) -> None: - """Called by dataset to set the namespace for metrics. + """Called by the dataset to set the namespace for metrics. - The dataset name is used to differentiate multiple datasets stats, - e.g. "train/dataset1/tokens_seen" and "train/dataset2/tokens_seen". + This is used to differentiate metrics from multiple datasets, for example, + "train_alpaca/tokens_seen" vs. "train_slim_orca/tokens_seen". Args: - dataset_name (str): Name of the dataset for metric namespacing + dataset_name (str): Name of the dataset, used for metric namespacing. """ self.dataset_name = dataset_name @@ -67,7 +82,17 @@ def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: raise NotImplementedError("Subclasses must implement _generate_metrics method") def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - """Apply transform to sample, adding generated metrics.""" + """Apply transform to sample, adding generated metrics to the sample. + + Args: + sample (dict[str, Any]): Input sample dictionary + + Returns: + dict[str, Any]: Sample with metrics added to "metrics" key (list[Metric]) + + Raises: + RuntimeError: If set_dataset_name() was not called before transform usage + """ if self.dataset_name is None: raise RuntimeError( "set_dataset_name() must be called before using the transform." @@ -84,14 +109,17 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: class DefaultTrainingMetricTransform(MetricTransform): - """Generates training metrics: samples_seen, tokens_seen, seq_len distribution. + """Generates common training metrics: samples seen, tokens seen, and sequence length. + + This transform detects the token key in a sample, checking for "tokens" + first and then falling back to "input_ids". - For details about MetricTransform base class behavior, see the parent class docstring. + For details on the base class behavior, see MetricTransform. Tracked metrics: - - samples_seen: Cumulative count of samples processed (SUM aggregation) - - tokens_seen: Cumulative sum of all tokens processed (SUM aggregation) - - seq_len: Distribution of sequence lengths (DISTRIBUTION aggregation) + - samples_seen: Cumulative count of samples processed (SUM aggregation) + - tokens_seen: Cumulative sum of all tokens processed (SUM aggregation) + - seq_len: Distribution of sequence lengths (DISTRIBUTION aggregation) Example: >>> transform = DefaultTrainingMetricTransform() @@ -99,7 +127,7 @@ class DefaultTrainingMetricTransform(MetricTransform): >>> >>> sample = {"tokens": [1, 2, 3, 4, 5]} # 5 tokens >>> metrics = transform._generate_metrics(sample) - >>> # Creates: + >>> # This generates the following Metric objects: >>> # [ >>> # Metric(dataset_name="alpaca", metric_name="samples_seen", value=1, agg_type=AggregationType.SUM), >>> # Metric(dataset_name="alpaca", metric_name="tokens_seen", value=5, agg_type=AggregationType.SUM), diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py index bb0f647508..f517fece31 100644 --- a/torchtune/datasets/_hf_iterable.py +++ b/torchtune/datasets/_hf_iterable.py @@ -24,34 +24,42 @@ class HfIterableDataset(InfiniteTuneIterableDataset): - """HuggingFace dataset implementation with composable metrics. + """HuggingFace dataset with infinite iteration and composable transforms. - This is an infinite dataset. After exhausting the dataset, it will restart from the beginning. + This is an infinite dataset that wraps a HuggingFace dataset. After exhausting + the dataset, it will restart from the beginning. + + Transform pipeline: raw_data -> message_transform -> model_transform -> output_transform -> metric_transform This dataset is responsible for: - Loading and sharding the dataset - Shuffling at initialization and after each epoch - - Applying transforms + - Applying transforms to the data - Returning an infinite iterator over the dataset - Args: - message_transform (Optional[Callable]): Transforms raw data into Message - model_transform (Optional[Callable]): Take messages and prepares it for the model. Usually the tokenizer. - output_transform (Optional[Callable]): Takes tokenized inputs and prepares it for the recipe. Usually - does some label manipulation, e.g. ignore index. Think of it as recipe-dependent, e.g. SFT, RL, DPO, etc. - metric_transform (Optional[MetricTransform]): Takes the sample and computes metrics, e.g. token count. - If None, a default transform is used. To stop tracking metrics, set it to lambda x: x. - shuffle_buffer_size (Optional[int]): Size of the shuffle buffer. If None or 0, no shuffling is done. - seed (int): Seed for shuffling. - num_shards_per_rank (int): Target number of shards per worker (GPU). It will find a multiple - of world_size * dataloader_workers. - dataset_name (Optional[str]): Name of the dataset. If None, a default name is generated - from the path, source, and split. + Args: + message_transform (Optional[Callable]): Transforms raw data into a `Message`. + model_transform (Optional[Callable]): Prepares messages for the model, + usually by tokenizing them. + output_transform (Optional[Callable]): Prepares tokenized inputs for the + recipe, often by manipulating labels (e.g., setting an ignore index). + This transform is recipe-dependent (e.g., SFT, DPO, etc.). + metric_transform (Optional[MetricTransform]): Computes metrics from a + sample (e.g., token count). If ``None``, a default transform is used. + To disable standard metric tracking, set this to ``lambda x: x``. + shuffle_buffer_size (Optional[int]): Size of the shuffle buffer. + If ``None`` or 0, no shuffling is performed. weight (Optional[float]): Weight for this dataset. Defaults to 1.0. - filter_fn (Optional[Callable]): Filter function to apply to the dataset. - filter_kwargs (Optional[dict[str, Any]]): Keyword arguments to pass to the filter function. - load_dataset_kwargs (dict[str, Any]): Keyword arguments to pass to the load_dataset function. - + seed (int): Seed for shuffling. + num_shards_per_rank (int): The target number of shards per worker (GPU). + The actual number of shards will be a multiple of + ``world_size * dataloader_workers``. + dataset_name (Optional[str]): Name of the dataset. If ``None``, a name is + generated from the ``path``, ``source``, and ``split``. + filter_fn (Optional[Callable]): A function to filter the dataset. + filter_kwargs (Optional[dict[str, Any]]): Keyword arguments for ``filter_fn``. + **load_dataset_kwargs: Keyword arguments for the + :func:`~datasets.load_dataset` function. """ def __init__( @@ -132,9 +140,10 @@ def _setup_hf_dataset( filter_kwargs: Optional[dict[str, Any]] = None, ): """ - Configures the Hugging Face dataset, including sharding, filtering, and - transform mapping. This method is called only once during initialization - to avoid expensive re-computation on each epoch. + One-time setup of HuggingFace dataset that handles Handles distributed sharding, + shuffle configuration, and filtering. + + Called once during __init__ to avoid expensive re-computation. """ # Distributed setup @@ -165,13 +174,13 @@ def _setup_hf_dataset( worker_info = torch.utils.data.get_worker_info() num_dataloader_workers = worker_info.num_workers if worker_info else 1 - # Calculate total workers + # Calculate total workers across all ranks and dataloader processes total_workers = world_size * num_dataloader_workers - # Calculate desired shards + # Find minimum shards that satisfies our target while being divisible by workers desired_shards = world_size * num_shards_per_rank - # Find the smallest multiple of total_workers that is >= desired_shards + # Round up to next multiple of total_workers for even distribution if desired_shards % total_workers == 0: num_shards = desired_shards else: @@ -207,14 +216,14 @@ def _setup_hf_dataset( self._ds = ds def __iter__(self) -> Iterator[dict[str, Any]]: - """Iterate through the dataset infinitely. - - It will restart from the beginning after exhausting the dataset. - - If shuffle_buffer_size is set, it will shuffle the dataset at the beginning of each epoch - when set_epoch is called. - - An additional metric "num_epochs" is added to the sample. + """Infinite iteration over dataset samples. + + Behavior: + - Restarts from beginning when dataset is exhausted + - Reshuffles at start of each epoch (if enabled) + - Applies full transform pipeline to each sample + - Adds 'num_epochs' metric to track dataset progress + - Yields samples indefinitely for continuous training """ while True: # Infinite iteration @@ -224,9 +233,10 @@ def __iter__(self) -> Iterator[dict[str, Any]]: try: for sample in epoch_iterator: - # NOTE: We apply transforms here instead of using .map() call - # to work around https://github.com/huggingface/datasets/issues/7630 - # where .map() can cause incorrect resumption from a checkpoint. + # NOTE: We apply transforms here instead of using .map() to work around + # HuggingFace datasets bug where .map() causes incorrect checkpoint resumption. + # See: https://github.com/huggingface/datasets/issues/7630 + # This ensures transforms are applied fresh on each sample during iteration. sample = self._apply_transforms(sample) # Track the number of epochs completed for each dataset. This is @@ -246,9 +256,10 @@ def __iter__(self) -> Iterator[dict[str, Any]]: yield sample except StopIteration: - pass # Iterator is exhausted, which is expected. + # Expected when dataset is exhausted + pass except Exception as e: - logger.warning( + logger.error( f"Dataset {self.info.name} encountered an unexpected error: {e}." ) raise @@ -263,9 +274,7 @@ def __iter__(self) -> Iterator[dict[str, Any]]: self._num_epochs += 1 def state_dict(self) -> dict[str, Any]: - """ - The dataset returns its own state directly, without namespacing. - """ + """Returns dataset checkpoint state.""" hf_state = self._ds.state_dict() state = { "num_epochs": self._num_epochs, diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py index 2267696ef4..fe911aff51 100644 --- a/torchtune/datasets/_interleaved.py +++ b/torchtune/datasets/_interleaved.py @@ -17,17 +17,18 @@ class InterleavedDataset(InfiniteTuneIterableDataset): - """Infinitely interleaves multiple TuneIterableDatasets according to their sampling weights. - - The weights are extracted from each dataset's info.weight property and normalized to sum to 1.0. - - This dataset is responsible for managing the state of its child datasets - to ensure correct checkpointing and resumption. + """Infinitely interleaves multiple datasets according to their sampling weights. + + The weights are extracted from each dataset's ``info.weight`` property and + normalized to sum to 1.0. This dataset manages the state of its child + datasets to ensure correct checkpointing and resumption. Args: - datasets (list[InfiniteTuneIterableDataset]): list of datasets to interleave. - seed (int): Seed for sampling. - weight (float): Weight for this dataset. Defaults to 1.0. - dataset_name (str): Name of the dataset. Defaults to "interleaved_dataset". - sampling_log_maxlen (int): Maximum length of the sampling log. + datasets (list[InfiniteTuneIterableDataset]): A list of datasets to interleave. + seed (int): The seed for sampling. + weight (float): The weight for this dataset. Defaults to 1.0. + dataset_name (str): The name of the dataset. Defaults to "interleaved_dataset". + sampling_log_maxlen (int): The maximum length of the sampling log. """ def __init__( @@ -100,7 +101,7 @@ def __iter__(self) -> Iterator[dict[str, Any]]: yield next(child_iters[ds_name]) def state_dict(self) -> dict[str, Any]: - """Save state for the interleaver and its children.""" + """Save interleaver state and all child dataset states.""" # The parent is responsible for namespacing the child states child_states = {ds.info.name: ds.state_dict() for ds in self._datasets} return { diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py index a26c22eafa..0f412e80dc 100644 --- a/torchtune/datasets/_iterable_base.py +++ b/torchtune/datasets/_iterable_base.py @@ -13,9 +13,33 @@ @dataclass(frozen=True) class DatasetInfo: - """Represents hierarchical information about a dataset, including its name, - sampling weight and children. Children is a common case when composing datasets, - e.g. Packed(InterleavedDataset([ds1, ds2])). + """Hierarchical metadata for datasets, enabling composition and weight tracking. + + Used to build tree structures when composing datasets. For example, a nested + `InterleavedDataset` dataset would have this structure: + + Example: + .. code-block:: python + + DatasetInfo(name='parent_interleaved', + weight=1.0, + children=(DatasetInfo(name='child_interleaved', + weight=0.7, + children=(DatasetInfo(name='dataset_a', + weight=0.6, + children=()), + DatasetInfo(name='dataset_b', + weight=0.4, + children=()))), + DatasetInfo(name='dataset_c', weight=0.3, children=()))) + + This hierarchical structure is used for validation (ensuring unique dataset + names) and for logging metrics. + + Attributes: + name (str): Unique identifier for the dataset + weight (float): Sampling weight for dataset selection (default: 1.0) + children (tuple[DatasetInfo, ...]): Nested datasets for composed structures """ name: str @@ -24,10 +48,16 @@ class DatasetInfo: class TuneIterableDataset(IterableDataset, ABC): - """Abstract base class for all torchtune iterable datasets. - It defines the minimal, consistent interface required for all dataset - implementations to ensure they are compatible with the training loop, - checkpointing, and metric logging systems. + """Base class for all torchtune iterable datasets. + + Datasets are composable, enabling complex structures such as: + ``PackedDataset(InterleavedDataset([InterleavedDataset([ds1, ds2]), ds3]))`` + + Each dataset implementation must: + - Track hierarchical metadata via the ``info`` property + - Ensure unique dataset names across the entire tree + - Handle checkpointing: parents resume children's state + - Provide proper state management for exact resumption """ @property @@ -64,19 +94,20 @@ def __iter__(self) -> Iterator[dict[str, Any]]: @abstractmethod def state_dict(self) -> dict[str, Any]: - """Returns a state dictionary for checkpointing""" + """Returns checkpoint state for dataset resumption.""" pass @abstractmethod def load_state_dict(self, state_dict: dict[str, Any]) -> None: - """Load state from a state dictionary, used when resuming from a checkpoint.""" + """Restores dataset state from checkpoint.""" pass class InfiniteTuneIterableDataset(TuneIterableDataset): - """Abstract base class for infinite datasets, which yield samples indefinitely. - It only purpose is to make it explicit that the dataset is expected to be infinite, i.e. - it never exhausts. This is helpful to avoid complexity due to some rank hanging because - of lack of data""" + """Base class for infinite datasets that never exhaust. + + Prevents distributed training hangs by ensuring all ranks always + have data available. Datasets restart from beginning when exhausted. + """ pass