[docs] add Tunix GRPO example #236
Conversation
/gemini review
Code Review
This pull request introduces a new Tunix GRPO example, including documentation and a Python launcher script for a smoke test. The review feedback identified an incorrect mesh configuration for 8-TPU setups and suggested promoting the tokenizer_path to a function parameter to align with Kinetic's API design guidelines for configurable resources.
```python
if num_tpus == 8:
    mesh_shape = (1, 4)
```
The mesh shape for 8 TPUs is incorrect. jax.make_mesh requires the product of the mesh shape dimensions to equal the total number of devices in the mesh. For 8 devices, a shape of (1, 4) will result in a ValueError. This should be updated to a valid configuration such as (2, 4) (2-way FSDP, 4-way TP) to utilize all available devices.
Suggested change:

```diff
 if num_tpus == 8:
-    mesh_shape = (1, 4)
+    mesh_shape = (2, 4)
```
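The constraint the reviewer cites can be checked without hardware. A minimal sketch of the device-count invariant (the `choose_mesh_shape` helper and its lookup table are illustrative, not code from the PR):

```python
import math

def choose_mesh_shape(num_devices: int) -> tuple[int, int]:
    """Pick an (fsdp, tp) mesh shape whose product covers every device."""
    # Hypothetical lookup table. jax.make_mesh raises a ValueError when
    # the product of the shape dimensions != len(devices), which is why
    # (1, 4) fails on an 8-TPU slice.
    shapes = {1: (1, 1), 4: (1, 4), 8: (2, 4)}
    shape = shapes[num_devices]
    assert math.prod(shape) == num_devices, "mesh must use all devices"
    return shape
```

Any table like this stays valid as long as the product check holds for every entry.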
```python
)
def run_tunix_grpo(
    data_dir: str,
    model_id: str = "google/gemma-3-270m-it",
```
The tokenizer_path should be promoted to a function parameter. According to the Kinetic API design guidelines, configurable resources should be resolvable through explicit parameters to the entry point function. This allows users to override the path via the CLI or environment variables without modifying the script. Remember to remove the local tokenizer_path definition on line 105 if you apply this change.
Suggested change:

```diff
 model_id: str = "google/gemma-3-270m-it",
+tokenizer_path: str = "gs://gemma-data/tokenizers/tokenizer_gemma3.model",
```
References
- Every configurable resource name (project, zone, cluster, namespace, etc.) must be resolvable through an explicit parameter to the entry point function. (link)
Code Review
This pull request introduces a new example for GRPO post-training using Tunix on Kinetic TPUs, including documentation and a comprehensive Python script with a smoke dataset and reward functions. Feedback includes a correction to the TPU mesh configuration for 8-device setups to ensure full resource utilization and a recommendation to load the tokenizer from the local model directory rather than using a hardcoded GCS path to improve robustness.
```python
if num_tpus == 8:
    mesh_shape = (1, 4)
```
The mesh_shape for 8 TPUs is set to (1, 4), which only accounts for 4 devices. jax.make_mesh requires the product of the shape dimensions to match the total number of available devices (8 in this case). This will result in a ValueError at runtime. It should be updated to use all 8 devices, for example (1, 8) or (2, 4).
Suggested change:

```diff
 if num_tpus == 8:
-    mesh_shape = (1, 4)
+    mesh_shape = (1, 8)
```
References
- Poke Holes in the Implementation: Actively search for and point out failing edge cases, race conditions, or unhandled exceptions in the implementation.
```python
output_dir = os.environ.get("KINETIC_OUTPUT_DIR", "/tmp/tunix-grpo")
checkpoint_dir = f"{output_dir.rstrip('/')}/checkpoints"

tokenizer_path = "gs://gemma-data/tokenizers/tokenizer_gemma3.model"
```
Hardcoding a GCS path for the tokenizer (gs://gemma-data/...) is brittle and may not be accessible to external users. Since the model is already being downloaded from Hugging Face using snapshot_download, it is recommended to load the tokenizer from the local model directory instead to ensure portability and robustness.
References
- Demand Robustness: Do not accept fragile code. If the proposed code is not robust enough... explicitly tell the author why the current approach is brittle.
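The recommended fallback can be sketched as a small resolver that prefers tokenizer files inside the locally downloaded snapshot. The helper and the candidate file names are assumptions (Gemma checkpoints commonly ship a `tokenizer.model`), not code from the PR:

```python
import os

def resolve_tokenizer_path(model_dir: str) -> str:
    # Prefer tokenizer artifacts that snapshot_download already fetched
    # locally, so no separate GCS access is required.
    for name in ("tokenizer.model", "tokenizer.json"):  # assumed names
        candidate = os.path.join(model_dir, name)
        if os.path.isfile(candidate):
            return candidate
    raise FileNotFoundError(f"no tokenizer file found under {model_dir}")
```

Failing fast with a clear error when no tokenizer is present is friendlier to external users than a permission error from an inaccessible bucket.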
Added a Tunix GRPO example