Merged
50 changes: 30 additions & 20 deletions open_instruct/grpo_fast.py
@@ -56,6 +56,7 @@
from queue import Empty, Full, Queue
from typing import Any, Literal

import backoff
import datasets
import numpy as np
import pandas as pd
@@ -2561,26 +2562,7 @@ def one_training_step(
            [policy_group.models[i].update_ref_policy.remote() for i in range(args.world_size)]
        )

    save_time = 0
    if args.save_freq > 0 and training_step % args.save_freq == 0 and (args.eval_on_step_0 or training_step > 1):
        with Timer("[Main Thread] 🗡️ Saving model") as timer:
            checkpoint_dir = f"{args.output_dir}_checkpoints"
            step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
            logger.info(f"Saving model at step {training_step} to {step_dir}")
            ray_get_with_progress(
                [
                    policy_group.models[i].save_model.remote(step_dir, chat_template_name, tokenizer)
                    for i in range(args.world_size)
                ],
                desc=f"Saving model at step {training_step}",
            )
            if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job():
                leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}"
                for i in range(args.world_size):
                    policy_group.models[i].launch_ai2_evals_on_weka_wrapper.remote(
                        step_dir, leaderboard_name, wandb_url, training_step
                    )
            save_time += timer.duration
    save_time = maybe_save_checkpoint(args, training_step, policy_group, chat_template_name, tokenizer, wandb_url)

    if len(update_ref_policy_future) > 0:
        with Timer("[Main Thread] 🔃 Updating reference policy"):
@@ -2634,6 +2616,34 @@ def one_training_step(
        wandb.log(metrics, step=episode)


@backoff.on_exception(backoff.expo, Exception, max_tries=3)
Contributor (review comment, severity: medium):
Using backoff.on_exception with the broad Exception class can mask underlying bugs by retrying on non-transient errors like TypeError or AttributeError. It's better to specify only the exceptions you expect to be transient, such as I/O or network-related errors. This makes the retry logic more robust and prevents hiding actual code issues.

Consider using a more specific set of exceptions. For example, you could catch IOError, OSError, and Ray-specific task errors. You might need to add import ray.exceptions for this.

Example:

@backoff.on_exception(backoff.expo, (IOError, OSError, ray.exceptions.RayTaskError), max_tries=3)
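
A slightly fuller sketch along those lines (the function and handler names below are illustrative, not from the PR); an on_backoff hook also logs each retry so transient failures stay visible instead of being silently absorbed:

import logging

import backoff

logger = logging.getLogger(__name__)


def _log_retry(details):
    # backoff calls this before each retry sleep; "tries", "wait", and
    # "target" are standard keys of the details dict it passes in.
    logger.warning(
        "Retry %d for %s, sleeping %.1fs",
        details["tries"], details["target"].__name__, details["wait"],
    )


# Retry only errors expected to be transient. On Python 3, IOError is an
# alias of OSError; ray.exceptions.RayTaskError could be added to the tuple
# once ray is imported in this module.
@backoff.on_exception(backoff.expo, (IOError, OSError), max_tries=3, on_backoff=_log_retry)
def write_checkpoint(path: str) -> None:
    # Hypothetical stand-in for the real save call.
    with open(path, "w") as f:
        f.write("checkpoint")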

Collaborator (author) replied:
This is probably correct, but I don't know what errors will come up. Instead, I'll keep a list and fix these.
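
If that list eventually moves into code, one possible shape (a hedged sketch; the constant and function names below are hypothetical, not from this PR) is a module-level tuple referenced by the decorator, so newly observed transient error types can be added in one place:

import backoff

# Start narrow and append exception types here as the retry logs reveal which
# failures are actually transient (e.g. ray.exceptions.RayTaskError).
RETRYABLE_CHECKPOINT_ERRORS: tuple[type[BaseException], ...] = (OSError,)


@backoff.on_exception(backoff.expo, RETRYABLE_CHECKPOINT_ERRORS, max_tries=3)
def save_with_retries() -> None:
    ...  # the checkpointing body from maybe_save_checkpoint would go here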

def maybe_save_checkpoint(
    args: Args, training_step: int, policy_group, chat_template_name: str, tokenizer, wandb_url: str
) -> float:
    save_time = 0
    if args.save_freq > 0 and training_step % args.save_freq == 0 and (args.eval_on_step_0 or training_step > 1):
        with Timer("[Main Thread] 🗡️ Saving model") as timer:
            checkpoint_dir = f"{args.output_dir}_checkpoints"
            step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
            logger.info(f"Saving model at step {training_step} to {step_dir}")
            ray_get_with_progress(
                [
                    policy_group.models[i].save_model.remote(step_dir, chat_template_name, tokenizer)
                    for i in range(args.world_size)
                ],
                desc=f"Saving model at step {training_step}",
            )
            if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job():
                leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}"
                for i in range(args.world_size):
                    policy_group.models[i].launch_ai2_evals_on_weka_wrapper.remote(
                        step_dir, leaderboard_name, wandb_url, training_step
                    )
            save_time = timer.duration

    return save_time


def maybe_evaluate(
    args: Args,
    training_step: int,
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@ requires-python = "==3.12.*"
dependencies = [
"accelerate>=1.10.1",
"antlr4-python3-runtime==4.11",
"backoff>=2.2.1",
"bitsandbytes>=0.44.1; platform_system != 'Darwin'",
"datasets>=4.0.0",
Review comment on lines 7 to 12 (P1): Regenerate requirements.txt after adding backoff

The new checkpoint retry logic imports backoff, and this commit adds the dependency in pyproject.toml. However, requirements.txt (the pinned export used for pip install -r requirements.txt) was not updated. Any environment that relies on the requirements file instead of uv will miss backoff, and importing open_instruct.grpo_fast will fail with ModuleNotFoundError. Please regenerate requirements.txt so that backoff is included.


"debugpy>=1.8.13",
11 changes: 11 additions & 0 deletions uv.lock

Some generated files are not rendered by default.