
Automatically Optimize grain_worker_count for Improved Data Loading Performance #2509

@bzantium

Description


What problem are you trying to solve?

The grain_worker_count parameter, which controls the number of parallel data loading workers, has a significant impact on training performance, especially when tokenizing raw text data on the fly. An incorrectly configured grain_worker_count can cause the data input and preprocessing pipeline to become a major bottleneck, leading to drastically reduced hardware utilization (TFLOP/s). Users must currently find the optimal value through manual trial and error, which is inefficient.

Why is this problem important?

Manually tuning grain_worker_count is a tedious process that requires multiple experimental runs to identify the best setting for a specific hardware and dataset configuration. This creates a poor user experience and can prevent users from achieving optimal training performance. An automatic solution would save significant time and effort, making it easier for users to maximize the efficiency of their training jobs right from the start.

Describe your requested feature or solution

We propose that MaxText add support for grain_worker_count = -1 as an opt-in feature to automatically determine the optimal number of data loading workers.

This implementation would leverage a function like grain.experimental.pick_performance_config to dynamically select the best value based on the system's capabilities.

While the default value will remain 1 (providing a stable, low-resource baseline), users can explicitly set grain_worker_count = -1 to enable this auto-tuning. This eliminates the guesswork for the user and prevents the data pipeline from bottlenecking the training process when high performance is desired.

Describe alternatives you’ve considered (if any)

The only alternative is the current method: manually setting and testing different integer values for grain_worker_count. This is the inefficient process we are seeking to improve.

Additional context or examples

The performance impact is clear from training experiments on a v6e-32 pod. Below is a summary of the TeraFLOPs per second per device (TFLOP/s/device) and the average time per step observed when training a Llama3-8B model with varying grain_worker_count values.

As shown, a low worker count (1-2) leads to slow, erratic step times and low throughput. Throughput improves dramatically at 4 workers and fully stabilizes at 8. Critically, setting grain_worker_count to -1 (auto-tune) achieves the same stable, high-throughput performance as a manually optimized value (like 8), reaching ~195 TFLOP/s and a consistent step time of ~4.3 seconds.

| grain_worker_count | Average TFLOP/s/device (Steps 3-9) | Average Time/Step (s) (Steps 3-9) | Stability |
|---|---|---|---|
| 1 | ~29 | ~30.6 | Unstable |
| 2 | ~60 | ~13.5 | Highly Unstable |
| 4 | ~195 | ~4.3 | Weakly Unstable |
| 8 | ~195 | ~4.3 | Stable |
| -1 (auto) | ~195 | ~4.3 | Stable |

This new feature allows users to opt in to an automatic configuration that selects an optimal value, ensuring stable and efficient training.

Full Logs for Reference:


grain_worker_count = 1

# TFLOP/s/device values: 13.8, 78.8, 29.4, 27.7, 28.7, 25.1, 23.1, 28.1, 29.9, 32.0
# seconds per step: 61.1, 10.7, 28.7, 30.5, 29.4, 33.6, 36.5, 30.0, 28.2, 26.3

grain_worker_count = 2

# TFLOP/s/device values: 14.5, 3267.4, 73.2, 129458.3, 28.2, 129717.1, 31.9, 517.3, 33.5, 455.1
# seconds per step: 58.3, 0.2, 11.5, 0.007, 29.9, 0.007, 26.5, 1.6, 25.2, 1.9

grain_worker_count = 4

# TFLOP/s/device values: 15.2, 3385.9, 207.3, 195.9, 195.3, 195.0, 195.3, 195.4, 85.9, 137321.5
# seconds per step: 55.3, 0.2, 4.1, 4.3, 4.3, 4.3, 4.3, 4.3, 9.8, 0.006

grain_worker_count = 8

# TFLOP/s/device values: 17.7, 3253.9, 109.3, 196.1, 195.4, 195.5, 195.6, 195.4, 195.5, 195.6
# seconds per step: 47.6, 0.3, 7.7, 4.3, 4.3, 4.3, 4.3, 4.3, 4.3, 4.3

