diff --git a/recipes/configs/llama3_2/3B_full.yaml b/recipes/configs/llama3_2/3B_full.yaml index bb765f1917..f825e9194e 100644 --- a/recipes/configs/llama3_2/3B_full.yaml +++ b/recipes/configs/llama3_2/3B_full.yaml @@ -26,21 +26,30 @@ tokenizer: path: /tmp/Llama-3.2-3B-Instruct/original/tokenizer.model max_seq_len: null -# Dataset and Sampler +# Dataloader +dataloader: + batch_size: 16 + # num_workers and pin_memory can be added here if needed + +# Dataset - now a list to support multiple weighted sources dataset: - _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False # True increases speed - split: train[:95%] -seed: null -shuffle: True -batch_size: 4 - -# Validation -run_val_every_n_steps: null # Change to an integer to enable validation every N steps -dataset_val: - _component_: torchtune.datasets.alpaca_cleaned_dataset - split: train[95%:] -batch_size_val: ${batch_size} + - _component_: torchtune.datasets.slimorca_iterable_dataset + shuffle_buffer_size: 1000 + weight: 0.8 + split: train[:5%] # simular 1 epoch quickly + - _component_: torchtune.datasets.alpaca_iterable_dataset + shuffle_buffer_size: 1000 + weight: 0.2 + split: train[:5%] # simular 1 epoch quickly + +# Packing (TBD by follow up PR) +# packing: +# _component_: torchtune.datasets.packing.SFTPacking +# max_seq_len: 8192 + +seed: 42 + +# Validation not supported yet with iterable datasets # Model Arguments model: @@ -65,10 +74,11 @@ optimizer: loss: _component_: torchtune.modules.loss.LinearCrossEntropyLoss -# Training -epochs: 1 -max_steps_per_epoch: null -gradient_accumulation_steps: 8 # Use to increase effective batch size +# Training - now step-based +num_training_steps: 100 # Total number of training steps to run +save_every_n_steps: 200 # Save a checkpoint every N steps. Using 200 to avoid ckpt. +gradient_accumulation_steps: 1 +dataset_metrics_log_freq: 5 # Log dataset-specific metrics every N steps # Environment device: cuda @@ -83,7 +93,7 @@ optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_ste # Logging metric_logger: - _component_: torchtune.training.metric_logging.DiskLogger + _component_: torchtune.training.metric_logging.WandBLogger log_dir: ${output_dir}/logs log_every_n_steps: 1 log_peak_memory_stats: True diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 10e0aaeb24..444596619b 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -22,11 +22,11 @@ from torch.optim import Optimizer from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp from torchdata.stateful_dataloader import StatefulDataLoader -from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler from torchtune import config, modules, training, utils from torchtune.config._utils import _get_component_from_path from torchtune.data import padded_collate_packed -from torchtune.datasets import ConcatDataset +from torchtune.data.metrics import MetricsAggregator +from torchtune.datasets import InterleavedDataset from torchtune.modules.embedding_utils import resize_token_embeddings from torchtune.modules.loss import SFTLoss from torchtune.modules.moe import utils as moe_utils @@ -207,7 +207,7 @@ def __init__(self, cfg: DictConfig) -> None: self._checkpoint_client = CheckpointClient(cfg) self._enable_fp8_training = cfg.get("enable_fp8_training", False) self._fp8_recipe_name = cfg.get("fp8_recipe_name", None) - self.save_every_n_steps = cfg.get("save_every_n_steps") + self.save_every_n_steps = cfg.get("save_every_n_steps", None) self._run_val_every_n_steps = cfg.get("run_val_every_n_steps", None) if self._run_val_every_n_steps is not None: @@ -273,18 +273,19 @@ def __init__(self, cfg: DictConfig) -> None: self.seed = training.set_seed( seed=cfg.seed, debug_mode=cfg.get("cudnn_deterministic_mode", None) ) - self.epochs_run = 0 - self.total_epochs = cfg.epochs - self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 + # Step-based training support + self.num_training_steps = cfg.num_training_steps + self._metrics_aggregator = None # Will be initialized in setup + def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None: """ Updates the recipe state from checkpoint. """ try: - self.epochs_run = ckpt_dict[training.EPOCHS_KEY] - self.global_step = ckpt_dict[training.STEPS_KEY] + # The new format stores steps directly + self.global_step = ckpt_dict["steps_run"] # on mismatch, warn the user and prevent the override if self.seed != ckpt_dict[training.SEED_KEY]: @@ -295,23 +296,6 @@ def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None: ) ) self.seed = ckpt_dict[training.SEED_KEY] - if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: - warn( - message=( - "Config value for max_steps_per_epoch does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" - ) - ) - self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] - - # on mismatch, warn the user but allow the override - if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: - warn( - message=( - "Config value for total_epochs does not match the checkpoint value, " - f"using the config value: {self.total_epochs}" - ) - ) except KeyError as e: raise KeyError( @@ -324,6 +308,11 @@ def setup(self, cfg: DictConfig) -> None: Setup the recipe. This includes training state (if resume_from_checkpoint is True), model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader. """ + if cfg.get("dataset_val") is not None: + raise NotImplementedError( + "Validation is not supported yet with iterable datasets since it currently requiresinfinite datasets." + ) + if self.fsdp_cpu_offload: # Utilize all available CPU cores for intra-op parallelism. This provides ~2x # speed up when benchmarking fused AdamW on CPU @@ -434,13 +423,15 @@ def setup(self, cfg: DictConfig) -> None: utils.log_rank_zero(self._logger, "Loss is initialized.") + # Initialize metrics aggregator for dataset metrics tracking + self._metrics_aggregator = MetricsAggregator() + # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after both of these are initialized collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft") self._dataloader = self._setup_data( cfg_dataset=cfg.dataset, - shuffle=cfg.shuffle, - batch_size=cfg.batch_size, + cfg_dataloader=cfg.dataloader, collate_fn=collate_name, dataloader_state_dict=( state_dict[training.DATALOADER_KEY] @@ -452,12 +443,10 @@ def setup(self, cfg: DictConfig) -> None: # Setup validation dataloader if validation dataset is provided self._val_dataloader = None if cfg.get("dataset_val") is not None: - batch_size_val = cfg.get("batch_size_val", cfg.batch_size) self._val_dataloader = self._setup_data( cfg_dataset=cfg.dataset_val, - batch_size=batch_size_val, + cfg_dataloader=cfg.get("dataloader_val", None), collate_fn=collate_name, - shuffle=False, dataloader_state_dict=( state_dict[training.VAL_DATALOADER_KEY] if training.VAL_DATALOADER_KEY in state_dict @@ -465,38 +454,13 @@ def setup(self, cfg: DictConfig) -> None: ), ) - # Finally update the recipe state which can only be correctly set after all of the - # other components have been initialized and updated. - # - # Number of training steps in each epoch depends on the number of batches produced - # by the dataloader, the max_steps_per_epoch param set by the user and the - # gradient_accumulation_steps param. This value is used for logging and tracking - # training state. The computation should happen after the dataloader has been setup - self._steps_per_epoch = ( - len(self._dataloader) // self._gradient_accumulation_steps - ) - if ( - self.max_steps_per_epoch is not None - and self.max_steps_per_epoch < self._steps_per_epoch - ): - self._steps_per_epoch = self.max_steps_per_epoch - - if self.save_every_n_steps is None: - self.save_every_n_steps = self._steps_per_epoch - self.checkpoint_dir_prefix = "epoch" - else: - self.checkpoint_dir_prefix = "step" - - if ( - self._resume_from_checkpoint - and self.global_step % self._steps_per_epoch == 0 - ): - list(self._dataloader) + # Set checkpoint dir prefix to step-based + self.checkpoint_dir_prefix = "step" # Setup lr scheduler self._lr_scheduler = self._setup_lr_scheduler( cfg_lr_scheduler=cfg.get("lr_scheduler", None), - num_training_steps=self.total_epochs * self._steps_per_epoch, + num_training_steps=self.num_training_steps, last_epoch=self.global_step - 1, ) @@ -799,53 +763,66 @@ def _setup_optimizer( def _setup_data( self, - cfg_dataset: DictConfig, - shuffle: bool, - batch_size: int, + cfg_dataset: Union[DictConfig, ListConfig], + cfg_dataloader: DictConfig, collate_fn: str, dataloader_state_dict: Optional[dict[str, Any]] = None, ) -> StatefulDataLoader: """ - All data related setup happens here. This recipe currently supports only - map-style datasets. If a state_dict is provided (meaning we are resuming a training run), - it is loaded into the dataloader. + Set up the dataloader for iterable datasets. """ - if isinstance(cfg_dataset, ListConfig): - datasets = [ - config.instantiate(single_cfg_dataset, self._tokenizer) - for single_cfg_dataset in cfg_dataset - ] - ds = ConcatDataset(datasets=datasets) - packed = getattr(ds, "packed", False) + + # 1. Create all datasets + iterable_datasets = [] + cfg_dataset_list = cfg_dataset + if not isinstance(cfg_dataset_list, ListConfig): + cfg_dataset_list = [cfg_dataset_list] + + for ds_cfg in cfg_dataset_list: + ds = config.instantiate(ds_cfg, model_transform=self._tokenizer) + iterable_datasets.append(ds) + + # 2. Interleave datasets if any + if len(iterable_datasets) > 1: + ds = InterleavedDataset( + datasets=iterable_datasets, + seed=self.seed, + ) else: - ds = config.instantiate(cfg_dataset, self._tokenizer) - packed = cfg_dataset.get("packed", False) + ds = iterable_datasets[0] - # Instantiate collate_fn - if "left_pad_sequence" in collate_fn: - raise RuntimeError("left_pad_sequence collator is only for inference.") - collate_fn = _get_component_from_path(collate_fn) + # 3. Apply packing + # TODO: follow up PR + packed = False - sampler = StatefulDistributedSampler( - ds, num_replicas=self.dp_degree, rank=self.dp_rank, shuffle=shuffle, seed=0 + # 4. Define a collate function wrapper to handle metrics + base_collate_fn = ( + padded_collate_packed if packed else _get_component_from_path(collate_fn) ) + + def _collate_with_metrics_wrapper( + batch: list[dict[str, Any]] + ) -> dict[str, Any]: + # TODO: handling of metrics should prob be done in collate_fn. + # putting this here for now to avoid making more changes to this PR. + all_metrics = [] + clean_batch = [] + for sample in batch: + if "metrics" in sample: + all_metrics.extend(sample.pop("metrics")) + clean_batch.append(sample) + + collated_batch = base_collate_fn(clean_batch) + collated_batch["metrics"] = all_metrics + return collated_batch + + # 5. Create DataLoader dataloader = StatefulDataLoader( dataset=ds, - batch_size=batch_size, - sampler=sampler, - collate_fn=( - partial( - collate_fn, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - pad_to_multiple_of=self.parallel_dims.min_seq_len_divisor, - ) - if not packed - else padded_collate_packed - ), - # dropping last avoids shape issues with compile + flex attention - drop_last=True, + collate_fn=_collate_with_metrics_wrapper, + **cfg_dataloader, ) + if dataloader_state_dict is not None: dataloader.load_state_dict(dataloader_state_dict) @@ -918,9 +895,7 @@ def validate(self) -> dict[str, float]: return log_dict def save_checkpoint(self, *, epoch: int, full_tensors: bool): - if self.global_step % self._steps_per_epoch == 0: - epoch += 1 - + """Save checkpoint based on global step.""" self._checkpoint_client.save_checkpoint( model=self._model, optimizer=( @@ -930,19 +905,21 @@ def save_checkpoint(self, *, epoch: int, full_tensors: bool): ), training_progress=TrainingProgress( seed=self.seed, - epochs_run=epoch, - total_epochs=self.total_epochs, - max_steps_per_epoch=self.max_steps_per_epoch, + epochs_run=0, # TODO: not needed. To be deprecated. + total_epochs=1, # TODO: not needed. To be deprecated. + max_steps_per_epoch=-1, # TODO: not needed. To be deprecated. steps_run=self.global_step, - total_training_steps=self.total_epochs * self._steps_per_epoch, + total_training_steps=self.num_training_steps, dataloader_state_dict=self._dataloader.state_dict(), val_dataloader_state_dict=( self._val_dataloader.state_dict() if self._val_dataloader is not None else {} ), + # FIXME: add to load_ckpt and TrainingProgress too + metrics_aggregator_state_dict=self._metrics_aggregator.state_dict(), ), - epoch=epoch, + epoch=epoch, # TODO: not needed. To be deprecated. single_device=False, full_tensors=full_tensors, dir_prefix=self.checkpoint_dir_prefix, @@ -968,180 +945,172 @@ def train(self) -> None: num_tokens = 0 self._profiler.start() - # self.epochs_run should be non-zero when we're resuming from a checkpoint - for curr_epoch in range(self.epochs_run, self.total_epochs): - inner_step_count = self.global_step % self._steps_per_epoch - pbar = tqdm( - initial=inner_step_count, - total=self._steps_per_epoch, - desc=f"{self.epochs_run}|{self.global_step}", - ) - # Get iterator for the dataloader - self._dataloader.sampler.set_epoch(curr_epoch) - dataloader_iter = iter(self._dataloader) - batch_count = 0 + pbar = tqdm( + initial=self.global_step, total=self.num_training_steps, desc="Training" + ) - # Continue looping until we reach max steps or exhaust the dataset - while inner_step_count < self._steps_per_epoch: - # Try to get the next batch, break if we've reached the end of the dataset - try: - batch = next(dataloader_iter) - except StopIteration: - break + dataloader_iter = iter(self._dataloader) + batch_count = 0 - # Start tracking CUDA memory for active steps for just the first epoch + while self.global_step < self.num_training_steps: + try: + batch = next(dataloader_iter) + except StopIteration: + self._logger.warning( + "Dataloader iterator exhausted unexpectedly. Ending training." + ) + break + + if "metrics" in batch: + self._metrics_aggregator.update(batch.pop("metrics")) + + # Start tracking CUDA memory for active steps for just the first epoch + if ( + self._is_rank_zero + and self.profiler_profile_memory + and batch_count == self.profiler_wait_steps + self.profiler_warmup_steps + and self._device.type == "cuda" + ): + torch.cuda.memory._record_memory_history() + + utils.batch_to_device(batch, self._device) + + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum() + num_tokens += current_num_tokens + + with self.train_context( + self.context_parallel_manager(list(batch.values())) + ): + # Loss is normalized by default so we multiply by the number of tokens + # This way we can normalize by the total number of tokens if we're accumulating gradients + current_loss = self._loss_step(batch) * current_num_tokens + running_loss += current_loss + # For optimizer in backward, we need to normalize before calling backward + # This case and gradient accumulation are mutually exclusive + if self._optimizer_in_bwd: + torch.distributed.all_reduce(num_tokens) + torch.distributed.all_reduce(running_loss) + current_loss = current_loss * (self.dp_degree / num_tokens) + current_loss.backward() + + # Optimizer step (if not fused in backward call) + if (batch_count + 1) % self._gradient_accumulation_steps == 0: + if not self._optimizer_in_bwd: + # Get total number of tokens across all ranks to normalize gradients + torch.distributed.all_reduce(num_tokens) + # This will ensure that the logged loss matches what we're optimizing + torch.distributed.all_reduce(running_loss) + + # Manually scale the gradients from unnormalized loss by total # of tokens + self._grad_scaler( + list(self._model.parameters()), + self.world_size / num_tokens, + False if self.parallel_dims.tp_enabled else None, + ) + + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), + ) + # If sharded, collect the DTensor here + if isinstance(grad_norm, DTensor): + grad_norm = grad_norm.full_tensor() + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + # Step the learning rate scheduler + if self._lr_scheduler is not None: + self._lr_scheduler.step() + + self.global_step += 1 + # If float8 training is enabled, perform a single all-reduce to compute the + # scale for all float8 parameters efficiently instead of doing many small + # all-reduces for each parameter if ( - self._is_rank_zero - and curr_epoch == 0 - and self.profiler_profile_memory - and batch_count - == self.profiler_wait_steps + self.profiler_warmup_steps - and self._device.type == "cuda" + self._enable_fp8_training + and is_fp8_tensorwise_scaling(self._fp8_recipe_name) + and self.dp_degree > 1 ): - torch.cuda.memory._record_memory_history() - - utils.batch_to_device(batch, self._device) - - # Calculate the number of unmasked tokens in the current batch - # and increment the total number of tokens seen in the step - current_num_tokens = ( - batch["labels"] != self._loss_fn.ignore_index - ).sum() - num_tokens += current_num_tokens + precompute_float8_dynamic_scale_for_fsdp(self._model) - with self.train_context( - self.context_parallel_manager(list(batch.values())) - ): - # Loss is normalized by default so we multiply by the number of tokens - # This way we can normalize by the total number of tokens if we're accumulating gradients - current_loss = self._loss_step(batch) * current_num_tokens - running_loss += current_loss - # For optimizer in backward, we need to normalize before calling backward - # This case and gradient accumulation are mutually exclusive - if self._optimizer_in_bwd: - torch.distributed.all_reduce(num_tokens) - torch.distributed.all_reduce(running_loss) - current_loss = current_loss * (self.dp_degree / num_tokens) - current_loss.backward() - - # Optimizer step (if not fused in backward call) - if (batch_count + 1) % self._gradient_accumulation_steps == 0: - if not self._optimizer_in_bwd: - # Get total number of tokens across all ranks to normalize gradients - torch.distributed.all_reduce(num_tokens) - # This will ensure that the logged loss matches what we're optimizing - torch.distributed.all_reduce(running_loss) - - # Manually scale the gradients from unnormalized loss by total # of tokens - self._grad_scaler( - list(self._model.parameters()), - self.world_size / num_tokens, - False if self.parallel_dims.tp_enabled else None, - ) + loss_to_log = running_loss.detach().item() / num_tokens + pbar.update(1) + pbar.set_description( + f"Step: {self.global_step}|Loss: {loss_to_log:.4f}" + ) - if self._clip_grad_norm is not None: - grad_norm = torch.nn.utils.clip_grad_norm_( - self._model.parameters(), - max_norm=float(self._clip_grad_norm), - ) - # If sharded, collect the DTensor here - if isinstance(grad_norm, DTensor): - grad_norm = grad_norm.full_tensor() - self._optimizer.step() - self._optimizer.zero_grad(set_to_none=True) - - # Step the learning rate scheduler - if self._lr_scheduler is not None: - self._lr_scheduler.step() - - self.global_step += 1 - inner_step_count += 1 - - # If float8 training is enabled, perform a single all-reduce to compute the - # scale for all float8 parameters efficiently instead of doing many small - # all-reduces for each parameter - if ( - self._enable_fp8_training - and is_fp8_tensorwise_scaling(self._fp8_recipe_name) - and self.dp_degree > 1 - ): - precompute_float8_dynamic_scale_for_fsdp(self._model) - - loss_to_log = running_loss.detach().item() / num_tokens - pbar.update(1) - pbar.set_description( - f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + # Log per-step metrics + if self.global_step % self._log_every_n_steps == 0: + # Get dataset metrics outside of rank zero check since it involves all_gather + dataset_metrics = self._metrics_aggregator.get_metrics_for_logging( + prefix="train" ) - # Log per-step metrics - if ( - self.global_step % self._log_every_n_steps == 0 - and self._is_rank_zero - ): + if self._is_rank_zero: time_per_step = time.perf_counter() - t0 log_dict = { "loss": loss_to_log, "lr": get_lr( - ( - self._optimizer - if not self._optimizer_in_bwd - else self._optim_ckpt_wrapper - ), + self._optimizer + if not self._optimizer_in_bwd + else self._optim_ckpt_wrapper ), "tokens_per_second_per_gpu": ( num_tokens / self.parallel_dims.non_data_parallel_size ) / (time_per_step * self.world_size), } + if dataset_metrics: + log_dict.update(dataset_metrics) if self._log_peak_memory_stats: log_dict.update( training.get_memory_stats(device=self._device) ) if self._clip_grad_norm is not None: log_dict.update({"grad_norm": grad_norm}) - self._metric_logger.log_dict( - log_dict, - step=self.global_step, - ) + self._metric_logger.log_dict(log_dict, step=self.global_step) - # Save checkpoint if specified by user - if self.global_step % self.save_every_n_steps == 0: - self.save_checkpoint(epoch=curr_epoch, full_tensors=False) - - # Reset running stats for the next step - running_loss = 0 - num_tokens = 0 - t0 = time.perf_counter() - - # Stop tracking CUDA memory now that active steps are complete + # Save checkpoint if specified by user if ( - self._is_rank_zero - and curr_epoch == 0 - and self.profiler_profile_memory - and batch_count - == self.profiler_wait_steps - + self.profiler_warmup_steps - + self.profiler_active_steps - and self._device.type == "cuda" + self.save_every_n_steps is not None + and self.global_step % self.save_every_n_steps == 0 ): - torch.cuda.memory._record_memory_history(enabled=None) - - self._profiler.step() - batch_count += 1 - - # Run validation after gradient update - if ( - self._run_val_every_n_steps is not None - and self.global_step % self._run_val_every_n_steps == 0 - ): - pbar.refresh() - self.validate() - - self.epochs_run += 1 + self.save_checkpoint(epoch=0, full_tensors=False) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + # Stop tracking CUDA memory now that active steps are complete + if ( + self._is_rank_zero + and self.profiler_profile_memory + and batch_count + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + and self._device.type == "cuda" + ): + torch.cuda.memory._record_memory_history(enabled=None) + + self._profiler.step() + batch_count += 1 + + # Run validation after gradient update + if ( + self._run_val_every_n_steps is not None + and self.global_step % self._run_val_every_n_steps == 0 + ): + pbar.refresh() + self.validate() self._profiler.stop() - self.save_checkpoint(epoch=curr_epoch, full_tensors=True) + self.save_checkpoint(epoch=0, full_tensors=True) def cleanup(self) -> None: if self._is_rank_zero: diff --git a/tests/torchtune/data/test_metrics_aggregator.py b/tests/torchtune/data/test_metrics_aggregator.py new file mode 100644 index 0000000000..db2ab3f617 --- /dev/null +++ b/tests/torchtune/data/test_metrics_aggregator.py @@ -0,0 +1,463 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for MetricsAggregator functionality. + +This module tests the metrics collection and aggregation system including: +- All aggregation types (SUM, MEAN, MAX, MIN, DISTRIBUTION, CATEGORICAL_COUNT) +- State management and checkpointing +- Multi-dataset metric namespacing +- Distributed metrics aggregation +- Metric consistency validation + +Uses synthetic metrics to verify correct aggregation behavior across scenarios. +""" + +import logging + +import pytest +import torch.distributed as dist +from tests.test_utils import gpu_test +from torch.testing._internal.common_fsdp import FSDPTest + +from torchtune.data.metrics import AggregationType, Metric, MetricsAggregator + + +class TestMetricsAggregator: + """Tests for MetricsAggregator core functionality and edge cases.""" + + @pytest.mark.parametrize( + "agg_type,test_values,expected", + [ + (AggregationType.SUM, [1, 2, 3, 4], 10), + (AggregationType.MEAN, [10, 20, 30, 40], 25.0), + (AggregationType.MAX, [-5, 10, 3, 15], 15), + (AggregationType.MIN, [5, -2, 8, 1], -2), + ( + AggregationType.CATEGORICAL_COUNT, + ["A", "B", "A", "C", "A"], + {"A": 3, "B": 1, "C": 1}, + ), + ], + ) + def test_aggregation_types(self, agg_type, test_values, expected): + """Tests each AggregationType with representative data to verify correct computation. + + Covers aggregation types: + - SUM: Simple addition across values + - MEAN: Average computation with proper count tracking + - MAX/MIN: Extrema identification + - CATEGORICAL_COUNT: Category frequency counting + """ + aggregator = MetricsAggregator() + + metrics = [ + Metric( + dataset_name="test", metric_name="metric", value=val, agg_type=agg_type + ) + for val in test_values + ] + aggregator.update(metrics) + + result = aggregator.get_metrics_for_logging(prefix="train") + + if agg_type == AggregationType.CATEGORICAL_COUNT: + for category, count in expected.items(): + assert result[f"train_test/metric_count_{category}"] == count + else: + assert result["train_test/metric"] == expected + + def test_distribution_metrics(self): + """Tests that DISTRIBUTION aggregation computes statistics (mean, min, max, percentiles).""" + aggregator = MetricsAggregator() + values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + metrics = [ + Metric("test", "dist_metric", val, AggregationType.DISTRIBUTION) + for val in values + ] + aggregator.update(metrics) + + result = aggregator.get_metrics_for_logging(prefix="train") + + # Verify distribution statistics + assert result["train_test/dist_metric_stat_mean"] == 5.5 + assert result["train_test/dist_metric_stat_min"] == 1 + assert result["train_test/dist_metric_stat_max"] == 10 + assert result["train_test/dist_metric_stat_p50"] == 5.5 + + def test_state_management(self): + """Test metrics aggregator state persistence and restoration for checkpointing scenarios.""" + # Create aggregator with mixed metric types to test state saving + aggregator1 = MetricsAggregator() + initial_metrics = [ + Metric("ds1", "counter", 10, AggregationType.SUM), + Metric("ds1", "average", 5.0, AggregationType.MEAN), + Metric("ds2", "categories", "X", AggregationType.CATEGORICAL_COUNT), + ] + aggregator1.update(initial_metrics) + + # Save state + state = aggregator1.state_dict() + + # Create new aggregator and restore state + aggregator2 = MetricsAggregator() + aggregator2.load_state_dict(state) + + # Both should have identical metrics + metrics1 = aggregator1.get_metrics_for_logging(prefix="train") + metrics2 = aggregator2.get_metrics_for_logging(prefix="train") + assert metrics1 == metrics2 + + # Continue updating both - should remain identical + additional_metrics = [ + Metric("ds1", "counter", 5, AggregationType.SUM), + Metric("ds1", "average", 15.0, AggregationType.MEAN), + ] + aggregator1.update(additional_metrics) + aggregator2.update(additional_metrics) + + final_metrics1 = aggregator1.get_metrics_for_logging(prefix="train") + final_metrics2 = aggregator2.get_metrics_for_logging(prefix="train") + assert final_metrics1 == final_metrics2 + + # Verify expected values + assert final_metrics1["train_ds1/counter"] == 15 # 10 + 5 + assert final_metrics1["train_ds1/average"] == 10.0 # (5 + 15) / 2 + + def test_multiple_datasets(self): + """Test that metrics from multiple datasets are correctly namespaced.""" + aggregator = MetricsAggregator() + + metrics = [ + Metric("dataset1", "samples", 100, AggregationType.SUM), + Metric("dataset2", "samples", 200, AggregationType.SUM), + Metric("dataset1", "tokens", 1000, AggregationType.SUM), + Metric("dataset2", "tokens", 2000, AggregationType.SUM), + ] + aggregator.update(metrics) + + result = aggregator.get_metrics_for_logging(prefix="train") + + assert result["train_dataset1/samples"] == 100 + assert result["train_dataset2/samples"] == 200 + assert result["train_dataset1/tokens"] == 1000 + assert result["train_dataset2/tokens"] == 2000 + + def test_empty_aggregator(self): + """Test that empty aggregator returns empty metrics.""" + aggregator = MetricsAggregator() + result = aggregator.get_metrics_for_logging(prefix="train") + assert result == {} + + def test_prefix_handling(self): + """Test that prefix is correctly applied to metric keys.""" + aggregator = MetricsAggregator() + metrics = [ + Metric("test_ds", "metric1", 42, AggregationType.SUM), + Metric("test_ds", "metric2", 84, AggregationType.SUM), + ] + aggregator.update(metrics) + + # Test with prefix + result_with_prefix = aggregator.get_metrics_for_logging(prefix="validation") + assert result_with_prefix["validation_test_ds/metric1"] == 42 + assert result_with_prefix["validation_test_ds/metric2"] == 84 + + # Test without prefix (uses default "data") + result_no_prefix = aggregator.get_metrics_for_logging() + assert result_no_prefix["data_test_ds/metric1"] == 42 + assert result_no_prefix["data_test_ds/metric2"] == 84 + + def test_metric_consistency_validation(self): + """Test that same metric name must use same aggregation type.""" + aggregator = MetricsAggregator() + + # First metric with SUM aggregation + metrics1 = [Metric("test", "my_metric", 10, AggregationType.SUM)] + aggregator.update(metrics1) + + # Try to use same metric name with different aggregation type - should fail + metrics2 = [Metric("test", "my_metric", 5.0, AggregationType.MEAN)] + with pytest.raises( + ValueError, match="is already registered with aggregation type sum" + ): + aggregator.update(metrics2) + + # Same metric name with same aggregation type should work + metrics3 = [Metric("test", "my_metric", 20, AggregationType.SUM)] + aggregator.update(metrics3) # Should not raise + + result = aggregator.get_metrics_for_logging(prefix="train") + assert result["train_test/my_metric"] == 30 # 10 + 20 + + def test_metric_consistency_across_datasets(self): + """Test that same metric name can use different aggregation types across different datasets.""" + aggregator = MetricsAggregator() + + # Same metric name but different datasets - should be allowed + metrics = [ + Metric("dataset1", "metric", 10, AggregationType.SUM), + Metric("dataset2", "metric", 5.0, AggregationType.MEAN), + ] + aggregator.update(metrics) # Should not raise + + result = aggregator.get_metrics_for_logging(prefix="train") + assert result["train_dataset1/metric"] == 10 + assert result["train_dataset2/metric"] == 5.0 + + def test_handler_generated_metric_validation(self): + """Test that handler-generated metrics are validated for consistency.""" + aggregator = MetricsAggregator() + + # Create a user-defined metric that will conflict with distribution stats + user_metrics = [ + Metric("test", "dist_metric_stat_mean", 42, AggregationType.SUM) + ] + aggregator.update(user_metrics) + + # Now try to add a distribution metric that will generate conflicting stat names + dist_metrics = [Metric("test", "dist_metric", 10, AggregationType.DISTRIBUTION)] + aggregator.update(dist_metrics) + + # This should fail when trying to get metrics for logging because the handler + # will try to create "dist_metric_stat_mean" which conflicts with the user metric + with pytest.raises( + ValueError, match="is already registered with aggregation type sum" + ): + aggregator.get_metrics_for_logging(prefix="train") + + def test_handler_replacement_warning(self, caplog): + """Test that replacing handlers in use generates a warning.""" + aggregator = MetricsAggregator() + + # Add a metric that uses SUM aggregation + metrics = [Metric("test", "sum_metric", 10, AggregationType.SUM)] + aggregator.update(metrics) + + # Replace the SUM handler - should generate warning + from torchtune.data.metrics._metric_agg_handlers import SumAggHandler + + with caplog.at_level(logging.WARNING): + aggregator.register_handler(AggregationType.SUM, SumAggHandler()) + + # Check that the expected warning was logged + assert len(caplog.records) == 1 + assert "Replacing handler for AggregationType.SUM" in caplog.records[0].message + + +class TestDistributedMetricsAggregator(FSDPTest): + """Distributed tests for MetricsAggregator using FSDPTest infrastructure.""" + + @property + def world_size(self) -> int: + return 2 + + @gpu_test(gpu_count=2) + def test_distributed_all_aggregation_types(self): + """ + Test that all aggregation types work correctly in distributed setting. + Each rank contributes different values to ensure proper reduction across ranks. + """ + aggregator = MetricsAggregator() + rank = dist.get_rank() + + # Each rank contributes different values to test cross-rank aggregation + base_value = (rank + 1) * 10 # rank 0: 10, rank 1: 20 + + metrics = [ + Metric("test", "sum_metric", base_value, AggregationType.SUM), + Metric("test", "mean_metric", base_value + 5, AggregationType.MEAN), + Metric("test", "max_metric", base_value * 10, AggregationType.MAX), + Metric("test", "min_metric", base_value // 2, AggregationType.MIN), + ] + + # DISTRIBUTION: Each rank adds 5 values for distribution statistics + # rank 0: [0, 1, 2, 3, 4], rank 1: [10, 11, 12, 13, 14] + for i in range(5): + metrics.append( + Metric( + "test", "dist_metric", rank * 10 + i, AggregationType.DISTRIBUTION + ) + ) + + # CATEGORICAL_COUNT: Different categories per rank to test counting + # rank 0: 3 of cat_A, 2 of cat_B + # rank 1: 1 of cat_A, 4 of cat_C + if rank == 0: + metrics.extend( + [ + Metric( + "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT + ), + ] + ) + else: + metrics.extend( + [ + Metric( + "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT + ), + ] + ) + + # Update aggregator and get results + aggregator.update(metrics) + result = aggregator.get_metrics_for_logging(prefix="train") + + # Verify aggregation results across all ranks + # SUM: rank 0 adds 10, rank 1 adds 20 -> total 30 + # MEAN: rank 0 has 15, rank 1 has 25 -> avg 20 + # MAX: rank 0 has 100, rank 1 has 200 -> max 200 + # MIN: rank 0 has 5, rank 1 has 10 -> min 5 + assert result["train_test/sum_metric"] == 30 + assert result["train_test/mean_metric"] == 20 + assert result["train_test/max_metric"] == 200 + assert result["train_test/min_metric"] == 5 + + # DISTRIBUTION: Combined values [0,1,2,3,4,10,11,12,13,14] + # Mean should be average of local means: (2 + 12) / 2 = 7 + assert result["train_test/dist_metric_stat_mean"] == 7 + assert result["train_test/dist_metric_stat_min"] == 0 + assert result["train_test/dist_metric_stat_max"] == 14 + + # CATEGORICAL_COUNT: Total counts across ranks + # cat_A: 3(rank0) + 1(rank1) = 4, cat_B: 2(rank0) + 0(rank1) = 2, cat_C: 0(rank0) + 4(rank1) = 4 + assert result["train_test/cat_metric_count_cat_A"] == 4 + assert result["train_test/cat_metric_count_cat_B"] == 2 + assert result["train_test/cat_metric_count_cat_C"] == 4 + + @gpu_test(gpu_count=2) + def test_distributed_state_dict_resumption(self): + """ + Test that MetricsAggregator state_dict save/restore works correctly in distributed setting. + Verifies: + - State can be saved after partial updates across ranks + - State can be restored consistently across ranks + - Continued updates after restore produce identical results + - Distributed aggregation works correctly after restoration + """ + rank = dist.get_rank() + + # Phase 1: Create aggregator and add initial metrics + aggregator1 = MetricsAggregator() + + # Each rank contributes different initial values + base_value = rank * 100 # rank 0: 0, rank 1: 100 + + initial_metrics = [ + Metric("test", "sum_metric", base_value, AggregationType.SUM), + Metric("test", "mean_metric", base_value // 2, AggregationType.MEAN), + Metric("test", "max_metric", base_value * 2, AggregationType.MAX), + ] + + # Add some DISTRIBUTION values - each rank adds 3 values + for i in range(3): + initial_metrics.append( + Metric( + "test", "dist_metric", rank * 100 + i, AggregationType.DISTRIBUTION + ) + ) + + # Add CATEGORICAL_COUNT values + if rank == 0: + initial_metrics.extend( + [ + Metric( + "test", + "cat_metric", + "type_A", + AggregationType.CATEGORICAL_COUNT, + ), + Metric( + "test", + "cat_metric", + "type_A", + AggregationType.CATEGORICAL_COUNT, + ), + ] + ) + else: + initial_metrics.extend( + [ + Metric( + "test", + "cat_metric", + "type_B", + AggregationType.CATEGORICAL_COUNT, + ), + Metric( + "test", + "cat_metric", + "type_B", + AggregationType.CATEGORICAL_COUNT, + ), + Metric( + "test", + "cat_metric", + "type_B", + AggregationType.CATEGORICAL_COUNT, + ), + ] + ) + + aggregator1.update(initial_metrics) + + # Save state_dict after initial update + state_dict = aggregator1.state_dict() + + # Phase 2: Create new aggregator and restore from state_dict + aggregator2 = MetricsAggregator() + aggregator2.load_state_dict(state_dict) + + # Verify both aggregators produce identical results after restore + result1 = aggregator1.get_metrics_for_logging(prefix="checkpoint") + result2 = aggregator2.get_metrics_for_logging(prefix="checkpoint") + assert ( + result1 == result2 + ), f"Rank {rank}: Aggregators differ after state_dict restore" + + # Phase 3: Add more metrics to both aggregators + additional_metrics = [ + Metric("test", "sum_metric", rank * 1000, AggregationType.SUM), + Metric("test", "min_metric", rank * 1000, AggregationType.MIN), + ] + + # Update both aggregators with additional metrics + aggregator1.update(additional_metrics) + aggregator2.update(additional_metrics) + + # Phase 4: Verify final results are identical across both aggregators + final_result1 = aggregator1.get_metrics_for_logging(prefix="final") + final_result2 = aggregator2.get_metrics_for_logging(prefix="final") + assert ( + final_result1 == final_result2 + ), f"Rank {rank}: Final results differ after continued updates" diff --git a/tests/torchtune/data/test_metrics_transform.py b/tests/torchtune/data/test_metrics_transform.py new file mode 100644 index 0000000000..ebfb1c81a1 --- /dev/null +++ b/tests/torchtune/data/test_metrics_transform.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests cover: +- DefaultTrainingMetricTransform +- Basic metric generation (samples_seen, tokens_seen, seq_len) +- Dataset name validation and requirements +- Proper metric type assignment and aggregation configuration +""" + +import pytest + +from torchtune.data.metrics import AggregationType, DefaultTrainingMetricTransform + + +class TestDefaultTrainingMetricTransform: + """Tests for DefaultTrainingMetricTransform functionality.""" + + def test_dataset_name_not_set_raises_error(self): + """Test that the transform raises a RuntimeError if used before + `set_dataset_name` is called, ensuring that metrics are always + correctly attributed to a dataset.""" + transform = DefaultTrainingMetricTransform() + sample = {"tokens": [1, 2, 3]} + + with pytest.raises(RuntimeError, match="set_dataset_name"): + transform(sample) + + def test_basic_metrics_generation(self): + """Test that transform generates expected training metrics for input samples.""" + transform = DefaultTrainingMetricTransform() + # Set dataset name required for metric generation + transform.set_dataset_name("test_dataset") + + sample = {"tokens": [1, 2, 3, 4, 5]} + result = transform(sample) + + # Transform should preserve original sample data unchanged + assert result["tokens"] == [1, 2, 3, 4, 5] + + # Should generate exactly 3 metrics: samples_seen, tokens_seen, seq_len + assert "metrics" in result + metrics = result["metrics"] + assert len(metrics) == 3 + + # Verify each metric has correct properties and aggregation type + for metric in metrics: + if metric.metric_name == "samples_seen": + assert metric.dataset_name == "test_dataset" + assert metric.value == 1 + assert metric.agg_type == AggregationType.SUM + + elif metric.metric_name == "tokens_seen": + assert metric.dataset_name == "test_dataset" + assert metric.value == 5 + assert metric.agg_type == AggregationType.SUM + + elif metric.metric_name == "seq_len": + assert metric.dataset_name == "test_dataset" + assert metric.value == 5 + assert metric.agg_type == AggregationType.DISTRIBUTION diff --git a/tests/torchtune/datasets/test_hf_iterable.py b/tests/torchtune/datasets/test_hf_iterable.py new file mode 100644 index 0000000000..adcede297a --- /dev/null +++ b/tests/torchtune/datasets/test_hf_iterable.py @@ -0,0 +1,383 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for HfIterableDataset core functionality. + +This module tests the foundational iterable dataset capabilities including: +- Basic iteration and data loading +- Epoch boundary handling and tracking +- Shuffling behavior across epochs +- Checkpointing and state restoration +- Distributed training scenarios + +Uses synthetic JSON data with predictable patterns to verify correct behavior. +""" + +import math +import shutil +import tempfile +from itertools import islice +from pathlib import Path + +import pytest +import torch.distributed as dist +from tests.test_utils import gpu_test +from torch.testing._internal.common_fsdp import FSDPTest + +from torchdata.stateful_dataloader import StatefulDataLoader + +from torchtune.data.metrics import DefaultTrainingMetricTransform, MetricsAggregator +from torchtune.datasets import HfIterableDataset + +from .test_iterable_utils import collate_with_metrics, generate_ckpt + +# Test Constants - Avoid perfect divisions +SMALL_DATASET_SIZE = 23 +MEDIUM_DATASET_SIZE = 35 +SEED = 42 +BATCH_SIZE = 5 +DEFAULT_SHUFFLE_BUFFER_SIZE = 8 + + +def create_test_json_file(path: Path, num_samples: int, offset: int = 0) -> None: + """Creates a dummy JSON test data file with token samples of varying lengths. + + Args: + path (Path): The path to the file to create + num_samples (int): The number of samples to create + offset (int): The offset to add to the sample ID to ensure unique IDs in different datasets + """ + with open(path, "w") as f: + for i in range(num_samples): + sample_id = i + offset + # Realistic token length variation (1-3 tokens) + token_len = (i % 3) + 1 + tokens = list(range(sample_id, sample_id + token_len)) + f.write( + f'{{"id": {sample_id}, "tokens": {tokens}, "text": "sample_{sample_id}", "labels": {tokens}}}\n' + ) + + +@pytest.fixture +def tmp_data_dir(tmp_path): + """Provide temporary directory for test data files.""" + return tmp_path + + +@pytest.fixture +def small_dataset_file(tmp_data_dir): + path = tmp_data_dir / "small_data.json" + create_test_json_file(path, SMALL_DATASET_SIZE, offset=0) + return str(path) + + +@pytest.fixture +def dataset_factory(): + """Factory for creating HfIterableDataset instances with common defaults.""" + + def _create_dataset( + data_file: str, + dataset_name: str = "test_dataset", + shuffle: bool = False, + **kwargs, + ) -> HfIterableDataset: + return HfIterableDataset( + path="json", + data_files=data_file, + split="train", + dataset_name=dataset_name, + seed=SEED, + shuffle_buffer_size=10 if shuffle else 0, + metric_transform=DefaultTrainingMetricTransform(), + num_shards_per_rank=2, + **kwargs, + ) + + return _create_dataset + + +class TestHfIterableDataset: + """Tests for HfIterableDataset basic functionality.""" + + def test_default_dataset_name(self, small_dataset_file): + """Test that dataset name is auto-generated from path when not provided.""" + # Create dataset without specifying name + dataset = HfIterableDataset( + path="json", + data_files=small_dataset_file, + split="train", + # dataset_name not provided - should auto-generate + seed=SEED, + metric_transform=DefaultTrainingMetricTransform(), + num_shards_per_rank=4, + ) + + # Should generate name from path and split + assert dataset.info.name == "json_train" + # Test default sampling weight + assert dataset.info.weight == 1.0 + + # Test giving a name and custom weight + custom_weight = 2.5 + dataset2 = HfIterableDataset( + path="json", + data_files=small_dataset_file, + split="train", + dataset_name="my_dataset", + weight=custom_weight, + seed=SEED, + metric_transform=DefaultTrainingMetricTransform(), + num_shards_per_rank=4, + ) + + # Should use provided name and weight + assert dataset2.info.name == "my_dataset" + # Test custom sampling weight + assert dataset2.info.weight == custom_weight + + @pytest.mark.parametrize("num_epochs", [0.5, 1.0, 2.5]) + def test_epoch_boundaries_and_checkpointing( + self, num_epochs, dataset_factory, small_dataset_file + ): + """ + Tests that for N epochs, each sample appears exactly N times (rounded down), + the epoch metric is correct, and checkpointing works as expected. + """ + + # 1. Setup Dataloaders and Aggregators for original and resumed runs + def create_loader_and_aggregator(): + dataset = dataset_factory(small_dataset_file, shuffle=False) + loader = StatefulDataLoader( + dataset, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics + ) + aggregator = MetricsAggregator() + return loader, aggregator + + loader1, aggregator1 = create_loader_and_aggregator() + loader2, aggregator2 = create_loader_and_aggregator() + + # 2. Calculate steps for the test run + total_samples = int(SMALL_DATASET_SIZE * num_epochs) + total_steps = total_samples // BATCH_SIZE + + steps_before_checkpoint = max(1, total_steps // 2) + steps_after_checkpoint = total_steps - steps_before_checkpoint + + # 3. Generate checkpoint and resume + result = generate_ckpt( + loader1, + aggregator1, + steps_before_checkpoint=steps_before_checkpoint, + steps_after_checkpoint=steps_after_checkpoint, + resume_dataloader=loader2, + resume_aggregator=aggregator2, + ) + + # 4. Verify checkpointing and resumption + orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]] + resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]] + assert ( + orig_post_ids == resumed_ids + ), "Resumed batches should be identical for deterministic run" + assert ( + result["final_metrics"] == result["resumed_metrics"] + ), "Final metrics should match" + + def test_shuffling_behavior(self, dataset_factory, small_dataset_file): + """Tests that shuffling changes data order between epochs but preserves the set of samples.""" + # Test unshuffled dataset + unshuffled_ds = dataset_factory( + small_dataset_file, dataset_name="unshuffled", shuffle=False + ) + + # Get samples from two passes through the dataset + epoch_samples = list(islice(iter(unshuffled_ds), SMALL_DATASET_SIZE * 2)) + + first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] + second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] + + # Unshuffled should have same order in both epochs + first_epoch_ids = [sample["id"] for sample in first_epoch_samples] + second_epoch_ids = [sample["id"] for sample in second_epoch_samples] + assert first_epoch_ids == list(range(SMALL_DATASET_SIZE)) + assert second_epoch_ids == list(range(SMALL_DATASET_SIZE)) + + # Test shuffled dataset + shuffled_ds = dataset_factory( + small_dataset_file, dataset_name="shuffled", shuffle=True + ) + + # Collect full epochs to compare + epoch_samples = list(islice(iter(shuffled_ds), SMALL_DATASET_SIZE * 2)) + + first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] + second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] + + # Extract IDs for comparison + first_epoch_ids = [sample["id"] for sample in first_epoch_samples] + second_epoch_ids = [sample["id"] for sample in second_epoch_samples] + + # Shuffled epochs should have different order + assert first_epoch_ids != list( + range(SMALL_DATASET_SIZE) + ), f"Shuffled should not be sorted, got {first_epoch_ids}" + assert ( + first_epoch_ids != second_epoch_ids + ), f"Shuffled epochs should be shuffled differently, got {first_epoch_ids} and {second_epoch_ids}" + + # But should contain the same set of IDs + assert set(first_epoch_ids) == set( + range(SMALL_DATASET_SIZE) + ), f"First epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {first_epoch_ids}" + assert set(second_epoch_ids) == set( + range(SMALL_DATASET_SIZE) + ), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {second_epoch_ids}" + + def test_epoch_tracking(self, dataset_factory, small_dataset_file): + """Test that epoch number is correctly tracked across dataset restarts.""" + dataset = dataset_factory(small_dataset_file, shuffle=False) + + # Two epoch samples + epoch_samples = list(islice(iter(dataset), SMALL_DATASET_SIZE * 2)) + + first_epoch_samples = epoch_samples[:SMALL_DATASET_SIZE] + second_epoch_samples = epoch_samples[SMALL_DATASET_SIZE:] + + # All should have epoch 0 + first_epoch_metrics = [] + for sample in first_epoch_samples: + first_epoch_metrics.extend(sample["metrics"]) + epoch_values = [ + metric.value + for metric in first_epoch_metrics + if metric.metric_name == "epoch" + ] + assert all( + epoch_value == 0 for epoch_value in epoch_values + ), f"Epoch values should be 0, got {epoch_values}" + + # All should have epoch 1 + second_epoch_metrics = [] + for sample in second_epoch_samples: + second_epoch_metrics.extend(sample["metrics"]) + epoch_values = [ + metric.value + for metric in second_epoch_metrics + if metric.metric_name == "epoch" + ] + assert all( + epoch_value == 1 for epoch_value in epoch_values + ), f"Epoch values should be 1, got {epoch_values}" + + +class TestDistributedHfIterableDataset(FSDPTest): + @property + def world_size(self) -> int: + return 2 + + @gpu_test(gpu_count=2) + def test_distributed_epoch_boundary_checkpointing(self): + """ + Test epoch boundary handling with checkpointing in distributed setting. + Ensures proper handling of: + - Checkpointing at 0.9, 1.0, and 2.5 epoch boundaries + - Correct sample distribution across epochs + - Proper state restoration after checkpointing + """ + rank = dist.get_rank() + + # Create shared temp directory (only rank 0 creates it) + if rank == 0: + temp_dir = tempfile.mkdtemp(prefix="epoch_test_") + else: + temp_dir = "" + + # Broadcast temp directory path to all ranks + temp_dir_list = [temp_dir] + dist.broadcast_object_list(temp_dir_list, src=0) + temp_dir = temp_dir_list[0] + tmp_path = Path(temp_dir) + + try: + medium_dataset_file = tmp_path / "medium_data.json" + + # Only rank 0 creates the data file, all ranks read from it + if rank == 0: + create_test_json_file(medium_dataset_file, MEDIUM_DATASET_SIZE) + dist.barrier() # Wait for file creation + + # Test multiple epoch boundaries + for num_epochs in [0.9, 1.0, 2.5]: + + def create_loader_and_aggregator(): + dataset = HfIterableDataset( + path="json", + data_files=str(medium_dataset_file), + split="train", + dataset_name="epoch_test", + seed=SEED, + shuffle_buffer_size=0, # No shuffle for determinism + metric_transform=DefaultTrainingMetricTransform(), + num_shards_per_rank=2, + ) + loader = StatefulDataLoader( + dataset, + batch_size=BATCH_SIZE, + collate_fn=collate_with_metrics, + num_workers=0, + ) + return loader, MetricsAggregator() + + loader1, aggregator1 = create_loader_and_aggregator() + loader2, aggregator2 = create_loader_and_aggregator() + + # Calculate steps to reach desired epoch boundary + samples_per_rank = MEDIUM_DATASET_SIZE // dist.get_world_size() + total_samples = int(samples_per_rank * num_epochs) + total_steps = total_samples // BATCH_SIZE + + if total_steps < 2: + raise ValueError( + f"Not enough steps for meaningful test: {total_steps}" + ) + + # Split steps between before and after checkpoint + steps_before = max(1, total_steps // 2) + steps_after = total_steps - steps_before + + result = generate_ckpt( + loader1, + aggregator1, + steps_before, + steps_after, + resume_dataloader=loader2, + resume_aggregator=aggregator2, + ) + + # Verify deterministic resumption - critical for distributed training + orig_post_ids = [ + b["id"].tolist() for b in result["post_checkpoint_batches"] + ] + resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]] + assert orig_post_ids == resumed_ids, ( + f"Rank {rank}: Non-deterministic resume for {num_epochs} epochs. " + f"This indicates checkpoint/resume state is not properly preserved." + ) + + # Verify epoch metric is correctly tracked + final_metrics = result["final_metrics"] + expected_epoch = math.floor( + num_epochs - 1e-9 + ) # -1e-9 so 1.0 epochs -> 0 + assert ( + final_metrics["train_epoch_test/num_epochs"] == expected_epoch + ), f"Epoch count incorrect for {num_epochs} epochs test scenario" + + finally: + # Clean up temp directory (only rank 0) + if rank == 0: + shutil.rmtree(temp_dir) diff --git a/tests/torchtune/datasets/test_interleaved.py b/tests/torchtune/datasets/test_interleaved.py new file mode 100644 index 0000000000..37bda9adcc --- /dev/null +++ b/tests/torchtune/datasets/test_interleaved.py @@ -0,0 +1,625 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for InterleavedDataset functionality. + +This module tests the multi-dataset interleaving capabilities, including: +- Dataset composition with weighted sampling +- Nested interleaving structures +- Metrics collection and aggregation across datasets +- Checkpointing and state restoration +- Distributed training scenarios + +The tests use synthetic JSON data with distinct ID ranges per dataset +to verify correct sampling ratios and data isolation. +""" + +import shutil +import tempfile +from itertools import islice +from pathlib import Path +from unittest.mock import patch + +import pytest + +import torch +import torch.distributed as dist +from tests.test_utils import gpu_test +from torch.testing._internal.common_fsdp import FSDPTest +from torchdata.stateful_dataloader import StatefulDataLoader + +from torchtune.data.metrics import DefaultTrainingMetricTransform, MetricsAggregator +from torchtune.datasets import HfIterableDataset, InterleavedDataset + +# Import test utilities +from .test_iterable_utils import collate_with_metrics, generate_ckpt + +# Test Constants +SMALL_DATASET_SIZE = 23 +MEDIUM_DATASET_SIZE = 35 +LARGE_DATASET_SIZE = 47 +SEED = 42 +BATCH_SIZE = 5 + + +def create_test_json_file(path: Path, num_samples: int, offset: int = 0) -> None: + """Creates a dummy JSON test data file with token samples of varying lengths. + + Args: + path (Path): The path to the file to create + num_samples (int): The number of samples to create + offset (int): The offset to add to the sample ID to ensure unique IDs in different datasets + """ + with open(path, "w") as f: + for i in range(num_samples): + sample_id = i + offset + # Realistic token length variation (1-3 tokens) + token_len = (i % 3) + 1 + tokens = list(range(sample_id, sample_id + token_len)) + f.write( + f'{{"id": {sample_id}, "tokens": {tokens}, "text": "sample_{sample_id}", "labels": {tokens}}}\n' + ) + + +@pytest.fixture +def tmp_data_dir(tmp_path): + """Provide temporary directory for test data files. + All test datasets are created in this isolated directory to avoid conflicts.""" + return tmp_path + + +@pytest.fixture +def small_dataset_file(tmp_data_dir): + """Create small dataset (23 samples) with IDs 0-22 for testing basic functionality.""" + path = tmp_data_dir / "small_data.json" + create_test_json_file(path, SMALL_DATASET_SIZE, offset=0) + return str(path) + + +@pytest.fixture +def medium_dataset_file(tmp_data_dir): + """Create medium dataset (35 samples) with IDs 100-134 for multi-dataset testing.""" + path = tmp_data_dir / "medium_data.json" + create_test_json_file(path, MEDIUM_DATASET_SIZE, offset=100) + return str(path) + + +@pytest.fixture +def large_dataset_file(tmp_data_dir): + """Create large dataset (47 samples) with IDs 1000-1046 for nested interleaving tests.""" + path = tmp_data_dir / "large_data.json" + create_test_json_file(path, LARGE_DATASET_SIZE, offset=1000) + return str(path) + + +@pytest.fixture +def dataset_factory(): + """Factory for creating HfIterableDataset instances with common defaults.""" + + def _create_dataset( + data_file: str, + dataset_name: str = "test_dataset", + shuffle: bool = False, + **kwargs, + ) -> HfIterableDataset: + return HfIterableDataset( + path="json", + data_files=data_file, + split="train", + dataset_name=dataset_name, + seed=SEED, + shuffle_buffer_size=10 if shuffle else 0, + metric_transform=DefaultTrainingMetricTransform(), + num_shards_per_rank=2, + **kwargs, + ) + + return _create_dataset + + +class TestInterleavedDataset: + """Tests for multi-dataset interleaving functionality.""" + + def test_initialization_validation(self, dataset_factory, small_dataset_file): + """Tests that the dataset raises errors for invalid configurations, like duplicate names.""" + + # Test 1: Duplicate dataset names should raise an error + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.5) + ds2 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.5) + + with pytest.raises( + ValueError, match="Duplicate dataset names found in hierarchy" + ): + InterleavedDataset(datasets=[ds1, ds2], seed=SEED) + + # Test 2: Nested interleaved datasets should be supported + ds3 = dataset_factory(small_dataset_file, dataset_name="ds3", weight=1.5) + interleaved_child = InterleavedDataset( + [ds1, ds3], seed=SEED, dataset_name="interleaved_child" + ) + + # Create a parent interleaved dataset containing the nested one + ds4 = dataset_factory(small_dataset_file, dataset_name="ds4", weight=0.5) + + # Test 3: Weight normalization should work with a warning + with patch("logging.Logger.warning") as mock_warning: + interleaved_parent = InterleavedDataset( + [interleaved_child, ds4], seed=SEED, dataset_name="interleaved_parent" + ) + + # Verify that a warning was logged about weight normalization + mock_warning.assert_called_once() + warning_message = mock_warning.call_args[0][0] + assert "normalized" in warning_message.lower() + + # Verify the hierarchical structure is correct + assert interleaved_parent.info.name == "interleaved_parent" + assert len(interleaved_parent.info.children) == 2 + # Datasets are sorted alphabetically, so ds4 comes before interleaved_child + assert interleaved_parent.info.children[0].name == "ds4" + assert interleaved_parent.info.children[1].name == "interleaved_child" + + # Verify the nested structure within the nested dataset + # interleaved_child is at index 1 due to alphabetical sorting + nested_info = interleaved_parent.info.children[1] + assert len(nested_info.children) == 2 + assert nested_info.children[0].name == "ds1" + assert nested_info.children[1].name == "ds3" + + # Verify that sampling weights are normalized to sum to 1.0 + # Access the internal normalized weights tensor + normalized_weights = interleaved_parent._normalized_weights + assert isinstance(normalized_weights, torch.Tensor) + assert len(normalized_weights) == 2 + + # ds4: 0.5/(0.5+1.0) = 1/3, interleaved_child: 1.0/(0.5+1.0) = 2/3 + assert abs(normalized_weights[0].item() - 1 / 3) < 1e-3 + assert abs(normalized_weights[1].item() - 2 / 3) < 1e-3 + assert abs(normalized_weights.sum().item() - 1.0) < 1e-6 + + # Verify that original weights in info remain unnormalized + child_weights = [child.weight for child in interleaved_parent.info.children] + assert abs(child_weights[0] - 0.5) < 1e-6 # ds4 original weight + assert ( + abs(child_weights[1] - 1.0) < 1e-6 + ) # interleaved_child original weight + + def test_single_dataset(self, dataset_factory, small_dataset_file): + """Tests that InterleavedDataset works correctly with a single dataset.""" + # Create a single dataset + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.5) + + # Should work without issues + interleaved = InterleavedDataset([ds1], seed=SEED) + + # Verify the hierarchical structure + assert interleaved.info.name == "interleaved_dataset" # default name + assert len(interleaved.info.children) == 1 + assert interleaved.info.children[0].name == "ds1" + assert interleaved.info.children[0].weight == 0.5 + + # Verify normalized weights sum to 1.0 (single dataset gets weight 1.0) + normalized_weights = interleaved._normalized_weights + assert isinstance(normalized_weights, torch.Tensor) + assert len(normalized_weights) == 1 + assert abs(normalized_weights[0].item() - 1.0) < 1e-6 + + # Test that iteration works correctly + samples = list(islice(iter(interleaved), 10)) + assert len(samples) == 10 + + # All samples should come from the single dataset (ds1 has IDs 0-22) + sample_ids = {sample["id"] for sample in samples} + expected_ids = set(range(10)) # ds1 has IDs 0-22 + assert sample_ids == expected_ids + + def test_sampling_ratios( + self, + dataset_factory, + small_dataset_file, + medium_dataset_file, + large_dataset_file, + ): + """Tests that datasets are sampled according to their assigned weights in nested structure.""" + # Create three datasets with distinct ID ranges + # ds1 has IDs 0-22, ds2 has IDs 100-134, ds3 has IDs 1000-1046 + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.3) + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.7) + ds3 = dataset_factory(large_dataset_file, dataset_name="ds3", weight=1.0) + + # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) + child_interleaved = InterleavedDataset( + [ds1, ds2], seed=SEED, dataset_name="child" + ) + parent_interleaved = InterleavedDataset( + [child_interleaved, ds3], seed=SEED, dataset_name="parent" + ) + + # Collect 400 samples + sample_count = 400 + samples = list(islice(iter(parent_interleaved), sample_count)) + + # Count samples by checking ID ranges + ds1_count = sum(1 for s in samples if 0 <= s["id"] < SMALL_DATASET_SIZE) + ds2_count = sum( + 1 for s in samples if 100 <= s["id"] < (MEDIUM_DATASET_SIZE + 100) + ) + ds3_count = sum( + 1 for s in samples if 1000 <= s["id"] < (LARGE_DATASET_SIZE + 1000) + ) + + assert ds1_count + ds2_count + ds3_count == sample_count + + # Calculate ratios + ds1_ratio = ds1_count / sample_count + ds2_ratio = ds2_count / sample_count + ds3_ratio = ds3_count / sample_count + + # Expected ratios based on nested weighting: + # Inner weights: ds1=0.3, ds2=0.7 -> inner total=1.0 + # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each + # Final ratios: ds1=0.5*0.3=0.15, ds2=0.5*0.7=0.35, ds3=0.5 + expected_ds1_ratio = 0.15 + expected_ds2_ratio = 0.35 + expected_ds3_ratio = 0.5 + + # Allow 10% tolerance due to randomness + assert ( + abs(ds1_ratio - expected_ds1_ratio) < 0.1 + ), f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert ( + abs(ds2_ratio - expected_ds2_ratio) < 0.1 + ), f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert ( + abs(ds3_ratio - expected_ds3_ratio) < 0.1 + ), f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" + + def test_metrics_aggregation( + self, + dataset_factory, + small_dataset_file, + medium_dataset_file, + large_dataset_file, + ): + """Tests that metrics from all child datasets are collected and aggregated in nested structure.""" + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.2) + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.8) + ds3 = dataset_factory(large_dataset_file, dataset_name="ds3", weight=1.0) + + # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) + child_interleaved = InterleavedDataset( + [ds1, ds2], seed=SEED, dataset_name="child" + ) + parent_interleaved = InterleavedDataset( + [child_interleaved, ds3], seed=SEED, dataset_name="parent" + ) + + aggregator = MetricsAggregator() + + # Process some samples + total_samples = 300 + for sample in islice(iter(parent_interleaved), total_samples): + aggregator.update(sample["metrics"]) + + metrics = aggregator.get_metrics_for_logging(prefix="train") + + # Should have metrics from all three datasets, with flat keys + assert "train_ds1/samples_seen" in metrics + assert "train_ds2/samples_seen" in metrics + assert "train_ds3/samples_seen" in metrics + + # All datasets should have contributed samples + assert metrics["train_ds1/samples_seen"] > 0 + assert metrics["train_ds2/samples_seen"] > 0 + assert metrics["train_ds3/samples_seen"] > 0 + + # Total samples should equal what we processed + calculated_total_samples = ( + metrics["train_ds1/samples_seen"] + + metrics["train_ds2/samples_seen"] + + metrics["train_ds3/samples_seen"] + ) + assert calculated_total_samples == total_samples + + # Test that ratios are approximately correct based on nested weighting + ds1_ratio = metrics["train_ds1/samples_seen"] / total_samples + ds2_ratio = metrics["train_ds2/samples_seen"] / total_samples + ds3_ratio = metrics["train_ds3/samples_seen"] / total_samples + + # Expected ratios based on nested weighting: + # Inner weights: ds1=0.2, ds2=0.8 -> inner total=1.0 + # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each + # Final ratios: ds1=0.5*0.2=0.1, ds2=0.5*0.8=0.4, ds3=0.5 + expected_ds1_ratio = 0.1 + expected_ds2_ratio = 0.4 + expected_ds3_ratio = 0.5 + + # Allow 10% tolerance due to randomness + assert ( + abs(ds1_ratio - expected_ds1_ratio) < 0.1 + ), f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert ( + abs(ds2_ratio - expected_ds2_ratio) < 0.1 + ), f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert ( + abs(ds3_ratio - expected_ds3_ratio) < 0.1 + ), f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" + + def test_checkpointing( + self, + dataset_factory, + small_dataset_file, + medium_dataset_file, + large_dataset_file, + ): + """Tests that interleaved dataset checkpointing preserves sampling state in nested structure.""" + + def create_interleaved(): + ds1 = dataset_factory(small_dataset_file, dataset_name="ds1", weight=0.3) + ds2 = dataset_factory(medium_dataset_file, dataset_name="ds2", weight=0.7) + ds3 = dataset_factory(large_dataset_file, dataset_name="ds3", weight=1.0) + + # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) + child_interleaved = InterleavedDataset( + [ds1, ds2], seed=SEED, dataset_name="child" + ) + return InterleavedDataset( + [child_interleaved, ds3], seed=SEED, dataset_name="parent" + ) + + # Original run + interleaved1 = create_interleaved() + loader1 = StatefulDataLoader( + interleaved1, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics + ) + aggregator1 = MetricsAggregator() + + # Resumed run + interleaved2 = create_interleaved() + loader2 = StatefulDataLoader( + interleaved2, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics + ) + aggregator2 = MetricsAggregator() + + result = generate_ckpt( + loader1, + aggregator1, + steps_before_checkpoint=10, + steps_after_checkpoint=20, + resume_dataloader=loader2, + resume_aggregator=aggregator2, + ) + + orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]] + resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]] + assert ( + orig_post_ids == resumed_ids + ), "Resumed batches should be identical for deterministic run" + assert ( + result["final_metrics"] == result["resumed_metrics"] + ), "Final metrics should match" + + # Test sampling log functionality + # Check that sampling log contains tuples of (iteration_count, dataset_name) + state_dict = interleaved1.state_dict() + sampling_log = state_dict["sampling_log"] + iteration_count = state_dict["iteration_count"] + + assert len(sampling_log) > 0, "Sampling log should not be empty" + assert iteration_count > 0, "Iteration count should be positive" + + # Check sampling ratios by analyzing the actual samples processed during the test + # Since the sampling log only shows immediate children ("child", "ds3"), + # we need to look at the actual sample IDs to determine leaf dataset usage + + # Collect all sample IDs from the batches processed during checkpointing + all_sample_ids = [] + for batch_list in [ + result["pre_checkpoint_batches"], + result["post_checkpoint_batches"], + ]: + for batch in batch_list: + all_sample_ids.extend(batch["id"].tolist()) + + # Count samples by ID ranges: ds1 has IDs 0-22, ds2 has IDs 100-134, ds3 has IDs 1000-1046 + ds1_count = sum(1 for id in all_sample_ids if 0 <= id < SMALL_DATASET_SIZE) + ds2_count = sum( + 1 for id in all_sample_ids if 100 <= id < (MEDIUM_DATASET_SIZE + 100) + ) + ds3_count = sum( + 1 for id in all_sample_ids if 1000 <= id < (LARGE_DATASET_SIZE + 1000) + ) + total_samples = ds1_count + ds2_count + ds3_count + ds1_ratio = ds1_count / total_samples + ds2_ratio = ds2_count / total_samples + ds3_ratio = ds3_count / total_samples + + # Expected ratios based on nested weighting: + # Inner weights: ds1=0.3, ds2=0.7 -> inner total=1.0 + # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each + # Final ratios: ds1=0.5*0.3=0.15, ds2=0.5*0.7=0.35, ds3=0.5 + expected_ds1_ratio = 0.15 + expected_ds2_ratio = 0.35 + expected_ds3_ratio = 0.5 + + # Allow larger tolerance due to small sample size in checkpointing test + assert ( + abs(ds1_ratio - expected_ds1_ratio) < 0.2 + ), f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert ( + abs(ds2_ratio - expected_ds2_ratio) < 0.2 + ), f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert ( + abs(ds3_ratio - expected_ds3_ratio) < 0.2 + ), f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" + + +class TestDistributedInterleavedDataset(FSDPTest): + @property + def world_size(self) -> int: + return 2 + + @gpu_test(gpu_count=2) + def test_distributed_interleaved_checkpointing(self): + """ + Test interleaved dataset checkpointing with distributed settings using nested structure. + Assertions: + - Each rank processes non-overlapping data shards + - Sampling ratios for nested structure (ds1: 15%, ds2: 35%, ds3: 50%) are maintained across ranks + - Checkpoint/resume produces identical batches (deterministic) + - Metrics correctly aggregate across ranks + """ + rank = dist.get_rank() + + # Create shared temp directory (only rank 0 creates it) + if rank == 0: + temp_dir = tempfile.mkdtemp(prefix="interleaved_test_") + else: + temp_dir = None + + # Broadcast temp directory to all ranks + temp_dir_list = [temp_dir] if temp_dir is not None else [""] + dist.broadcast_object_list(temp_dir_list, src=0) + temp_dir = temp_dir_list[0] + tmp_path = Path(temp_dir) + + try: + + def create_dataset(): + file1 = tmp_path / "ds1.json" + file2 = tmp_path / "ds2.json" + file3 = tmp_path / "ds3.json" + + # Only rank 0 creates the data files + if rank == 0: + create_test_json_file(file1, SMALL_DATASET_SIZE) # IDs 0-22 + create_test_json_file( + file2, MEDIUM_DATASET_SIZE, offset=100 + ) # IDs 100-134 + create_test_json_file( + file3, LARGE_DATASET_SIZE, offset=1000 + ) # IDs 1000-1046 + dist.barrier() # Wait for file creation + + ds1 = HfIterableDataset( + path="json", + data_files=str(file1), + split="train", + dataset_name="ds1", + shuffle_buffer_size=0, # No shuffle for determinism + metric_transform=DefaultTrainingMetricTransform(), + num_shards_per_rank=2, + weight=0.3, + ) + ds2 = HfIterableDataset( + path="json", + data_files=str(file2), + split="train", + dataset_name="ds2", + shuffle_buffer_size=0, # No shuffle for determinism + metric_transform=DefaultTrainingMetricTransform(), + num_shards_per_rank=2, + weight=0.7, + ) + ds3 = HfIterableDataset( + path="json", + data_files=str(file3), + split="train", + dataset_name="ds3", + shuffle_buffer_size=0, # No shuffle for determinism + metric_transform=DefaultTrainingMetricTransform(), + num_shards_per_rank=2, + weight=1.0, + ) + + # Create nested structure: interleaved([interleaved([ds1, ds2]), ds3]) + child_interleaved = InterleavedDataset( + [ds1, ds2], seed=SEED, dataset_name="child" + ) + return InterleavedDataset( + [child_interleaved, ds3], seed=SEED, dataset_name="parent" + ) + + def create_dataloader(dataset): + loader = StatefulDataLoader( + dataset, + batch_size=BATCH_SIZE, + num_workers=0, # Avoid multiprocessing in distributed tests + collate_fn=collate_with_metrics, + ) + return loader, MetricsAggregator() + + # Run checkpointing test with small number of steps + loader1, aggregator1 = create_dataloader(create_dataset()) + loader2, aggregator2 = create_dataloader(create_dataset()) + + result = generate_ckpt( + loader1, + aggregator1, + 3, + 3, # 3 steps before, 3 steps after checkpoint + resume_dataloader=loader2, + resume_aggregator=aggregator2, + ) + + # Verify deterministic resumption + orig_post_ids = [ + b["id"].tolist() for b in result["post_checkpoint_batches"] + ] + resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]] + assert orig_post_ids == resumed_ids, ( + f"Rank {rank}: Non-deterministic interleaved resume. " + f"This indicates sampling state is not properly preserved." + ) + assert ( + result["final_metrics"] == result["resumed_metrics"] + ), "Final metrics don't match resumed metrics - aggregator state issue" + + # Verify sampling ratio is approximately maintained for nested structure + all_ids = [] + for batch in ( + result["pre_checkpoint_batches"] + result["post_checkpoint_batches"] + ): + all_ids.extend(batch["id"].tolist()) + + # Count samples by ID ranges: ds1 has IDs 0-22, ds2 has IDs 100-134, ds3 has IDs 1000-1046 + ds1_samples = sum(1 for id in all_ids if 0 <= id < SMALL_DATASET_SIZE) + ds2_samples = sum( + 1 for id in all_ids if 100 <= id < (MEDIUM_DATASET_SIZE + 100) + ) + ds3_samples = sum( + 1 for id in all_ids if 1000 <= id < (LARGE_DATASET_SIZE + 1000) + ) + total_samples = ds1_samples + ds2_samples + ds3_samples + + if total_samples > 0: + ds1_ratio = ds1_samples / total_samples + ds2_ratio = ds2_samples / total_samples + ds3_ratio = ds3_samples / total_samples + + # Expected ratios based on nested weighting: + # Inner weights: ds1=0.3, ds2=0.7 -> inner total=1.0 + # Outer weights: inner=1.0, ds3=1.0 -> normalized to 0.5 each + # Final ratios: ds1=0.5*0.3=0.15, ds2=0.5*0.7=0.35, ds3=0.5 + expected_ds1_ratio = 0.15 + expected_ds2_ratio = 0.35 + expected_ds3_ratio = 0.5 + + assert ( + abs(ds1_ratio - expected_ds1_ratio) < 0.1 + ), f"ds1 ratio {ds1_ratio:.2f} should be ~{expected_ds1_ratio}" + assert ( + abs(ds2_ratio - expected_ds2_ratio) < 0.1 + ), f"ds2 ratio {ds2_ratio:.2f} should be ~{expected_ds2_ratio}" + assert ( + abs(ds3_ratio - expected_ds3_ratio) < 0.1 + ), f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" + + finally: + # Clean up temp directory (only rank 0) + if rank == 0: + shutil.rmtree(temp_dir) diff --git a/tests/torchtune/datasets/test_iterable_utils.py b/tests/torchtune/datasets/test_iterable_utils.py new file mode 100644 index 0000000000..28c6d8e464 --- /dev/null +++ b/tests/torchtune/datasets/test_iterable_utils.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Optional + +import torch + +from torch.utils.data import DataLoader +from torchtune.data.metrics import MetricsAggregator + + +def collate_with_metrics(batch: list[dict[str, Any]]) -> dict[str, Any]: + """Simple collate that extracts metrics and pads tokens.""" + all_metrics = [] + clean_batch = [] + for sample in batch: + if "metrics" in sample: + all_metrics.extend(sample.pop("metrics")) + clean_batch.append(sample) + + if not clean_batch: + return {"metrics": all_metrics} + + # Simple padding for tokens + ids = torch.tensor([item["id"] for item in clean_batch]) + tokens = torch.nn.utils.rnn.pad_sequence( + [torch.tensor(item["tokens"]) for item in clean_batch], + batch_first=True, + padding_value=-1, # Use -1 for padding to distinguish from valid IDs + ) + collated = { + "id": ids, + "tokens": tokens, + } + + # Add text field for non-tensor data + if "text" in clean_batch[0]: + collated["text"] = [item["text"] for item in clean_batch] + + collated["metrics"] = all_metrics + return collated + + +def generate_ckpt( + dataloader: DataLoader, + aggregator: MetricsAggregator, + steps_before_checkpoint: int, + steps_after_checkpoint: int, + resume_dataloader: Optional[DataLoader] = None, + resume_aggregator: Optional[MetricsAggregator] = None, +) -> dict[str, Any]: + """ + Generates a checkpoint by running through data and saving checkpoint mid-stream. + Optionally, a second dataloader and aggregator can be given to resume from ckpt + and run steps_after_checkpoint to match the first one. + + Args: + dataloader (DataLoader): The dataloader to test + aggregator (MetricsAggregator): The metrics aggregator to use + steps_before_checkpoint (int): Number of steps to run before saving checkpoint + steps_after_checkpoint (int): Number of steps to run after checkpoint + resume_dataloader (Optional[DataLoader]): Optional new dataloader to test resuming. + If None, returns empty resumed_batches. + resume_aggregator (Optional[MetricsAggregator]): Optional new aggregator to test resuming. + If None, returns empty resumed_metrics. + + Returns: + dict[str, Any]: Dict with batches/metrics from both pre and post checkpoint runs. + """ + iterator = iter(dataloader) + + # Collect batches before and after checkpoint + batches = [] + checkpoint_state = None + metrics_at_checkpoint = {} + + total_steps = steps_before_checkpoint + steps_after_checkpoint + + for idx, batch in enumerate(iterator): + batches.append(batch) + + # Process metrics + if "metrics" in batch: + aggregator.update(batch.pop("metrics")) + + # Save checkpoint state after steps_before_checkpoint + if idx == steps_before_checkpoint - 1: # -1 because idx is 0-based + checkpoint_state = { + "loader": dataloader.state_dict(), + "aggregator": aggregator.state_dict(), + } + metrics_at_checkpoint = aggregator.get_metrics_for_logging(prefix="train") + + # Stop after total steps + if idx == total_steps - 1: + break + + # Split batches + pre_checkpoint_batches = batches[:steps_before_checkpoint] + post_checkpoint_batches = batches[steps_before_checkpoint:] + + # Resume with new instances if provided + resumed_batches = [] + resumed_metrics = {} + + if ( + resume_dataloader is not None + and resume_aggregator is not None + and checkpoint_state is not None + ): + # Test resuming with new instances + resume_dataloader.load_state_dict(checkpoint_state["loader"]) + resume_aggregator.load_state_dict(checkpoint_state["aggregator"]) + resume_iterator = iter(resume_dataloader) + + # Collect only the post-checkpoint batches when resuming + for idx, batch in enumerate(resume_iterator): + resumed_batches.append(batch) + + # Process metrics + if "metrics" in batch: + resume_aggregator.update(batch.pop("metrics")) + + # Stop after steps_after_checkpoint + if idx == steps_after_checkpoint - 1: + break + + resumed_metrics = resume_aggregator.get_metrics_for_logging(prefix="train") + + return { + # Original run + "pre_checkpoint_batches": pre_checkpoint_batches, + "post_checkpoint_batches": post_checkpoint_batches, + "metrics_at_checkpoint": metrics_at_checkpoint, + "final_metrics": aggregator.get_metrics_for_logging(prefix="train"), + # Resumed run + "resumed_batches": resumed_batches, + "resumed_metrics": resumed_metrics, + # Internal state for loading - only if someone needs to manually load + "_checkpoint_state": checkpoint_state, + } diff --git a/torchtune/data/metrics/__init__.py b/torchtune/data/metrics/__init__.py new file mode 100644 index 0000000000..778245f83a --- /dev/null +++ b/torchtune/data/metrics/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtune.data.metrics._metric_agg_handlers import ( + AggregationHandler, + CategoricalCountAggHandler, + DistributionAggHandler, + MaxAggHandler, + MeanAggHandler, + MetricState, + MinAggHandler, + SumAggHandler, +) +from torchtune.data.metrics._metric_aggregator import MetricsAggregator +from torchtune.data.metrics._metric_transform import ( + AggregationType, + DefaultTrainingMetricTransform, + Metric, + MetricTransform, +) + +__all__ = [ + "AggregationType", + "AggregationHandler", + "CategoricalCountAggHandler", + "DefaultTrainingMetricTransform", + "DistributionAggHandler", + "MaxAggHandler", + "MeanAggHandler", + "Metric", + "MetricState", + "MetricsAggregator", + "MetricTransform", + "MinAggHandler", + "SumAggHandler", +] diff --git a/torchtune/data/metrics/_metric_agg_handlers.py b/torchtune/data/metrics/_metric_agg_handlers.py new file mode 100644 index 0000000000..ac3f9a2fd7 --- /dev/null +++ b/torchtune/data/metrics/_metric_agg_handlers.py @@ -0,0 +1,476 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from abc import ABC, abstractmethod +from collections import Counter, deque +from dataclasses import dataclass, field +from typing import Any + +import torch + +from torchtune.data.metrics._metric_transform import AggregationType, Metric + +logger = logging.getLogger(__name__) + + +@dataclass +class MetricState: + """Mutable state object representing aggregated metric for (dataset, metric) on a single rank. + + Attributes: + dataset_name (str): Name of the dataset. + metric_name (str): Name of the metric. + value (float): Current aggregated value, whose meaning depends on the aggregation type + (e.g., running sum, current max). + agg_type (AggregationType): Aggregation type. + metadata (dict[str, Any]): Additional state like count, list of values, etc. + """ + + dataset_name: str + metric_name: str + value: float + agg_type: AggregationType + metadata: dict[str, Any] = field(default_factory=dict) + + +class AggregationHandler(ABC): + """Base class for handling metric aggregation using the Strategy pattern. + + Each handler implements a specific aggregation strategy (SUM, MEAN, DISTRIBUTION, etc.) + and manages the complete lifecycle: initialization, updates, local finalization, + and distributed reduction. Handlers also handle serialization for checkpointing. + + The handler architecture allows pluggable aggregation strategies while maintaining + consistent interfaces for the MetricsAggregator. + """ + + @abstractmethod + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: + """Create a new MetricState for a (dataset_name, metric_name) pair. + + Args: + dataset_name (str): Name of the dataset. Especially useful when tracking multiple datasets. + metric_name (str): Name of the metric. + agg_type (AggregationType): Aggregation type. + + Returns: + MetricState: New MetricState for this (dataset_name, metric_name) pair. + """ + pass + + @abstractmethod + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + """Update cumulative MetricState with new metric info. + + Args: + local_agg_metric (MetricState): Cumulative state of the aggregation for this metric in the local rank. + metric (Metric): Input metric info. + """ + pass + + @abstractmethod + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + """ + Computes the final value from the locally aggregated state. + + This method may expand a single metric into multiple, for instance, + a distribution into mean, min, max, and percentiles. + + Args: + local_agg_metric (MetricState): The locally aggregated metric state to finalize. + + Returns: + list[MetricState]: List of finalized metric states. + """ + pass + + @abstractmethod + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + """ + Merge MetricStates from all ranks into final result. + + Args: + local_agg_metrics (list[MetricState]): list of MetricStates for this (dataset_name, metric_name) pair. + + Returns: + MetricState: Final result for this (dataset_name, metric_name) pair. + """ + pass + + def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert handler-specific metadata to serializable format. + + Args: + metadata (dict[str, Any]): AggHandler-specific metadata. + + Returns: + dict[str, Any]: Serializable metadata. + + Override this when using non-serializable types like deque or Counter. + For example, convert deque to list, Counter to dict. + """ + return metadata.copy() + + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Restore handler-specific metadata from serialized format. + + Args: + metadata (dict[str, Any]): AggHandler-specific metadata. + + Returns: + dict[str, Any]: Deserialized metadata. + + Override this to reverse the serialize_metadata transformation. + For example, convert list back to deque, dict back to Counter. + """ + return metadata.copy() + + +class SumAggHandler(AggregationHandler): + """AggHandler for SUM aggregation. Initializes with 0.0 and accumulates metric values.""" + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=0.0, + agg_type=agg_type, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + if not isinstance(metric.value, (int, float)): + raise ValueError( + f"SumAggHandler expects numeric values, got {type(metric.value)}" + ) + local_agg_metric.value += metric.value + + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + return [local_agg_metric] + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + if not local_agg_metrics: + raise ValueError("Cannot aggregate empty list of metrics") + + total = sum(metric.value for metric in local_agg_metrics) + return MetricState( + dataset_name=local_agg_metrics[0].dataset_name, + metric_name=local_agg_metrics[0].metric_name, + value=total, + agg_type=local_agg_metrics[0].agg_type, + metadata=local_agg_metrics[0].metadata.copy(), + ) + + +class MaxAggHandler(AggregationHandler): + """AggHandler for MAX aggregation. Tracks maximum value across all updates.""" + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=float("-inf"), + agg_type=agg_type, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + if not isinstance(metric.value, (int, float)): + raise ValueError( + f"MaxAggHandler expects numeric values, got {type(metric.value)}" + ) + local_agg_metric.value = max(local_agg_metric.value, metric.value) + + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + return [local_agg_metric] + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + max_value = max(r.value for r in local_agg_metrics) + return MetricState( + dataset_name=local_agg_metrics[0].dataset_name, + metric_name=local_agg_metrics[0].metric_name, + value=max_value, + agg_type=local_agg_metrics[0].agg_type, + metadata=local_agg_metrics[0].metadata.copy(), + ) + + +class MinAggHandler(AggregationHandler): + """AggHandler for MIN aggregation. Tracks minimum value across all updates.""" + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=float("inf"), + agg_type=agg_type, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + if not isinstance(metric.value, (int, float)): + raise ValueError( + f"MinAggHandler expects numeric values, got {type(metric.value)}" + ) + local_agg_metric.value = min(local_agg_metric.value, metric.value) + + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + return [local_agg_metric] + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + min_value = min(r.value for r in local_agg_metrics) + return MetricState( + dataset_name=local_agg_metrics[0].dataset_name, + metric_name=local_agg_metrics[0].metric_name, + value=min_value, + agg_type=local_agg_metrics[0].agg_type, + metadata=local_agg_metrics[0].metadata.copy(), + ) + + +class MeanAggHandler(AggregationHandler): + """AggHandler for MEAN aggregation. Maintains running sum and count to compute average.""" + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=0.0, + agg_type=agg_type, + metadata={"sum": 0.0, "count": 0}, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.metadata["sum"] += metric.value + local_agg_metric.metadata["count"] += 1 + + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + count = local_agg_metric.metadata["count"] + local_agg_metric.value = ( + local_agg_metric.metadata["sum"] / count if count > 0 else 0.0 + ) + return [local_agg_metric] + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + total_sum = sum(metric.metadata["sum"] for metric in local_agg_metrics) + total_count = sum(metric.metadata["count"] for metric in local_agg_metrics) + + return MetricState( + dataset_name=local_agg_metrics[0].dataset_name, + metric_name=local_agg_metrics[0].metric_name, + value=total_sum / total_count if total_count > 0 else 0.0, + agg_type=local_agg_metrics[0].agg_type, + metadata={"sum": total_sum, "count": total_count}, + ) + + +class DistributionAggHandler(AggregationHandler): + """AggHandler for DISTRIBUTION aggregation. Maintains a sliding window of values + and expands into multiple statistical metrics (mean, min, max, percentiles, std). + + Note: Percentiles and standard deviation are approximated in distributed settings by averaging local + percentiles and standard deviations across ranks. This is mathematically imprecise but provides a + reasonable approximation for monitoring purposes. + + Args: + window_size (int): Maximum number of recent values to retain for statistics. + + Raises: + ValueError: If window_size is not positive. + """ + + def __init__(self, window_size: int = 1000): + if window_size <= 0: + raise ValueError(f"window_size must be positive, got {window_size}") + self.window_size = window_size + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=0.0, + agg_type=agg_type, + metadata={"values": deque(maxlen=self.window_size)}, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.metadata["values"].append(metric.value) + + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + values = list(local_agg_metric.metadata["values"]) + if not values: + return [] + + return self._compute_distribution_stats(local_agg_metric, values) + + def _compute_distribution_stats( + self, local_agg_metric: MetricState, values: list[float] + ) -> list[MetricState]: + """Compute statistical metrics from distribution values using torch for efficiency.""" + if not values: + return [] + + # Use float64 for precision matching python's float + values_tensor = torch.tensor(values, dtype=torch.float64) + n = len(values_tensor) + + # Compute all stats from the tensor + sum_val = torch.sum(values_tensor).item() + mean_val = sum_val / n + min_val = torch.min(values_tensor).item() + max_val = torch.max(values_tensor).item() + + # Compute all percentiles in one go + percentile_definitions = torch.tensor([0.05, 0.5, 0.95], dtype=torch.float64) + p05_val, p50_val, p95_val = torch.quantile( + values_tensor, percentile_definitions + ).tolist() + + # Return multiple MetricStates with proper agg_types for distributed reduction + # NOTE: Percentiles use MEAN aggregation which approximates global percentiles + # by averaging local percentiles. + metrics = [ + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_stat_mean", + value=mean_val, + agg_type=AggregationType.MEAN, + metadata={"sum": sum_val, "count": n}, + ), + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_stat_min", + value=min_val, + agg_type=AggregationType.MIN, + metadata={}, + ), + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_stat_max", + value=max_val, + agg_type=AggregationType.MAX, + metadata={}, + ), + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_stat_p05", + value=p05_val, + agg_type=AggregationType.MEAN, + metadata={"sum": p05_val, "count": 1}, + ), + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_stat_p50", + value=p50_val, + agg_type=AggregationType.MEAN, + metadata={"sum": p50_val, "count": 1}, + ), + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_stat_p95", + value=p95_val, + agg_type=AggregationType.MEAN, + metadata={"sum": p95_val, "count": 1}, + ), + ] + # Standard deviation is only well-defined for n > 1 + if n > 1: + std_val = torch.std(values_tensor).item() + metrics.append( + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_stat_std", + value=std_val, + agg_type=AggregationType.MEAN, + metadata={"sum": std_val, "count": 1}, + ) + ) + return metrics + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + raise NotImplementedError( + "Metrics with AggregationType.DISTRIBUTION are converted to other " + "AggregationTypes for distributed reduction. finalize_dist_agg should not be called." + ) + + def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert deque to list for serialization.""" + serialized = metadata.copy() + if "values" in serialized: + serialized["values"] = list(serialized["values"]) + return serialized + + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert list back to deque.""" + deserialized = metadata.copy() + if "values" in deserialized: + deserialized["values"] = deque( + deserialized["values"], maxlen=self.window_size + ) + return deserialized + + +class CategoricalCountAggHandler(AggregationHandler): + """AggHandler for CATEGORICAL_COUNT aggregation. Counts occurrences of categorical values + and expands into individual count metrics for each category.""" + + def initialize_metric_state( + self, dataset_name: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=0.0, + agg_type=agg_type, + metadata={"counts": Counter()}, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.metadata["counts"][metric.value] += 1 + + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + # Expand categorical counts into individual metrics + results = [] + for category, count in local_agg_metric.metadata["counts"].items(): + results.append( + MetricState( + dataset_name=local_agg_metric.dataset_name, + metric_name=f"{local_agg_metric.metric_name}_count_{category}", + value=count, + agg_type=AggregationType.SUM, + ) + ) + return results + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + raise NotImplementedError( + "Metrics with AggregationType.CATEGORICAL_COUNT are converted to other " + "AggregationTypes for distributed reduction. finalize_dist_agg should not be called." + ) + + def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert Counter to dict for serialization.""" + serialized = metadata.copy() + if "counts" in serialized: + serialized["counts"] = dict(serialized["counts"]) + return serialized + + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert dict back to Counter.""" + deserialized = metadata.copy() + if "counts" in deserialized: + deserialized["counts"] = Counter(deserialized["counts"]) + return deserialized diff --git a/torchtune/data/metrics/_metric_aggregator.py b/torchtune/data/metrics/_metric_aggregator.py new file mode 100644 index 0000000000..da6b152350 --- /dev/null +++ b/torchtune/data/metrics/_metric_aggregator.py @@ -0,0 +1,344 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import ast +import logging +from collections import defaultdict +from typing import Any, Union + +import torch.distributed as dist + +from torchtune.data.metrics._metric_agg_handlers import ( + AggregationHandler, + CategoricalCountAggHandler, + DistributionAggHandler, + MaxAggHandler, + MeanAggHandler, + MetricState, + MinAggHandler, + SumAggHandler, +) +from torchtune.data.metrics._metric_transform import AggregationType, Metric + +logger = logging.getLogger(__name__) + + +class MetricsAggregator: + """Aggregates metrics across datasets and distributed ranks using pluggable handlers. + + This class uses a handler-based strategy, where each aggregation type (SUM, MEAN, etc.) + has a corresponding AggregationHandler. It maintains a single state object for each + (dataset, metric) pair. + + Internal State Visualization: + { + ("alpaca", "tokens_seen"): MetricState(value=200.0, agg_type=SUM, ...), + ("alpaca", "avg_loss"): MetricState(value=0.01, agg_type=MEAN, metadata={'sum': ..., 'count': ...}), + ("slim_orca", "seq_len"): MetricState(agg_type=DISTRIBUTION, metadata={'values': deque([...])}), + } + + When preparing metrics for logging, the aggregator follows a two-phase process: + 1. Local Aggregation: Each rank aggregates its metrics independently + 2. Distributed Reduction: If in distributed mode, results are combined across ranks + + The aggregator's state is checkpointable, allowing training resumption. + + Args: + dist_window_size (int): Window size for DistributionAggHandler tracking. + + Example: + >>> from torchtune.data.metrics import MetricsAggregator, Metric, AggregationType + >>> + >>> aggregator = MetricsAggregator() + >>> + >>> # Sample metrics from different batches + >>> batch1_metrics = [ + ... Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), + ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), + ... ] + >>> + >>> batch2_metrics = [ + ... Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), + ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), + ... ] + >>> + >>> # Update with metrics + >>> aggregator.update(batch1_metrics) + >>> aggregator.update(batch2_metrics) + >>> + >>> # Get final results + >>> results = aggregator.get_metrics_for_logging(prefix="train") + >>> # {"train_alpaca/tokens_seen": 200.0, "train_alpaca/avg_tokens_seen": 100.0} + + Raises: + ValueError: If dist_window_size is not positive. + """ + + def __init__(self, dist_window_size: int = 1000): + if dist_window_size <= 0: + raise ValueError( + f"dist_window_size must be positive, got {dist_window_size}" + ) + + # Storage: {(dataset, metric): MetricState} - O(unique metrics) not O(samples) + self._metric_states: dict[tuple[str, str], MetricState] = {} + self._dist_window_size = dist_window_size + + # Track aggregation types for validation - prevents same metric name with different agg types + self._metric_agg_types: dict[tuple[str, str], AggregationType] = {} + + # Create handler registry - all handlers initialized upfront + self._handlers: dict[AggregationType, AggregationHandler] = { + AggregationType.SUM: SumAggHandler(), + AggregationType.MAX: MaxAggHandler(), + AggregationType.MIN: MinAggHandler(), + AggregationType.MEAN: MeanAggHandler(), + AggregationType.DISTRIBUTION: DistributionAggHandler(dist_window_size), + AggregationType.CATEGORICAL_COUNT: CategoricalCountAggHandler(), + } + + def _validate_metric_consistency(self, metric: Union[Metric, MetricState]) -> None: + """Validate that metric name uses consistent aggregation type.""" + metric_key = (metric.dataset_name, metric.metric_name) + metric_name = metric.metric_name + + if metric_key in self._metric_agg_types: + existing_agg_type = self._metric_agg_types[metric_key] + if existing_agg_type != metric.agg_type: + raise ValueError( + f"Metric '{metric_name}' in dataset '{metric.dataset_name}' " + f"is already registered with aggregation type {existing_agg_type.value}, " + f"but a handler or user code tried to use it with type {metric.agg_type.value}. " + f"Use different metric names for different aggregation types." + ) + else: + # Track this metric's aggregation type + self._metric_agg_types[metric_key] = metric.agg_type + + def register_handler( + self, agg_type: AggregationType, handler: AggregationHandler + ) -> None: + """Register custom aggregation handler for specified type. + + Args: + agg_type (AggregationType): The aggregation type to handle + handler (AggregationHandler): Handler instance implementing the AggregationHandler interface + """ + # Warn if replacing a handler that's already in use + if agg_type in self._handlers and any( + state.agg_type == agg_type for state in self._metric_states.values() + ): + logger.warning( + f"Replacing handler for {agg_type} - aggregation type already in use by existing metrics. " + f"This may affect existing metric behavior." + ) + + self._handlers[agg_type] = handler + + def update(self, metrics: list[Metric]) -> None: + """Update (dataset_name, metric_name) metric state with new values. + + Args: + metrics (list[Metric]): List of metrics to update the state with + + Raises: + ValueError: If no handler is registered for a metric's aggregation type, + or if metric name conflicts with existing aggregation type. + """ + for metric in metrics: + # Same metric name must use same aggregation type + self._validate_metric_consistency(metric) + + metric_key = (metric.dataset_name, metric.metric_name) + handler = self._handlers.get(metric.agg_type) + + if handler is None: + raise ValueError( + f"No handler registered for aggregation type: {metric.agg_type}" + ) + + if metric_key not in self._metric_states: + self._metric_states[metric_key] = handler.initialize_metric_state( + metric.dataset_name, metric.metric_name, metric.agg_type + ) + + local_agg_metric = self._metric_states[metric_key] + handler.update(local_agg_metric, metric) # Mutates local_agg_metric + + def get_metrics_for_logging(self, prefix: str = "data") -> dict[str, float]: + """Get final metrics for logging in standard format. + + Args: + prefix (str): Prefix for metric names in the returned dictionary + + Returns: + dict[str, float]: Dictionary with keys like "{prefix}_{dataset_name}/{metric_name}" + and float values. For example, with `prefix="train"`, `dataset_name="alpaca"`, + `metric_name="loss"`, the key would be `train_alpaca/loss`. + """ + final_results = self._compute_unified_metrics() + + return { + f"{prefix}_{result.dataset_name}/{result.metric_name}": result.value + for result in final_results + } + + def _compute_unified_metrics(self) -> list[MetricState]: + """ + Compute metrics handling both local and distributed cases uniformly. + + Returns: + list[MetricState]: Final results ready for logging + """ + # Step 1: Get local results from all handlers (may expand distributions/categoricals) + prepared_results = [] + for local_agg_metric in self._metric_states.values(): + handler = self._handlers[local_agg_metric.agg_type] + generated_metrics = handler.finalize_local_agg(local_agg_metric) + + # Validate each newly generated metric state immediately + for gen_metric in generated_metrics: + self._validate_metric_consistency(gen_metric) + + prepared_results.extend(generated_metrics) + + # Step 2: Apply distributed reduction if needed + if dist.is_initialized() and dist.get_world_size() > 1: + prepared_results = self._finalize_dist_agg(prepared_results) + + return prepared_results + + def _finalize_dist_agg( + self, local_agg_metrics: list[MetricState] + ) -> list[MetricState]: + """Apply distributed reduction to local results. + + Args: + local_agg_metrics (list[MetricState]): (dataset_name, metric_name) metric pairs from this rank + + Returns: + list[MetricState]: Reduced results combining all ranks + """ + world_size = dist.get_world_size() + + # Gather all results from all ranks + all_results = [None] * world_size + dist.all_gather_object(all_results, local_agg_metrics) + + # Group by (dataset_name, metric_name) for reduction + grouped = defaultdict(list) + for rank_results in all_results: + if rank_results: # Handle ranks with no metrics + for result in rank_results: + result_key = (result.dataset_name, result.metric_name) + grouped[result_key].append(result) + + # Apply handler-specific distributed reduction + reduced_results = [] + for result_key, results_list in grouped.items(): + if not results_list: + continue # Skip empty groups + + # All results for a key should have same agg_type + agg_type = results_list[0].agg_type + handler = self._handlers[agg_type] + reduced_result = handler.finalize_dist_agg(results_list) + reduced_results.append(reduced_result) + + return reduced_results + + def state_dict(self) -> dict[str, Any]: + """Serialize aggregator state for checkpointing. + + Returns: + dict[str, Any]: Serializable dictionary containing all aggregator state + """ + serializable_state = {} + required_agg_types = set() # Track aggregation types used in saved states + + for metric_key, local_agg_metric in self._metric_states.items(): + # Get handler for this result's aggregation type + handler = self._handlers[local_agg_metric.agg_type] + required_agg_types.add(local_agg_metric.agg_type) + + # Convert MetricState to serializable dict + result_dict = { + "dataset_name": local_agg_metric.dataset_name, + "metric_name": local_agg_metric.metric_name, + "value": local_agg_metric.value, + "agg_type": local_agg_metric.agg_type, + "metadata": handler.serialize_metadata(local_agg_metric.metadata), + } + + # Convert tuple key to string for JSON compatibility + serializable_state[str(metric_key)] = result_dict + + return { + "state": serializable_state, + "dist_window_size": self._dist_window_size, + "required_agg_types": list( + required_agg_types + ), # Save which handlers are needed + # Save which aggregation types are used for each metric + "metric_agg_types": { + str(k): v.value for k, v in self._metric_agg_types.items() + }, + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load aggregator state from checkpoint. + + Args: + state_dict (dict[str, Any]): Dictionary containing serialized aggregator state + + Raises: + ValueError: If required handlers are missing after checkpoint restore + """ + self._dist_window_size = state_dict.get("dist_window_size", 1000) + + # Sanity check: Ensure all required handlers are available + required_agg_types = state_dict.get("required_agg_types", []) + missing_handlers = [] + for agg_type in required_agg_types: + if agg_type not in self._handlers: + missing_handlers.append(agg_type) + + if missing_handlers: + raise ValueError( + f"Missing handlers for aggregation types: {missing_handlers}. " + f"Custom handlers must be re-registered before checkpoint restore." + ) + + deserialized_state = {} + for key_str, result_dict in state_dict["state"].items(): + # Convert string keys back to tuples + metric_key = ast.literal_eval(key_str) + + # Get handler for this aggregation type + agg_type = result_dict["agg_type"] + handler = self._handlers[agg_type] + + # Restore metadata using handler-specific deserialization + metadata = handler.deserialize_metadata(result_dict["metadata"]) + + # Create MetricState from dict + local_agg_metric = MetricState( + dataset_name=result_dict["dataset_name"], + metric_name=result_dict["metric_name"], + value=result_dict["value"], + agg_type=result_dict["agg_type"], + metadata=metadata, + ) + + deserialized_state[metric_key] = local_agg_metric + + self._metric_states = deserialized_state + + # Restore validation state + self._metric_agg_types = {} + for key_str, agg_type_str in state_dict.get("metric_agg_types", {}).items(): + key = ast.literal_eval(key_str) + self._metric_agg_types[key] = AggregationType(agg_type_str) diff --git a/torchtune/data/metrics/_metric_transform.py b/torchtune/data/metrics/_metric_transform.py new file mode 100644 index 0000000000..8521f6e6dd --- /dev/null +++ b/torchtune/data/metrics/_metric_transform.py @@ -0,0 +1,168 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Optional, Union + +from torchtune.modules.transforms import Transform + + +@dataclass(frozen=True) +class Metric: + dataset_name: str + metric_name: str + value: Union[int, float, str] + agg_type: "AggregationType" + + +class AggregationType(Enum): + """Defines how a metric's value should be aggregated by the MetricsAggregator. + + Each type corresponds to a specific AggregationHandler that implements the logic + for initialization, updates, and distributed reduction. + """ + + SUM = "sum" + MEAN = "mean" + DISTRIBUTION = "distribution" + CATEGORICAL_COUNT = "categorical_count" + MAX = "max" + MIN = "min" + + +class MetricTransform(Transform): + """Applied to each dataset sample to generate per-sample metrics for training tracking. + + Creates Metric objects that are later aggregated by MetricsAggregator. This separation + of concerns ensures metrics are correctly aggregated even with multiple dataloader + workers and in distributed settings. + + The transform must be configured with a dataset name via set_dataset_name() before use. + Each call to __call__ adds metrics to the sample's "metrics" key. + + Example: + >>> transform = DefaultTrainingMetricTransform() + >>> transform.set_dataset_name("alpaca") + >>> sample = {"tokens": [1, 2, 3]} + >>> result = transform(sample) + >>> # result["metrics"] contains list of Metric objects + """ + + def __init__(self): + # dataset_name is set by the dataset using set_dataset_name + self.dataset_name: Optional[str] = None + + def set_dataset_name(self, dataset_name: str) -> None: + """Called by the dataset to set the namespace for metrics. + + This is used to differentiate metrics from multiple datasets, for example, + "train_alpaca/tokens_seen" vs. "train_slim_orca/tokens_seen". + + Args: + dataset_name (str): Name of the dataset, used for metric namespacing. + """ + self.dataset_name = dataset_name + + def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: + """Generate metrics for a single sample. Must be implemented by subclasses. + + Args: + sample (dict[str, Any]): The sample dictionary to generate metrics from + + Returns: + list[Metric]: List of metrics generated for this sample + + Raises: + NotImplementedError: If subclass does not implement this method. + """ + raise NotImplementedError("Subclasses must implement _generate_metrics method") + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + """Apply transform to sample, adding generated metrics to the sample. + + Args: + sample (dict[str, Any]): Input sample dictionary + + Returns: + dict[str, Any]: Sample with metrics added to "metrics" key (list[Metric]) + + Raises: + RuntimeError: If set_dataset_name() was not called before transform usage + """ + if self.dataset_name is None: + raise RuntimeError( + "set_dataset_name() must be called before using the transform." + ) + + # Generate metrics for this sample + metrics = self._generate_metrics(sample) + + # Add to existing metrics list or create new one + if "metrics" not in sample: + sample["metrics"] = [] + sample["metrics"].extend(metrics) + return sample + + +class DefaultTrainingMetricTransform(MetricTransform): + """Generates common training metrics: samples seen, tokens seen, and sequence length. + + This transform detects the token key in a sample, checking for "tokens" + first and then falling back to "input_ids". + + For details on the base class behavior, see MetricTransform. + + Tracked metrics: + - samples_seen: Cumulative count of samples processed (SUM aggregation) + - tokens_seen: Cumulative sum of all tokens processed (SUM aggregation) + - seq_len: Distribution of sequence lengths (DISTRIBUTION aggregation) + + Example: + >>> transform = DefaultTrainingMetricTransform() + >>> transform.set_dataset_name("alpaca") + >>> + >>> sample = {"tokens": [1, 2, 3, 4, 5]} # 5 tokens + >>> metrics = transform._generate_metrics(sample) + >>> # This generates the following Metric objects: + >>> # [ + >>> # Metric(dataset_name="alpaca", metric_name="samples_seen", value=1, agg_type=AggregationType.SUM), + >>> # Metric(dataset_name="alpaca", metric_name="tokens_seen", value=5, agg_type=AggregationType.SUM), + >>> # Metric(dataset_name="alpaca", metric_name="seq_len", value=5, agg_type=AggregationType.DISTRIBUTION) + >>> # ] + """ + + def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: + if self.dataset_name is None: + raise RuntimeError( + "set_dataset_name() must be called before using the transform." + ) + + # Determine token key + token_key = "tokens" if "tokens" in sample else "input_ids" + token_len = len(sample.get(token_key, [])) + + # Create metrics for this sample + return [ + Metric( + dataset_name=self.dataset_name, + metric_name="samples_seen", + value=1, + agg_type=AggregationType.SUM, + ), + Metric( + dataset_name=self.dataset_name, + metric_name="tokens_seen", + value=token_len, + agg_type=AggregationType.SUM, + ), + Metric( + dataset_name=self.dataset_name, + metric_name="seq_len", + value=token_len, + agg_type=AggregationType.DISTRIBUTION, + ), + ] diff --git a/torchtune/data/metrics/readme.md b/torchtune/data/metrics/readme.md new file mode 100644 index 0000000000..bf6cb8d27b --- /dev/null +++ b/torchtune/data/metrics/readme.md @@ -0,0 +1,176 @@ +# TorchTune Metrics Module + +## Overview + +The metrics module provides a robust system for tracking and aggregating training metrics across multiple datasets and distributed environments. It follows a **strategy pattern** design with pluggable aggregation handlers to efficiently handle different types of metrics. + +## Architecture Overview + +``` +┌────────────────────────────────────────────────────┐ +│ Training Loop │ +└─────────────────────┬──────────────────────────────┘ + │ +┌─────────────────────▼──────────────────────────────┐ +│ MetricTransform │ +│ • Applied to each sample │ +│ • Generates per-sample metrics │ +│ • Examples: tokens_seen, seq_len, samples_seen │ +└─────────────────────┬──────────────────────────────┘ + │ list[Metric] +┌─────────────────────▼──────────────────────────────┐ +│ MetricsAggregator │ +│ • Aggregates metrics across samples and ranks │ +│ • Uses pluggable AggregationHandlers │ +│ • Handles distributed reduction │ +└─────────────────────┬──────────────────────────────┘ + │ {prefix}_{dataset_name}/{metric_name} # prefix is "train", "val", etc. +┌─────────────────────▼──────────────────────────────┐ +│ Logging System │ +│ • W&B, TensorBoard, etc. │ +│ • Gets formatted metrics ready for logging │ +└────────────────────────────────────────────────────┘ +``` + +## File Structure + +- **`_metric_transform.py`**: Defines `Metric`, `AggregationType`, and transform classes +- **`_metric_agg_handlers.py`**: Aggregation strategy implementations +- **`_metric_aggregator.py`**: Main aggregator orchestrating the handlers + +## Customizing metrics + +- **Custom transforms**: Extend `MetricTransform` for domain-specific metrics +- **Handler registration**: Register custom handlers for specialized aggregation needs + +####### +## TODO +## Move this from here to website docs +####### + +## Core Components + +### 1. MetricTransform +Generates per-sample metrics during data processing. + +**Key Features:** +- Applied to each sample in the dataset +- Creates `Metric` objects with dataset name, metric name, value, and aggregation type +- Handles dataset namespacing for multi-dataset scenarios + +**Example Usage:** +```python +from torchtune.data.metrics import DefaultTrainingMetricTransform, AggregationType + +transform = DefaultTrainingMetricTransform() +transform.set_dataset_name("alpaca") + +# Applied to each sample +sample = {"tokens": [1, 2, 3, 4, 5]} +sample = transform(sample) +# sample["metrics"] now contains: +# [ +# Metric(dataset_name="alpaca", name="samples_seen", value=1, agg_type=AggregationType.SUM), +# Metric(dataset_name="alpaca", name="tokens_seen", value=5, agg_type=AggregationType.SUM), +# Metric(dataset_name="alpaca", name="seq_len", value=5, agg_type=AggregationType.DISTRIBUTION) +# ] +``` + +### 2. MetricsAggregator +Efficiently aggregates metrics across samples and distributed ranks. + +**Key Features:** +- Handler-based strategy pattern for different aggregation types +- Distributed-aware with automatic rank reduction +- Checkpointable state for training resumption +- Keep track of (metric, dataset) pairs + +**Aggregation Types (at the time of writing):** +- `SUM`: Cumulative totals (e.g., total tokens processed) +- `MEAN`: Running averages (e.g., average loss) +- `MAX/MIN`: Extrema tracking (e.g., max sequence length seen) +- `DISTRIBUTION`: Statistical summaries (mean, min, max, percentiles) +- `CATEGORICAL_COUNT`: Category cumulative counts (e.g. num of samples from a given category) + +**Example Usage:** +```python +from torchtune.data.metrics import MetricsAggregator, Metric, AggregationType + +# Create aggregator +aggregator = MetricsAggregator() + +# Sample metrics from different batches +batch1_metrics = [ + Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), + Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), +] + +batch2_metrics = [ + Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), + Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), +] + +# Update with metrics +aggregator.update(batch1_metrics) +aggregator.update(batch2_metrics) + +# Get final results +results = aggregator.get_metrics_for_logging(prefix="train") +# {"train_alpaca/tokens_seen": 200.0, "train_alpaca/avg_tokens_seen": 100.0} +``` + +### 3. AggregationHandlers +Pluggable strategies for different aggregation patterns. + +``` +AggregationHandler (ABC) +├── SumAggHandler # value += metric.value +├── MeanAggHandler # tracks sum and count +├── MaxAggHandler # value = max(value, metric.value) +├── MinAggHandler # value = min(value, metric.value) +├── DistributionAggHandler # maintains value window + stats +└── CategoricalCountAggHandler # Counter for categories +``` + +**Custom Handler Example:** +```python +class CustomAggHandler(AggregationHandler): + def initialize_metric_state(self, dataset_name, metric_name, agg_type): + return MetricState( + dataset_name=dataset_name, + metric_name=metric_name, + value=, # should change + agg_type=agg_type, + metadata={} # may need to change + ) + + def update(self, local_agg_metric, metric): + ... + + def finalize_local_agg(self, local_agg_metric): + ... + + def finalize_dist_agg(self, local_agg_metrics): + ... + +# Register with aggregator +aggregator.register_handler(AggregationType.CUSTOM, CustomAggHandler()) +``` + +## Distributed Training Support + +The metrics system automatically handles distributed environments: + +1. **Local Aggregation**: Each rank aggregates its own metrics +2. **Distributed Reduction**: Results are combined across ranks using `all_gather_object` +3. **Type-Aware Reduction**: Each aggregation type uses appropriate reduction (sum, mean, max, etc.) + +**Distributed Flow:** +``` +Rank 0: [(ds1, metric1), (ds1, metric2)] → LocalAgg → [(ds1, metric1), (ds1, metric2)] +Rank 1: [(ds1, metric1), (ds1, metric2)] → LocalAgg → [(ds1, metric1), (ds1, metric2)] + ↓ + AllGather + Reduce + ↓ + Final Results [(ds1, metric1), (ds1, metric2)] +``` diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index b0c7c11738..b38663578e 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -5,18 +5,29 @@ # LICENSE file in the root directory of this source tree. from torchtune.datasets import multimodal -from torchtune.datasets._alpaca import alpaca_cleaned_dataset, alpaca_dataset +from torchtune.datasets._alpaca import ( + alpaca_cleaned_dataset, + alpaca_dataset, + alpaca_iterable_dataset, +) from torchtune.datasets._chat import chat_dataset from torchtune.datasets._cnn_dailymail import cnn_dailymail_articles_dataset from torchtune.datasets._concat import ConcatDataset from torchtune.datasets._grammar import grammar_dataset +from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset from torchtune.datasets._instruct import instruct_dataset +from torchtune.datasets._interleaved import InterleavedDataset +from torchtune.datasets._iterable_base import ( + DatasetInfo, + InfiniteTuneIterableDataset, + TuneIterableDataset, +) from torchtune.datasets._packed import PackedDataset from torchtune.datasets._preference import preference_dataset, PreferenceDataset from torchtune.datasets._samsum import samsum_dataset -from torchtune.datasets._sft import SFTDataset -from torchtune.datasets._slimorca import slimorca_dataset +from torchtune.datasets._sft import sft_iterable_dataset, SFTDataset +from torchtune.datasets._slimorca import slimorca_dataset, slimorca_iterable_dataset from torchtune.datasets._stack_exchange_paired import stack_exchange_paired_dataset from torchtune.datasets._text_completion import ( text_completion_dataset, @@ -25,23 +36,31 @@ from torchtune.datasets._wikitext import wikitext_dataset __all__ = [ - "alpaca_dataset", "alpaca_cleaned_dataset", + "alpaca_dataset", + "alpaca_iterable_dataset", + "chat_dataset", + "cnn_dailymail_articles_dataset", + "ConcatDataset", + "DatasetInfo", "grammar_dataset", - "samsum_dataset", - "stack_exchange_paired_dataset", - "slimorca_dataset", + "hh_rlhf_helpful_dataset", + "HfIterableDataset", "instruct_dataset", + "InterleavedDataset", + "multimodal", + "PackedDataset", "preference_dataset", - "chat_dataset", + "PreferenceDataset", + "samsum_dataset", + "SFTDataset", + "sft_iterable_dataset", + "slimorca_dataset", + "slimorca_iterable_dataset", + "stack_exchange_paired_dataset", "text_completion_dataset", "TextCompletionDataset", - "cnn_dailymail_articles_dataset", - "PackedDataset", - "ConcatDataset", + "InfiniteTuneIterableDataset", + "TuneIterableDataset", "wikitext_dataset", - "PreferenceDataset", - "SFTDataset", - "hh_rlhf_helpful_dataset", - "multimodal", ] diff --git a/torchtune/datasets/_alpaca.py b/torchtune/datasets/_alpaca.py index 1ecee62f53..4326b7024f 100644 --- a/torchtune/datasets/_alpaca.py +++ b/torchtune/datasets/_alpaca.py @@ -10,8 +10,9 @@ from torchtune.data._messages import AlpacaToMessages +from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.datasets._packed import PackedDataset -from torchtune.datasets._sft import SFTDataset +from torchtune.datasets._sft import sft_iterable_dataset, SFTDataset from torchtune.modules.transforms.tokenizers import ModelTokenizer @@ -101,3 +102,64 @@ def alpaca_dataset( original Alpaca dataset, `yahma/alpaca-cleaned `_. See the dataset page and :func:`~torchtune.datasets.alpaca_dataset` for more details. """ + + +def alpaca_iterable_dataset( + model_transform: ModelTokenizer, + *, + path: str = "tatsu-lab/alpaca", + column_map: Optional[dict[str, str]] = None, + train_on_input: bool = True, + shuffle_buffer_size: Optional[int] = 1000, + seed: int = 42, + dataset_name: Optional[str] = None, + filter_fn: Optional[Callable] = None, + split: str = "train", + **load_dataset_kwargs: dict[str, Any], +) -> HfIterableDataset: + """ + Support for iterable version of Alpaca-style datasets. + + This returns an infinite iterable dataset that supports checkpointing + and metrics tracking, designed for step-based training. + + Args: + model_transform (ModelTokenizer): Model tokenizer used to tokenize the messages. + path (str): path to dataset repository on Hugging Face. Default is ``tatsu-lab/alpaca``. + column_map (Optional[dict[str, str]]): a mapping from the expected columns in the message transform + :class:`~torchtune.data.AlpacaToMessages` to the new column names in the dataset. Keys should be + "instruction", "input", and "output" and values should be the actual column names. + train_on_input (bool): Whether the model is trained on the prompt or not. Default is True. + shuffle_buffer_size (Optional[int]): Size of the shuffle buffer. If None or 0, no shuffling is done. + seed (int): Seed for shuffling. + dataset_name (Optional[str]): Name of the dataset for metrics tracking. If None, auto-generated. + filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. + split (str): ``split`` argument for ``datasets.load_dataset``. Default is "train". + **load_dataset_kwargs (dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. + + Returns: + HfIterableDataset: iterable dataset configured with source data and transforms + + Example: + >>> from torchdata.stateful_dataloader import StatefulDataLoader + >>> alpaca_ds = alpaca_iterable_dataset(tokenizer=tokenizer) + >>> dataloader = StatefulDataLoader(alpaca_ds, batch_size=8) + >>> for batch in dataloader: + >>> print(f"Batch size: {len(batch)}") + >>> Batch size: 8 + """ + message_transform = AlpacaToMessages( + train_on_input=train_on_input, column_map=column_map + ) + + return sft_iterable_dataset( + message_transform=message_transform, + model_transform=model_transform, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + dataset_name=dataset_name, + filter_fn=filter_fn, + split=split, + path=path, + **load_dataset_kwargs, + ) diff --git a/torchtune/datasets/_hf_iterable.py b/torchtune/datasets/_hf_iterable.py new file mode 100644 index 0000000000..f517fece31 --- /dev/null +++ b/torchtune/datasets/_hf_iterable.py @@ -0,0 +1,295 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Any, Callable, Iterator, Optional + +import torch +import torch.distributed as dist +from datasets import load_dataset +from datasets.distributed import split_dataset_by_node + +from torchtune.data.metrics import ( + AggregationType, + DefaultTrainingMetricTransform, + Metric, + MetricTransform, +) +from torchtune.datasets._iterable_base import DatasetInfo, InfiniteTuneIterableDataset + +logger = logging.getLogger(__name__) + + +class HfIterableDataset(InfiniteTuneIterableDataset): + """HuggingFace dataset with infinite iteration and composable transforms. + + This is an infinite dataset that wraps a HuggingFace dataset. After exhausting + the dataset, it will restart from the beginning. + + Transform pipeline: raw_data -> message_transform -> model_transform -> output_transform -> metric_transform + + This dataset is responsible for: + - Loading and sharding the dataset + - Shuffling at initialization and after each epoch + - Applying transforms to the data + - Returning an infinite iterator over the dataset + + Args: + message_transform (Optional[Callable]): Transforms raw data into a `Message`. + model_transform (Optional[Callable]): Prepares messages for the model, + usually by tokenizing them. + output_transform (Optional[Callable]): Prepares tokenized inputs for the + recipe, often by manipulating labels (e.g., setting an ignore index). + This transform is recipe-dependent (e.g., SFT, DPO, etc.). + metric_transform (Optional[MetricTransform]): Computes metrics from a + sample (e.g., token count). If ``None``, a default transform is used. + To disable standard metric tracking, set this to ``lambda x: x``. + shuffle_buffer_size (Optional[int]): Size of the shuffle buffer. + If ``None`` or 0, no shuffling is performed. + weight (Optional[float]): Weight for this dataset. Defaults to 1.0. + seed (int): Seed for shuffling. + num_shards_per_rank (int): The target number of shards per worker (GPU). + The actual number of shards will be a multiple of + ``world_size * dataloader_workers``. + dataset_name (Optional[str]): Name of the dataset. If ``None``, a name is + generated from the ``path``, ``source``, and ``split``. + filter_fn (Optional[Callable]): A function to filter the dataset. + filter_kwargs (Optional[dict[str, Any]]): Keyword arguments for ``filter_fn``. + **load_dataset_kwargs: Keyword arguments for the + :func:`~datasets.load_dataset` function. + """ + + def __init__( + self, + *, + message_transform: Optional[Callable] = None, + model_transform: Optional[Callable] = None, + output_transform: Optional[Callable] = None, + metric_transform: Optional[MetricTransform] = None, + shuffle_buffer_size: Optional[int] = 1000, + weight: Optional[float] = 1.0, + seed: int = 42, + num_shards_per_rank: int = 64, + dataset_name: Optional[str] = None, + filter_fn: Optional[Callable] = None, + filter_kwargs: Optional[dict[str, Any]] = None, + **load_dataset_kwargs, + ): + # Store configuration + self._shuffle_buffer_size = shuffle_buffer_size + self._seed = seed + self._message_transform = message_transform + self._model_transform = model_transform + self._output_transform = output_transform + self._weight = weight if weight is not None else 1.0 + + # Create default transform if not provided + self._metric_transform = metric_transform or DefaultTrainingMetricTransform() + + # Auto-generate dataset name if not provided, ensuring it's always a string + if dataset_name is None: + path = load_dataset_kwargs.get("path", None) + source = load_dataset_kwargs.get("source", None) + split = load_dataset_kwargs.get("split", None) + name_parts = [] + for item in [path, source, split]: + if item is not None: + name_parts.append(str(item).replace("/", "_")) + dataset_name = "_".join(name_parts) + + # Build the hierarchical info object for this dataset + self._info = DatasetInfo(name=dataset_name, weight=self._weight) + + # Set dataset name on the transform if it supports it + if hasattr(self._metric_transform, "set_dataset_name"): + self._metric_transform.set_dataset_name(dataset_name) + + # Internal state for resumption + self._num_epochs = 0 + + # Load and setup HF dataset + self._setup_hf_dataset( + load_dataset_kwargs, num_shards_per_rank, filter_fn, filter_kwargs + ) + + @property + def info(self) -> DatasetInfo: + """Returns info for this leaf dataset, which has no children.""" + return self._info + + def _apply_transforms(self, sample: dict[str, Any]) -> dict[str, Any]: + """Apply transforms if they exist, otherwise return sample unchanged.""" + if self._message_transform is not None: + sample = self._message_transform(sample) + if self._model_transform is not None: + sample = self._model_transform(sample) + if self._output_transform is not None: + sample = self._output_transform(sample) + if self._metric_transform is not None: + sample = self._metric_transform(sample) + return sample + + def _setup_hf_dataset( + self, + load_dataset_kwargs: dict[str, Any], + num_shards_per_rank: int, + filter_fn: Optional[Callable] = None, + filter_kwargs: Optional[dict[str, Any]] = None, + ): + """ + One-time setup of HuggingFace dataset that handles Handles distributed sharding, + shuffle configuration, and filtering. + + Called once during __init__ to avoid expensive re-computation. + """ + + # Distributed setup + world_size, rank = 1, 0 + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + + # Load and shard dataset + ds = load_dataset(**load_dataset_kwargs) + + # Use to_iterable_dataset for non-streaming datasets + is_streaming = load_dataset_kwargs.get("streaming", False) + if is_streaming: + logger.warning( + f"Streaming datasets were not yet tested for distributed training. " + f"split_dataset_by_node is applied, but no resharding was done manually. " + f"Dataset '{self.info.name}' has " + f"{getattr(ds, 'num_shards', 'unknown')}, and your training has {world_size} ranks." + f"See: https://huggingface.co/docs/datasets/en/package_reference/main_classes?#datasets.IterableDataset.shard" + f"Consider setting streaming=False, which should also be faster." + ) + if not is_streaming: + # Define number of shards based on (world_size, num of shards per GPU, dataloader workers) + # E.g. world_size=2, num_shards_per_rank=16, dataloader_workers=3 + # we will try 2*16 = 32 shards. Since 32 is not a multiple of 6, we will do 36 shards. + # Each rank gets 18 shards, each dataloader worker in that rank gets 6 shards. + worker_info = torch.utils.data.get_worker_info() + num_dataloader_workers = worker_info.num_workers if worker_info else 1 + + # Calculate total workers across all ranks and dataloader processes + total_workers = world_size * num_dataloader_workers + + # Find minimum shards that satisfies our target while being divisible by workers + desired_shards = world_size * num_shards_per_rank + + # Round up to next multiple of total_workers for even distribution + if desired_shards % total_workers == 0: + num_shards = desired_shards + else: + num_shards = total_workers * ( + (desired_shards + total_workers - 1) // total_workers + ) + + # If the dataset is not streaming and has a defined length, + # we cannot have num_shards > dataset_size. + if hasattr(ds, "__len__"): + dataset_size = len(ds) + if num_shards > dataset_size: + raise ValueError( + f"Number of shards ({num_shards}) is greater than the dataset size ({dataset_size})." + f"Please decrease one of {num_shards_per_rank=} or {num_dataloader_workers=} or {world_size=}." + ) + + ds = ds.to_iterable_dataset(num_shards=num_shards) + + # Shuffle the dataset + if self._shuffle_buffer_size and self._shuffle_buffer_size > 0: + ds = ds.shuffle(seed=self._seed, buffer_size=self._shuffle_buffer_size) + + # Distribute across ranks + if world_size > 1: + ds = split_dataset_by_node(ds, rank=rank, world_size=world_size) + + # Apply filtering if specified + if filter_fn: + filter_kwargs = filter_kwargs or {} + ds = ds.filter(filter_fn, **filter_kwargs) + + self._ds = ds + + def __iter__(self) -> Iterator[dict[str, Any]]: + """Infinite iteration over dataset samples. + + Behavior: + - Restarts from beginning when dataset is exhausted + - Reshuffles at start of each epoch (if enabled) + - Applies full transform pipeline to each sample + - Adds 'num_epochs' metric to track dataset progress + - Yields samples indefinitely for continuous training + """ + + while True: # Infinite iteration + self._ds.set_epoch(self._num_epochs) + epoch_iterator = iter(self._ds) + samples_yielded = 0 + + try: + for sample in epoch_iterator: + # NOTE: We apply transforms here instead of using .map() to work around + # HuggingFace datasets bug where .map() causes incorrect checkpoint resumption. + # See: https://github.com/huggingface/datasets/issues/7630 + # This ensures transforms are applied fresh on each sample during iteration. + sample = self._apply_transforms(sample) + + # Track the number of epochs completed for each dataset. This is + # especially useful when interleaving multiple datasets, but + # also necessary to track dataset-level metrics. + metric_num_epochs = Metric( + dataset_name=self.info.name, + metric_name="num_epochs", + value=self._num_epochs, + agg_type=AggregationType.MAX, + ) + if "metrics" not in sample: + sample["metrics"] = [] + sample["metrics"].append(metric_num_epochs) + + samples_yielded += 1 + yield sample + + except StopIteration: + # Expected when dataset is exhausted + pass + except Exception as e: + logger.error( + f"Dataset {self.info.name} encountered an unexpected error: {e}." + ) + raise + + # Check if we got zero samples - this might indicate an issue + if samples_yielded == 0: + logger.warning( + f"Dataset {self.info.name} epoch {self._num_epochs} yielded 0 samples - potential issue!" + ) + + # Epoch complete - increment and continue infinite loop + self._num_epochs += 1 + + def state_dict(self) -> dict[str, Any]: + """Returns dataset checkpoint state.""" + hf_state = self._ds.state_dict() + state = { + "num_epochs": self._num_epochs, + "hf_dataset_state": hf_state, + } + return state + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """ + Load state from checkpoint, including restoring the state of the + Hugging Face IterableDataset. + """ + self._num_epochs = state_dict["num_epochs"] + hf_state = state_dict["hf_dataset_state"] + + # HF is responsible for resuming the dataset state + # where it last left off + self._ds.load_state_dict(hf_state) diff --git a/torchtune/datasets/_interleaved.py b/torchtune/datasets/_interleaved.py new file mode 100644 index 0000000000..fe911aff51 --- /dev/null +++ b/torchtune/datasets/_interleaved.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +from collections import deque +from typing import Any, Iterator + +import torch + +from torchtune.datasets._iterable_base import DatasetInfo, InfiniteTuneIterableDataset + +logger = logging.getLogger(__name__) + + +class InterleavedDataset(InfiniteTuneIterableDataset): + """Infinitely interleaves multiple datasets according to their sampling weights. + + The weights are extracted from each dataset's ``info.weight`` property and + normalized to sum to 1.0. This dataset manages the state of its child + datasets to ensure correct checkpointing and resumption. + + Args: + datasets (list[InfiniteTuneIterableDataset]): A list of datasets to interleave. + seed (int): The seed for sampling. + weight (float): The weight for this dataset. Defaults to 1.0. + dataset_name (str): The name of the dataset. Defaults to "interleaved_dataset". + sampling_log_maxlen (int): The maximum length of the sampling log. + """ + + def __init__( + self, + datasets: list[InfiniteTuneIterableDataset], + seed: int, + weight: float = 1.0, + dataset_name: str = "interleaved_dataset", + sampling_log_maxlen: int = 10000, + ): + self._datasets = sorted(datasets, key=lambda ds: ds.info.name) + self._sampling_log_maxlen = sampling_log_maxlen + + # Build the hierarchical info object for this dataset + self._info = DatasetInfo( + name=dataset_name, + weight=weight, + children=tuple(ds.info for ds in self._datasets), + ) + + # Validate the entire hierarchy using the base class method + self._validate_unique_dataset_names() + + # Extract weights from direct children and normalize them + child_weights = [info.weight for info in self._info.children] + total_weight = sum(child_weights) + if not math.isclose(total_weight, 1.0, rel_tol=1e-9): + logger.warning( + f"Interleaved dataset normalized weights to sum to 1.0. " + f"Previous weights={child_weights}, " + f"new weights={[w / total_weight for w in child_weights]}" + ) + self._normalized_weights = torch.tensor( + [w / total_weight for w in child_weights], dtype=torch.float + ) + + # Track sampling decisions for debugging and analysis + self._sampling_log: deque[tuple[int, str]] = deque( + maxlen=self._sampling_log_maxlen + ) + self._iteration_count = 0 + self._sampling_generator = torch.Generator().manual_seed(seed) + + @property + def info(self) -> DatasetInfo: + return self._info + + def __iter__(self) -> Iterator[dict[str, Any]]: + """Interleave samples from child infinite datasets""" + # Create a dictionary of iterators for each child dataset + child_iters = {ds.info.name: iter(ds) for ds in self._datasets} + + while True: + # Sample a child dataset based on the normalized weights + ds_idx: int = torch.multinomial( + self._normalized_weights, + 1, + replacement=True, + generator=self._sampling_generator, + ).item() + + selected_ds = self._datasets[ds_idx] + ds_name = selected_ds.info.name + + # Log + self._sampling_log.append((self._iteration_count, ds_name)) + self._iteration_count += 1 + + # Yield the next sample from the selected child iterator + yield next(child_iters[ds_name]) + + def state_dict(self) -> dict[str, Any]: + """Save interleaver state and all child dataset states.""" + # The parent is responsible for namespacing the child states + child_states = {ds.info.name: ds.state_dict() for ds in self._datasets} + return { + "sampling_generator_state": self._sampling_generator.get_state(), + "child_states": child_states, + "sampling_log": list(self._sampling_log), + "iteration_count": self._iteration_count, + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load state for the interleaver and its children.""" + self._sampling_generator.set_state(state_dict["sampling_generator_state"]) + child_states = state_dict["child_states"] + + for ds in self._datasets: + ds.load_state_dict(child_states[ds.info.name]) + + # Load sampling log and iteration count + self._sampling_log = deque( + state_dict.get("sampling_log", []), maxlen=self._sampling_log_maxlen + ) + self._iteration_count = state_dict.get("iteration_count", 0) diff --git a/torchtune/datasets/_iterable_base.py b/torchtune/datasets/_iterable_base.py new file mode 100644 index 0000000000..0f412e80dc --- /dev/null +++ b/torchtune/datasets/_iterable_base.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Iterator + +from torch.utils.data import IterableDataset + + +@dataclass(frozen=True) +class DatasetInfo: + """Hierarchical metadata for datasets, enabling composition and weight tracking. + + Used to build tree structures when composing datasets. For example, a nested + `InterleavedDataset` dataset would have this structure: + + Example: + .. code-block:: python + + DatasetInfo(name='parent_interleaved', + weight=1.0, + children=(DatasetInfo(name='child_interleaved', + weight=0.7, + children=(DatasetInfo(name='dataset_a', + weight=0.6, + children=()), + DatasetInfo(name='dataset_b', + weight=0.4, + children=()))), + DatasetInfo(name='dataset_c', weight=0.3, children=()))) + + This hierarchical structure is used for validation (ensuring unique dataset + names) and for logging metrics. + + Attributes: + name (str): Unique identifier for the dataset + weight (float): Sampling weight for dataset selection (default: 1.0) + children (tuple[DatasetInfo, ...]): Nested datasets for composed structures + """ + + name: str + weight: float = 1.0 + children: tuple["DatasetInfo", ...] = field(default_factory=tuple) + + +class TuneIterableDataset(IterableDataset, ABC): + """Base class for all torchtune iterable datasets. + + Datasets are composable, enabling complex structures such as: + ``PackedDataset(InterleavedDataset([InterleavedDataset([ds1, ds2]), ds3]))`` + + Each dataset implementation must: + - Track hierarchical metadata via the ``info`` property + - Ensure unique dataset names across the entire tree + - Handle checkpointing: parents resume children's state + - Provide proper state management for exact resumption + """ + + @property + @abstractmethod + def info(self) -> DatasetInfo: + """Returns a hierarchical structure of all dataset information, including + this dataset and its children.""" + pass + + def _validate_unique_dataset_names(self) -> None: + """Traverses the DatasetInfo tree and raises ValueError on duplicate names.""" + root_info = self.info + names = [] + to_process = [root_info] + + while to_process: + node = to_process.pop(0) + names.append(node.name) + to_process.extend(node.children) + + # Check for duplicates after traversing the whole tree + duplicates = [name for name in set(names) if names.count(name) > 1] + if duplicates: + raise ValueError( + f"Duplicate dataset names found in hierarchy: {duplicates=}, all names={names}" + ) + + @abstractmethod + def __iter__(self) -> Iterator[dict[str, Any]]: + """Returns an iterator over the dataset. Each implementation is responsible + for its own iteration logic, including shuffling, distribution of data across ranks, + and making it an infinite stream.""" + pass + + @abstractmethod + def state_dict(self) -> dict[str, Any]: + """Returns checkpoint state for dataset resumption.""" + pass + + @abstractmethod + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Restores dataset state from checkpoint.""" + pass + + +class InfiniteTuneIterableDataset(TuneIterableDataset): + """Base class for infinite datasets that never exhaust. + + Prevents distributed training hangs by ensuring all ranks always + have data available. Datasets restart from beginning when exhausted. + """ + + pass diff --git a/torchtune/datasets/_sft.py b/torchtune/datasets/_sft.py index 2e74ec66a0..f7638bb609 100644 --- a/torchtune/datasets/_sft.py +++ b/torchtune/datasets/_sft.py @@ -4,14 +4,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Callable, Mapping, Optional +from typing import Any, Callable, Optional import numpy as np +import torch from datasets import load_dataset from torch.utils.data import Dataset from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX from torchtune.data._messages import validate_messages +from torchtune.data.metrics import DefaultTrainingMetricTransform +from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.modules.transforms import Transform @@ -143,7 +146,7 @@ def __init__( self._message_transform = message_transform self._model_transform = model_transform - def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: if self._message_transform is not None: transformed_sample = self._message_transform(sample) if "messages" in transformed_sample: @@ -178,3 +181,109 @@ def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]: tokenized_dict = transformed_sample return tokenized_dict + + +class SFTOutputTransform(Transform): + """Applied to each dataset sample to build the `"labels"` tensor for causal-LM SFT training. + + Expects sample to contain 1-D torch tensors + "tokens": token IDs, dtype=torch.long + "mask": bool/int where **True** marks positions to ignore + + If they are not tensors, they are converted to tensors. + + Produces ``"labels"`` of the same shape such that + labels[t] = tokens[t+1] # shift left + labels[t] = IGNORE_IDX if mask[t+1] # respect mask + labels[-1] = IGNORE_IDX # last token has no target + + All ops are vectorised; only one fresh tensor (`labels`) is allocated. + """ + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + + tokens = sample["tokens"] + mask = sample["mask"] + + # Sanity checks + if not isinstance(tokens, torch.Tensor): + tokens = torch.tensor(tokens) + if not isinstance(mask, torch.Tensor): + mask = torch.tensor(mask) + + if tokens.ndim != 1 or mask.ndim != 1: + raise ValueError("Both 'tokens' and 'mask' must be 1-D tensors.") + + # build labels + # pre-fill with IGNORE so we don’t need extra assignments later + labels = tokens.new_full(tokens.shape, CROSS_ENTROPY_IGNORE_IDX) + + # left-shift via cheap views (no copy) + labels[:-1].copy_(tokens[1:]) + + # apply mask in-place (single fused kernel on GPU/CPU) + labels[:-1].masked_fill_(mask[1:].bool(), CROSS_ENTROPY_IGNORE_IDX) + + # return a shallow-copied mapping so the original sample stays intact + out = dict(sample) + out["labels"] = labels + return out + + +def sft_iterable_dataset( + model_transform: Transform, + *, + weight: int = 1, + message_transform: Transform, + shuffle_buffer_size: Optional[int] = 1000, + seed: int = 42, + num_shards_per_rank: int = 64, + dataset_name: Optional[str] = None, + filter_fn: Optional[Callable] = None, + filter_kwargs: Optional[dict[str, Any]] = None, + **load_dataset_kwargs: dict[str, Any], +) -> HfIterableDataset: + """ + Creates an SFT-ready iterable dataset with appropriate output transform. + + Args: + model_transform (Transform): Usually the tokenizer + weight (int): Weight of the dataset. Used for sampling when interleaving datasets. + message_transform (Transform): Transform to convert raw data to messages + shuffle_buffer_size (Optional[int]): Buffer size for shuffling + seed (int): Random seed for shuffling + num_shards_per_rank (int): Target shards per worker + dataset_name (Optional[str]): Name for metrics namespacing + filter_fn (Optional[Callable]): Filter function + filter_kwargs (Optional[dict[str, Any]]): Filter function kwargs + **load_dataset_kwargs (dict[str, Any]): Args passed to load_dataset + + Returns: + HfIterableDataset: Configured for SFT training + + Example: + >>> from torchtune.data import AlpacaToMessages + >>> message_transform = AlpacaToMessages(train_on_input=False) + >>> ds = sft_iterable_dataset( + ... message_transform=message_transform, + ... model_transform=tokenizer, + ... path="tatsu-lab/alpaca" + ... ) + """ + + output_transform = SFTOutputTransform() + + return HfIterableDataset( + message_transform=message_transform, + model_transform=model_transform, + output_transform=output_transform, + metric_transform=DefaultTrainingMetricTransform(), + shuffle_buffer_size=shuffle_buffer_size, + weight=weight, + seed=seed, + num_shards_per_rank=num_shards_per_rank, + dataset_name=dataset_name, + filter_fn=filter_fn, + filter_kwargs=filter_kwargs, + **load_dataset_kwargs, + ) diff --git a/torchtune/datasets/_slimorca.py b/torchtune/datasets/_slimorca.py index ac49b56d63..0346e5b73a 100644 --- a/torchtune/datasets/_slimorca.py +++ b/torchtune/datasets/_slimorca.py @@ -7,9 +7,10 @@ from typing import Any, Callable, Optional, Union from torchtune.data import ShareGPTToMessages +from torchtune.datasets._hf_iterable import HfIterableDataset from torchtune.datasets._packed import PackedDataset -from torchtune.datasets._sft import SFTDataset +from torchtune.datasets._sft import sft_iterable_dataset, SFTDataset from torchtune.modules.transforms.tokenizers import ModelTokenizer @@ -94,3 +95,72 @@ def slimorca_dataset( ) return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len) return ds + + +def slimorca_iterable_dataset( + model_transform: ModelTokenizer, + *, + path: str = "Open-Orca/SlimOrca-Dedup", + split: str = "train", + column_map: Optional[dict[str, str]] = None, + train_on_input: bool = False, + new_system_prompt: Optional[str] = None, + shuffle_buffer_size: Optional[int] = 1000, + seed: int = 42, + num_shards_per_rank: int = 64, + dataset_name: Optional[str] = None, + filter_fn: Optional[Callable] = None, + filter_kwargs: Optional[dict[str, Any]] = None, + **load_dataset_kwargs: dict[str, Any], +) -> HfIterableDataset: + """ + Support for SlimOrca-style conversational datasets using iterable approach. + + This creates an infinite iterable dataset that automatically shards and shuffles data, + making it suitable for step-based training without explicit epoch boundaries. + + Args: + model_transform (ModelTokenizer): Model tokenizer used to tokenize the messages. + path (str): path to dataset repository on Hugging Face. Default is "Open-Orca/SlimOrca-Dedup". + split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a + subset of a given split, e.g. ``split="train[:10%]"``. Default is "train + column_map (Optional[dict[str, str]]): mapping from expected "conversations" column + to actual column name in dataset. If None, uses default "conversations". + train_on_input (bool): Whether to train on input or mask it. Default is False. + new_system_prompt (Optional[str]): If specified, prepend system message to every sample. + shuffle_buffer_size (Optional[int]): Size of shuffle buffer. If None or 0, no shuffling. + seed (int): Seed for shuffling. Default is 42. + num_shards_per_rank (int): Target number of shards per worker. Default is 64. + dataset_name (Optional[str]): Name for metrics. If None, auto-generated from source. + filter_fn (Optional[Callable]): Filter function to apply to dataset. + filter_kwargs (Optional[dict[str, Any]]): Kwargs for filter function. + **load_dataset_kwargs (dict[str, Any]): Additional kwargs for load_dataset. + + Returns: + HfIterableDataset: Configured iterable dataset + + Example: + >>> from torchtune.datasets import slimorca_iterable_dataset + >>> ds = slimorca_iterable_dataset(shuffle_buffer_size=1000) + >>> for sample in ds: + >>> print(sample["tokens"][:10]) # First 10 tokens + """ + message_transform = ShareGPTToMessages( + train_on_input=train_on_input, + column_map=column_map, + new_system_prompt=new_system_prompt, + ) + + return sft_iterable_dataset( + path=path, + split=split, + message_transform=message_transform, + model_transform=model_transform, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + num_shards_per_rank=num_shards_per_rank, + dataset_name=dataset_name, + filter_fn=filter_fn, + filter_kwargs=filter_kwargs, + **load_dataset_kwargs, + ) diff --git a/torchtune/training/checkpointing/_checkpoint_client.py b/torchtune/training/checkpointing/_checkpoint_client.py index 39b8989284..d943ad697f 100644 --- a/torchtune/training/checkpointing/_checkpoint_client.py +++ b/torchtune/training/checkpointing/_checkpoint_client.py @@ -19,6 +19,7 @@ StateDictOptions, ) from torchtune import config, training, utils +from torchtune.data.metrics import MetricsAggregator from torchtune.modules.optim import OptimizerInBackward from torchtune.modules.peft import ( get_adapter_state_dict, @@ -47,6 +48,7 @@ class TrainingProgress: total_training_steps: Optional[int] = None dataloader_state_dict: Optional[dict[str, Any]] = None val_dataloader_state_dict: Optional[dict[str, Any]] = None + metrics_aggregator_state_dict: Optional[dict[str, Any]] = None def state_dict(self) -> dict[str, object]: return { @@ -58,6 +60,7 @@ def state_dict(self) -> dict[str, object]: "total_training_steps": self.total_training_steps, training.DATALOADER_KEY: self.dataloader_state_dict, training.VAL_DATALOADER_KEY: self.val_dataloader_state_dict, + "metrics_aggregator_state_dict": self.metrics_aggregator_state_dict, } @@ -442,6 +445,7 @@ def load_distributed_checkpoint( adapter_config: Optional[dict[str, Any]] = None, dataloader: Optional[torchdata.stateful_dataloader.StatefulDataLoader] = None, single_device: bool = False, + metrics_aggregator: Optional[MetricsAggregator] = None, ) -> dict[str, Any]: """ This method is used to resume training from a distributed checkpoint state. @@ -459,6 +463,9 @@ def load_distributed_checkpoint( checkpoint_dict: dict[str, Any] = {} model_state_dict = model.state_dict() optim_state_dict = optimizer.state_dict() + metrics_aggregator_state_dict = ( + metrics_aggregator.state_dict() if metrics_aggregator else {} + ) # Hack to properly initialize the learning rate scheduler # TODO: Find a better way to do this, possibly by including the following @@ -481,6 +488,7 @@ def load_distributed_checkpoint( "steps_run": 0, "total_training_steps": 0, training.DATALOADER_KEY: dataloader.state_dict() if dataloader else {}, + "metrics_aggregator_state_dict": metrics_aggregator_state_dict, } )