|
18 | 18 | from pathlib import Path |
19 | 19 | import functools |
20 | 20 | import ml_collections |
| 21 | +from concurrent import futures |
21 | 22 |
|
22 | 23 | import jax |
23 | 24 |
|
@@ -53,26 +54,58 @@ def get_datasets( |
53 | 54 | dataloading_host_index, |
54 | 55 | dataloading_host_count, |
55 | 56 | grain_worker_count, |
| 57 | + grain_num_threads, |
| 58 | + grain_prefetch_buffer_size, |
| 59 | + grain_data_source_max_workers, |
56 | 60 | ): |
57 | 61 | """Load dataset from array_record files for use with grain""" |
58 | 62 | if data_file_type == "arrayrecord": |
| 63 | + # Helper function to find files, create data source, and wrap in MapDataset |
| 64 | + def create_dataset_from_pattern(pattern): |
| 65 | + files = find_data_files(pattern) |
| 66 | + source = grain.ArrayRecordDataSource(files) |
| 67 | + return grain.MapDataset.source(source) |
| 68 | + |
59 | 69 | if ";" in data_file_pattern: |
60 | 70 | data_file_patterns, weights = zip(*[pattern.split(",") for pattern in data_file_pattern.split(";")]) |
61 | 71 | assert len(data_file_patterns) == len(weights), "Number of data file patterns and weights must match" |
62 | 72 | weights = [float(weight) for weight in weights] |
63 | 73 | weights = [round(weight / sum(weights), 4) for weight in weights] |
64 | | - dataset_list = [ |
65 | | - grain.MapDataset.source(grain.ArrayRecordDataSource(find_data_files(pattern))) for pattern in data_file_patterns |
66 | | - ] |
67 | | - dataset = grain.MapDataset.mix(dataset_list, weights) |
| 74 | + |
| 75 | + # Parallelize file finding (globbing), data source creation, and dataset wrapping |
| 76 | + # File finding and source creation are I/O-bound operations that release the GIL |
| 77 | + executor = futures.ThreadPoolExecutor(max_workers=grain_data_source_max_workers) |
| 78 | + dataset_list = list(executor.map(create_dataset_from_pattern, data_file_patterns)) |
| 79 | + executor.shutdown(wait=True) |
| 80 | + |
| 81 | + # Shuffle, repeat, shard, and convert each dataset to an IterDataset before mixing |
| 82 | + for d, _ in enumerate(dataset_list): |
| 83 | + if shuffle: |
| 84 | + dataset_list[d] = dataset_list[d].shuffle(seed=shuffle_seed) |
| 85 | + dataset_list[d] = dataset_list[d].repeat(num_epoch) |
| 86 | + dataset_list[d] = dataset_list[d][dataloading_host_index::dataloading_host_count] # sharding |
| 87 | + dataset_list[d] = dataset_list[d].to_iter_dataset( |
| 88 | + read_options=grain.ReadOptions( |
| 89 | + num_threads=grain_num_threads, |
| 90 | + prefetch_buffer_size=grain_prefetch_buffer_size, |
| 91 | + ) |
| 92 | + ) |
| 93 | + # Use IterDataset.mix instead of MapDataset.mix so that each mixture component keeps its own checkpoint state, |
| 94 | + # which makes it possible to change the mixture after checkpointing |
| 95 | + dataset = grain.IterDataset.mix(dataset_list, weights) |
68 | 96 | else: |
69 | | - data_files = find_data_files(data_file_pattern) |
70 | | - dataset = grain.MapDataset.source(grain.ArrayRecordDataSource(data_files)) |
71 | | - if shuffle: |
72 | | - dataset = dataset.shuffle(seed=shuffle_seed) |
73 | | - dataset = dataset.repeat(num_epoch) |
74 | | - dataset = dataset[dataloading_host_index::dataloading_host_count] # sharding |
75 | | - dataset = dataset.to_iter_dataset() |
| 97 | + # Single pattern case - no need for parallelization |
| 98 | + dataset = create_dataset_from_pattern(data_file_pattern) |
| 99 | + if shuffle: |
| 100 | + dataset = dataset.shuffle(seed=shuffle_seed) |
| 101 | + dataset = dataset.repeat(num_epoch) |
| 102 | + dataset = dataset[dataloading_host_index::dataloading_host_count] # sharding |
| 103 | + dataset = dataset.to_iter_dataset( |
| 104 | + read_options=grain.ReadOptions( |
| 105 | + num_threads=grain_num_threads, |
| 106 | + prefetch_buffer_size=grain_prefetch_buffer_size, |
| 107 | + ) |
| 108 | + ) |
76 | 109 | elif data_file_type == "parquet": |
77 | 110 | data_files = find_data_files(data_file_pattern) |
78 | 111 | dataset = grain.MapDataset.source(data_files) |
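
The multi-pattern path above globs each pattern in a thread pool, builds a per-component pipeline (shuffle, repeat, host sharding, then `to_iter_dataset` with `ReadOptions`), and only then mixes with `IterDataset.mix`. Below is a minimal, self-contained sketch of that flow on toy in-memory sources rather than ArrayRecord files; the toy data, weights, thread counts, and host index/count are illustrative only, and the `grain.python` import alias is an assumption about how the surrounding file imports Grain.

```python
# Sketch only: mirrors the per-component pipeline added above, using toy in-memory
# sources instead of ArrayRecord files. All values here are illustrative.
from concurrent import futures

import grain.python as grain  # assumed import alias; adjust to match the file


def build_component(items, shuffle_seed=0, num_epoch=2, host_index=0, host_count=1):
  ds = grain.MapDataset.source(items)  # any random-access sequence works here
  ds = ds.shuffle(seed=shuffle_seed)
  ds = ds.repeat(num_epoch)
  ds = ds[host_index::host_count]  # per-host sharding, as in the diff
  return ds.to_iter_dataset(
      read_options=grain.ReadOptions(num_threads=2, prefetch_buffer_size=64)
  )


# Stand-in for the parallel globbing: executor.map preserves input order, which
# keeps each component dataset aligned with its mixture weight.
toy_sources = [list(range(10)), list(range(100, 110))]
with futures.ThreadPoolExecutor(max_workers=2) as executor:
  components = list(executor.map(build_component, toy_sources))

mixed = grain.IterDataset.mix(components, [0.75, 0.25])
for step, example in enumerate(mixed):
  if step >= 5:
    break
  print(step, example)
```
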
@@ -237,6 +270,9 @@ def make_grain_train_iterator( |
237 | 270 | dataloading_host_index=process_indices.index(jax.process_index()), |
238 | 271 | dataloading_host_count=len(process_indices), |
239 | 272 | grain_worker_count=config.grain_worker_count, |
| 273 | + grain_num_threads=config.grain_num_threads, |
| 274 | + grain_prefetch_buffer_size=config.grain_prefetch_buffer_size, |
| 275 | + grain_data_source_max_workers=config.grain_data_source_max_workers, |
240 | 276 | ) |
241 | 277 | if config.use_dpo: |
242 | 278 | train_dataloader = dpo_preprocessing_pipeline( |
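
The call sites above thread three new settings through from `config`: `grain_num_threads`, `grain_prefetch_buffer_size`, and `grain_data_source_max_workers` (plus `_eval` variants of the first two for the eval iterator). A rough sketch of the corresponding config entries is below; the key names come from the call sites in this diff, but the `ConfigDict` container and the values shown are assumptions, not part of the PR.

```python
# Hypothetical config entries for the new knobs. Key names are taken from the call
# sites in this diff; the container type and the values are illustrative only.
import ml_collections

config = ml_collections.ConfigDict()
config.grain_worker_count = 4              # pre-existing setting
config.grain_num_threads = 16              # ReadOptions.num_threads per dataset
config.grain_prefetch_buffer_size = 256    # ReadOptions.prefetch_buffer_size
config.grain_data_source_max_workers = 8   # ThreadPoolExecutor size for globbing
# Eval-side variants used by make_grain_eval_iterator:
config.grain_worker_count_eval = 2
config.grain_num_threads_eval = 8
config.grain_prefetch_buffer_size_eval = 128
```
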
@@ -271,6 +307,9 @@ def make_grain_train_iterator( |
271 | 307 | shuffle_seed=config.data_shuffle_seed, |
272 | 308 | num_epoch=config.num_epoch, |
273 | 309 | grain_worker_count=config.grain_worker_count, |
| 310 | + grain_num_threads=config.grain_num_threads, |
| 311 | + grain_prefetch_buffer_size=config.grain_prefetch_buffer_size, |
| 312 | + grain_data_source_max_workers=config.grain_data_source_max_workers, |
274 | 313 | ) |
275 | 314 | if config.use_dpo: |
276 | 315 | preprocessing_fn = functools.partial( |
@@ -328,6 +367,9 @@ def make_grain_eval_iterator( |
328 | 367 | dataloading_host_index=process_indices.index(jax.process_index()), |
329 | 368 | dataloading_host_count=len(process_indices), |
330 | 369 | grain_worker_count=config.grain_worker_count_eval, |
| 370 | + grain_num_threads=config.grain_num_threads_eval, |
| 371 | + grain_prefetch_buffer_size=config.grain_prefetch_buffer_size_eval, |
| 372 | + grain_data_source_max_workers=config.grain_data_source_max_workers, |
331 | 373 | ) |
332 | 374 | if config.use_dpo: |
333 | 375 | eval_dataloader = dpo_preprocessing_pipeline( |
@@ -359,6 +401,9 @@ def make_grain_eval_iterator( |
359 | 401 | shuffle_seed=config.data_shuffle_seed, |
360 | 402 | num_epoch=1, |
361 | 403 | grain_worker_count=config.grain_worker_count_eval, |
| 404 | + grain_num_threads=config.grain_num_threads_eval, |
| 405 | + grain_prefetch_buffer_size=config.grain_prefetch_buffer_size_eval, |
| 406 | + grain_data_source_max_workers=config.grain_data_source_max_workers, |
362 | 407 | ) |
363 | 408 | if config.use_dpo: |
364 | 409 | preprocessing_fn = functools.partial( |
|