
Commit f601e42

Merge pull request #2703 from AI-Hypercomputer:aireen/grain_source_pool
PiperOrigin-RevId: 834944697
Parents: acd9e07, c7bacec

File tree: 3 files changed, +72 −11 lines

src/MaxText/configs/base.yml (7 additions, 0 deletions)

@@ -596,8 +596,15 @@ grain_eval_files: ''
 grain_file_type: 'arrayrecord' # arrayrecord or parquet
 grain_worker_count: 1
 grain_per_worker_buffer_size: 1
+# num_threads and prefetch_buffer_size are per-worker, per-dataset. Used in ReadOptions (https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#per-worker-readoptions)
+# The default values match those in the Grain package. If mixing multiple data sources, consider lowering these values to reduce memory usage.
+grain_num_threads: 16
+grain_prefetch_buffer_size: 500
 grain_worker_count_eval: 1
 grain_per_worker_buffer_size_eval: 1
+grain_num_threads_eval: 16
+grain_prefetch_buffer_size_eval: 500
+grain_data_source_max_workers: 16 # Max workers for ThreadPoolExecutor when mixing multiple Grain data sources.
 # for using pathways
 colocated_python_data_input: False # experimental feature, under testing
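
For reference, a minimal sketch of the per-worker read configuration these knobs feed, mirroring the diff's grain.ReadOptions usage; the plain `import grain` and the literal values are illustrative stand-ins for the config entries named in the comments:

import grain

# ReadOptions applies per worker and per dataset: with W Grain workers and D mixed
# sources, the buffered-record footprint scales roughly with W * D * prefetch_buffer_size.
read_options = grain.ReadOptions(
    num_threads=16,            # grain_num_threads
    prefetch_buffer_size=500,  # grain_prefetch_buffer_size
)

That multiplicative scaling is why the new comment advises lowering both values when mixing many data sources.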

src/MaxText/configs/types.py (9 additions, 0 deletions)

@@ -857,6 +857,15 @@ class GrainDataset(BaseModel):
   grain_per_worker_buffer_size_eval: int = Field(
       1, description="Buffer size for each worker for Grain data loading during evaluation."
   )
+  grain_num_threads: int = Field(16, description="Number of threads for Grain ReadOptions during training.")
+  grain_prefetch_buffer_size: int = Field(500, description="Prefetch buffer size for Grain ReadOptions during training.")
+  grain_num_threads_eval: int = Field(16, description="Number of threads for Grain ReadOptions during evaluation.")
+  grain_prefetch_buffer_size_eval: int = Field(
+      500, description="Prefetch buffer size for Grain ReadOptions during evaluation."
+  )
+  grain_data_source_max_workers: int = Field(
+      16, description="Max workers for ThreadPoolExecutor when mixing multiple Grain data sources."
+  )


 class FineTuning(BaseModel):
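
As a quick illustration of how these pydantic fields behave, here is a trimmed, hypothetical stand-in for the GrainDataset model (two of the five new fields, not the real class):

from pydantic import BaseModel, Field

class GrainDatasetDemo(BaseModel):
  """Hypothetical subset of the GrainDataset config model."""

  grain_num_threads: int = Field(16, description="Number of threads for Grain ReadOptions during training.")
  grain_prefetch_buffer_size: int = Field(500, description="Prefetch buffer size for Grain ReadOptions during training.")

defaults = GrainDatasetDemo()  # grain_num_threads=16, grain_prefetch_buffer_size=500
low_mem = GrainDatasetDemo(grain_num_threads=4, grain_prefetch_buffer_size=64)  # e.g. when mixing many sources

Pydantic fills in the Field defaults when a value is absent and type-checks any override, so config overrides are validated against the declared int types.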

src/MaxText/input_pipeline/_grain_data_processing.py (56 additions, 11 deletions)

@@ -18,6 +18,7 @@
 from pathlib import Path
 import functools
 import ml_collections
+from concurrent import futures

 import jax
@@ -53,26 +54,58 @@ def get_datasets(
     dataloading_host_index,
     dataloading_host_count,
     grain_worker_count,
+    grain_num_threads,
+    grain_prefetch_buffer_size,
+    grain_data_source_max_workers,
 ):
   """Load dataset from array_record files for using with grain"""
   if data_file_type == "arrayrecord":
+    # Helper function to find files, create a data source, and wrap it in a MapDataset
+    def create_dataset_from_pattern(pattern):
+      files = find_data_files(pattern)
+      source = grain.ArrayRecordDataSource(files)
+      return grain.MapDataset.source(source)
+
     if ";" in data_file_pattern:
       data_file_patterns, weights = zip(*[pattern.split(",") for pattern in data_file_pattern.split(";")])
       assert len(data_file_patterns) == len(weights), "Number of data file patterns and weights must match"
       weights = [float(weight) for weight in weights]
       weights = [round(weight / sum(weights), 4) for weight in weights]
-      dataset_list = [
-          grain.MapDataset.source(grain.ArrayRecordDataSource(find_data_files(pattern))) for pattern in data_file_patterns
-      ]
-      dataset = grain.MapDataset.mix(dataset_list, weights)
+
+      # Parallelize file finding (globbing), data source creation, and dataset wrapping.
+      # File finding and source creation are I/O-bound operations that release the GIL.
+      executor = futures.ThreadPoolExecutor(max_workers=grain_data_source_max_workers)
+      dataset_list = list(executor.map(create_dataset_from_pattern, data_file_patterns))
+      executor.shutdown(wait=True)
+
+      # Apply shuffle, repeat, sharding, and conversion to IterDataset to each dataset before mixing
+      for d, _ in enumerate(dataset_list):
+        if shuffle:
+          dataset_list[d] = dataset_list[d].shuffle(seed=shuffle_seed)
+        dataset_list[d] = dataset_list[d].repeat(num_epoch)
+        dataset_list[d] = dataset_list[d][dataloading_host_index::dataloading_host_count]  # sharding
+        dataset_list[d] = dataset_list[d].to_iter_dataset(
+            read_options=grain.ReadOptions(
+                num_threads=grain_num_threads,
+                prefetch_buffer_size=grain_prefetch_buffer_size,
+            )
+        )
+      # Use IterDataset.mix instead of MapDataset.mix to get per-mixture-component checkpoints,
+      # which supports changing the mixture after checkpointing
+      dataset = grain.IterDataset.mix(dataset_list, weights)
     else:
-      data_files = find_data_files(data_file_pattern)
-      dataset = grain.MapDataset.source(grain.ArrayRecordDataSource(data_files))
-      if shuffle:
-        dataset = dataset.shuffle(seed=shuffle_seed)
-      dataset = dataset.repeat(num_epoch)
-      dataset = dataset[dataloading_host_index::dataloading_host_count]  # sharding
-      dataset = dataset.to_iter_dataset()
+      # Single-pattern case: no need for parallelization
+      dataset = create_dataset_from_pattern(data_file_pattern)
+      if shuffle:
+        dataset = dataset.shuffle(seed=shuffle_seed)
+      dataset = dataset.repeat(num_epoch)
+      dataset = dataset[dataloading_host_index::dataloading_host_count]  # sharding
+      dataset = dataset.to_iter_dataset(
+          read_options=grain.ReadOptions(
+              num_threads=grain_num_threads,
+              prefetch_buffer_size=grain_prefetch_buffer_size,
+          )
+      )
   elif data_file_type == "parquet":
     data_files = find_data_files(data_file_pattern)
     dataset = grain.MapDataset.source(data_files)
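
To make the new multi-source path concrete, the sketch below replays its two key steps outside MaxText: parsing the weighted "pattern,weight;pattern,weight" string, then building the per-pattern file lists on a thread pool. The paths and the local glob-based find_data_files are hypothetical stand-ins for MaxText's own helper:

from concurrent import futures
import glob

def find_data_files(pattern):
  # Hypothetical local stand-in for MaxText's helper; listing files on (possibly
  # remote) storage is I/O-bound, so threads overlap the waits despite the GIL.
  return sorted(glob.glob(pattern))

data_file_pattern = "data/en-*.array_record,0.7;data/code-*.array_record,0.3"
patterns, weights = zip(*[p.split(",") for p in data_file_pattern.split(";")])
weights = [float(w) for w in weights]
weights = [round(w / sum(weights), 4) for w in weights]  # normalize to sum to 1.0

# One task per pattern, capped by a pool the size of grain_data_source_max_workers (16 by default)
with futures.ThreadPoolExecutor(max_workers=16) as executor:
  file_lists = list(executor.map(find_data_files, patterns))

Note that executor.map returns results in input order, so the created datasets stay aligned with their mixture weights; the with block also shuts the pool down, matching the committed code's explicit executor.shutdown(wait=True).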
@@ -237,6 +270,9 @@ def make_grain_train_iterator(
       dataloading_host_index=process_indices.index(jax.process_index()),
       dataloading_host_count=len(process_indices),
       grain_worker_count=config.grain_worker_count,
+      grain_num_threads=config.grain_num_threads,
+      grain_prefetch_buffer_size=config.grain_prefetch_buffer_size,
+      grain_data_source_max_workers=config.grain_data_source_max_workers,
   )
   if config.use_dpo:
     train_dataloader = dpo_preprocessing_pipeline(

@@ -271,6 +307,9 @@ def make_grain_train_iterator(
       shuffle_seed=config.data_shuffle_seed,
       num_epoch=config.num_epoch,
       grain_worker_count=config.grain_worker_count,
+      grain_num_threads=config.grain_num_threads,
+      grain_prefetch_buffer_size=config.grain_prefetch_buffer_size,
+      grain_data_source_max_workers=config.grain_data_source_max_workers,
   )
   if config.use_dpo:
     preprocessing_fn = functools.partial(

@@ -328,6 +367,9 @@ def make_grain_eval_iterator(
       dataloading_host_index=process_indices.index(jax.process_index()),
       dataloading_host_count=len(process_indices),
       grain_worker_count=config.grain_worker_count_eval,
+      grain_num_threads=config.grain_num_threads_eval,
+      grain_prefetch_buffer_size=config.grain_prefetch_buffer_size_eval,
+      grain_data_source_max_workers=config.grain_data_source_max_workers,
   )
   if config.use_dpo:
     eval_dataloader = dpo_preprocessing_pipeline(

@@ -359,6 +401,9 @@ def make_grain_eval_iterator(
       shuffle_seed=config.data_shuffle_seed,
       num_epoch=1,
       grain_worker_count=config.grain_worker_count_eval,
+      grain_num_threads=config.grain_num_threads_eval,
+      grain_prefetch_buffer_size=config.grain_prefetch_buffer_size_eval,
+      grain_data_source_max_workers=config.grain_data_source_max_workers,
   )
   if config.use_dpo:
     preprocessing_fn = functools.partial(
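
Finally, the switch to IterDataset.mix in get_datasets is what enables per-mixture-component checkpointing. A minimal sketch with toy in-memory sources, assuming a recent Grain release where IterDataset.mix and the iterator's get_state/set_state checkpoint hooks are available as used in the diff:

import grain

ds_a = grain.MapDataset.source([1, 2, 3]).repeat().to_iter_dataset()
ds_b = grain.MapDataset.source([10, 20, 30]).repeat().to_iter_dataset()

mixed = grain.IterDataset.mix([ds_a, ds_b], [0.5, 0.5])
it = iter(mixed)
_ = [next(it) for _ in range(4)]

state = it.get_state()  # records each component's position separately,
it.set_state(state)     # so the mixture can be re-weighted or extended after restore

With MapDataset.mix, by contrast, the mixture is flattened into a single index space, so a saved position is only meaningful for the exact mixture it was created with.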
