[docs] add Tunix GRPO example #236

Draft
ChiragSW wants to merge 3 commits into keras-team:main from ChiragSW:issue#225

Conversation

@ChiragSW

Added a Tunix GRPO example

@ChiragSW
Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a new Tunix GRPO example, including documentation and a Python launcher script for a smoke test. The review feedback identified an incorrect mesh configuration for 8-TPU setups and suggested promoting the tokenizer_path to a function parameter to align with Kinetic's API design guidelines for configurable resources.

Comment thread examples/tunix_grpo.py Outdated
Comment on lines +157 to +158
if num_tpus == 8:
    mesh_shape = (1, 4)
Contributor

high

The mesh shape for 8 TPUs is incorrect. jax.make_mesh requires the product of the mesh shape dimensions to equal the total number of devices in the mesh. For 8 devices, a shape of (1, 4) will result in a ValueError. This should be updated to a valid configuration such as (2, 4) (2-way FSDP, 4-way TP) to utilize all available devices.

Suggested change
- if num_tpus == 8:
-     mesh_shape = (1, 4)
+ if num_tpus == 8:
+     mesh_shape = (2, 4)
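
The device-count check described in this comment can be sketched in plain Python, without importing JAX. This is only an illustration: `check_mesh_shape` and `pick_mesh_shape` are hypothetical helpers that are not part of the PR, and the (FSDP, TP) axis order is assumed from the reviewer's "2-way FSDP, 4-way TP" wording. An explicit check like this makes a mismatch visible up front, whatever the mesh-construction call itself does with it.

```python
import math


def check_mesh_shape(mesh_shape: tuple[int, ...], num_devices: int) -> None:
    """Raise ValueError unless the mesh shape covers exactly num_devices."""
    covered = math.prod(mesh_shape)
    if covered != num_devices:
        raise ValueError(
            f"mesh shape {mesh_shape} covers {covered} devices, "
            f"but {num_devices} are available"
        )


def pick_mesh_shape(num_devices: int, tp: int = 4) -> tuple[int, int]:
    """Return a hypothetical (fsdp, tp) shape that uses every device."""
    tp = min(tp, num_devices)  # never ask for more TP ways than devices
    if num_devices % tp != 0:
        raise ValueError(f"{num_devices} devices cannot be split {tp} ways")
    return (num_devices // tp, tp)


if __name__ == "__main__":
    shape = pick_mesh_shape(8)
    check_mesh_shape(shape, 8)  # passes: 2 * 4 == 8
    print(shape)
```

Deriving the shape from the device count, rather than hardcoding one tuple per `num_tpus` value, avoids this class of mismatch when a new topology is added.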

Comment thread examples/tunix_grpo.py
)
def run_tunix_grpo(
    data_dir: str,
    model_id: str = "google/gemma-3-270m-it",
Contributor

medium

The tokenizer_path should be promoted to a function parameter. According to the Kinetic API design guidelines, configurable resources should be resolvable through explicit parameters to the entry point function. This allows users to override the path via the CLI or environment variables without modifying the script. Remember to remove the local tokenizer_path definition on line 105 if you apply this change.

Suggested change
-     model_id: str = "google/gemma-3-270m-it",
+     model_id: str = "google/gemma-3-270m-it",
+     tokenizer_path: str = "gs://gemma-data/tokenizers/tokenizer_gemma3.model",
References
  1. Every configurable resource name (project, zone, cluster, namespace, etc.) must be resolvable through an explicit parameter to the entry point function.
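
One way to read this suggestion is sketched below: an explicit argument wins, then an environment variable, then the default. The helper `resolve_tokenizer_path` and the variable name `TOKENIZER_PATH` are assumptions made for illustration; neither appears in the PR or, as far as this thread shows, in Kinetic itself.

```python
import os
from typing import Mapping, Optional

# Default copied from the suggested change above.
DEFAULT_TOKENIZER_PATH = "gs://gemma-data/tokenizers/tokenizer_gemma3.model"


def resolve_tokenizer_path(
    tokenizer_path: Optional[str] = None,
    env: Optional[Mapping[str, str]] = None,
) -> str:
    """Pick the tokenizer path: explicit argument, then env var, then default."""
    env = os.environ if env is None else env
    if tokenizer_path:
        return tokenizer_path
    # TOKENIZER_PATH is a hypothetical variable name used only in this sketch.
    return env.get("TOKENIZER_PATH", DEFAULT_TOKENIZER_PATH)


if __name__ == "__main__":
    print(resolve_tokenizer_path())
    print(resolve_tokenizer_path("/models/gemma/tokenizer.model"))
```

Passing `env` as a parameter keeps the helper testable without mutating the real process environment.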

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a new example for GRPO post-training using Tunix on Kinetic TPUs, including documentation and a comprehensive Python script with a smoke dataset and reward functions. Feedback includes a correction to the TPU mesh configuration for 8-device setups to ensure full resource utilization and a recommendation to load the tokenizer from the local model directory rather than using a hardcoded GCS path to improve robustness.

Comment thread examples/tunix_grpo.py Outdated
Comment on lines +157 to +158
if num_tpus == 8:
    mesh_shape = (1, 4)
Contributor

high

The mesh_shape for 8 TPUs is set to (1, 4), which only accounts for 4 devices. jax.make_mesh requires the product of the shape dimensions to match the total number of available devices (8 in this case). This will result in a ValueError at runtime. It should be updated to use all 8 devices, for example (1, 8) or (2, 4).

Suggested change
- if num_tpus == 8:
-     mesh_shape = (1, 4)
+ if num_tpus == 8:
+     mesh_shape = (1, 8)
References
  1. Poke Holes in the Implementation: Actively search for and point out failing edge cases, race conditions, or unhandled exceptions in the implementation.

Comment thread examples/tunix_grpo.py Outdated
output_dir = os.environ.get("KINETIC_OUTPUT_DIR", "/tmp/tunix-grpo")
checkpoint_dir = f"{output_dir.rstrip('/')}/checkpoints"

tokenizer_path = "gs://gemma-data/tokenizers/tokenizer_gemma3.model"
Contributor

medium

Hardcoding a GCS path for the tokenizer (gs://gemma-data/...) is brittle and may not be accessible to external users. Since the model is already being downloaded from Hugging Face using snapshot_download, it is recommended to load the tokenizer from the local model directory instead to ensure portability and robustness.

References
  1. Demand Robustness: Do not accept fragile code. If the proposed code is not robust enough... explicitly tell the author why the current approach is brittle.
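
A rough sketch of the local lookup this comment recommends follows. It assumes the model snapshot has already been downloaded to a local directory (for example, the path returned by `snapshot_download`) and that the tokenizer ships in that directory under one of the candidate file names. Both the helper `find_local_tokenizer` and the file names `tokenizer.model` and `tokenizer.json` are assumptions, not details confirmed by the PR.

```python
from pathlib import Path
from typing import Iterable, Union


def find_local_tokenizer(
    model_dir: Union[str, Path],
    candidates: Iterable[str] = ("tokenizer.model", "tokenizer.json"),
) -> Path:
    """Return the first candidate tokenizer file found directly in model_dir."""
    root = Path(model_dir)
    names = list(candidates)
    for name in names:
        path = root / name
        if path.is_file():
            return path
    # Fail loudly so a missing tokenizer is caught before training starts.
    raise FileNotFoundError(f"no tokenizer file in {root}; looked for {names}")
```

Failing with a clear error when no candidate is present keeps the example from silently falling back to a path that external users may not be able to read.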

@ChiragSW ChiragSW marked this pull request as ready for review May 13, 2026 11:28
@ChiragSW ChiragSW marked this pull request as draft May 13, 2026 11:28