Skip to content

Iterable Dataset #2852

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 27 commits into
base: impl-step-based-ckpt
Choose a base branch
from

Conversation

felipemello1
Copy link
Contributor

@felipemello1 felipemello1 commented Jun 26, 2025

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Enable Iterable datasets in torchtune.

CONTEXT: built on top of ongoing PR step-based-ckpt: #2384

TIps when reviewing this pr

Follow this order:

  1. recipes/configs/llama3_2/3B_full.yaml: see the configs
  2. torchtune/datasets/_iterable_base.py: base class for iterable dataset
  3. torchtune/datasets/_hf_iterable.py: ds based on HF -- Can be replaced easily. Downstream does not expect HF.
  4. torchtune/datasets/_interleaved.py: interleave the datasets
  5. torchtune/data/_metrics.py: metrics transform to create the metrics
  6. torchtune/data/_aggregator.py: aggregate the metrics at the recipe level
  7. recipes/full_finetune_distributed.py: everything put together
  8. unit tests

torchtune/datasets/_hf_iterable.py

Changelog

  1. Datasets are infinite
  2. User doesn't define epochs anymore, but training steps (how many times we update the optimizer)
  3. Support for dataset mixing -- follow up PRs is to enable curriculum learning
  4. Support for dataset metric logging -- User can understand epoch per dataset, distribution of token lens, etc. Easy to add new metrics.
  5. HF agnostic. Even though the current dataset is HF, the dataloader, packed, datamixing, metric logging is agnostic to it
  6. Well tested in distributed setting -- WARNING: need better testing for multiprocess dataloader. It doesnt guarantee determinism, so I postponed testing this setting

Config and builder design based on the discussions after this RFC: #2785

Next steps:
7. Gather feedback on metric logging. E.g. we can add more aggregation types.
8. Polish the code a little bit
9. Add packing from this RFC: #2819
10. Add curriculum learning
11. Docs?

Test plan

image image image

UNTESTED: resume from ckpt in the recipe. However, we have plenty of tests showing that resuming works for these iterable datasets.

Copy link

pytorch-bot bot commented Jun 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2852

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Cancelled Jobs

As of commit f89eefe with merge base 3d73591 (image):

NEW FAILURE - The following job has failed:

  • GPU tests / gpu_test (3.11, stable) (gh)
    tests/recipes/test_qat_lora_finetune_distributed.py::TestQATLoRAFinetuneDistributedRecipe::test_training_state_on_resume_with_async_checkpointing[llama3/8B_qat_lora-llama3-tune-False]

CANCELLED JOBS - The following jobs were cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 26, 2025
@felipemello1 felipemello1 changed the title first commit Iterable Dataset Jun 26, 2025
@@ -94,3 +95,72 @@ def slimorca_dataset(
)
return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len)
return ds


def slimorca_iterable_dataset(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added here to demonstrate datamix iterable dataset with this example. Personally, i dislike exposing all of the args and defaults. I would prefer to expose only whats specific to this builder.

Comment on lines 101 to 104
logger.warning(
f"Child dataset {self._datasets[ds_name].dataset_name} was exhausted. "
"This is unexpected for an infinite dataset. Re-initializing its iterator."
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not 100% sure i like this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do this: simply have a subclass for InfiniteIterable so this is super explicit

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where did this one land? I don't see InfiniteIterable anywhere (personally I don't know enough yet to have a strong preference here, just wanna understand where things currently stand)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i made changes but didnt push them yet. I added a dummy class that does nothing:

class InfiniteTuneIterableDataset(TuneIterableDataset):
    """Abstract base class for infinite datasets, which yield samples indefinitely.
    It only purpose is to make it explicit that the dataset is expected to be infinite, i.e. 
    it never exhausts. This is helpful to avoid complexity due to some rank hanging because
    of lack of data""
    pass

and replaced this logger.warning with raise ValueError.

I think its better to have zero tolerance. Datasets that are not infinite need work to make sure no rank hangs.

@@ -101,3 +102,64 @@ def alpaca_dataset(
original Alpaca dataset, `yahma/alpaca-cleaned <https://huggingface.co/datasets/yahma/alpaca-cleaned>`_.
See the dataset page and :func:`~torchtune.datasets.alpaca_dataset` for more details.
"""


def alpaca_iterable_dataset(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added here to demonstrate datamix iterable dataset with this example. Personally, i dislike exposing all of the args and defaults. I would prefer to expose only whats specific to this builder.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But you are doing this with ``load_dataset_kwargs, right? Or did you mean something else?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it's a function, so... get_alpaca_iterable_dataset?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the get makes sense, but its not the pattern we have in tune :/

Copy link
Contributor

@Darktex Darktex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great PR! I mainly had a question on the interaction with packing and on the SFT transform

@@ -101,3 +102,64 @@ def alpaca_dataset(
original Alpaca dataset, `yahma/alpaca-cleaned <https://huggingface.co/datasets/yahma/alpaca-cleaned>`_.
See the dataset page and :func:`~torchtune.datasets.alpaca_dataset` for more details.
"""


def alpaca_iterable_dataset(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But you are doing this with ``load_dataset_kwargs, right? Or did you mean something else?

Comment on lines 101 to 104
logger.warning(
f"Child dataset {self._datasets[ds_name].dataset_name} was exhausted. "
"This is unexpected for an infinite dataset. Re-initializing its iterator."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do this: simply have a subclass for InfiniteIterable so this is super explicit

from torch.utils.data import IterableDataset


class TuneIterableDataset(IterableDataset, ABC):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need this guy to interact with packing and IIUC I don't believe this is currently happening?

The algo we should implement is this:

  1. One batch can be made of multiple calls to next. We keep taking until we exceed the max seq len. When we do, we put the last one aside (we'll use it to start the next batch), pad the current one to max len and return.
  2. The calls to next will go to the interleaved dataset, therefore we automatically construct mixed batches from multiple datasets without much effort
  3. Also, every time we call next we should make space for logging transforms (which we are, you already wrote them). I think it's ok to make your metrics transforms and aggregators an optional property here so the semantics are clearer

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have packing here: #2819

Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I read every line of this PR. (Kidding but I tried to at least look at most of the important stuff.) Thanks for taking on this massive set of changes, I think the dataset classes are a big improvement

Comment on lines 127 to 132
self.new_metric(
name="tokens_seen", value=token_len, agg_type=AggregationType.SUM
),
self.new_metric(
name="seq_len", value=token_len, agg_type=AggregationType.DISTRIBUTION
),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A minor thing, but to me metrics having the same value but different aggregation types should not actually be represented as distinct metrics. Like I should be able to just define how a metric is computed for a given sample, then separately choose different types of aggregation as needed

Copy link
Contributor Author

@felipemello1 felipemello1 Jul 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, so agg_type being a List[AggregationType]?

self.new_metric(
                name="tokens_seen", value=token_len, 
                agg_type=[AggregationType.SUM, AggregationType.MEAN
            ),

I dont know if the extra complexity is worth it. Adding two metrics is cheap. wdyt?

Today the user can just do:
self.new_metric(
                name="tokens_seen_sum", value=token_len, 
                agg_type=AggregationType.SUM
            ), 
self.new_metric(
                name="tokens_seen_mean", value=token_len, 
                agg_type=AggregationType.MEAN
            ), 

from torchtune.data.metrics._metric_transform import AggregationType, Metric


class MetricsAggregator:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A high level comment: the relationship between this and the agg handlers is not super clear to me. It seems like we are using a registry pattern where the handlers are responsible for defining the actual aggregation logic. But then the all-gather happens in here. (Separately I stand by my claim that it would be better to hold off on more complex cases like distribution aggregators so as not to boil the ocean here.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the handlers are responsible for defining the actual aggregation logic. But then the all-gather happens in here.

why is that a contradiction?

  1. The MetricsAggregator calls the handler.finalize_local_agg
  2. then does a single all_gather to get the results from all ranks for all metrics
  3. Then calls handler._finalize_dist_agg([aggregated_results_per_rank]*n_ranks)

Do you wanna suggest a different way of doing it? Or is it hard to spot this pattern in the code?

Copy link
Contributor Author

@felipemello1 felipemello1 Jul 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my claim that it would be better to hold off on more complex cases like distribution aggregators

If we dont do aggregation across ranks, we wouldnt be able to count things like "tokens_seen", right? :/

Or do you mean that we should delete DistributionAggHandler? To clarify, this distribution has nothing to do with multiple gpus. Its just stats, e.g. std, percentiles, max, min, etc. Maybe i should rename if its causing confusion.

Comment on lines 312 to 315
if cfg.get("dataset_val") is not None:
raise NotImplementedError(
"Validation is not supported yet with iterable datasets."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a specific technical reason here? Or we just haven't gotten to it yet

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

validation datasets are not infinite!!! Need to figure out how to solve this one, but it wont be on this PR

@codecov-commenter
Copy link

codecov-commenter commented Jul 7, 2025

Codecov Report

Attention: Patch coverage is 74.18831% with 318 lines in your changes missing coverage. Please review.

Please upload report for BASE (impl-step-based-ckpt@54a48bb). Learn more about missing BASE report.

Files with missing lines Patch % Lines
recipes/full_finetune_distributed.py 0.00% 92 Missing ⚠️
tests/torchtune/datasets/test_interleaved.py 78.18% 53 Missing ⚠️
tests/torchtune/data/test_metrics_aggregator.py 71.42% 44 Missing ⚠️
tests/torchtune/datasets/test_hf_iterable.py 73.75% 37 Missing ⚠️
torchtune/data/metrics/_metric_agg_handlers.py 84.45% 23 Missing ⚠️
torchtune/data/metrics/_metric_aggregator.py 77.66% 23 Missing ⚠️
torchtune/datasets/_hf_iterable.py 84.76% 16 Missing ⚠️
torchtune/datasets/_sft.py 33.33% 16 Missing ⚠️
torchtune/datasets/_iterable_base.py 88.88% 4 Missing ⚠️
...htune/training/checkpointing/_checkpoint_client.py 0.00% 3 Missing ⚠️
... and 4 more
Additional details and impacted files
@@                   Coverage Diff                   @@
##             impl-step-based-ckpt    #2852   +/-   ##
=======================================================
  Coverage                        ?   60.64%           
=======================================================
  Files                           ?      449           
  Lines                           ?    28224           
  Branches                        ?        0           
=======================================================
  Hits                            ?    17116           
  Misses                          ?    11108           
  Partials                        ?        0           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

# HuggingFace datasets bug where .map() causes incorrect checkpoint resumption.
# See: https://github.com/huggingface/datasets/issues/7630
# This ensures transforms are applied fresh on each sample during iteration.
sample = self._apply_transforms(sample)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Applying transformations inside the dataset before returning every sample would mean no possibility of parallelizing them (either within every dataset or across datasets). Is that expected?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Synced offline. I had the wrong impression. We shard the dataset so if a user turned on num_workers>0 it would lead to multiple processes all reading the same dataset but different shards of it; and so apply transformations in each process.

Comment on lines +203 to +205
# Shuffle the dataset
if self._shuffle_buffer_size and self._shuffle_buffer_size > 0:
ds = ds.shuffle(seed=self._seed, buffer_size=self._shuffle_buffer_size)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better if we shuffled before sharding ?

Copy link
Contributor Author

@felipemello1 felipemello1 Jul 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i believe the sharding is happening when we call .to_iterable_dataset. If we do after split_by_node, then i guess the shuffle would only happen inside of the node, and not across nodes.

e.g.
shuffling before:
[0,1,2,3,4,5] -> [4,1,0,5,3,2]

shuffling after
[0,1], [2,3], [4,5] -> [1,0], [2,3], [5,4]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. So we should shuffle first, then shard with .to_iterable_dataset call, then split_by_node.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm, why not this?

  1. .to_iterable_dataset(num_shards)
  2. shuffle
  3. split_by_node

sharding happens at 1. Shouldnt we shuffle after sharding?

Copy link
Contributor Author

@felipemello1 felipemello1 Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw, thanks for helping me double check this

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can shuffle after sharding, but then the shuffle will only be within the shard (unless HF's ds.shuffle is doing something non-obvious here). The example you had above is clearly depicting that. Smaller the sample size, less optimal the shuffle. But obv this is something that can be tweaked based on perf.

Copy link
Contributor Author

@felipemello1 felipemello1 Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can shuffle after sharding, but then the shuffle will only be within the shard (

oh, i see. I guess the shuffle will still be across all shards, because the shards are not assigned to any rank yet. Just need to make sure that every rank uses the same seed. The issue is if we shuffle after split_by_node. But i need to double check that in their docs/forum. Last time i saw that was a few weeks ago.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants