diff --git a/CHANGELOG.md b/CHANGELOG.md index e245407e5e..941c8ca298 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,15 @@ Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs. +- **Orchestrator async-pipeline rewrite** (collection of removals/renames). The orchestrator was rewritten to overlap train/eval rollouts on a shared concurrency limiter; several config fields were removed or renamed. + - **`orchestrator.seed` removed**: was only consumed by the deleted buffer; no replacement. + - **`orchestrator.eval.eval_base_model` → `orchestrator.eval.skip_first_step`** (semantics inverted): `eval_base_model = true` becomes `skip_first_step = false` (the default — run the step-0 eval before any train rollouts). No alias; configs setting `eval_base_model` must rename. + - **`orchestrator.eval.skip_eval_on_resume` (and its alias `skip_eval_on_restart`) removed**: folded into `skip_first_step`. Resume no longer re-fires an already-completed eval (deduped via per-env last-eval-step tracking in the checkpointed progress). + - **`orchestrator.eval.cancel_inflight_rollouts_on_eval` removed**: the drain-switch overlap (stop scheduling new train, let in-flight train drain while eval queues) is now the only eval transition mode. + - **`ckpt.skip_buffer` removed**: there is no buffer to skip. + - **`[orchestrator.buffer]` removed** (difficulty pools): the whole block and every key (`seed`, `easy_threshold`, `hard_threshold`, `easy_fraction`, `hard_fraction`, `online_difficulty_filtering`, `hash_keys`) are gone. A before-validator drops the block and emits a `FutureWarning` (so old configs still parse). To preserve `online_difficulty_filtering = true`, enforce the zero-advantage pre-batch filter: `[[orchestrator.pre_batch_filters]]\ntype = "zero_advantage"\nenforce = true`. + - **`orchestrator.filters` → `orchestrator.post_batch_filters`** (backward compatible): a before-validator aliases `filters`, so existing TOML/CLI keep parsing. New configs should use `post_batch_filters` (and the new `pre_batch_filters`). Filters are **train-only** now — eval rollouts are no longer filtered. + - **`orchestrator.max_off_policy_steps` now also applies to eval** (behavior change, field unchanged): eval rollouts that fall more than `max_off_policy_steps` versions behind the policy are cancelled, same as train. (2026-05-29) - **`sampling.min_tokens`, `sampling.repetition_penalty`, `sampling.seed` removed**: Dropped from both `TrainSamplingConfig` and `EvalSamplingConfig` (group-level `[orchestrator.train.sampling]` / `[orchestrator.eval.sampling]` and per-env `[[orchestrator.train.env.sampling]]` / `[[orchestrator.eval.env.sampling]]`). `min_tokens` suppressed natural EOS, `repetition_penalty` distorts the on-policy sampling distribution, and `seed` wasn't pulling its weight — none belonged on the supported config surface. Existing configs setting any of these must delete the field. Hard-deprecation, no migration window. (2026-05-27) - **`wandb.shared` removed**: The deprecation shim that popped `wandb.shared` from input dicts with a `FutureWarning` (introduced in #2649) is gone. The `rl` entrypoint always uses shared W&B mode now, and existing configs that still set `wandb.shared = true` (or `false`) will fail validation. Drop the field from your config. (2026-05-27) - **`max_async_level` and `strict_async_level` removed**: The async-execution semantics between trainer and orchestrator are now design invariants, not config knobs. The trainer always runs exactly one step ahead of inference, and the orchestrator always adopts the freshest checkpoint that doesn't violate the one-step barrier. The shared top-level `max_async_level`, the per-sub-config `trainer.max_async_level` / `orchestrator.max_async_level`, and `orchestrator.strict_async_level` have all been removed. Existing configs setting any of these must drop the field; the previous defaults (`max_async_level = 1`, `strict_async_level = false`) match the new hardcoded behavior. Bench mode no longer bypasses the weight-ckpt wait (the `int(1e9)` workaround is gone) and `multimodal/rl_color_codeword_feat_renderer.toml`'s prior `max_async_level = 0` (fully synchronous on-policy) is no longer expressible. (2026-05-25) diff --git a/configs/ci/integration/alphabet_sort.toml b/configs/ci/integration/alphabet_sort.toml index a1666edd24..2253941ed9 100644 --- a/configs/ci/integration/alphabet_sort.toml +++ b/configs/ci/integration/alphabet_sort.toml @@ -1,4 +1,4 @@ -max_steps = 5 +max_steps = 10 seq_len = 2048 [ckpt] diff --git a/configs/hendrycks_math/sanity.toml b/configs/debug/hendrycks_sanity/rl.toml similarity index 89% rename from configs/hendrycks_math/sanity.toml rename to configs/debug/hendrycks_sanity/rl.toml index 11da4f0bbf..e01c97208b 100644 --- a/configs/hendrycks_math/sanity.toml +++ b/configs/debug/hendrycks_sanity/rl.toml @@ -1,5 +1,5 @@ max_steps = 3000 -seq_len = 8192 +seq_len = 4096 [wandb] project = "hendrycks-math-debug" @@ -9,7 +9,8 @@ name = "hendrycks-math-sanity" name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" [orchestrator] -batch_size = 512 +batch_size = 256 +max_inflight_rollouts = 512 group_size = 8 [[orchestrator.train.env]] @@ -30,7 +31,7 @@ group_size = 16 [trainer.model.compile] [inference.model] -max_model_len = 8192 +max_model_len = 4096 [log] level = "debug" diff --git a/configs/debug/multi_env/reverse_text.toml b/configs/debug/multi_env/reverse_text.toml new file mode 100644 index 0000000000..e57f65b1ff --- /dev/null +++ b/configs/debug/multi_env/reverse_text.toml @@ -0,0 +1,52 @@ +max_steps = 20 +seq_len = 2048 + +[model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" + +[wandb] +project = "reverse-text-debug" +name = "debug-multi-env" + +[orchestrator] +training_mode = "rl" +batch_size = 128 +group_size = 16 + +[orchestrator.renderer] +name = "qwen3" + +# -- multi train envs -- + +[[orchestrator.train.env]] +id = "reverse-text" +name = "reverse-text-train-1" + +[[orchestrator.train.env]] +id = "reverse-text" +name = "reverse-text-train-2" + +# -- multi eval envs -- + +[orchestrator.eval] +interval = 10 +num_examples = 16 +group_size = 4 + +[[orchestrator.eval.env]] +id = "reverse-text" +name = "reverse-text-eval-1" + +[[orchestrator.eval.env]] +id = "reverse-text" +name = "reverse-text-eval-2" +interval = 5 + +[trainer.optim] +lr = 3e-6 + +[inference] +gpu_memory_utilization = 0.5 + +[inference.model] +max_model_len = 128 \ No newline at end of file diff --git a/configs/debug/multimodal.toml b/configs/debug/multimodal.toml new file mode 100644 index 0000000000..48fc8aafd1 --- /dev/null +++ b/configs/debug/multimodal.toml @@ -0,0 +1,75 @@ +# 2-GPU debug RL run for the multimodal (renderer) path: Qwen3-VL-4B on +# color-codeword. Sized to actually learn (reward should trend up) while +# staying 2-GPU friendly. Exercises RendererClient + Qwen3VLRenderer end-to-end. + +max_steps = 15 +seq_len = 4096 + +[model] +name = "Qwen/Qwen3-VL-4B-Instruct" + +[model.vlm] +vision_encoder_attr = "model.visual" +language_model_attr = "model.language_model" + +[deployment] +num_train_gpus = 1 +num_infer_gpus = 1 +gpus_per_node = 2 + +[orchestrator] +batch_size = 256 +group_size = 16 +# Image processor is CPU-bound and dominates for VLMs; returns diminish past 4. +pool_size = 4 + +# Step 0 on Qwen3-VL-4B vs color-codeword can be uniform (all-correct or +# all-wrong), so don't enforce zero-advantage dropping or training would crash +# before any progress. +[[orchestrator.filters]] +type = "gibberish" + +[[orchestrator.filters]] +type = "repetition" + +[[orchestrator.filters]] +type = "zero_advantage" +enforce = false + +[orchestrator.train.sampling] +max_completion_tokens = 64 + +[[orchestrator.train.env]] +id = "color-codeword" +args = { images_per_turn = 1, max_turns = 3, num_examples = 1000, seed = 42 } + +# Default renderer (AutoRendererConfig) resolves Qwen3-VL-4B-Instruct from +# MODEL_RENDERER_MAP to Qwen3VLRenderer at runtime; no explicit name needed. + +[trainer] + +[trainer.model] +optimization_dtype = "bfloat16" +reduce_dtype = "bfloat16" + +[trainer.optim] +lr = 3e-6 + +[inference] + +[inference.model] +# Workaround for vLLM 0.20.1 Qwen3-VL deepstack buffer bug: when num_scheduled_tokens +# (188) gets padded up to the next cudagraph_capture_size (192), the model's +# _set_deepstack_input_embeds sizes the buffer to 188 but forward() runs with 192, +# triggering "Requested more deepstack tokens than available in buffer". Eager mode +# skips the padding so num_input_tokens == num_scheduled_tokens. +enforce_eager = true + +[inference.parallel] +dp = 1 +tp = 1 + +[wandb] +project = "debug" +name = "multimodal" +tags = ["qwen3vl-4b", "color-codeword", "renderer"] diff --git a/configs/hendrycks_math/rl.toml b/configs/hendrycks_math/rl.toml index b4fb071ba4..79ae58542a 100644 --- a/configs/hendrycks_math/rl.toml +++ b/configs/hendrycks_math/rl.toml @@ -20,10 +20,6 @@ id = "math-env" name = "hendrycks-math" args = { dataset_name = "PrimeIntellect/Hendrycks-Math", dataset_subset = "default", math_verify_max_workers = 128, math_verify_timeout = 60 } -[orchestrator.buffer] -easy_threshold = 1.0 -hard_threshold = 0.0 - [orchestrator.eval] interval = 10 diff --git a/configs/math_group/rl.toml b/configs/math_group/rl.toml index 51781dc9f2..45105c4c6e 100644 --- a/configs/math_group/rl.toml +++ b/configs/math_group/rl.toml @@ -23,10 +23,6 @@ name = "acereason-math" args = { dataset_name = "nvidia/AceReason-Math", dataset_subset = "default", question_key = "problem" } ratio = 0.5 -[orchestrator.buffer] -easy_threshold = 1.0 -hard_threshold = 0.0 - [orchestrator.eval] interval = 50 diff --git a/configs/multi_reverse_text/rl.toml b/configs/multi_reverse_text/rl.toml deleted file mode 100644 index d19602c543..0000000000 --- a/configs/multi_reverse_text/rl.toml +++ /dev/null @@ -1,59 +0,0 @@ -max_steps = 20 -seq_len = 2048 - -[model] -name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" - -[orchestrator] -batch_size = 128 -group_size = 16 - -# -- train envs -- - -[orchestrator.train.sampling] -max_completion_tokens = 128 - -[[orchestrator.train.env]] -id = "reverse-text" -name = "reverse-text-1" - -[orchestrator.train.env.sampling] -max_completion_tokens = 127 - -[[orchestrator.train.env]] -id = "reverse-text" -name = "reverse-text-2" - -# -- eval envs -- - -[orchestrator.eval] -interval = 5 - -[orchestrator.eval.sampling] -max_completion_tokens = 512 - -[[orchestrator.eval.env]] -id = "reverse-text" -name = "eval-default" -num_examples = 32 -group_size = 4 - -[[orchestrator.eval.env]] -id = "reverse-text" -name = "eval-custom" -num_examples = 16 -group_size = 2 -interval = 10 - -[orchestrator.eval.env.sampling] -max_completion_tokens = 256 -temperature = 0.5 - -[trainer.optim] -lr = 3e-6 - -[inference] # Default inference config - -# Model not in MODEL_RENDERER_MAP — opt into DefaultRenderer (apply_chat_template). -[orchestrator.renderer] -name = "default" diff --git a/configs/multimodal/rl_color_codeword.toml b/configs/multimodal/rl_color_codeword.toml deleted file mode 100644 index 35cfcab809..0000000000 --- a/configs/multimodal/rl_color_codeword.toml +++ /dev/null @@ -1,34 +0,0 @@ -max_steps = 15 -seq_len = 4096 - -[model] -name = "Qwen/Qwen3-VL-4B-Instruct" - -[model.vlm] -vision_encoder_attr = "model.visual" -language_model_attr = "model.language_model" - -[orchestrator] -batch_size = 256 -group_size = 16 - -[orchestrator.train.sampling] -max_completion_tokens = 64 - -[[orchestrator.train.env]] -id = "color-codeword" -args = { images_per_turn = 1, max_turns = 3, num_examples = 1000, seed = 42 } - -[trainer] - -[trainer.model] -optimization_dtype = "bfloat16" -reduce_dtype = "bfloat16" - -[trainer.optim] -lr = 3e-6 - -[inference] - -[inference.parallel] -dp = 1 diff --git a/configs/multimodal/rl_color_codeword_feat_renderer.toml b/configs/multimodal/rl_color_codeword_feat_renderer.toml deleted file mode 100644 index e4ed4b3969..0000000000 --- a/configs/multimodal/rl_color_codeword_feat_renderer.toml +++ /dev/null @@ -1,94 +0,0 @@ -# 20-step Qwen3-VL-4B RL run on color-codeword using the renderer multimodal path. -# -# Pair with rl_color_codeword_main_mito.toml for an A/B comparison: same env, -# same hyperparameters, same step count — only difference is the inference -# client. The feat-branch run uses the new RendererClient + Qwen3VLRenderer -# (renderers package with multimodal support); the main-branch baseline uses -# the existing TITO chat-completions path through the inference server. -# -# Compare in W&B project ``multimodal-renderer``: -# - ``kl/sampler_vs_trainer`` should be ~0 on this branch (the renderer -# produces byte-identical tokens to what the trainer re-tokenizes) and -# can spike on main when BPE drifts mid-rollout. -# - ``reward`` and ``loss`` should track within noise — same model, same -# env, same hyperparameters. -# - ``bridge_break_rate`` is renderer-only; surfaces multi-turn extension -# failures. - -max_steps = 20 -seq_len = 4096 -output_dir = "outputs/rl_color_codeword_feat_renderer" -clean_output_dir = true - -[model] -name = "Qwen/Qwen3-VL-4B-Instruct" - -[model.vlm] -vision_encoder_attr = "model.visual" -language_model_attr = "model.language_model" - -[deployment] -num_train_gpus = 1 -num_infer_gpus = 1 -gpus_per_node = 2 - -[orchestrator] -batch_size = 16 -group_size = 8 -# 64 concurrent rollouts (batch_size=16 × group_size=4) want -# more than one tokenizer slot to avoid serialization queueing. The -# image processor (CPU-bound) dominates for VLMs so returns diminish -# past 4; bump to 4 as the default for multimodal runs. -pool_size = 4 - -# Track zero-advantage groups but don't drop them — we're validating the -# multimodal renderer path on 20 steps, not optimizing training efficiency. -# Step 0 on Qwen3-VL-4B vs color-codeword is likely uniform (all-correct or -# all-wrong) so enforce=True would crash before any training happens. -[[orchestrator.filters]] -type = "gibberish" - -[[orchestrator.filters]] -type = "repetition" - -[[orchestrator.filters]] -type = "zero_advantage" -enforce = false - -[orchestrator.train.sampling] -max_completion_tokens = 64 - -[[orchestrator.train.env]] -id = "color-codeword" -args = { images_per_turn = 2, max_turns = 2, num_examples = 100, seed = 42 } - -# Default renderer (AutoRendererConfig) resolves Qwen3-VL-4B-Instruct from -# MODEL_RENDERER_MAP to Qwen3VLRenderer at runtime; no explicit name needed. - -[trainer] - -[trainer.model] -optimization_dtype = "bfloat16" -reduce_dtype = "bfloat16" - -[trainer.optim] -lr = 3e-6 - -[inference] - -[inference.model] -# Workaround for vLLM 0.20.1 Qwen3-VL deepstack buffer bug: when num_scheduled_tokens -# (188) gets padded up to the next cudagraph_capture_size (192), the model's -# _set_deepstack_input_embeds sizes the buffer to 188 but forward() runs with 192, -# triggering "Requested more deepstack tokens than available in buffer". Eager mode -# skips the padding so num_input_tokens == num_scheduled_tokens. -enforce_eager = true - -[inference.parallel] -dp = 1 -tp = 1 - -[wandb] -project = "multimodal-renderer" -name = "feat-renderer-20step-r8-i2-t2-onpolicy" -tags = ["qwen3vl-4b", "color-codeword", "renderer", "feat-branch", "mm-kwargs-generic", "on-policy"] diff --git a/configs/multimodal/rl_color_codeword_test.toml b/configs/multimodal/rl_color_codeword_test.toml deleted file mode 100644 index 23a25a94f9..0000000000 --- a/configs/multimodal/rl_color_codeword_test.toml +++ /dev/null @@ -1,35 +0,0 @@ -max_steps = 3 -seq_len = 2048 -output_dir = "outputs/rl_color_codeword_test" - -[model] -name = "Qwen/Qwen3-VL-4B-Instruct" - -[model.vlm] -vision_encoder_attr = "model.visual" -language_model_attr = "model.language_model" - -[orchestrator] -batch_size = 16 -group_size = 2 - -[orchestrator.train.sampling] -max_completion_tokens = 32 - -[[orchestrator.train.env]] -id = "color-codeword" -args = { images_per_turn = 1, max_turns = 2, num_examples = 100, seed = 42 } - -[trainer] - -[trainer.model] -optimization_dtype = "bfloat16" -reduce_dtype = "bfloat16" - -[trainer.optim] -lr = 3e-6 - -[inference] - -[inference.parallel] -dp = 1 diff --git a/configs/nemotron_4node/rl.toml b/configs/nemotron_4node/rl.toml index c75e753998..db39767b9e 100644 --- a/configs/nemotron_4node/rl.toml +++ b/configs/nemotron_4node/rl.toml @@ -55,10 +55,6 @@ id = "math-env" name = "hendrycks-math" args = { dataset_name = "PrimeIntellect/Hendrycks-Math", dataset_subset = "default", math_verify_max_workers = 128, math_verify_timeout = 60 } -[orchestrator.buffer] -easy_threshold = 1.0 -hard_threshold = 0.0 - [orchestrator.eval] interval = 10 diff --git a/configs/nemotron_debug/rl.toml b/configs/nemotron_debug/rl.toml index bf92b2c1f2..4c2d020912 100644 --- a/configs/nemotron_debug/rl.toml +++ b/configs/nemotron_debug/rl.toml @@ -54,10 +54,6 @@ id = "math-env" name = "hendrycks-math" args = { dataset_name = "PrimeIntellect/Hendrycks-Math", dataset_subset = "default", math_verify_max_workers = 128, math_verify_timeout = 60 } -[orchestrator.buffer] -easy_threshold = 1.0 -hard_threshold = 0.0 - [orchestrator.eval] interval = 10 diff --git a/configs/rlm_swe/qwen35_4b.toml b/configs/rlm_swe/qwen35_4b.toml new file mode 100644 index 0000000000..626b62b8ef --- /dev/null +++ b/configs/rlm_swe/qwen35_4b.toml @@ -0,0 +1,99 @@ +output_dir = "/beegfs/mika/rlm-swe-qwen35-4b" +max_steps = 400 +seq_len = 65536 + +[slurm] +job_name = "rlm-swe-qwen35-4b" +project_dir = "." +pre_run_command = "prime sandbox delete --label rlm-swe-qwen35-4b -y --plain || true" + +[deployment] +type = "multi_node" +num_train_nodes = 1 +num_infer_nodes = 1 +num_infer_replicas = 2 + +[wandb] +project = "rlm-swe-debug" +name = "qwen35-4b" + +[weight_broadcast] +type = "nccl" + +[ckpt] +interval = 50 +keep_last = 1 +resume_step = -1 + +[model] +name = "Qwen/Qwen3.5-4B" + +# --- Trainer --- + +[trainer] + +[trainer.model] +cp = 4 +cp_style = "ulysses" + +[trainer.model.ac] +freq = 1 + +[trainer.model.compile] + +# --- Orchestrator --- + +[orchestrator] +batch_size = 256 +group_size = 8 +max_inflight_rollouts = 512 +max_off_policy_steps = 16 + +# Thinking enabled for the Qwen3.5 renderer. +[orchestrator.renderer] +name = "qwen3.5" +enable_thinking = true + +[orchestrator.train.sampling] +temperature = 1.0 + +[[orchestrator.train.env]] +id = "rlm_swe" +name = "rlm-swe-r2e" +num_workers = 4 + +[orchestrator.train.env.args] +labels = ["rlm-swe-qwen35-4b"] + +[orchestrator.prime_monitor] + +[orchestrator.eval] +interval = 20 + +[[orchestrator.eval.env]] +id = "rlm_swe" +name = "rlm-swe-swebench-verified-quick" +num_workers = 4 +timeout = 3600 + +[orchestrator.eval.env.args] +task_type = "swebench" +dataset_name = "PrimeIntellect/SWE-Bench-Verified-Quick" +labels = ["rlm-swe-qwen35-4b"] + +# --- Inference --- + +[inference] +gpu_memory_utilization = 0.85 +enable_prefix_caching = true + +[inference.model] +max_model_len = 65536 + +[inference.parallel] +dp = 8 + +# Qwen3.5-4B is a VL model; skip the vision tower for text-only SWE. +# `language_model_only` is a vLLM MultiModalConfig arg (no prime-rl field) → pass via vllm_extra. +[inference.vllm_extra] +language_model_only = true diff --git a/configs/wiki_search/rl.toml b/configs/wiki_search/rl.toml index ebf0037b03..b66d41f3ae 100644 --- a/configs/wiki_search/rl.toml +++ b/configs/wiki_search/rl.toml @@ -33,9 +33,6 @@ oversampling_factor = 2.0 [orchestrator.train.sampling] max_completion_tokens = 512 -[orchestrator.buffer] -online_difficulty_filtering = true - [[orchestrator.train.env]] id = "wiki-search" diff --git a/examples/Intellect-3.1/rl.toml b/examples/Intellect-3.1/rl.toml index 82c8694755..d7d28d4531 100644 --- a/examples/Intellect-3.1/rl.toml +++ b/examples/Intellect-3.1/rl.toml @@ -77,10 +77,9 @@ name = "code" ratio = 0.2 args = { pool_size = 512 } -[orchestrator.buffer] -easy_threshold = 1.0 -online_difficulty_filtering = true -seed = 42 +[[orchestrator.pre_batch_filters]] +type = "zero_advantage" +enforce = true [orchestrator.eval] interval = 25 diff --git a/examples/glm5_pd_disag/rl.toml b/examples/glm5_pd_disag/rl.toml index 1ec1a5b435..671b8ff7e8 100644 --- a/examples/glm5_pd_disag/rl.toml +++ b/examples/glm5_pd_disag/rl.toml @@ -81,11 +81,6 @@ args = { dataset_name="PrimeIntellect/SWE-Bench-Verified-Quick", max_turns = 200 type = "gibberish" enforce = true -[orchestrator.buffer] -easy_threshold = 1.0 -hard_threshold = 0.0 -seed = 42 - [inference] enable_expert_parallel = true # we need <0.85 bc glm5 layers are too large for 0.85 diff --git a/examples/minimax_m2.5_swe/rl.toml b/examples/minimax_m2.5_swe/rl.toml index 7ab62275c2..efa2406be2 100644 --- a/examples/minimax_m2.5_swe/rl.toml +++ b/examples/minimax_m2.5_swe/rl.toml @@ -62,11 +62,6 @@ id = "mini-swe-agent-plus" name = "swe-bench-verified-quick" args = { dataset_name="PrimeIntellect/SWE-Bench-Verified-Quick", max_turns = 200, cpu_cores = 2, memory_gb = 4, disk_size_gb = 4, labels = ["mini-swe-agent-plus"], total_timeout_minutes = 720, sandbox_client_max_workers = 256, max_command_timeouts = 3, sandbox_command_timeout = 30} -[orchestrator.buffer] -easy_threshold = 1.0 -hard_threshold = 0.0 -seed = 42 - [inference.parallel] tp = 8 diff --git a/examples/multinode/rl.toml b/examples/multinode/rl.toml index f5ee93d16b..be27bdf355 100644 --- a/examples/multinode/rl.toml +++ b/examples/multinode/rl.toml @@ -46,10 +46,6 @@ id = "math-env" name = "hendrycks-math" args = { dataset_name = "PrimeIntellect/Hendrycks-Math", dataset_subset = "default", math_verify_max_workers = 128, math_verify_timeout = 60 } -[orchestrator.buffer] -easy_threshold = 1.0 -hard_threshold = 0.0 - [inference.parallel] tp = 4 dp = 2 diff --git a/examples/qwen30b_math/rl.toml b/examples/qwen30b_math/rl.toml index 82edf7e050..9bb6d9a36e 100644 --- a/examples/qwen30b_math/rl.toml +++ b/examples/qwen30b_math/rl.toml @@ -58,10 +58,5 @@ interval = 25 [[orchestrator.eval.env]] id = "aime2025" -[orchestrator.buffer] -easy_threshold = 1.0 -hard_threshold = 0.0 -seed = 42 - [inference.parallel] tp = 8 \ No newline at end of file diff --git a/examples/qwen30b_swe/rl.toml b/examples/qwen30b_swe/rl.toml index ba5902b71b..929fc8b481 100644 --- a/examples/qwen30b_swe/rl.toml +++ b/examples/qwen30b_swe/rl.toml @@ -59,10 +59,5 @@ id = "mini-swe-agent-plus" name = "swe-bench-verified-quick" args = { dataset_name="PrimeIntellect/SWE-Bench-Verified-Quick", max_turns = 200, cpu_cores = 2, memory_gb = 4, disk_size_gb = 4, labels = ["mini-swe-agent-plus"], total_timeout_minutes = 720, sandbox_client_max_workers = 256, max_command_timeouts = 3, sandbox_command_timeout = 30} -[orchestrator.buffer] -easy_threshold = 1.0 -hard_threshold = 0.0 -seed = 42 - [inference.parallel] tp = 8 \ No newline at end of file diff --git a/examples/wiki_search/rl.toml b/examples/wiki_search/rl.toml index 599d70658f..715ed18794 100644 --- a/examples/wiki_search/rl.toml +++ b/examples/wiki_search/rl.toml @@ -40,8 +40,9 @@ name = "qwen3-4b-wiki-search" [orchestrator.train.sampling] max_completion_tokens = 512 -[orchestrator.buffer] -online_difficulty_filtering = true +[[orchestrator.pre_batch_filters]] +type = "zero_advantage" +enforce = true [[orchestrator.train.env]] id = "primeintellect/wiki-search" diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 62b0c0b37c..4c6efc9431 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -210,6 +210,10 @@ class TrainEnvConfig(EnvConfig): sampling: TrainSamplingConfig = TrainSamplingConfig() """Per-env sampling overrides. Unset fields inherit from the group-level train sampling config.""" + group_size: int = Field(1, ge=1, validation_alias=AliasChoices("group_size", "rollouts_per_example")) + """Rollouts generated per example for GRPO group-relative advantages. + Inherits from ``orchestrator.group_size`` when unset.""" + class EvalEnvConfig(EnvConfig): sampling: EvalSamplingConfig = EvalSamplingConfig() @@ -296,6 +300,10 @@ class EvalConfig(BaseConfig): interval: int = Field(100, ge=1) """Step interval at which to evaluate the model.""" + skip_first_step: bool = False + """If True, skip the startup eval that otherwise runs before any + train rollouts.""" + @model_validator(mode="after") def resolve_env_defaults(self): """Resolve per-env overrides: inherit group-level sampling, num_workers, max_retries, num_examples, group_size, and interval. Then resolve auto num_workers.""" @@ -325,6 +333,16 @@ def resolve_env_defaults(self): env.num_workers = max(1, math.ceil(max_concurrent / 256)) return self + @model_validator(mode="after") + def validate_non_empty_envs(self): + if not self.env: + raise ValueError( + "EvalConfig must define at least one env. Either drop the " + "[orchestrator.eval] block entirely (to disable eval) or " + "add a [[orchestrator.eval.env]] block." + ) + return self + @model_validator(mode="after") def validate_unique_env_names(self): env_names = [env.resolved_name for env in self.env] @@ -335,17 +353,6 @@ def validate_unique_env_names(self): ) return self - eval_base_model: bool = True - """Evaluate the base model we are training on.""" - - skip_eval_on_resume: bool = Field( - True, validation_alias=AliasChoices("skip_eval_on_resume", "skip_eval_on_restart") - ) - """When resuming the orchestrator from a checkpoint, skip the (potentially redundant) online eval that would otherwise run immediately at the resumed step.""" - - cancel_inflight_rollouts_on_eval: bool = False - """Cancel in-flight training rollouts before starting online evals. Avoids congestion (no training + eval rollouts at the same time) at the cost of slower training steps as the pipeline has to refill after each eval.""" - class CheckpointConfig(BaseConfig): interval: int | None = Field(None, ge=1) @@ -366,38 +373,6 @@ class CheckpointConfig(BaseConfig): skip_progress: bool = False """Skip loading the progress from checkpoint.""" - skip_buffer: bool = False - """Skip loading the buffer from checkpoint.""" - - -class BufferConfig(BaseConfig): - seed: int | None = None - """Random seed for the buffer. When set, sampling from the buffer is deterministic.""" - - easy_threshold: float | None = None - """Average-reward threshold above which a problem is classified ``easy``.""" - - hard_threshold: float | None = None - """Average-reward threshold below which a problem is classified ``hard``.""" - - easy_fraction: float = Field(0.0, ge=0, le=1) - """Fraction of easy problems to convert to ``normal`` when resuming or starting training. Only problems with difficulty ``normal`` are sampled.""" - - hard_fraction: float = Field(0.0, ge=0, le=1) - """Fraction of hard problems to convert to ``normal`` when resuming or starting training. Only problems with difficulty ``normal`` are sampled.""" - - online_difficulty_filtering: bool = False - """Filter rollouts based on difficulty. When True, rollouts with average reward 0.0 or 1.0 are not added to the buffer.""" - - hash_keys: list[str] = Field(["env_name", "prompt"], min_length=1) - """Keys used to compute example hashes. Used to match examples from buffer checkpoints and determine buffer resume behavior.""" - - @model_validator(mode="after") - def validate_thresholds(self): - if self.easy_threshold is not None and self.hard_threshold is not None: - assert self.easy_threshold > self.hard_threshold, "easy_threshold must be greater than hard_threshold." - return self - class TokensLengthPenaltyConfig(BaseConfig): type: Literal["tokens"] = "tokens" @@ -569,12 +544,27 @@ def _preserve_mito_renderer(self, handler: SerializerFunctionWrapHandler) -> dic eval: EvalConfig | None = None """Evaluation configuration.""" - buffer: BufferConfig = BufferConfig() - advantage: AdvantageConfig | None = DefaultAdvantageConfig() - filters: list[FilterConfig] = [GibberishFilterConfig(), RepetitionFilterConfig(), ZeroAdvantageFilterConfig()] - """Rollout filters. Each filter can ``monitor`` (default) or ``enforce`` (skip rollouts).""" + pre_batch_filters: list[FilterConfig] = [ + GibberishFilterConfig(enforce=False), + RepetitionFilterConfig(enforce=False), + ZeroAdvantageFilterConfig(enforce=False), + ] + """Filters applied *before* a rollout enters the training batch buffer. + All three filter types are registered in monitor mode by default; flip ``enforce=true`` per type + to drop matching rollouts before they consume a slot in the batch (e.g. a zero-advantage group + never makes it into a training batch).""" + + post_batch_filters: list[FilterConfig] = [ + GibberishFilterConfig(), + RepetitionFilterConfig(), + ZeroAdvantageFilterConfig(), + ] + """Filters applied *after* a batch has been assembled. Each filter annotates each rollout; + rollouts flagged by an enforcing filter are still recorded but not shipped to the trainer. + The TOML/CLI key ``filters`` is accepted as an alias for ``post_batch_filters`` (see + ``_alias_filters_to_post_batch_filters``).""" log: LogConfig = LogConfig() @@ -634,9 +624,6 @@ def _preserve_mito_renderer(self, handler: SerializerFunctionWrapHandler) -> dic bench: bool = False """Benchmark mode. Sets ``max_steps`` to 5 and disables W&B.""" - seed: int | None = 42 - """Random seed for the orchestrator.""" - heartbeat: HeartbeatConfig | None = None """BetterStack heartbeat configuration for monitoring training progress.""" @@ -645,6 +632,48 @@ def _preserve_mito_renderer(self, handler: SerializerFunctionWrapHandler) -> dic experimental: OrchestratorExperimentalConfig = OrchestratorExperimentalConfig() + @model_validator(mode="before") + @classmethod + def _alias_filters_to_post_batch_filters(cls, data: Any) -> Any: + """``filters`` is accepted as an alias for ``post_batch_filters``.""" + if isinstance(data, dict) and "filters" in data and "post_batch_filters" not in data: + data = dict(data) + data["post_batch_filters"] = data.pop("filters") + return data + + @model_validator(mode="before") + @classmethod + def _warn_removed_buffer(cls, data: Any) -> Any: + """The ``[orchestrator.buffer]`` block (difficulty pools + online + difficulty filtering) has been removed. Drop it so old configs still + parse, but warn — loudly for ``online_difficulty_filtering``, since it + has a direct replacement.""" + if not isinstance(data, dict) or "buffer" not in data: + return data + data = dict(data) + buffer = data.pop("buffer") + if isinstance(buffer, dict) and buffer.get("online_difficulty_filtering"): + warnings.warn( + "'[orchestrator.buffer]' has been removed and 'online_difficulty_filtering' " + "is now a no-op. To preserve the behavior (drop zero-advantage groups before " + "they enter the training batch), enforce the zero_advantage pre-batch filter:\n" + " [[orchestrator.pre_batch_filters]]\n" + ' type = "zero_advantage"\n' + " enforce = true\n" + "Difficulty pools (easy/hard thresholds and fractions) are removed with no " + "replacement.", + FutureWarning, + stacklevel=2, + ) + else: + warnings.warn( + "'[orchestrator.buffer]' has been removed (difficulty pools are no longer " + "supported) and is being ignored. Remove it from your config.", + FutureWarning, + stacklevel=2, + ) + return data + @model_validator(mode="before") @classmethod def fold_student_shortcuts(cls, data: Any) -> Any: @@ -759,9 +788,12 @@ def auto_setup_prime_monitor_run_name(self): @model_validator(mode="after") def validate_unique_filter_types(self): - types = [f.type for f in self.filters] - if len(types) != len(set(types)): - raise ValueError(f"Duplicate filter types: {types}. Each filter type may only appear once.") + for slot_name in ("pre_batch_filters", "post_batch_filters"): + types = [f.type for f in getattr(self, slot_name)] + if len(types) != len(set(types)): + raise ValueError( + f"Duplicate filter types in {slot_name}: {types}. Each filter type may only appear once per slot." + ) return self @model_validator(mode="after") @@ -883,6 +915,11 @@ def resolve_batching(self): if self.max_inflight_rollouts is not None and self.max_inflight_rollouts < self.group_size: raise ValueError("max_inflight_rollouts must be at least the number of rollouts per example") + # Propagate the top-level ``group_size`` into each train env that didn't set its own. + for env_cfg in self.train.env: + if "group_size" not in env_cfg.model_fields_set: + env_cfg.group_size = self.group_size + # Resolve train env num_workers from max_inflight_rollouts for env_cfg in self.train.env: if env_cfg.num_workers == "auto": diff --git a/packages/prime-rl-configs/src/prime_rl/configs/shared.py b/packages/prime-rl-configs/src/prime_rl/configs/shared.py index 515634cf4c..ff311f145d 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/shared.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/shared.py @@ -158,6 +158,9 @@ class LogConfig(BaseConfig): log_data: bool = False """Log the first data sample at startup.""" + interval: float = Field(10.0, gt=0) + """Interval (seconds) for periodic logs across components.""" + class TrainerLogConfig(LogConfig): ranks_filter: list[int] = [0] diff --git a/pyproject.toml b/pyproject.toml index 3756b233af..1b1a18d016 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ envs = [ "opencode-science", "opencode-swe", "reverse-text", + "rlm-swe", "science-env", "simpleqa-verified", "tau2-bench", @@ -158,6 +159,7 @@ members = [ "deps/research-environments/environments/opencode_math", "deps/research-environments/environments/opencode_science", "deps/research-environments/environments/opencode_swe", + "deps/research-environments/environments/rlm_swe", "deps/research-environments/environments/science_env", "deps/research-environments/environments/simpleqa_verified", "deps/research-environments/environments/tau2_bench", @@ -254,6 +256,7 @@ opencode-math = { workspace = true } opencode-science = { workspace = true } opencode-swe = { workspace = true } reverse-text = { workspace = true } +rlm-swe = { workspace = true } science-env = { workspace = true } simpleqa-verified = { workspace = true } tau2-bench = { workspace = true } diff --git a/src/prime_rl/entrypoints/orchestrator.py b/src/prime_rl/entrypoints/orchestrator.py index ef1b02d2a8..48a8a354f8 100644 --- a/src/prime_rl/entrypoints/orchestrator.py +++ b/src/prime_rl/entrypoints/orchestrator.py @@ -19,9 +19,9 @@ def main(): set_proc_title("Orchestrator") config = cli(OrchestratorConfig) - from prime_rl.orchestrator.orchestrator import orchestrate + from prime_rl.orchestrator.orchestrator import run_orchestrator - asyncio.run(orchestrate(config)) + asyncio.run(run_orchestrator(config)) if __name__ == "__main__": diff --git a/src/prime_rl/entrypoints/rl.py b/src/prime_rl/entrypoints/rl.py index 93902e76ae..21afe03752 100644 --- a/src/prime_rl/entrypoints/rl.py +++ b/src/prime_rl/entrypoints/rl.py @@ -194,12 +194,7 @@ def sigterm_handler(signum, frame): "orchestrator starts, otherwise rollouts will hang." ) - # Start orchestrator process - orchestrator_cmd = [ - "orchestrator", - "@", - (config_dir / ORCHESTRATOR_TOML).as_posix(), - ] + orchestrator_cmd = ["orchestrator", "@", (config_dir / ORCHESTRATOR_TOML).as_posix()] logger.info("Starting orchestrator process") logger.debug(f"Orchestrator start command: {' '.join(orchestrator_cmd)}") with open(log_dir / "orchestrator.log", "w") as log_file: @@ -313,7 +308,7 @@ def sigterm_handler(signum, frame): cleanup_processes(processes) sys.exit(1) - logger.success("RL training finished!") + logger.success("Training finished!") # Cleanup threads and processes cleanup_threads(monitor_threads) diff --git a/src/prime_rl/orchestrator/advantage.py b/src/prime_rl/orchestrator/advantage.py index c60a99aa13..3866bb0f56 100644 --- a/src/prime_rl/orchestrator/advantage.py +++ b/src/prime_rl/orchestrator/advantage.py @@ -1,12 +1,16 @@ -from collections import defaultdict +from __future__ import annotations + from dataclasses import dataclass -from typing import Callable +from typing import TYPE_CHECKING, Callable import torch import verifiers as vf from jaxtyping import Float from torch import Tensor +if TYPE_CHECKING: + from prime_rl.orchestrator.types import TrainRollout + from prime_rl.configs.orchestrator import ( AdvantageConfig, CustomAdvantageConfig, @@ -14,7 +18,7 @@ TokensLengthPenaltyConfig, TurnsLengthPenaltyConfig, ) -from prime_rl.orchestrator.vf_utils import get_model_completion_len, get_num_turns, get_tool_response_len +from prime_rl.orchestrator.utils import get_model_completion_len, get_tool_response_len from prime_rl.utils.utils import import_object @@ -40,7 +44,7 @@ def my_advantage(inputs: AdvantageInputs, **kwargs) -> AdvantageOutputs: ... The function receives a single group and returns a list of advantages with one -entry per rollout. `compute_advantages` calls it once per group. +entry per rollout. `assign_advantages` calls it on one already-grouped cohort. """ @@ -64,7 +68,7 @@ def default_advantage_fn( ) return AdvantageOutputs(advantages=_efficiency_shaping(rewards, costs).tolist()) if isinstance(length_penalty, TurnsLengthPenaltyConfig): - costs = torch.tensor([get_num_turns(r) for r in inputs.rollouts], dtype=rewards.dtype) + costs = torch.tensor([len(r["trajectory"]) for r in inputs.rollouts], dtype=rewards.dtype) return AdvantageOutputs(advantages=_efficiency_shaping(rewards, costs).tolist()) return AdvantageOutputs(advantages=(rewards - rewards.mean()).tolist()) @@ -124,28 +128,22 @@ def advantage_fn(inputs: AdvantageInputs) -> AdvantageOutputs: return advantage_fn -def compute_advantages( - rollouts: list[vf.RolloutOutput], - advantage_config: AdvantageConfig | None, +def assign_advantages( + rollouts: list["TrainRollout"], # noqa: F821 (forward ref) + advantage_fn: AdvantageFn | None, ) -> None: - """Computes advantages from rollouts, grouped by (env_name, example_id), and - stores them in-place on the rollouts. + """Compute and assign advantages for one finished group of rollouts. - `advantage_fn` is called once per group, so groups may have varying sizes - (partial-group training drops failed rollouts rather than rescheduling them). + Caller (``TrainSink.process_group``) hands in a single group's + post-error-filter survivors, so no grouping logic is needed here. + ``advantage_fn=None`` is the trivial case (advantage = reward). + Custom advantage functions still receive the raw ``vf.RolloutOutput``\\ s + via ``AdvantageInputs.rollouts`` — public API unchanged. """ - if not advantage_config: + if advantage_fn is None: for rollout in rollouts: - rollout["advantage"] = rollout["reward"] + rollout.advantage = rollout.reward return - - advantage_fn = setup_advantage_fn(advantage_config) - - groups_by_example: dict[tuple[str, int], list[vf.RolloutOutput]] = defaultdict(list) - for rollout in rollouts: - groups_by_example[(rollout["env_name"], rollout["example_id"])].append(rollout) - - for group in groups_by_example.values(): - result = advantage_fn(AdvantageInputs(rollouts=group)) - for rollout, advantage in zip(group, result.advantages): - rollout["advantage"] = advantage + result = advantage_fn(AdvantageInputs(rollouts=[r.raw for r in rollouts])) + for rollout, advantage in zip(rollouts, result.advantages): + rollout.advantage = advantage diff --git a/src/prime_rl/orchestrator/buffer.py b/src/prime_rl/orchestrator/buffer.py deleted file mode 100644 index 0217faea76..0000000000 --- a/src/prime_rl/orchestrator/buffer.py +++ /dev/null @@ -1,323 +0,0 @@ -from __future__ import annotations - -import hashlib -import json -import random -from collections import defaultdict -from functools import partial -from pathlib import Path -from typing import TYPE_CHECKING, cast - -import verifiers as vf -from verifiers.utils.save_utils import make_serializable - -from prime_rl.configs.orchestrator import BufferConfig -from prime_rl.utils.logger import get_logger -from prime_rl.utils.utils import format_num, mean, mean_normalize - -if TYPE_CHECKING: - from prime_rl.orchestrator.envs import TrainEnv, TrainEnvs - - -POOLS = ["easy", "normal", "hard"] - - -class _EnvBuffer: - """Manages examples and difficulty pools for a single env.""" - - def __init__(self, env: TrainEnv, config: BufferConfig): - self.env_name = env.name - self.config = config - - dataset = env.get_dataset(seed=config.seed) - if "example_id" not in dataset.column_names: - dataset = dataset.map(lambda ex, idx: {**ex, "example_id": idx}, with_indices=True) - - assert len(dataset) > 0, f"Dataset for {env.name} must contain at least one example." - assert "example_id" in dataset.column_names, f"Dataset for {env.name} must contain an `example_id` column." - assert "prompt" in dataset.column_names, f"Dataset for {env.name} must contain a `prompt` column." - - self.examples: dict[int, dict] = {} - for example in map(partial(cast, dict), dataset): - example["env_name"] = env.name - self.examples[example["example_id"]] = example - - self.easy_examples: list[dict] = [] - self.hard_examples: list[dict] = [] - - self.reset_step_metrics() - - @property - def num_normal(self) -> int: - return len(self.examples) - - @property - def num_total(self) -> int: - return self.num_normal + len(self.easy_examples) + len(self.hard_examples) - - def sample_example(self) -> dict: - key = random.choice(tuple(self.examples)) - return self.examples[key] - - def get_example_hash(self, example: dict) -> str: - hash_keys = [key for key in self.config.hash_keys if key in example] - assert hash_keys, "No hashable keys found in example." - return hashlib.sha256(json.dumps([example[key] for key in hash_keys]).encode()).hexdigest() - - def update_pools(self, example_id: int, avg_reward: float) -> str: - """Assign example to pool based on reward. Returns pool name.""" - if self.config.easy_threshold is not None and avg_reward >= self.config.easy_threshold: - pool = "easy" - elif self.config.hard_threshold is not None and avg_reward <= self.config.hard_threshold: - pool = "hard" - else: - pool = "normal" - - if pool != "normal" and example_id in self.examples: - example = self.examples.pop(example_id) - target = self.easy_examples if pool == "easy" else self.hard_examples - target.append(example) - - self.num_examples_per_step[pool] += 1 - return pool - - def reset_step_metrics(self) -> None: - zero = lambda: {p: 0 for p in POOLS} - self.num_examples_per_step = zero() - self.num_rollouts_per_step = zero() - - def get_metrics(self) -> dict[str, float]: - metrics = {} - num_examples = sum(self.num_examples_per_step.values()) - num_rollouts = sum(self.num_rollouts_per_step.values()) - - for pool in ["easy", "hard"]: - if num_examples: - metrics[f"evicted_examples/{self.env_name}/{pool}"] = self.num_examples_per_step[pool] / num_examples - if num_rollouts: - metrics[f"filtered_rollouts/{self.env_name}/{pool}"] = self.num_rollouts_per_step[pool] / num_rollouts - - pool_counts = [len(self.easy_examples), self.num_normal, len(self.hard_examples)] - pool_ratios = mean_normalize(pool_counts) - for pool, ratio in zip(POOLS, pool_ratios): - metrics[f"pool/{self.env_name}/{pool}"] = ratio - - self.reset_step_metrics() - return metrics - - -class Buffer: - """Manages multiple Buffers with env-ratio-aware sampling.""" - - def __init__(self, envs: TrainEnvs, config: BufferConfig): - self.config = config - self.logger = get_logger() - - if config.seed is not None: - random.seed(config.seed) - - self.env_buffers: dict[str, _EnvBuffer] = {} - for env in envs: - self.env_buffers[env.name] = _EnvBuffer(env, config) - self.env_names = envs.names - - total = sum(eb.num_total for eb in self.env_buffers.values()) - self.logger.debug( - f"Initialized buffer with {format_num(total, precision=0)} example(s) " - f"in {len(self.env_names)} environment(s)" - ) - - env_ratios = [env.config.ratio for env in envs] - if any(r is not None for r in env_ratios): - env_ratio = mean_normalize(env_ratios) - self.env_probs = dict(zip(self.env_names, env_ratio)) - self.logger.debug( - f"Sampling buffer according to provided environment ratios " - f"({', '.join(f'{k}={v:.2f}' for k, v in self.env_probs.items())})" - ) - else: - env_counts = [self.env_buffers[name].num_normal for name in self.env_names] - env_ratio = mean_normalize(env_counts) - self.env_probs = dict(zip(self.env_names, env_ratio)) - self.logger.debug( - f"Sampling buffer according to natural environment distribution " - f"({', '.join(f'{k}={v:.2f}' for k, v in self.env_probs.items())})" - ) - - self.rollout_buffer: list[vf.RolloutOutput] = [] - - def sample_examples(self, n: int) -> list[dict]: - """Samples n examples across envs, respecting env ratios.""" - non_empty = [name for name, eb in self.env_buffers.items() if eb.examples] - if not non_empty: - raise ValueError("No environments left with examples.") - - weights = [self.env_probs[name] for name in non_empty] - return [self.env_buffers[name].sample_example() for name in random.choices(non_empty, weights=weights, k=n)] - - def update(self, rollouts: list[vf.RolloutOutput]): - """Updates buffer state with completed rollouts.""" - rollouts_by_example = defaultdict(list) - for rollout in rollouts: - rollouts_by_example[(rollout["env_name"], rollout["example_id"])].append(rollout) - - for (env_name, example_id), example_rollouts in rollouts_by_example.items(): - eb = self.env_buffers[env_name] - avg_reward = mean([r["reward"] for r in example_rollouts]) - eb.update_pools(example_id, avg_reward) - - if self.config.online_difficulty_filtering: - if avg_reward == 0.0: - eb.num_rollouts_per_step["hard"] += len(example_rollouts) - continue - elif avg_reward == 1.0: - eb.num_rollouts_per_step["easy"] += len(example_rollouts) - continue - - eb.num_rollouts_per_step["normal"] += len(example_rollouts) - self.rollout_buffer.extend(example_rollouts) - - def sample_rollouts(self, n: int) -> list[vf.RolloutOutput]: - """Samples the latest n rollouts from the buffer.""" - n = min(n, len(self.rollout_buffer)) - sampled = self.rollout_buffer[-n:] - self.rollout_buffer = self.rollout_buffer[:-n] - return sampled - - def save(self, path: Path) -> None: - """Saves pool assignments and rollout buffer.""" - path.mkdir(parents=True, exist_ok=True) - - def write_jsonl(lst: list, filepath: Path) -> None: - with open(filepath, "w") as f: - for item in lst: - f.write(json.dumps(item, default=make_serializable) + "\n") - - all_easy = [ex for eb in self.env_buffers.values() for ex in eb.easy_examples] - all_hard = [ex for eb in self.env_buffers.values() for ex in eb.hard_examples] - write_jsonl(all_easy, path / "easy_examples.jsonl") - write_jsonl(all_hard, path / "hard_examples.jsonl") - write_jsonl(self.rollout_buffer, path / "rollout_buffer.jsonl") - - def load(self, path: Path) -> None: - """Loads pool assignments and rollouts from checkpoint.""" - - def read_jsonl(filepath: Path) -> list[dict]: - with open(filepath, "r") as f: - return [json.loads(line) for line in f] - - saved_easy = read_jsonl(path / "easy_examples.jsonl") - saved_hard = read_jsonl(path / "hard_examples.jsonl") - saved_rollouts = cast(list[vf.RolloutOutput], read_jsonl(path / "rollout_buffer.jsonl")) - - if not any(saved_easy) and not any(saved_hard) and not any(saved_rollouts): - self.logger.debug("No easy/ hard examples or rollouts found in checkpoint") - return - - # Build hash lookup across all env buffers: env -> (hash -> example_id) - hash_lookup: dict[str, dict[str, int]] = defaultdict(dict) - all_hashes: set[str] = set() - for env_name, eb in self.env_buffers.items(): - for example_id, example in eb.examples.items(): - h = eb.get_example_hash(example) - if h in all_hashes: - self.logger.warning( - f"Duplicate example hash found based on hash_keys={self.config.hash_keys}. " - "Overwriting with latest example. This may cause unexpected behavior when resuming the buffer." - ) - hash_lookup[env_name][h] = example_id - all_hashes.add(h) - - def move_saved_pool(saved_examples: list[dict], pool_name: str) -> int: - num_moved = 0 - for example in saved_examples: - # Use any env buffer to compute hash (hash_keys are config-level) - first_eb = next(iter(self.env_buffers.values())) - h = first_eb.get_example_hash(example) - for env_name, env_hashes in hash_lookup.items(): - if h in env_hashes: - example_id = env_hashes[h] - eb = self.env_buffers[env_name] - matched = eb.examples.pop(example_id, None) - if matched is not None: - target = eb.easy_examples if pool_name == "easy" else eb.hard_examples - target.append(matched) - num_moved += 1 - break - return num_moved - - if any(saved_easy): - num_moved = move_saved_pool(saved_easy, "easy") - self.logger.debug(f"Loaded {num_moved}/{len(saved_easy)} example(s) to easy pool from checkpoint.") - if num_moved != len(saved_easy): - self.logger.warning( - f"Could not move {len(saved_easy) - num_moved} example(s) from checkpoint to easy pool. " - "This usually means you resumed with an env mix that does not contain all previous examples." - ) - - if any(saved_hard): - num_moved = move_saved_pool(saved_hard, "hard") - self.logger.debug(f"Moved {num_moved}/{len(saved_hard)} example(s) to hard pool from checkpoint.") - if num_moved != len(saved_hard): - self.logger.warning( - f"Could not move {len(saved_hard) - num_moved} example(s) from checkpoint to hard pool. " - "This usually means you resumed with an env mix that does not contain all previous examples." - ) - - if any(saved_rollouts): - valid = [r for r in saved_rollouts if r.get("env_name") in self.env_names] - self.rollout_buffer.extend(valid) - self.logger.debug(f"Loaded {len(valid)} rollout(s) from checkpoint.") - - def convert_to_normal(eb: _EnvBuffer, pool: list[dict], fraction: float) -> int: - if fraction <= 0.0 or not pool: - return 0 - num_to_move = round(len(pool) * fraction) - if num_to_move <= 0: - return 0 - for _ in range(num_to_move): - example = random.choice(pool) - pool.remove(example) - eb.examples[example["example_id"]] = example - return num_to_move - - for eb in self.env_buffers.values(): - n_easy = len(eb.easy_examples) - moved = convert_to_normal(eb, eb.easy_examples, self.config.easy_fraction) - self.logger.debug(f"Converted {moved}/{n_easy} example(s) back to normal from easy pool ({eb.env_name}).") - n_hard = len(eb.hard_examples) - moved = convert_to_normal(eb, eb.hard_examples, self.config.hard_fraction) - self.logger.debug(f"Converted {moved}/{n_hard} example(s) back to normal from hard pool ({eb.env_name}).") - - def get_metrics(self) -> dict[str, float]: - metrics = {} - - # Aggregate cross-env totals - total_examples_per_pool = {p: 0 for p in POOLS} - total_rollouts_per_pool = {p: 0 for p in POOLS} - for eb in self.env_buffers.values(): - for p in POOLS: - total_examples_per_pool[p] += eb.num_examples_per_step[p] - total_rollouts_per_pool[p] += eb.num_rollouts_per_step[p] - - total_examples = sum(total_examples_per_pool.values()) - total_rollouts = sum(total_rollouts_per_pool.values()) - - for pool in ["easy", "hard"]: - if total_examples: - metrics[f"evicted_examples/{pool}"] = total_examples_per_pool[pool] / total_examples - if total_rollouts: - metrics[f"filtered_rollouts/{pool}"] = total_rollouts_per_pool[pool] / total_rollouts - - total_normal = sum(eb.num_normal for eb in self.env_buffers.values()) - total_easy = sum(len(eb.easy_examples) for eb in self.env_buffers.values()) - total_hard = sum(len(eb.hard_examples) for eb in self.env_buffers.values()) - pool_ratios = mean_normalize([total_easy, total_normal, total_hard]) - for pool, ratio in zip(POOLS, pool_ratios): - metrics[f"pool/{pool}"] = ratio - - # Per-env metrics - for eb in self.env_buffers.values(): - metrics.update(eb.get_metrics()) - - return metrics diff --git a/src/prime_rl/orchestrator/ckpt.py b/src/prime_rl/orchestrator/ckpt.py index 19277e3063..8cf5842722 100644 --- a/src/prime_rl/orchestrator/ckpt.py +++ b/src/prime_rl/orchestrator/ckpt.py @@ -1,93 +1,55 @@ +"""Checkpoint manager for ``Progress``. Layout: +``/checkpoints/step_N/orchestrator/progress.pt``.""" + +from __future__ import annotations + import time -from dataclasses import asdict, dataclass +from dataclasses import asdict from pathlib import Path import torch from prime_rl.configs.orchestrator import CheckpointConfig -from prime_rl.orchestrator.buffer import Buffer -from prime_rl.utils.logger import get_logger -from prime_rl.utils.utils import get_ckpt_dir, get_step_path - - -@dataclass -class Progress: - step: int = 0 - total_tokens: int = 0 - total_samples: int = 0 - total_problems: int = 0 +from prime_rl.orchestrator.types import Progress +from prime_rl.utils.logger import format_time, get_logger +from prime_rl.utils.pathing import get_ckpt_dir, get_step_path class CheckpointManager: - """Utility class to save and load orchestrator checkpoints to resume orchestrator.""" - - def __init__(self, output_dir: Path, config: CheckpointConfig): + def __init__(self, output_dir: Path, config: CheckpointConfig) -> None: self.config = config self.ckpt_dir = get_ckpt_dir(output_dir) - self.logger = get_logger() def get_ckpt_path(self, step: int) -> Path: return get_step_path(self.ckpt_dir, step) / "orchestrator" - def save_to_path( - self, - ckpt_path: Path, - progress: Progress, - buffer: Buffer, - ): - self.logger.debug(f"Saving orchestrator checkpoint to {ckpt_path}") - start_time = time.perf_counter() - - # Save progress + def save(self, progress: Progress, step: int) -> None: + ckpt_path = self.get_ckpt_path(step) + ckpt_path.mkdir(parents=True, exist_ok=True) + start = time.perf_counter() with open(ckpt_path / "progress.pt", "wb") as f: torch.save({"progress": progress}, f) + get_logger().debug( + f"Orchestrator checkpoint saved to {ckpt_path} in {format_time(time.perf_counter() - start)}" + ) - # Save buffer - buffer.save(ckpt_path / "buffer") - - self.logger.debug(f"Orchestrator checkpoint saved in {time.perf_counter() - start_time:.2f} seconds") - - def load_from_path(self, ckpt_path: Path, progress: Progress, buffer: Buffer) -> None: - """Loads a checkpoint from a given path in-place.""" - self.logger.debug(f"Loading checkpoint from {ckpt_path}") - start_time = time.perf_counter() - - # Load progress + def load(self, progress: Progress, step: int) -> None: + ckpt_path = self.get_ckpt_path(step) + state_file = ckpt_path / "progress.pt" + if not state_file.exists(): + raise FileNotFoundError(f"Orchestrator checkpoint not found at {state_file}") + get_logger().debug(f"Loading checkpoint from {state_file}") + start = time.perf_counter() if self.config.skip_progress: - self.logger.info("Skipping progress loading from checkpoint") + get_logger().info("Skipping progress loading from checkpoint") else: - with open(ckpt_path / "progress.pt", "rb") as f: + with open(state_file, "rb") as f: state = torch.load(f, weights_only=False) - - # Set progress in-place - for key, value in asdict(state["progress"]).items(): - setattr(progress, key, value) - - # Load buffer - if self.config.skip_buffer: - self.logger.info("Skipping buffer loading from checkpoint") - else: - buffer.load(ckpt_path / "buffer") - - self.logger.debug(f"Orchestrator checkpoint loaded in {time.perf_counter() - start_time:.2f} seconds") - - def load(self, progress: Progress, buffer: Buffer, step: int) -> None: - """Loads a checkpoint from a given path.""" - ckpt_path = self.get_ckpt_path(step) - if not ckpt_path.exists(): - raise FileNotFoundError(f"Checkpoint not found at {ckpt_path}") - self.load_from_path(ckpt_path, progress, buffer) - - def save( - self, - progress: Progress, - buffer: Buffer, - step: int, - ) -> None: - """Saves the full checkpoint state for a specified step.""" - ckpt_path = self.get_ckpt_path(step) - ckpt_path.mkdir(parents=True, exist_ok=True) - self.save_to_path(ckpt_path, progress, buffer) + saved: Progress = state["progress"] + for key, value in asdict(saved).items(): + if hasattr(progress, key): + setattr(progress, key, value) + get_logger().debug(f"Orchestrator checkpoint loaded in {format_time(time.perf_counter() - start)}") def setup_ckpt_manager(output_dir: Path, config: CheckpointConfig | None) -> CheckpointManager | None: diff --git a/src/prime_rl/orchestrator/dispatcher.py b/src/prime_rl/orchestrator/dispatcher.py new file mode 100644 index 0000000000..57eaf8903c --- /dev/null +++ b/src/prime_rl/orchestrator/dispatcher.py @@ -0,0 +1,683 @@ +"""RolloutDispatcher: schedules rollouts under a shared permit counter. + +- Capacity (``max_inflight_rollouts``) is shared across train + eval. + A group-scoring task that runs N rollouts in one call reserves N permits. +- Optional rate limiting via ``AsyncLimiter(tasks_per_minute, 60)``. +- Emit-everything invariant: every dispatched rollout eventually reaches + ``out_q`` exactly once as a ``TrainRollout`` / ``EvalRollout``. Failures + (env error, empty trajectory, task exception, off-policy cancel) carry + ``raw["error"]`` set; sinks decide drop / partial-train policy. +- ``DispatcherMode.PREFER_TRAIN`` / ``PREFER_EVAL`` controls which kind to + schedule next. Transitions are level-triggered (driven by the eval + source's emptiness), so in-flight rollouts of the opposite kind drain + naturally on either side of an eval boundary. +- ``on_new_version`` (called by the watcher) bumps ``off_policy_steps`` on + every in-flight rollout and drops groups past ``max_off_policy_steps``. + Cancellations surface as synthetic ``Cancelled`` markers so the sink's + count-to-``group_size`` finalization still fires. +""" + +from __future__ import annotations + +import asyncio +import uuid +from collections import Counter, defaultdict +from dataclasses import dataclass, field +from enum import Enum, auto +from typing import Literal + +import verifiers as vf +from aiolimiter import AsyncLimiter + +from prime_rl.orchestrator.envs import EvalEnvs, TrainEnvs +from prime_rl.orchestrator.eval_source import EvalSource +from prime_rl.orchestrator.train_source import TrainSource +from prime_rl.orchestrator.types import ( + EvalRollout, + FinishedRollout, + GroupState, + InflightRollout, + Kind, + Policy, + TrainRollout, +) +from prime_rl.utils.async_utils import safe_cancel, safe_cancel_all +from prime_rl.utils.client import InferencePool, client_identity +from prime_rl.utils.logger import get_logger + + +class DispatcherMode(Enum): + """Which kind of work the dispatcher schedules next.""" + + PREFER_TRAIN = auto() + PREFER_EVAL = auto() + + +@dataclass +class DispatcherMetrics: + """Per-tick drain counters for the orchestrator's periodic log. + ``drained()`` returns the current values and clears them; point-in-time + gauges live on ``RolloutDispatcher.gauges`` instead.""" + + cancelled_by_kind_env: dict[tuple[Literal["train", "eval"], str], int] = field( + default_factory=lambda: defaultdict(int) + ) + errored_by_kind_env: dict[tuple[Literal["train", "eval"], str], int] = field( + default_factory=lambda: defaultdict(int) + ) + + def record_cancellation(self, *, kind: Literal["train", "eval"], env_name: str, n: int = 1) -> None: + self.cancelled_by_kind_env[(kind, env_name)] += n + + def record_error(self, *, kind: Literal["train", "eval"], env_name: str) -> None: + self.errored_by_kind_env[(kind, env_name)] += 1 + + def drained(self, *, train_envs: set[str], eval_envs: set[str]) -> dict[str, float]: + """Return per-tick counters and clear them. Emits the full pre- + registered key set every tick (zero when no activity) so the wandb + time axis stays dense and ``define_metric`` lines up.""" + out: dict[str, float] = {} + for kind in ("train", "eval"): + envs = train_envs if kind == "train" else eval_envs + cancelled_total = sum(self.cancelled_by_kind_env.get((kind, e), 0) for e in envs) + errored_total = sum(self.errored_by_kind_env.get((kind, e), 0) for e in envs) + out[f"dispatcher/cancelled/{kind}"] = float(cancelled_total) + out[f"dispatcher/errored/{kind}"] = float(errored_total) + for env in train_envs | eval_envs: + out[f"dispatcher/cancelled/{env}"] = float( + self.cancelled_by_kind_env.get(("train", env), 0) + self.cancelled_by_kind_env.get(("eval", env), 0) + ) + out[f"dispatcher/errored/{env}"] = float( + self.errored_by_kind_env.get(("train", env), 0) + self.errored_by_kind_env.get(("eval", env), 0) + ) + self.cancelled_by_kind_env.clear() + self.errored_by_kind_env.clear() + return out + + @staticmethod + def drain_keys(*, train_envs: set[str], eval_envs: set[str]) -> list[str]: + """Full set of keys ``drained`` may emit; used by the periodic + logger for ``wandb.define_metric``.""" + keys = [ + "dispatcher/cancelled/train", + "dispatcher/cancelled/eval", + "dispatcher/errored/train", + "dispatcher/errored/eval", + ] + for env in train_envs | eval_envs: + keys.append(f"dispatcher/cancelled/{env}") + keys.append(f"dispatcher/errored/{env}") + return keys + + +class RolloutDispatcher: + """``await dispatcher.start()`` runs the dispatch loop until ``stop()``. + Pulls examples from ``TrainSource`` / ``EvalSource``, schedules + rollouts under shared capacity, and emits ``FinishedRollout``\\ s to + ``out_q``. The watcher drives ``on_new_version`` for off-policy + cancellation; the orchestrator triggers eval epochs.""" + + def __init__( + self, + *, + train_envs: TrainEnvs, + eval_envs: EvalEnvs | None, + train_source: TrainSource, + eval_source: EvalSource | None, + inference: InferencePool, + eval_inference: InferencePool, + policy: Policy, + max_inflight_rollouts: int, + tasks_per_minute: float | None, + max_off_policy_steps: int, + training_mode: Literal["rl", "opd", "sft"], + ) -> None: + self.policy = policy + self.train_envs = train_envs + self.eval_envs = eval_envs + # Train rollouts go to ``inference`` (the teacher in SFT mode); + # eval always evaluates the student, so it uses ``eval_inference``. + self.inference = inference + self.eval_inference = eval_inference + self.train_source = train_source + self.eval_source = eval_source + self.training_mode = training_mode + self.max_off_policy_steps = max_off_policy_steps + + self.max_inflight = max_inflight_rollouts + self.inflight_permits = 0 + self.rate_limiter: AsyncLimiter | None = ( + AsyncLimiter(tasks_per_minute, time_period=60) if tasks_per_minute else None + ) + + self.inflight: dict[asyncio.Task, InflightRollout] = {} + self.groups: dict[uuid.UUID, GroupState] = {} + + # Bounded so the dispatcher backpressures on a slow sink + self.out_q: asyncio.Queue[FinishedRollout] = asyncio.Queue(maxsize=max(8, self.max_inflight)) + + self.mode: DispatcherMode = DispatcherMode.PREFER_TRAIN + # Set by the orchestrator after the final train step; pipeline then + # winds down without scheduling new train rollouts + self.train_scheduling_disabled: bool = False + self.metrics = DispatcherMetrics() + + # Orchestrator-owned gate. When clear, ``fill_inflight`` returns + # without scheduling new groups. The dispatcher itself doesn't know + # *why* — the orchestrator toggles this based on step / policy lead. + self.dispatch_allowed = asyncio.Event() + self.dispatch_allowed.set() + + self.stopped = asyncio.Event() + self.task: asyncio.Task | None = None + + @property + def train_model_name(self) -> str: + """Model name for *train* rollouts. In SFT mode train data comes from + the teacher pool, so use its model name; otherwise the live student + policy. (Eval always uses ``policy.model_name`` — the student.)""" + if self.training_mode == "sft": + return self.inference.model_name + return self.policy.model_name + + @property + def inflight_train_count(self) -> int: + return sum(m.rollout_count for m in self.inflight.values() if m.kind == "train") + + @property + def inflight_eval_count(self) -> int: + return sum(m.rollout_count for m in self.inflight.values() if m.kind == "eval") + + @property + def available_permits(self) -> int: + return self.max_inflight - self.inflight_permits + + @property + def inflight_by_env(self) -> dict[tuple[Kind, str], int]: + counts: dict[tuple[Kind, str], int] = defaultdict(int) + for meta in self.inflight.values(): + counts[(meta.kind, meta.env_name)] += meta.rollout_count + return dict(counts) + + @property + def queued_eval_examples(self) -> int: + return len(self.eval_source) if self.eval_source is not None else 0 + + @property + def is_idle(self) -> bool: + """True once nothing is in flight, no eval queued, and ``out_q`` is + empty — the pipeline has fully drained.""" + eval_drained = self.eval_source is None or not self.eval_source + return not self.inflight and eval_drained and self.out_q.empty() + + def disable_train_scheduling(self) -> None: + """Stop scheduling new train rollouts; in-flight train + any + triggered eval drain naturally.""" + self.train_scheduling_disabled = True + + @property + def max_off_policy_level(self) -> int: + steps = [m.off_policy_steps for m in self.inflight.values() if m.kind == "train"] + return max(steps) if steps else 0 + + @property + def mean_off_policy_level(self) -> float: + steps = [m.off_policy_steps for m in self.inflight.values() if m.kind == "train"] + return sum(steps) / len(steps) if steps else 0.0 + + # ── lifecycle ────────────────────────────────────────────────────────── + + async def start(self) -> None: + """Single dispatch loop: schedule, wait, collect, repeat.""" + self.task = asyncio.current_task() + try: + while not self.stopped.is_set(): + await self.fill_inflight() + if not self.inflight: + # No work — sleep briefly. Eval triggers from the + # orchestrator wake the next iteration via a mode flip + try: + await asyncio.wait_for(self.stopped.wait(), timeout=0.1) + except asyncio.TimeoutError: + pass + continue + + done, _pending = await asyncio.wait( + list(self.inflight.keys()), + return_when=asyncio.FIRST_COMPLETED, + timeout=0.5, # wake periodically to re-check fill (mode flips) + ) + for task in done: + await self.handle_completed_rollout(task) + except asyncio.CancelledError: + return + + async def stop(self) -> None: + self.stopped.set() + await self.cancel_inflight_rollouts() + if self.task is not None: + await safe_cancel(self.task) + self.task = None + + async def on_new_version(self, step: int) -> None: + """Bump off-policy counters and drop groups past + ``max_off_policy_steps`` (drop_group emits ``Cancelled`` markers so + the sink still finalizes the partial group).""" + stale_groups: dict[uuid.UUID, Kind] = {} + cancelled_by_kind: dict[Kind, int] = {"train": 0, "eval": 0} + for meta in self.inflight.values(): + meta.off_policy_steps += 1 + if meta.off_policy_steps > self.max_off_policy_steps: + stale_groups[meta.group_id] = meta.kind + + for gid, kind in stale_groups.items(): + removed = await self.drop_group(gid) + cancelled_by_kind[kind] += removed + + for kind in ("train", "eval"): + n = cancelled_by_kind[kind] + if n: + get_logger().warning( + f"Cancelled {n} {kind} rollouts past max_off_policy_steps={self.max_off_policy_steps}. " + "Consider increasing it to avoid this." + ) + + async def fill_inflight(self) -> None: + """Schedule new rollouts up to ``max_inflight``, honoring + ``self.mode``. When ``PREFER_EVAL``'s source exhausts we flip back + to ``PREFER_TRAIN`` so the eval tail drains alongside fresh train.""" + if not self.dispatch_allowed.is_set(): + return + while True: + if self.available_permits <= 0: + return + + if self.mode == DispatcherMode.PREFER_EVAL: + # PREFER_EVAL is only entered when the orchestrator triggers + # eval, which requires ``eval_source`` to be configured + assert self.eval_source is not None + eval_has_work = bool(self.eval_source) or any( + g.kind == "eval" and g.rollouts_to_schedule > 0 for g in self.groups.values() + ) + if not eval_has_work: + # Eval source + all eval groups fully dispatched. Flip + # to PREFER_TRAIN so any remaining permits go to train + # while the in-flight eval tail completes naturally + self.switch_mode(DispatcherMode.PREFER_TRAIN, reason="the eval queue drained") + continue + scheduled = await self.try_schedule("eval") + if not scheduled: + return + else: # PREFER_TRAIN + scheduled = await self.try_schedule("train") + if not scheduled: + return + + def switch_mode(self, new_mode: DispatcherMode, *, reason: str) -> None: + if new_mode == self.mode: + return + prefer = "eval" if new_mode == DispatcherMode.PREFER_EVAL else "train" + get_logger().info(f"Switching dispatcher mode to prefer {prefer} rollouts because {reason}") + self.mode = new_mode + + async def try_schedule(self, kind: Kind) -> bool: + """Schedule one rollout of ``kind``: prefer continuing an existing + group (keeps prefix-cache hits); otherwise open a fresh group from + the corresponding source. Returns False if nothing could be + scheduled.""" + if kind == "train" and self.train_scheduling_disabled: + return False + envs = self.train_envs if kind == "train" else self.eval_envs + if envs is None: + return False + + for gid, group in list(self.groups.items()): + if group.kind != kind or group.rollouts_to_schedule <= 0: + continue + env = envs.get(group.env_name) + cost = group.rollouts_to_schedule if env.requires_group_scoring else 1 + if cost <= self.available_permits: + return await self.schedule_group_rollout(gid, group) + + fresh = self.next_fresh_group(kind, envs) + if fresh is None: + return False + gid = uuid.uuid4() + self.groups[gid] = fresh + return await self.schedule_group_rollout(gid, fresh) + + def next_fresh_group(self, kind: Kind, envs) -> GroupState | None: + """Pop the next example from the corresponding source and wrap it in + a ``GroupState``. Returns ``None`` if the source is empty or the + picked env's permit cost doesn't fit.""" + if kind == "train": + source = self.train_source + else: + assert self.eval_source is not None + source = self.eval_source + example = source.next_example(self.available_permits) + if example is None: + return None + + env_name = example["env_name"] + group_size = envs.get(env_name).config.group_size + eval_step: int | None = example.get("eval_step") if kind == "eval" else None + + return GroupState( + kind=kind, + env_name=env_name, + example=example, + rollouts_to_schedule=group_size, + target_rollouts=group_size, + eval_step=eval_step, + policy_version_at_start=self.policy.version, + ) + + async def schedule_group_rollout(self, group_id: uuid.UUID, group: GroupState) -> bool: + """Dispatch one ``run_rollout`` / ``run_group`` task for this group. + + Returns False only if we couldn't even schedule one rollout (no clients + ready, no permits). Returns True after issuing one task — the caller + loops to keep scheduling. + """ + # Train rollouts use the rollout pool (teacher in SFT) via the + # renderer/token train client. Eval always evaluates the student and + # goes through the eval client (chat-completions) — the same path the + # legacy orchestrator used, so eval scores stay comparable. + if group.kind == "eval": + pool, model_name = self.eval_inference, self.policy.model_name + else: + pool, model_name = self.inference, self.train_model_name + + # Pin a single client per group to keep prefix-cache hits + if group.pinned_client is None: + if group.kind == "eval": + client = await pool.get_eval_client() + else: + load = Counter( + client_identity(m.client_config) for m in self.inflight.values() if m.client_config is not None + ) + client = await pool.select_train_client(load) + if group_id not in self.groups: + return False + group.pinned_client = client + else: + client = group.pinned_client + + env_collection = self.train_envs if group.kind == "train" else self.eval_envs + if env_collection is None: + return False + env = env_collection.get(group.env_name) + cache_salt = str(group.policy_version_at_start) + + if env.requires_group_scoring: + permits = group.rollouts_to_schedule + group.rollouts_to_schedule = 0 + await self.acquire(permits) + task: asyncio.Task = asyncio.create_task( + env.run_group( + client=client, + example=group.example, + model_name=model_name, + group_size=permits, + cache_salt=cache_salt, + ) + ) + else: + permits = 1 + group.rollouts_to_schedule -= 1 + await self.acquire(permits) + task = asyncio.create_task( + env.run_rollout( + client=client, + example=group.example, + model_name=model_name, + cache_salt=cache_salt, + ) + ) + + self.inflight[task] = InflightRollout( + kind=group.kind, + env_name=group.env_name, + group_id=group_id, + policy_version=group.policy_version_at_start, + rollout_count=permits, + client_config=client, + eval_step=group.eval_step, + ) + return True + + async def acquire(self, n: int) -> None: + """Reserve ``n`` permits + rate-limit each one. Caller must precheck + ``available_permits >= n``; this is not a blocking acquire.""" + for _ in range(n): + if self.rate_limiter is not None: + await self.rate_limiter.acquire() + self.inflight_permits += 1 + + def release(self, n: int) -> None: + self.inflight_permits -= n + + async def handle_completed_rollout(self, task: asyncio.Task) -> None: + """Emit every dispatched rollout exactly once to ``out_q``. Task + exceptions synthesize ``meta.rollout_count`` error markers so the + sink's count-to-``group_size`` finalization still triggers. + Cancelled tasks (popped by ``drop_group``) raise ``CancelledError`` + and are discarded — ``drop_group`` already emitted their markers. + """ + meta = self.inflight.pop(task, None) + if meta is None: + return # already handled by drop_group / cancel_inflight_rollouts + self.release(meta.rollout_count) + group = self.groups.get(meta.group_id) + + is_synth_exception = False + try: + result = task.result() + rollouts: list[vf.RolloutOutput] = result if isinstance(result, list) else [result] + except asyncio.CancelledError: + return + except Exception as exc: + get_logger().warning(f"Rollout task failed in group {meta.group_id} ({meta.env_name}): {exc!r}") + rollouts = [ + self.error_rollout_output(error_type=type(exc).__name__, error_repr=repr(exc)) + for _ in range(meta.rollout_count) + ] + is_synth_exception = True + + for r in rollouts: + if r.get("error") is None and len(r.get("trajectory") or []) == 0: + # Empty trajectory: promote to an explicit error so the sink + # treats it like any other failure + r["error"] = { + "error": "EmptyTrajectory", + "error_chain_repr": "Rollout returned with no trajectory steps", + "error_chain_str": "", + } + get_logger().warning(f"Empty trajectory in group {meta.group_id} ({meta.env_name})") + if r.get("error") is not None: + err_type = r["error"].get("error", "Unknown") + self.metrics.record_error(kind=meta.kind, env_name=meta.env_name) + if not is_synth_exception: + get_logger().warning( + f"Rollout failed in group {meta.group_id} ({meta.env_name}) — " + f"{r['error'].get('error_chain_repr', err_type)}" + ) + await self.emit_rollout(meta, group, r) + + async def emit_rollout(self, meta: InflightRollout, group: GroupState | None, raw: vf.RolloutOutput) -> None: + """Build a ``TrainRollout`` / ``EvalRollout`` and put it on ``out_q``. + Pops the group from ``self.groups`` once every member has been emitted.""" + eval_step = meta.eval_step + policy_version = meta.policy_version + example_id = raw.get("example_id") + if group is not None: + eval_step = group.eval_step + policy_version = group.policy_version_at_start + example_id = group.example["example_id"] + group.emitted += 1 + if group.emitted >= group.target_rollouts: + self.groups.pop(meta.group_id, None) + + common = dict( + raw=raw, + env_name=meta.env_name, + example_id=example_id if example_id is not None else -1, + group_id=meta.group_id, + policy_version=policy_version, + off_policy_steps=meta.off_policy_steps, + ) + rollout: FinishedRollout + if meta.kind == "train": + rollout = TrainRollout(**common) + else: + assert eval_step is not None, "eval rollout missing eval_step" + rollout = EvalRollout(**common, eval_step=eval_step) + await self.out_q.put(rollout) + + @staticmethod + def error_rollout_output(*, error_type: str, error_repr: str) -> vf.RolloutOutput: + """Minimal ``vf.RolloutOutput`` for rollouts that never produced + real output (task exception, off-policy cancel).""" + out: vf.RolloutOutput = vf.RolloutOutput() + out["error"] = { + "error": error_type, + "error_chain_repr": error_repr, + "error_chain_str": error_repr, + } + out["trajectory"] = [] + out["completion"] = None + out["reward"] = 0.0 + out["is_truncated"] = False + out["metrics"] = {} + out["stop_condition"] = None + out["token_usage"] = { + "input_tokens": 0.0, + "output_tokens": 0.0, + "final_input_tokens": 0.0, + "final_output_tokens": 0.0, + } + return out + + async def drop_group(self, group_id: uuid.UUID) -> int: + """Cancel remaining in-flight tasks for this group and emit a + ``Cancelled`` marker for every rollout it still owes the sink + (both in-flight and not-yet-scheduled). Returns the count for + off-policy metrics.""" + group = self.groups.pop(group_id, None) + + # Sync claim phase: pop matching tasks from ``self.inflight`` and + # release their permits in one non-yielding sweep. After this loop + # the dropped tasks are no longer reachable from ``self.inflight``, + # so ``handle_completed_rollout``'s existing None-guard makes the + # subsequent async emit phase race-free. + claimed: list[tuple[asyncio.Task, InflightRollout]] = [] + for task, meta in list(self.inflight.items()): + if meta.group_id != group_id: + continue + del self.inflight[task] + self.release(meta.rollout_count) + claimed.append((task, meta)) + + tasks_to_cancel = [task for task, _ in claimed] + inflight_cancelled = sum(meta.rollout_count for _, meta in claimed) + last_meta: InflightRollout | None = claimed[-1][1] if claimed else None + for _, meta in claimed: + for _ in range(meta.rollout_count): + raw = self.error_rollout_output(error_type="Cancelled", error_repr="Off-policy cancel") + await self.emit_rollout(meta, group, raw) + + # For non-group-scoring envs, the group may have rollouts that + # were never dispatched (``rollouts_to_schedule > 0``). Emit + # markers for those too so the sink hits ``target_rollouts`` + # + # ``last_meta`` can be ``None`` if the only inflight task for this + # group completed naturally between ``on_new_version``'s snapshot + # and us reaching it — synthesize a stand-in from the group state + unscheduled_cancelled = 0 + if group is not None and group.rollouts_to_schedule > 0: + fallback_meta = last_meta or InflightRollout( + kind=group.kind, + env_name=group.env_name, + group_id=group_id, + policy_version=group.policy_version_at_start, + rollout_count=1, + eval_step=group.eval_step, + ) + unscheduled_cancelled = group.rollouts_to_schedule + for _ in range(unscheduled_cancelled): + raw = self.error_rollout_output(error_type="Cancelled", error_repr="Off-policy cancel") + await self.emit_rollout(fallback_meta, group, raw) + + cancelled = inflight_cancelled + unscheduled_cancelled + if cancelled > 0: + meta_for_log = last_meta or ( + InflightRollout( + kind=group.kind, + env_name=group.env_name, + group_id=group_id, + policy_version=group.policy_version_at_start if group else 0, + rollout_count=1, + eval_step=group.eval_step, + ) + if group is not None + else None + ) + if meta_for_log is not None: + self.metrics.record_cancellation(kind=meta_for_log.kind, env_name=meta_for_log.env_name, n=cancelled) + get_logger().debug( + f"drain {meta_for_log.kind} | group={str(group_id)[:8]} env={meta_for_log.env_name} | " + f"cancelled={cancelled} (inflight={inflight_cancelled} unscheduled={unscheduled_cancelled})" + ) + + if tasks_to_cancel: + await safe_cancel_all(tasks_to_cancel) + return cancelled + + async def cancel_inflight_rollouts(self) -> None: + """Cancel all in-flight rollouts. Used on shutdown — doesn't emit + markers since the sinks are being torn down anyway.""" + for meta in self.inflight.values(): + self.metrics.record_cancellation(kind=meta.kind, env_name=meta.env_name, n=meta.rollout_count) + self.release(meta.rollout_count) + tasks = list(self.inflight.keys()) + self.inflight.clear() + self.groups.clear() + if tasks: + await safe_cancel_all(tasks) + + async def cancel_inflight_train_rollouts(self) -> int: + """Cancel in-flight train rollouts, leaving eval alone. Used by the + orchestrator at ``max_steps`` so triggered eval can still complete + through the pipeline while wasted train inference is short-circuited.""" + train_tasks: list[asyncio.Task] = [] + train_group_ids: set[uuid.UUID] = set() + cancelled = 0 + for task, meta in list(self.inflight.items()): + if meta.kind != "train": + continue + self.inflight.pop(task, None) + self.release(meta.rollout_count) + self.metrics.record_cancellation(kind="train", env_name=meta.env_name, n=meta.rollout_count) + cancelled += meta.rollout_count + train_tasks.append(task) + train_group_ids.add(meta.group_id) + for gid in train_group_ids: + self.groups.pop(gid, None) + if train_tasks: + await safe_cancel_all(train_tasks) + return cancelled + + # ── metrics ──────────────────────────────────────────────────────────── + + def gauges(self) -> dict[str, float]: + """Instantaneous, read-only gauges sampled by the periodic logger.""" + return { + "dispatcher/inflight_train": float(self.inflight_train_count), + "dispatcher/inflight_eval": float(self.inflight_eval_count), + "dispatcher/queued/eval": float(self.queued_eval_examples), + "dispatcher/mode": float(self.mode == DispatcherMode.PREFER_EVAL), + "dispatcher/groups_in_flight": float(len(self.groups)), + "dispatcher/off_policy_level_max": float(self.max_off_policy_level), + "dispatcher/off_policy_level_mean": self.mean_off_policy_level, + } diff --git a/src/prime_rl/orchestrator/envs.py b/src/prime_rl/orchestrator/envs.py index 03b1f59e90..e9d426b403 100644 --- a/src/prime_rl/orchestrator/envs.py +++ b/src/prime_rl/orchestrator/envs.py @@ -16,7 +16,6 @@ from prime_rl.configs.orchestrator import EnvConfig, EvalEnvConfig, TrainEnvConfig from prime_rl.orchestrator.eval_utils import compute_pass_at_k -from prime_rl.orchestrator.vf_utils import get_completion_len from prime_rl.utils.logger import ProgressTracker, get_logger from prime_rl.utils.monitor import get_monitor from prime_rl.utils.utils import capitalize @@ -31,7 +30,7 @@ def __init__(self, config: EnvConfig): self.config = config self.sampling_args: dict = {} - get_logger().info(f"Initializing {config.resolved_name} ({config})") + get_logger().debug(f"Initializing {config.resolved_name} ({config})") self._env: vf.Environment = vf.load_environment(config.stripped_id, **config.args) self._env_client: ZMQEnvClient | None = None self._env_server_process: BaseProcess | None = None @@ -268,7 +267,7 @@ async def run_with_progress(example: dict) -> list[vf.RolloutOutput] | None: { "example_id": o["example_id"], "reward": o["reward"], - "completion_len": get_completion_len(o), + "completion_len": o["token_usage"]["final_output_tokens"], "is_truncated": o["is_truncated"], "has_error": o.get("error") is not None, "no_response": not o.get("completion"), @@ -370,7 +369,7 @@ def shutdown(self) -> None: if not processes: return logger = get_logger() - logger.info(f"Shutting down {len(processes)} env server(s), waiting for sandbox cleanup...") + logger.debug(f"Shutting down {len(processes)} env server(s)") for p in processes: p.terminate() for p in processes: diff --git a/src/prime_rl/orchestrator/eval_sink.py b/src/prime_rl/orchestrator/eval_sink.py new file mode 100644 index 0000000000..2c21841678 --- /dev/null +++ b/src/prime_rl/orchestrator/eval_sink.py @@ -0,0 +1,160 @@ +"""EvalSink: three-level rollout sink for eval epochs. + +Same shape as ``TrainSink``, but no tokenization / advantages / filters: + +1. ``process_rollout`` — no-op. +2. ``process_group`` — at ``group_size`` arrivals, move the rollouts + (errored ones included) into the ``(env, eval_step)`` bucket. +3. ``process_batch`` — at ``num_examples × group_size`` arrivals, build + the ``EvalBatchMetrics`` and return an ``EvalBatch``. + +``add()`` returns ``EvalBatch | None``. +""" + +from __future__ import annotations + +import uuid +from collections import defaultdict + +from prime_rl.orchestrator.envs import EvalEnvs +from prime_rl.orchestrator.eval_utils import compute_pass_at_k +from prime_rl.orchestrator.types import EvalBatch, EvalBatchMetrics, EvalRollout +from prime_rl.utils.logger import get_logger + + +class EvalSink: + """Constructed only when eval is configured.""" + + def __init__(self, *, eval_envs: EvalEnvs) -> None: + self.eval_envs = eval_envs + self.pending_groups: dict[uuid.UUID, list[EvalRollout]] = defaultdict(list) + # Bucket size IS the arrival count — ``process_group`` flushes + # everything in without filtering + self.pending_batches: dict[tuple[str, int], list[EvalRollout]] = defaultdict(list) + + def add(self, rollout: EvalRollout) -> EvalBatch | None: + """Process one arrival; finalize the group on the ``group_size``-th + arrival and the per-env epoch on the ``num_examples × group_size``-th.""" + env_name = rollout.env_name + self.process_rollout(rollout) + bkey = (env_name, rollout.eval_step) + self.pending_groups[rollout.group_id].append(rollout) + if len(self.pending_groups[rollout.group_id]) >= self.group_size_for(env_name): + self.process_group(rollout.group_id) + if len(self.pending_batches[bkey]) >= self.batch_size_for(env_name): + return self.process_batch(bkey) + return None + + def group_size_for(self, env_name: str) -> int: + return self.eval_envs.get(env_name).config.group_size + + def batch_size_for(self, env_name: str) -> int: + """``num_examples × group_size`` — total rollouts expected for one + epoch of ``env_name``.""" + env = self.eval_envs.get(env_name) + return len(env.examples) * env.config.group_size + + def batch_progress(self) -> list[tuple[str, int, int, int, int]]: + """One entry per accumulating ``(env, eval_step)`` batch: + ``(env_name, eval_step, batch_count, expected, buffered)``. + ``batch_count`` is finalized-group survivors in ``pending_batches``; + ``buffered`` is partial-group arrivals from non-group-scoring envs.""" + batch_counts: dict[tuple[str, int], int] = {bkey: len(bucket) for bkey, bucket in self.pending_batches.items()} + buffered: dict[tuple[str, int], int] = {} + for rollouts in self.pending_groups.values(): + if not rollouts: + continue + env_name = rollouts[0].env_name + if self.eval_envs.get(env_name).requires_group_scoring: + continue + bkey = (env_name, rollouts[0].eval_step) + buffered[bkey] = buffered.get(bkey, 0) + len(rollouts) + return [ + ( + env_name, + eval_step, + batch_counts.get((env_name, eval_step), 0), + self.batch_size_for(env_name), + buffered.get((env_name, eval_step), 0), + ) + for (env_name, eval_step) in set(batch_counts) | set(buffered) + ] + + # ── level 1: per-rollout (no-op for eval) ───────────────────────────── + + def process_rollout(self, rollout: EvalRollout) -> None: + """No-op. Eval rollouts don't need trainer-bound tokenization; the + method exists to keep the three-level structure uniform with + ``TrainSink``. + """ + return None + + # ── level 2: per-group (move into batch bucket) ─────────────────────── + + def process_group(self, group_id: uuid.UUID) -> None: + group = self.pending_groups.pop(group_id, []) + if not group: + return + env_name = group[0].env_name + example_id = group[0].example_id + eval_step = group[0].eval_step + bucket = self.pending_batches[(env_name, eval_step)] + bucket.extend(group) + + survivors = [r for r in group if r.error is None] + num_errored = len(group) - len(survivors) + rewards = [r.reward for r in survivors] + avg_reward = sum(rewards) / len(rewards) if rewards else 0.0 + get_logger().debug( + f"Finished group | env={env_name} example_id={example_id} eval_step={eval_step} | " + f"rollouts={len(group)} (errored={num_errored}) | reward={avg_reward:.4f}" + ) + + def process_batch(self, key: tuple[str, int]) -> EvalBatch: + """Build ``EvalBatchMetrics`` and return the finalized ``EvalBatch``. + Errored rollouts (env failures, cancellations, task exceptions) are + excluded from reward / pass@k / seq_len aggregation (including them + at reward=0 would bias the score down) and surfaced separately as + ``n_cancelled`` / ``n_errored``.""" + env_name, step = key + rollouts = self.pending_batches.pop(key, []) + + n_total = len(rollouts) + n_cancelled = sum(1 for r in rollouts if (r.error or {}).get("error") == "Cancelled") + n_errored = sum(1 for r in rollouts if r.error is not None) - n_cancelled + valid = [r for r in rollouts if r.error is None] + metrics = EvalBatchMetrics( + n_rollouts=n_total, + n_cancelled=n_cancelled, + n_errored=n_errored, + ) + + if valid: + rewards = [r.reward for r in valid] + lens = [r.raw["token_usage"]["final_output_tokens"] for r in valid] + metrics.group_size = self.group_size_for(env_name) + metrics.reward_mean = float(sum(rewards) / len(rewards)) + metrics.completion_len_mean = float(sum(lens) / len(lens)) + metrics.completion_len_max = float(max(lens)) + metrics.completion_len_min = float(min(lens)) + metrics.truncation_rate = float(sum(1 for r in valid if r.is_truncated) / len(valid)) + metrics.no_response_rate = float(sum(1 for r in valid if not r.raw.get("completion")) / len(valid)) + num_turns = [len(r.raw.get("trajectory") or []) for r in valid] + metrics.num_turns_mean = float(sum(num_turns) / len(num_turns)) + metrics.num_turns_min = float(min(num_turns)) + metrics.num_turns_max = float(max(num_turns)) + + # pass@k: errored attempts don't count toward k tries + by_example: dict[int | str, list[float]] = {} + for r in valid: + by_example.setdefault(r.example_id, []).append(r.reward) + metrics.n_examples = len(by_example) + unique_rewards = {float(r) for r in rewards} + if unique_rewards.issubset({0.0, 1.0}) and by_example: + pass_at_k_per_example = [compute_pass_at_k(rs) for rs in by_example.values()] + keys = set().union(*(d.keys() for d in pass_at_k_per_example)) + for k in keys: + values = [d[k] for d in pass_at_k_per_example if k in d] + metrics.pass_at_k[k] = float(sum(values) / len(values)) + + return EvalBatch(env_name=env_name, step=step, rollouts=rollouts, metrics=metrics) diff --git a/src/prime_rl/orchestrator/eval_source.py b/src/prime_rl/orchestrator/eval_source.py new file mode 100644 index 0000000000..65744fc793 --- /dev/null +++ b/src/prime_rl/orchestrator/eval_source.py @@ -0,0 +1,87 @@ +"""EvalSource: trigger-driven, finite-per-epoch pull of eval examples. + +The orchestrator pokes ``trigger(step)`` after each ship + once at +startup; the dispatcher pulls via ``next_example(available_permits)`` +until ``bool(source) == False``. Constructed only when eval is +configured.""" + +from __future__ import annotations + +from collections import deque +from itertools import zip_longest + +from prime_rl.configs.orchestrator import EvalConfig +from prime_rl.orchestrator.envs import EvalEnvs + + +class EvalSource: + """Finite-per-epoch source of eval examples.""" + + def __init__( + self, + eval_envs: EvalEnvs, + eval_config: EvalConfig, + *, + is_resumed: bool = False, + ) -> None: + self.eval_envs = eval_envs + self.eval_config = eval_config + + self.examples_by_env: dict[str, list[dict]] = {} + self.intervals: dict[str, int] = {} + for env in eval_envs: + rows: list[dict] = [] + for ex in env.examples: + row = dict(ex) + row["env_name"] = env.name + rows.append(row) + self.examples_by_env[env.name] = rows + self.intervals[env.name] = env.config.interval + + self.queue: deque[dict] = deque() + + # On resume we skip the startup eval; on fresh start the first + # trigger fires every env (subject to ``skip_first_step``) + self.first_trigger = not is_resumed + + def trigger(self, step: int) -> list[str]: + """Fire eligible envs for ``step`` and return their names. On resume + ``first_trigger`` is False, so the startup/base eval doesn't re-run.""" + is_first, self.first_trigger = self.first_trigger, False + if is_first and self.eval_config.skip_first_step: + return [] + fired: list[str] = [] + for name, interval in self.intervals.items(): + if is_first or step % interval == 0: + fired.append(name) + # Round-robin across fired envs (A₁, B₁, A₂, B₂, …) so the + # dispatcher rotates at example granularity. ``try_schedule``'s + # continue-group branch still keeps each example's group_size + # rollouts back-to-back, so per-example prefix-cache locality holds + iters = [iter(self.examples_by_env[name]) for name in fired] + for round_examples in zip_longest(*iters): + for example in round_examples: + if example is None: + continue + row = dict(example) + row["eval_step"] = step + self.queue.append(row) + return fired + + def next_example(self, available_permits: int) -> dict | None: + """Pop the next eval example if the head's permit cost fits in + ``available_permits``; otherwise leave it for a later call.""" + if not self.queue: + return None + head = self.queue[0] + env = self.eval_envs.get(head["env_name"]) + cost = env.config.group_size if env.requires_group_scoring else 1 + if cost > available_permits: + return None + return self.queue.popleft() + + def __bool__(self) -> bool: + return bool(self.queue) + + def __len__(self) -> int: + return len(self.queue) diff --git a/src/prime_rl/orchestrator/event_loop_lag.py b/src/prime_rl/orchestrator/event_loop_lag.py deleted file mode 100644 index d053601e0f..0000000000 --- a/src/prime_rl/orchestrator/event_loop_lag.py +++ /dev/null @@ -1,83 +0,0 @@ -import asyncio -from time import perf_counter - -import numpy as np - -from prime_rl.utils.logger import get_logger - - -class EventLoopLagMonitor: - """A class to monitor how busy the main event loop is.""" - - def __init__( - self, - interval: float = 0.1, - max_window_size: int = 10000, - warn_p90_lag_threshold: float = 1.0, - warn_p99_lag_threshold: float = 5.0, - warn_max_lag_threshold: float = 30.0, - ): - assert ( - interval > 0 - and max_window_size > 0 - and warn_p90_lag_threshold > 0 - and warn_p99_lag_threshold > 0 - and warn_max_lag_threshold > 0 - ) - self.interval = interval - self.max_window_size = max_window_size - self.warn_p90_lag_threshold = warn_p90_lag_threshold - self.warn_p99_lag_threshold = warn_p99_lag_threshold - self.warn_max_lag_threshold = warn_max_lag_threshold - self.logger = get_logger() - self.lags = [] - - async def measure_lag(self): - """Measures event loop lag by asynchronously sleeping for interval seconds""" - next_time = perf_counter() + self.interval - await asyncio.sleep(self.interval) - now = perf_counter() - lag = now - next_time - return lag - - async def run(self): - """Infinite loop to periodically measure event loop lag. Should be started as background task.""" - while True: - lag = await self.measure_lag() - self.lags.append(lag) - if len(self.lags) > self.max_window_size: - self.lags.pop(0) - - def reset(self): - """Reset the list of measured lags.""" - self.lags = [] - - def get_metrics(self) -> dict[str, float]: - """Compute metrics for the event loop lag over the last window_size measurements.""" - window_size = int(min(self.max_window_size, len(self.lags))) - if window_size <= 0: - return {} - last_lags = np.array(self.lags[-window_size:]) - mean_lag = float(np.mean(last_lags)) - med_lag = float(np.median(last_lags)) - p90_lag = float(np.percentile(last_lags, 90)) - p99_lag = float(np.percentile(last_lags, 99)) - min_lag = float(np.min(last_lags)) - max_lag = float(np.max(last_lags)) - if ( - p90_lag > self.warn_p90_lag_threshold - or p99_lag > self.warn_p99_lag_threshold - or max_lag > self.warn_max_lag_threshold - ): - self.logger.warning( - f"Detected busy event loop. Measured {mean_lag:.1f}s (min={min_lag:.1f}s, med={med_lag:.1f}s, p90={p90_lag:.1f}s, p99={p99_lag:.1f}s, max={max_lag:.1f}s) event loop lag over the last {len(last_lags)} measurement(s)" - ) - - return { - "event_loop_lag/min": min_lag, - "event_loop_lag/mean": mean_lag, - "event_loop_lag/med": med_lag, - "event_loop_lag/p90": p90_lag, - "event_loop_lag/p99": p99_lag, - "event_loop_lag/max": max_lag, - } diff --git a/src/prime_rl/orchestrator/filters.py b/src/prime_rl/orchestrator/filters.py index b2921d22b1..f8deda1230 100644 --- a/src/prime_rl/orchestrator/filters.py +++ b/src/prime_rl/orchestrator/filters.py @@ -6,15 +6,18 @@ are not sent to the trainer. Reward is kept as-is for baseline calculation. """ +from __future__ import annotations + import math from dataclasses import dataclass -from typing import Protocol - -import verifiers as vf +from typing import TYPE_CHECKING, Protocol from prime_rl.configs.orchestrator import FilterConfig from prime_rl.utils.logger import get_logger +if TYPE_CHECKING: + from prime_rl.orchestrator.types import TrainRollout + @dataclass class FilterResult: @@ -26,7 +29,7 @@ class RolloutFilter(Protocol): name: str enforce: bool - def check(self, rollout: vf.RolloutOutput) -> FilterResult: ... + def check(self, rollout: "TrainRollout") -> FilterResult: ... @dataclass @@ -46,9 +49,9 @@ class GibberishFilter: logprob_threshold: float enforce: bool = False - def check(self, rollout: vf.RolloutOutput) -> FilterResult: + def check(self, rollout: "TrainRollout") -> FilterResult: global_idx = 0 - for step in rollout["trajectory"]: + for step in rollout.raw["trajectory"]: tokens = step["tokens"] if tokens is None: continue @@ -76,10 +79,10 @@ class RepetitionFilter: logprob_threshold: float enforce: bool = False - def check(self, rollout: vf.RolloutOutput) -> FilterResult: + def check(self, rollout: "TrainRollout") -> FilterResult: consecutive = 0 global_idx = 0 - for step in rollout["trajectory"]: + for step in rollout.raw["trajectory"]: tokens = step["tokens"] if tokens is None: continue @@ -96,18 +99,14 @@ def check(self, rollout: vf.RolloutOutput) -> FilterResult: @dataclass class ZeroAdvantageFilter: - """Flags rollouts with zero advantage. - - This filter is applied after advantages are computed and checks if the - rollout's advantage field is zero. - """ + """Flags rollouts whose computed advantage is zero (e.g. all rollouts in a + GRPO group earned the same reward, so the centered advantage collapses).""" name: str enforce: bool = True - def check(self, rollout: vf.RolloutOutput) -> FilterResult: - advantage = rollout.get("advantage") - if advantage is not None and advantage == 0.0: + def check(self, rollout: "TrainRollout") -> FilterResult: + if rollout.advantage is not None and rollout.advantage == 0.0: return FilterResult(detected=True) return FilterResult(detected=False) @@ -136,11 +135,11 @@ def setup_filter(config: FilterConfig, vocab_size: int) -> RolloutFilter: raise ValueError(f"Unknown filter type: {config.type}") -def setup_filters(configs: list[FilterConfig], vocab_size: int) -> list[RolloutFilter]: +def setup_filters(configs: list[FilterConfig], vocab_size: int, *, kind: str) -> list[RolloutFilter]: """Create RolloutFilters from a list of filter configs.""" filters = [setup_filter(config, vocab_size) for config in configs] if filters: - get_logger().info(f"Configured {len(filters)} rollout filter(s):") + get_logger().info(f"Configured {len(filters)} {kind} rollout filter(s):") for config, filt in zip(configs, filters): mode = "Enforcing" if filt.enforce else "Monitoring" params = ", ".join(f"{k}={v}" for k, v in config.model_dump().items()) @@ -148,41 +147,27 @@ def setup_filters(configs: list[FilterConfig], vocab_size: int) -> list[RolloutF return filters -def apply_filters(filters: list[RolloutFilter], rollouts: list[vf.RolloutOutput]) -> None: - """Flag rollouts in-place with per-filter detection and drop decision. +def apply_filters(filters: list[RolloutFilter], rollouts: list["TrainRollout"]) -> None: # noqa: F821 (forward ref) + """Flag ``TrainRollout``\\ s in place with per-filter detection + drop decision. - Each rollout gets a `filters` dict with per-filter detection booleans and - an `is_filtered` bool that is True iff an enforcing filter detected it. - First matching filter wins per rollout (no double-counting). Reward and - trajectory tokens are left untouched so the rollout can still contribute - to baseline calculations and metric aggregation. + Each rollout's ``filter_results`` dict records per-filter detection bools; + ``is_filtered`` is True iff an enforcing filter detected it. First matching + filter wins per rollout (no double-counting). Reward and trajectory tokens + are left untouched so the rollout can still contribute to baseline + calculations and metric aggregation. """ for rollout in rollouts: - rollout["filters"] = {f.name: False for f in filters} - rollout["is_filtered"] = False + rollout.filter_results = {f.name: False for f in filters} + rollout.is_filtered = False if not filters: return - counts: dict[str, int] = {f.name: 0 for f in filters} - total_detected = 0 - total_enforced = 0 - for rollout in rollouts: for filt in filters: result = filt.check(rollout) if result.detected: - counts[filt.name] += 1 - total_detected += 1 - rollout["filters"][filt.name] = True + rollout.filter_results[filt.name] = True if filt.enforce: - rollout["is_filtered"] = True - total_enforced += 1 + rollout.is_filtered = True break - - if total_detected > 0: - enforced_msg = f", enforced {total_enforced}" if total_enforced > 0 else "" - get_logger().info( - f"Detected {total_detected}/{len(rollouts)} rollouts " - f"({', '.join(f'{name}={c}' for name, c in counts.items() if c > 0)})" + enforced_msg - ) diff --git a/src/prime_rl/orchestrator/inference_metrics.py b/src/prime_rl/orchestrator/inference_metrics.py index 1d5916a385..4468edda5d 100644 --- a/src/prime_rl/orchestrator/inference_metrics.py +++ b/src/prime_rl/orchestrator/inference_metrics.py @@ -392,7 +392,6 @@ class InferenceMetricsCollector: def __init__(self, admin_clients: list[AsyncClient], roles: list[str | None] | None = None): self.endpoints = build_metrics_endpoints(admin_clients, roles=roles) - self.logger = get_logger() self.metric_history: dict[str, deque[float]] = {} self.previous: dict[str, TimedRollup] = {} self.task: asyncio.Task | None = None @@ -406,7 +405,7 @@ async def poll_loop(): try: await self.collect_and_log() except Exception as e: - self.logger.debug(f"Inference metrics poll failed: {e!r}") + get_logger().debug(f"Inference metrics poll failed: {e!r}") await asyncio.sleep(POLL_INTERVAL) self.task = asyncio.create_task(poll_loop()) @@ -420,7 +419,7 @@ async def fetch(endpoint: MetricsEndpoint) -> str | None: response.raise_for_status() return response.text except Exception as e: - self.logger.debug(f"Failed to fetch metrics from {endpoint.client.base_url}: {e!r}") + get_logger().debug(f"Failed to fetch metrics from {endpoint.client.base_url}: {e!r}") return None results = await asyncio.gather(*[fetch(endpoint) for endpoint in self.endpoints]) diff --git a/src/prime_rl/orchestrator/metrics.py b/src/prime_rl/orchestrator/metrics.py new file mode 100644 index 0000000000..87ec99c424 --- /dev/null +++ b/src/prime_rl/orchestrator/metrics.py @@ -0,0 +1,199 @@ +"""MetricsBuilder: assembles the per-step W&B dict. No I/O, no side effects.""" + +from __future__ import annotations + +from typing import Any + +import pandas as pd + +from prime_rl.configs.orchestrator import OrchestratorConfig +from prime_rl.orchestrator.types import Progress, TrainBatchMetrics, TrainRollout + + +class MetricsBuilder: + def __init__(self, config: OrchestratorConfig) -> None: + self.config = config + + def build( + self, + *, + step: int, + rollouts: list[TrainRollout], + metrics: TrainBatchMetrics, + progress: Progress, + step_time: float, + save_ckpt_time: float, + teacher_logprobs_time: float, + pre_filter_seen: int, + pre_filter_dropped: int, + pre_filter_dropped_by_name: dict[str, int], + ) -> dict[str, Any]: + """Builds the per-step W&B dict. Stable metric names so + existing dashboards / alerts keep working.""" + num_rollouts = len(rollouts) + num_unique_examples = len({r.group_id for r in rollouts}) + num_tokens = sum( + r.raw["token_usage"]["final_input_tokens"] + r.raw["token_usage"]["final_output_tokens"] for r in rollouts + ) + + results_df = pd.DataFrame( + { + "group_id": [r.group_id for r in rollouts], + "example_id": [r.example_id for r in rollouts], + "env_name": [r.env_name for r in rollouts], + "reward": [r.reward for r in rollouts], + "is_truncated": [r.is_truncated for r in rollouts], + "is_filtered": [r.is_filtered for r in rollouts], + "stop_condition": [r.raw.get("stop_condition") for r in rollouts], + "seq_len": [ + r.raw["token_usage"]["final_input_tokens"] + r.raw["token_usage"]["final_output_tokens"] + for r in rollouts + ], + "prefill_len": metrics.rollout_prefill_lens, + "decode_len": metrics.rollout_decode_lens, + "samples_per_rollout": metrics.samples_per_rollout, + "num_turns": [len(r.raw["trajectory"]) for r in rollouts], + } + ) + metrics_df = pd.DataFrame([(r.raw.get("metrics") or {}) for r in rollouts]) + filter_df = pd.DataFrame([r.filter_results for r in rollouts]) + timing_df = self.timing_df(rollouts) + + # Each group's full-solve threshold is its own env's group_size (envs + # can override the top-level group_size). + env_group_size = {env.resolved_name: env.group_size for env in self.config.train.env} + + def compute_solve_rates(df): + grouped = df.groupby("group_id") + reward_per_problem = grouped.reward.sum() + solve_none = (reward_per_problem == 0).mean() + expected = grouped.env_name.first().map(env_group_size) + solve_all = (reward_per_problem == expected).mean() + return solve_none, solve_all, 1 - solve_none - solve_all + + by_example = results_df.groupby("group_id") + solve_none, solve_all, effective_batch_size = compute_solve_rates(results_df) + + to_log: dict[str, Any] = { + "progress/tokens": num_tokens, + "progress/prefill_tokens": metrics.num_prefill_tokens, + "progress/decode_tokens": metrics.num_decode_tokens, + "progress/samples": num_rollouts, + "progress/problems": num_unique_examples, + "progress/total_tokens": progress.total_tokens, + "progress/total_samples": progress.total_samples, + "progress/total_problems": progress.total_problems, + "seq_len/all/mean": by_example.seq_len.mean().mean(), + "seq_len/all/max": by_example.seq_len.mean().max(), + "seq_len/all/min": by_example.seq_len.mean().min(), + "prefill_len/all/mean": by_example.prefill_len.mean().mean(), + "prefill_len/all/max": by_example.prefill_len.mean().max(), + "prefill_len/all/min": by_example.prefill_len.mean().min(), + "decode_len/all/mean": by_example.decode_len.mean().mean(), + "decode_len/all/max": by_example.decode_len.mean().max(), + "decode_len/all/min": by_example.decode_len.mean().min(), + "is_truncated/all/mean": by_example.is_truncated.mean().mean(), + "is_truncated/all/max": by_example.is_truncated.mean().max(), + "stop_condition/all/generation_truncated": ( + results_df.is_truncated & (results_df.stop_condition != "prompt_too_long") + ).mean(), + **{ + f"stop_condition/all/{sc}": rate + for sc, rate in results_df.stop_condition.dropna().value_counts(normalize=True).items() + }, + "samples_per_rollout/all/mean": by_example.samples_per_rollout.mean().mean(), + "samples_per_rollout/all/max": by_example.samples_per_rollout.mean().max(), + "samples_per_rollout/all/min": by_example.samples_per_rollout.mean().min(), + "num_turns/all/mean": by_example.num_turns.mean().mean(), + "num_turns/all/max": by_example.num_turns.mean().max(), + "num_turns/all/min": by_example.num_turns.mean().min(), + **{ + f"timing/all/{key}/{stat}": getattr( + timing_df[key].groupby(results_df.group_id).mean(), + stat, + )() + for key in timing_df.columns + for stat in ("mean", "max", "min") + }, + "reward/all/mean": by_example.reward.mean().mean(), + "reward/all/max": by_example.reward.mean().max(), + "reward/all/min": by_example.reward.mean().min(), + "solve_none/all": solve_none, + "solve_all/all": solve_all, + "effective_batch_size/all": effective_batch_size, + **{f"batch/{env}": r for env, r in results_df.env_name.value_counts(normalize=True).items()}, + "time/step": step_time, + "time/teacher_logprobs": teacher_logprobs_time, + "time/save_ckpt": save_ckpt_time, + "filters/all/is_filtered": results_df.is_filtered.astype(float).mean(), + **{f"filters/all/{name}": filter_df[name].astype(float).mean() for name in filter_df.columns}, + "step": step, + } + + # Per-env metrics + per_env_columns = [ + "seq_len", + "prefill_len", + "decode_len", + "is_truncated", + "samples_per_rollout", + "num_turns", + ] + for env, env_df in results_df.groupby("env_name"): + env_by_example = env_df.groupby("group_id") + for col in per_env_columns: + to_log[f"{col}/{env}/mean"] = env_by_example[col].mean().mean() + to_log[f"{col}/{env}/max"] = env_by_example[col].mean().max() + if col != "is_truncated": + to_log[f"{col}/{env}/min"] = env_by_example[col].mean().min() + env_timing_df = timing_df.loc[env_df.index] + for key in timing_df.columns: + per_example = env_timing_df.groupby(env_df["group_id"])[key].mean() + to_log[f"timing/{env}/{key}/mean"] = per_example.mean() + to_log[f"timing/{env}/{key}/max"] = per_example.max() + to_log[f"timing/{env}/{key}/min"] = per_example.min() + to_log[f"reward/{env}/mean"] = env_by_example.reward.mean().mean() + to_log[f"reward/{env}/max"] = env_by_example.reward.mean().max() + to_log[f"reward/{env}/min"] = env_by_example.reward.mean().min() + sn, sa, eb = compute_solve_rates(env_df) + to_log[f"solve_none/{env}"] = sn + to_log[f"solve_all/{env}"] = sa + to_log[f"effective_batch_size/{env}"] = eb + to_log[f"stop_condition/{env}/generation_truncated"] = ( + env_df.is_truncated & (env_df.stop_condition != "prompt_too_long") + ).mean() + for sc, rate in env_df.stop_condition.dropna().value_counts(normalize=True).items(): + to_log[f"stop_condition/{env}/{sc}"] = rate + env_metrics_df = metrics_df.loc[env_df.index] if not metrics_df.empty else metrics_df + for metric in metrics_df.columns: + to_log[f"metrics/{env}/{metric}"] = env_metrics_df.groupby(env_df["group_id"])[metric].mean().mean() + to_log[f"filters/{env}/is_filtered"] = env_df.is_filtered.astype(float).mean() + env_filter_df = filter_df.loc[env_df.index] if not filter_df.empty else filter_df + for name in filter_df.columns: + to_log[f"filters/{env}/{name}"] = env_filter_df[name].astype(float).mean() + + # Dispatcher / watcher gauges live on the ``_timestamp`` axis via + # the periodic logger — keep this dict step-axis only + if pre_filter_seen > 0: + to_log["pre_filters/all/dropped_rate"] = pre_filter_dropped / pre_filter_seen + for name, count in pre_filter_dropped_by_name.items(): + to_log[f"pre_filters/all/{name}/rate"] = count / pre_filter_seen + + return to_log + + @staticmethod + def timing_df(rollouts: list[TrainRollout]) -> pd.DataFrame: + return pd.DataFrame( + [ + { + "total": r.raw["timing"]["total"], + "setup": r.raw["timing"]["setup"]["duration"], + "generation": r.raw["timing"]["generation"]["duration"], + "model": r.raw["timing"]["model"]["duration"], + "env": r.raw["timing"]["env"]["duration"], + "scoring": r.raw["timing"]["scoring"]["duration"], + "overhead": r.raw["timing"]["overhead"], + } + for r in rollouts + ] + ) diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 65f12ccf77..902c8b963b 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -1,901 +1,903 @@ +"""Async-pipelined RL orchestrator. + +``Orchestrator`` owns the shared state (policy, progress, ckpt, monitor) +and drives the pipeline. Components are single-purpose: + +- ``RolloutDispatcher`` schedules rollouts; emits ``TrainRollout`` / + ``EvalRollout`` on its queue. +- ``TrainSink`` ingests train rollouts (tokenize → advantages → filters) + and returns a ``TrainBatch`` when the threshold is met. +- ``EvalSink`` ingests eval rollouts and returns an ``EvalBatch`` (with + per-env metrics) on epoch completion. +- ``MetricsBuilder`` builds the per-step train W&B dict. +- ``WeightWatcher`` advances ``Policy`` and notifies observers. +- ``PeriodicLogger`` polls the components on a shared interval for the + ``_timestamp``-axis pipeline log. + +Components don't reference the orchestrator. The orchestrator wires them +in ``setup()`` and drives them from ``main_loop()``. +""" + +from __future__ import annotations + import asyncio import ctypes -import gc +import logging import os import time +from typing import TYPE_CHECKING import tomli_w -import prime_rl._compat # noqa: F401 — patch ring_flash_attn compat before transitive import -from prime_rl.orchestrator.advantage import compute_advantages -from prime_rl.orchestrator.event_loop_lag import EventLoopLagMonitor -from prime_rl.orchestrator.inference_metrics import InferenceMetricsCollector -from prime_rl.orchestrator.patches import monkey_patch_chat_completion_logprobs, monkey_patch_oai_iterable_types -from prime_rl.orchestrator.trajectories import ( - backfill_rollout_tokens, - interleave_rollout, - offload_images_to_disk, -) -from prime_rl.transport import TrainingBatch, TrainingSample, setup_training_batch_sender -from prime_rl.utils.pathing import get_log_dir, get_rollout_dir, get_step_path -from prime_rl.utils.usage_reporter import UsageReporter - -# This monkey patch is necessary to avoid Pydantic validating fields using typing.Iterable (e.g. in multimodal or tool call messages) lazily which leads to tokenization errors, for more info see https://github.com/PrimeIntellect-ai/prime-rl/pull/1249 -monkey_patch_oai_iterable_types() - - -# This monkey patch is necessary to avoid heavy CPU overhead from constructing the OAI ChatCompletion Pydantic model with logprobs, for more info see https://github.com/PrimeIntellect-ai/prime-rl/pull/1189 -monkey_patch_chat_completion_logprobs() - -# Import environment before any other imports +if TYPE_CHECKING: + from renderers.base import Renderer + from transformers.tokenization_utils import PreTrainedTokenizer -import pandas as pd -import verifiers as vf -from renderers.base import create_renderer + from prime_rl.orchestrator.ckpt import CheckpointManager + from prime_rl.transport.base import TrainingBatchSender + from prime_rl.utils.client import InferencePool + from prime_rl.utils.monitor.base import Monitor +from verifiers.utils.async_utils import EventLoopLagMonitor, EventLoopLagStats +import prime_rl._compat # noqa: F401 — patch ring_flash_attn compat before transitive imports from prime_rl.configs.orchestrator import OrchestratorConfig -from prime_rl.orchestrator.buffer import Buffer -from prime_rl.orchestrator.ckpt import Progress, setup_ckpt_manager -from prime_rl.orchestrator.envs import EvalEnv, EvalEnvs, TrainEnvs -from prime_rl.orchestrator.filters import apply_filters, setup_filters -from prime_rl.orchestrator.scheduler import Scheduler +from prime_rl.orchestrator.ckpt import setup_ckpt_manager +from prime_rl.orchestrator.dispatcher import DispatcherMetrics, DispatcherMode, RolloutDispatcher +from prime_rl.orchestrator.envs import EvalEnvs, TrainEnvs +from prime_rl.orchestrator.eval_sink import EvalSink +from prime_rl.orchestrator.eval_source import EvalSource +from prime_rl.orchestrator.filters import setup_filters +from prime_rl.orchestrator.inference_metrics import InferenceMetricsCollector +from prime_rl.orchestrator.metrics import MetricsBuilder +from prime_rl.orchestrator.patches import ( + monkey_patch_chat_completion_logprobs, + monkey_patch_oai_iterable_types, +) +from prime_rl.orchestrator.periodic_logger import PeriodicLogger +from prime_rl.orchestrator.train_sink import TrainSink +from prime_rl.orchestrator.train_source import TrainSource +from prime_rl.orchestrator.types import ( + EvalBatch, + EvalRollout, + FinishedRollout, + Policy, + Progress, + TrainBatch, + TrainRollout, +) from prime_rl.orchestrator.utils import ( compute_teacher_logprobs, get_weight_dir, - print_benchmark, - set_default_executor, -) -from prime_rl.orchestrator.vf_utils import ( - get_seq_len, intercept_vf_logging, save_rollouts, + set_default_executor, + setup_student_inference_pool, ) +from prime_rl.orchestrator.watcher import WeightWatcher from prime_rl.trainer.model import setup_tokenizer -from prime_rl.utils.client import ( - init_nccl_broadcast, - setup_inference_pool, -) -from prime_rl.utils.config import cli +from prime_rl.transport import TrainingBatch, setup_training_batch_sender +from prime_rl.utils.async_utils import safe_cancel +from prime_rl.utils.client import init_nccl_broadcast, setup_inference_pool from prime_rl.utils.heartbeat import Heartbeat -from prime_rl.utils.logger import setup_logger +from prime_rl.utils.logger import format_time, get_logger, setup_logger from prime_rl.utils.monitor import setup_monitor -from prime_rl.utils.process import set_proc_title +from prime_rl.utils.pathing import get_log_dir, get_rollout_dir, get_step_path +from prime_rl.utils.usage_reporter import UsageReporter from prime_rl.utils.utils import ( clean_exit, get_env_ids_to_install, install_env, resolve_latest_ckpt_step, - to_col_format, ) -# Hard wall-clock budget for the orchestrator's post-training cleanup. If the -# graceful shutdown sequence (scheduler / inference pool / env teardown) is -# still running after this many seconds, we force-exit the process so the run -# pod terminates instead of sitting wedged forever. The training checkpoint -# and artifacts are persisted *before* this point, so a forced exit is safe. -SHUTDOWN_TIMEOUT_S = 300 +monkey_patch_oai_iterable_types() +monkey_patch_chat_completion_logprobs() -# Maximum number of times to attempt generating a training batch when all -# rollouts are filtered out. After this many attempts, the orchestrator crashes -# rather than silently skipping training steps. -MAX_EMPTY_BATCH_ATTEMPTS = 3 +# Wall-clock budget for post-training cleanup; force-exit if graceful +# shutdown wedges (env-server ZMQ recv, vLLM admin aclose, etc) +SHUTDOWN_TIMEOUT_S = 300 -@clean_exit -async def orchestrate(config: OrchestratorConfig): - # Initialize the logger - logger = setup_logger( - config.log.level, - json_logging=config.log.json_logging, - ) - intercept_vf_logging(logger="verifiers.serve", level="WARN") # show logs from env clients - - logger.info(f"Starting orchestrator ({config.training_mode})") - - set_default_executor() - event_loop_lag_monitor = EventLoopLagMonitor() - event_loop_lag_monitor_task = asyncio.create_task(event_loop_lag_monitor.run()) - - # Print warning if running in benchmark mode - if config.bench: - logger.warning(f"Running in benchmark mode (max_steps={config.max_steps})") - - # Save configs to output directory - config_dir = config.output_dir / "control" - config_dir.mkdir(parents=True, exist_ok=True) - with open(config_dir / "orch.toml", "wb") as f: - tomli_w.dump(config.model_dump(exclude_none=True, mode="json"), f) - - # Install environments - env_ids_to_install = set() - env_ids_to_install.update(get_env_ids_to_install(config.train.env)) - if config.eval is not None: - env_ids_to_install.update(get_env_ids_to_install(config.eval.env)) - - for env_id in env_ids_to_install: - install_env(env_id, prerelease=config.env_install_prerelease) - - logger.info(f"Initializing tokenizer ({config.tokenizer})") - tokenizer = setup_tokenizer(config.tokenizer) - - # Set up student inference pool (required for all training modes). - logger.info( - f"Initializing student inference pool (base_url={', '.join(config.student.client.base_url)}, " - f"model={config.student.model.name})" - ) - renderer, student_inference = await setup_student_inference_pool( - config=config, - tokenizer=tokenizer, - logger=logger, - ) - - # Token-id → modality marker (1 = image patch, 2 = video patch) used - # to build ``mm_token_type_ids`` per sample. The renderer is the - # single source of truth — it already knows its own special-token - # IDs (``<|image_pad|>`` etc.) from the tokenizer it owns, so the - # orchestrator never needs to load a separate ``AutoProcessor``. - # Text-only renderers expose an empty map (or no attribute). - mm_token_type_ids_mapping: dict[int, int] | None = ( - getattr(renderer, "mm_token_type_id_map", None) if renderer is not None else None - ) - if mm_token_type_ids_mapping == {}: - mm_token_type_ids_mapping = None - - # Set up teacher inference pool (configured for opd or sft). Always MITO for - # simplicity - this also keeps external OAI-compatible teachers (PI inference, - # OpenAI) working as drop-in endpoints. - teacher_inference = None - if config.teacher is not None: - logger.info( - f"Initializing teacher inference pool (base_url={', '.join(config.teacher.client.base_url)}, " - f"model={config.teacher.model.name})" +# Abort after this many consecutive train batches drop all rollouts to +# post-batch filters — usually a misconfigured filter or homogeneous-reward +# dataset; fail loudly instead of spinning +MAX_CONSECUTIVE_EMPTY_BATCHES = 10 + +# Maximum batches the orchestrator may run ahead of the trainer. The +# dispatcher is paused via ``update_dispatch_gate`` once this is exceeded; +# resumed when the watcher advances ``policy.version``. +TARGET_LAG = 1 + + +class Orchestrator: + # Set in ``__init__`` + config: OrchestratorConfig + progress: Progress + policy: Policy + stopped: asyncio.Event + draining: bool + last_batch_at: float | None + consecutive_empty_batches: int + eval_triggered_at: dict[tuple[str, int], float] + ckpt_manager: CheckpointManager | None + component_tasks: list[asyncio.Task] + + # Always set by ``setup()`` + tokenizer: PreTrainedTokenizer + student_inference: InferencePool + monitor: Monitor + sender: TrainingBatchSender + train_envs: TrainEnvs + train_source: TrainSource + train_sink: TrainSink + dispatcher: RolloutDispatcher + watcher: WeightWatcher + metrics: MetricsBuilder + lag_monitor: EventLoopLagMonitor + periodic_logger: PeriodicLogger + + # Set by ``setup()`` only when relevant config is present + renderer: Renderer | None + mm_token_type_ids_mapping: dict[int, int] | None + teacher_inference: InferencePool | None + heart: Heartbeat | None + usage_reporter: UsageReporter | None + inference_metrics: InferenceMetricsCollector | None + eval_envs: EvalEnvs | None + eval_sink: EvalSink | None + eval_source: EvalSource | None + lora_name: str | None + resume_step: int | None + lag_task: asyncio.Task | None + + def __init__(self, config: OrchestratorConfig) -> None: + self.config = config + setup_logger(config.log.level, json_logging=config.log.json_logging) + # Silence in-process ``verifiers.*`` library noise but keep + # ``verifiers.serve`` (env-server lifecycle) through our handler + logging.getLogger("verifiers").setLevel(logging.CRITICAL + 1) + intercept_vf_logging(logger="verifiers.serve", level="WARN") + get_logger().info(f"Starting orchestrator ({config.training_mode})") + + if config.bench: + get_logger().warning(f"Running in benchmark mode (max_steps={config.max_steps})") + + self.progress = Progress() + self.ckpt_manager = setup_ckpt_manager(config.output_dir, config.ckpt) + self.policy = Policy(version=0, model_name="") + self.stopped = asyncio.Event() + # True after the final train step ships — pipeline winds down without + # scheduling new train rollouts + self.draining = False + # Previous ``TrainBatch`` arrival timestamp; reset every ship so + # ``step_time`` in the success log is real sink-to-sink cycle time + self.last_batch_at = None + # Trigger timestamps so eval success logs can report epoch duration + self.eval_triggered_at = {} + self.consecutive_empty_batches = 0 + self.component_tasks = [] + + # Optional attributes — ``setup()`` populates them when the relevant + # config is present + self.renderer = None + self.mm_token_type_ids_mapping = None + self.teacher_inference = None + self.heart = None + self.usage_reporter = None + self.inference_metrics = None + self.eval_envs = None + self.eval_sink = None + self.eval_source = None + self.lora_name = None + self.resume_step = None + self.lag_task = None + + # ── lifecycle ────────────────────────────────────────────────────────── + + async def setup(self) -> None: + """Install envs, load models/pools, resume from checkpoint, and + construct the pipeline components.""" + config = self.config + set_default_executor() + + # Persist the resolved config alongside the run + config_dir = config.output_dir / "control" + config_dir.mkdir(parents=True, exist_ok=True) + with open(config_dir / "orch.toml", "wb") as f: + tomli_w.dump(config.model_dump(exclude_none=True, mode="json"), f) + + env_ids_to_install = set(get_env_ids_to_install(config.train.env)) + if config.eval is not None: + env_ids_to_install.update(get_env_ids_to_install(config.eval.env)) + for env_id in env_ids_to_install: + install_env(env_id, prerelease=config.env_install_prerelease) + + get_logger().info(f"Initializing tokenizer ({config.tokenizer})") + self.tokenizer = setup_tokenizer(config.tokenizer) + + # Student inference pool + get_logger().info( + f"Initializing student inference pool (base_url={', '.join(config.student.client.base_url)}, " + f"model={config.student.model.name})" ) - teacher_inference = await setup_inference_pool( - config.teacher.client, - model_name=config.teacher.model.name, - train_client_type="openai_chat_completions", + self.renderer, self.student_inference = await setup_student_inference_pool( + config=config, tokenizer=self.tokenizer ) - - # Setup monitor (may register the run and set RUN_ID in the environment) - logger.info(f"Initializing monitor (wandb={config.wandb}, prime_monitor={config.prime_monitor})") - monitor = setup_monitor( - wandb_config=config.wandb, - prime_config=config.prime_monitor, - output_dir=config.output_dir, - tokenizer=tokenizer, - run_config=config, - keep_full_history=config.bench, - ) - - # Read run_id AFTER setup_monitor so that newly registered runs are captured - run_id = os.getenv("RUN_ID", "") - - # Usage reporter requires BOTH the base URL and the API key. Activating - # with only one set used to crash every POST inside httpx (None header - # value), so we now gate construction on both being present and log a - # clear warning when half-configured. - usage_base_url = os.environ.get("PI_USAGE_BASE_URL") - usage_api_key = os.environ.get("PI_USAGE_API_KEY") - if usage_base_url and usage_api_key: - usage_reporter = UsageReporter() - else: - if usage_base_url and not usage_api_key: - logger.warning("PI_USAGE_BASE_URL is set but PI_USAGE_API_KEY is missing; usage reporting disabled.") - usage_reporter = None - - # Setup heartbeat (only on rank 0, orchestrator is single process) - heart = None - if config.heartbeat is not None: - logger.info("Initializing heartbeat") - heart = Heartbeat(config.heartbeat.url) - - # Build rollout filters - rollout_filters = setup_filters(config.filters, vocab_size=tokenizer.vocab_size) - - # Load environments - logger.info("Loading training environments") - train_envs = TrainEnvs(config.train.env) - if config.training_mode == "sft": - # Teacher rollouts don't need inference-side logprobs (the trainer - # reconstructs teacher tokens), and some external reasoning-model - # endpoints (e.g. openai/gpt-5*) reject the parameter. - for env in train_envs: - env.sampling_args.pop("logprobs", None) - logger.info(f"Loaded {len(train_envs)} training environment(s) ({', '.join(train_envs.names)})") - - await train_envs.start( - log_dir=get_log_dir(config.output_dir.parent) / "envs" / "train", - log_level=config.log.vf_level, - json_logging=config.log.json_logging, - ) - logger.success("Train environment(s) ready") - - eval_envs: EvalEnvs | None = None - if config.eval: - logger.info("Loading eval environment(s)") - eval_envs = EvalEnvs(config.eval.env) - logger.info(f"Loaded {len(eval_envs)} eval environment(s) ({', '.join(eval_envs.names)})") - - await eval_envs.start( - log_dir=get_log_dir(config.output_dir.parent) / "envs" / "eval", - log_level=config.log.vf_level, - json_logging=config.log.json_logging, + self.mm_token_type_ids_mapping = ( + getattr(self.renderer, "mm_token_type_id_map", None) if self.renderer is not None else None ) - logger.success("Eval environment(s) ready") + if self.mm_token_type_ids_mapping == {}: + self.mm_token_type_ids_mapping = None - # Setup buffer - logger.info(f"Setting up buffer ({config.buffer})") - buffer = Buffer(train_envs, config.buffer) + if config.teacher is not None: + get_logger().info( + f"Initializing teacher inference pool (base_url={', '.join(config.teacher.client.base_url)}, " + f"model={config.teacher.model.name})" + ) + self.teacher_inference = await setup_inference_pool( + config.teacher.client, + model_name=config.teacher.model.name, + train_client_type="openai_chat_completions", + ) - # Get checkpoint manager - logger.info(f"Initializing checkpoint manager ({config.ckpt})") - ckpt_manager = setup_ckpt_manager(config.output_dir, config.ckpt) + get_logger().info(f"Initializing monitor (wandb={config.wandb}, prime_monitor={config.prime_monitor})") + self.monitor = setup_monitor( + wandb_config=config.wandb, + prime_config=config.prime_monitor, + output_dir=config.output_dir, + tokenizer=self.tokenizer, + run_config=config, + keep_full_history=config.bench, + ) - checkpoint_step = None - if config.ckpt and config.ckpt.resume_step is not None and ckpt_manager is not None: - if config.ckpt.resume_step == -1: - checkpoint_step = resolve_latest_ckpt_step(ckpt_manager.ckpt_dir) - else: - checkpoint_step = config.ckpt.resume_step - - scheduler = Scheduler( - train_envs=train_envs, - buffer=buffer, - student_inference=student_inference, - teacher_inference=teacher_inference, - max_inflight_rollouts=config.max_inflight_rollouts, - max_off_policy_steps=config.max_off_policy_steps, - tasks_per_minute=config.tasks_per_minute, - lora_name=config.student.model.lora.name if config.student.model.lora else None, - config=config, - ) - - # Wait for pools to be ready - logger.info("Waiting for student inference pool to be ready") - await student_inference.wait_for_ready(config.student.model.name) - logger.success("Student inference pool ready") - if teacher_inference is not None: - assert config.teacher is not None - logger.info("Waiting for teacher inference pool to be ready") - await teacher_inference.wait_for_ready(config.teacher.model.name) - logger.success("Teacher inference pool ready") - - # Start inference metrics collector (requires W&B) - inference_metrics_collector = None - if config.wandb is not None and config.collect_inference_metrics: - inference_metrics_collector = InferenceMetricsCollector( - student_inference.admin_clients, - roles=config.inference_metrics_roles, + if config.heartbeat is not None: + self.heart = Heartbeat(config.heartbeat.url) + + usage_base_url = os.environ.get("PI_USAGE_BASE_URL") + usage_api_key = os.environ.get("PI_USAGE_API_KEY") + if usage_base_url and usage_api_key: + self.usage_reporter = UsageReporter() + + # Filters apply to train rollouts only + pre_filters = setup_filters(config.pre_batch_filters, vocab_size=self.tokenizer.vocab_size, kind="pre-batch") + post_filters = setup_filters(config.post_batch_filters, vocab_size=self.tokenizer.vocab_size, kind="post-batch") + + get_logger().info("Loading training environments") + self.train_envs = TrainEnvs(config.train.env) + if config.training_mode == "sft": + for env in self.train_envs: + env.sampling_args.pop("logprobs", None) + get_logger().debug( + f"Loaded {len(self.train_envs)} training environment(s) ({', '.join(self.train_envs.names)})" ) - await inference_metrics_collector.start() - - # Set up weight broadcast backend (targets student inference) - logger.info(f"Initializing weight broadcast ({config.weight_broadcast})") - if config.weight_broadcast.type == "nccl": - await init_nccl_broadcast( - student_inference.admin_clients, - config.weight_broadcast.host, - config.weight_broadcast.port, - config.weight_broadcast.timeout, - inference_world_size=config.weight_broadcast.inference_world_size, - quantize_in_weight_transfer=config.weight_broadcast.quantize_in_weight_transfer, + await self.train_envs.start( + log_dir=get_log_dir(config.output_dir.parent) / "envs" / "train", + log_level=config.log.vf_level, + json_logging=config.log.json_logging, ) + get_logger().success("Train environment(s) ready") + + if config.eval is not None: + get_logger().info("Loading eval environment(s)") + self.eval_envs = EvalEnvs(config.eval.env) + get_logger().debug(f"Loaded {len(self.eval_envs)} eval environment(s) ({', '.join(self.eval_envs.names)})") + await self.eval_envs.start( + log_dir=get_log_dir(config.output_dir.parent) / "envs" / "eval", + log_level=config.log.vf_level, + json_logging=config.log.json_logging, + ) + get_logger().success("Eval environment(s) ready") - # Setup training batch sender for sending training examples to trainer - logger.info(f"Initializing training batch sender ({config.rollout_transport})") - training_batch_sender = setup_training_batch_sender(config.output_dir, config.rollout_transport) - - # Reset weights to base model if starting from scratch - progress = Progress() + if config.ckpt is not None and config.ckpt.resume_step is not None and self.ckpt_manager is not None: + if config.ckpt.resume_step == -1: + self.resume_step = resolve_latest_ckpt_step(self.ckpt_manager.ckpt_dir) + else: + self.resume_step = config.ckpt.resume_step - if checkpoint_step is not None and ckpt_manager is not None: - ckpt_manager.load(progress, buffer, step=checkpoint_step) - logger.info(f"Resuming training from checkpoint step {checkpoint_step}") - scheduler.ckpt_step = progress.step # Always resume from the latest checkpoint + # Resume below may bump ``policy.version`` and the LoRA model name + self.policy.model_name = self.student_inference.model_name - # In NCCL mode, skip existence check - weights are broadcasted, not stored on disk - check_exists = config.weight_broadcast.type != "nccl" - wait_timeout = config.ckpt.wait_for_weights_timeout if config.ckpt else None - weights_path = get_weight_dir( - config.output_dir, scheduler.ckpt_step, check_exists=check_exists, wait_timeout=wait_timeout - ) - lora_name = config.student.model.lora.name if config.student.model.lora else None - await student_inference.update_weights(weights_path, lora_name=lora_name, step=scheduler.ckpt_step) - if lora_name is not None: - student_inference.update_model_name(lora_name) - if scheduler.rollout_inference is student_inference: - scheduler.model_name = lora_name - else: - logger.info("Training from scratch") - - # Iterate over dataset in batches - logger.info(f"Starting orchestrator loop (max_steps={config.max_steps or 'infinite'})") - is_first_step = True - - while True: - # Check if this run has been evicted by the trainer - evicted_path = config.output_dir / "control" / "evicted.txt" - if evicted_path.exists(): - reason = evicted_path.read_text().strip() - raise RuntimeError(f"Run evicted by trainer: {reason}") - - # Capture ckpt_step once for consistency (it's updated inside the scheduler) - ckpt_step = scheduler.ckpt_step - scheduler.ckpt_step = ckpt_step - - # Save checkpoint (if we are at an interval step and not at the first or last step) - is_last_step = config.max_steps is not None and progress.step == config.max_steps - 1 - save_ckpt_time = 0 - if ( - ckpt_manager is not None - and (config.ckpt and config.ckpt.interval) - and not (is_first_step or is_last_step) - and progress.step % config.ckpt.interval == 0 - ): - logger.info(f"Saving checkpoint at step {progress.step}") - save_ckpt_start_time = time.perf_counter() - ckpt_manager.save(progress, buffer, step=progress.step) - save_ckpt_time = time.perf_counter() - save_ckpt_start_time - - # Break if we have reached the maximum number of steps - if config.max_steps and progress.step >= config.max_steps: - break - - logger.info(f"Starting orchestrator step {progress.step}") - step_start_time = time.perf_counter() - - # Run evals BEFORE training (blocking). Weight updates are paused via - # scheduler.checkpoint_ready during eval to ensure consistent weights. - # Each eval env has its own interval, so we check each independently. - envs_to_eval: list[EvalEnv] = [] - if config.eval: - assert eval_envs is not None - if is_first_step and checkpoint_step is not None and config.eval.skip_eval_on_resume: - logger.info(f"Skipping online eval on resume (step={progress.step})") - else: - for eval_env in eval_envs: - if progress.step % eval_env.config.interval == 0 and ( - progress.step > 0 or config.eval.eval_base_model - ): - envs_to_eval.append(eval_env) - - if envs_to_eval: - env_names = ", ".join(e.name for e in envs_to_eval) - logger.info(f"Running evals at step={progress.step} for {env_names}") - - # Pause weight updates and re-scheduling of training rollouts during eval - # to avoid evaluating across different checkpoints and avoid congestion - scheduler.checkpoint_ready.clear() - - # For heavy eval workloads, it might be necessary additionally cancel in-flight training rollouts - if config.eval.cancel_inflight_rollouts_on_eval: - logger.info("Cancelling in-flight training rollouts before starting evals to avoid congestion.") - await scheduler.cancel_inflight_rollouts() - - eval_results = await asyncio.gather( - *[ - eval_env.evaluate( - model_name=student_inference.model_name, - get_client=student_inference.get_eval_client, - step=progress.step, - cache_salt=str(ckpt_step), - ) - for eval_env in envs_to_eval - ] + get_logger().info("Waiting for student inference pool to be ready") + await self.student_inference.wait_for_ready(config.student.model.name) + get_logger().success("Student inference pool ready") + if self.teacher_inference is not None: + assert config.teacher is not None + get_logger().info("Waiting for teacher inference pool to be ready") + await self.teacher_inference.wait_for_ready(config.teacher.model.name) + get_logger().success("Teacher inference pool ready") + + if config.wandb is not None and config.collect_inference_metrics: + self.inference_metrics = InferenceMetricsCollector( + self.student_inference.admin_clients, + roles=config.inference_metrics_roles, + ) + await self.inference_metrics.start() + + get_logger().info(f"Initializing weight broadcast ({config.weight_broadcast})") + if config.weight_broadcast.type == "nccl": + await init_nccl_broadcast( + self.student_inference.admin_clients, + config.weight_broadcast.host, + config.weight_broadcast.port, + config.weight_broadcast.timeout, + inference_world_size=config.weight_broadcast.inference_world_size, + quantize_in_weight_transfer=config.weight_broadcast.quantize_in_weight_transfer, ) - # Save eval rollouts to disk (fire-and-forget background thread) - eval_rollouts = [o for outputs in eval_results for o in outputs] - if eval_rollouts: - step_path = get_step_path(get_rollout_dir(config.output_dir), progress.step) - await asyncio.to_thread( - save_rollouts, eval_rollouts, step_path / "eval_rollouts.jsonl", exclude_keys={"trajectory"} - ) + get_logger().info(f"Initializing training batch sender ({config.rollout_transport})") + self.sender = setup_training_batch_sender(config.output_dir, config.rollout_transport) - # Resume weight updates - scheduler.checkpoint_ready.set() - - # Schedule generating the training batch. Retry on empty-after-filter - # batches so the trainer never receives an empty batch. - generate_completions_time = 0.0 - train_rollouts: list[vf.RolloutOutput] = [] - num_rollouts = 0 - num_unique_examples = 0 - n_trainable = 0 - for attempt in range(MAX_EMPTY_BATCH_ATTEMPTS): - train_rollouts = await scheduler.generate_batch(step=progress.step) - generate_completions_time += scheduler.last_batch_generation_time - - # Compute advantages (in-place) - num_rollouts = len(train_rollouts) - num_unique_examples = len({(r["env_name"], r["example_id"]) for r in train_rollouts}) - await asyncio.to_thread(compute_advantages, train_rollouts, config.advantage) - - # Apply rollout filters — sets rollout["filters"] and rollout["is_filtered"] - await asyncio.to_thread(apply_filters, rollout_filters, train_rollouts) - - n_trainable = sum(1 for r in train_rollouts if not r["is_filtered"]) - if n_trainable > 0: - break + self.lora_name = config.student.model.lora.name if config.student.model.lora else None - if attempt == MAX_EMPTY_BATCH_ATTEMPTS - 1: - logger.error( - f"Attempt {attempt + 1}/{MAX_EMPTY_BATCH_ATTEMPTS} at step {progress.step} " - f"filtered out all {num_rollouts} rollouts - crashing orchestrator" - ) - reason = ( - f"All {num_rollouts} rollouts were filtered out on " - f"{MAX_EMPTY_BATCH_ATTEMPTS} consecutive attempts at step {progress.step}" - ) - evicted_path = config.output_dir / "control" / "evicted.txt" - evicted_path.parent.mkdir(parents=True, exist_ok=True) - evicted_path.write_text(reason) - raise RuntimeError(reason) - - logger.warning( - f"Attempt {attempt + 1}/{MAX_EMPTY_BATCH_ATTEMPTS} at step {progress.step} " - f"filtered out all {num_rollouts} rollouts - retrying batch generation" + if self.resume_step is not None and self.ckpt_manager is not None: + self.ckpt_manager.load(self.progress, step=self.resume_step) + get_logger().info(f"Resuming orchestrator from checkpoint step {self.resume_step}") + check_exists = config.weight_broadcast.type != "nccl" + wait_timeout = config.ckpt.wait_for_weights_timeout if config.ckpt else None + weights_path = get_weight_dir( + config.output_dir, self.progress.step, check_exists=check_exists, wait_timeout=wait_timeout ) + await self.student_inference.update_weights(weights_path, lora_name=self.lora_name, step=self.progress.step) + if self.lora_name is not None: + self.student_inference.update_model_name(self.lora_name) + self.policy.model_name = self.lora_name + self.policy.version = self.progress.step + else: + get_logger().info("Training from scratch") - trainable_ratio = n_trainable / num_rollouts - if trainable_ratio <= 0.1: - logger.warning( - f"Only {n_trainable}/{num_rollouts} rollouts in the batch are trainable " - f"({trainable_ratio:.1%}) - this can mean the tasks are too easy or too hard for the " - "model, consider reviewing the task difficulty of your environment(s)" + # SFT generates rollouts via the teacher (the student is trained on + # the teacher's outputs); RL / OPD generate via the student + if config.training_mode == "sft": + assert self.teacher_inference is not None, "sft mode requires teacher inference" + rollout_inference = self.teacher_inference + else: + rollout_inference = self.student_inference + + self.train_source = TrainSource(self.train_envs, seed=42) + self.eval_source: EvalSource | None = ( + EvalSource( + self.eval_envs, + config.eval, + is_resumed=self.resume_step is not None, ) + if config.eval is not None and self.eval_envs is not None + else None + ) - # Save train rollouts to disk (fire-and-forget background thread) - step_path = get_step_path(get_rollout_dir(config.output_dir), progress.step) - await asyncio.to_thread( - save_rollouts, train_rollouts, step_path / "train_rollouts.jsonl", exclude_keys={"trajectory"} + assert config.max_inflight_rollouts is not None, "max_inflight_rollouts must be resolved before dispatcher init" + log_interval = config.log.interval + wandb_enabled = config.wandb is not None + self.dispatcher = RolloutDispatcher( + train_envs=self.train_envs, + eval_envs=self.eval_envs, + train_source=self.train_source, + eval_source=self.eval_source, + inference=rollout_inference, + eval_inference=self.student_inference, + policy=self.policy, + max_inflight_rollouts=config.max_inflight_rollouts, + tasks_per_minute=config.tasks_per_minute, + max_off_policy_steps=config.max_off_policy_steps, + training_mode=config.training_mode, + ) + self.metrics = MetricsBuilder(config) + self.train_sink = TrainSink( + config, + tokenizer=self.tokenizer, + renderer=self.renderer, + train_envs=self.train_envs, + mm_token_type_ids_mapping=self.mm_token_type_ids_mapping, + batch_size=config.batch_size, + token_batch_size=config.token_batch_size, + advantage_config=config.advantage, + pre_filters=pre_filters, + post_filters=post_filters, ) + self.eval_sink = EvalSink(eval_envs=self.eval_envs) if self.eval_envs is not None else None + self.watcher = WeightWatcher( + config, + policy=self.policy, + inference=self.student_inference, + observers=[self.dispatcher, self], + lora_name=self.lora_name, + ckpt_step=self.progress.step, + ) + # Single periodic logger for the whole pipeline. It's the only + # consumer of ``dispatcher.metrics.drained()`` (which clears on read) + self.lag_monitor = EventLoopLagMonitor() + self.periodic_logger = PeriodicLogger( + name="Pipeline", + collect=self.collect_pipeline_view, + metric_keys=[ + *list(self.dispatcher.gauges().keys()), + *DispatcherMetrics.drain_keys( + train_envs={e.name for e in self.train_envs}, + eval_envs={e.name for e in self.eval_envs} if self.eval_envs is not None else set(), + ), + *list(self.watcher.gauges().keys()), + "event_loop_lag/min", + "event_loop_lag/mean", + "event_loop_lag/median", + "event_loop_lag/p90", + "event_loop_lag/p99", + "event_loop_lag/max", + "event_loop_lag/n", + ], + interval=log_interval, + wandb_enabled=wandb_enabled, + ) + + async def start(self) -> None: + """Run the orchestrator until shutdown. Drives setup, spawns the + background tasks, runs the main loop in this task, then cleans up.""" + await self.setup() + config = self.config + get_logger().info(f"Starting orchestrator loop (max_steps={config.max_steps or 'infinite'})") + start_time = time.perf_counter() + + # Spawn background loops (dispatcher schedules, watcher polls). The + # pipeline ``main_loop`` runs inline in this task; the single + # ``PeriodicLogger`` polls dispatcher / watcher / sinks / lag + # monitor each ``log.interval`` seconds for the pipeline-view log + self.lag_task = asyncio.create_task(self.lag_monitor.run(), name="event_loop_lag") + await self.periodic_logger.start() + self.component_tasks = [ + asyncio.create_task(self.dispatcher.start(), name="dispatcher"), + asyncio.create_task(self.watcher.start(), name="watcher"), + ] + + # Default step-0 base-model eval — fires before any train rollouts + # unless ``eval.skip_first_step=True`` (or this is a resume) + self.maybe_trigger_eval(self.progress.step) + + # Anchor step-time clock so step 0 measures startup → first batch + self.last_batch_at = time.perf_counter() + + # ``clean_exit`` stays False if ``main_loop`` raises (signal-driven + # CancelledError, KeyboardInterrupt, or a real error), so the teardown + # logs a forced-cleanup warning instead of a clean-exit success. + clean_exit = False + try: + await self.main_loop() + clean_exit = True + finally: + elapsed = format_time(time.perf_counter() - start_time) + if clean_exit: + get_logger().success(f"Orchestrator step loop done in {elapsed}") + else: + get_logger().warning(f"Orchestrator interrupted after {elapsed} — forcing cleanup (not a clean exit)") + self.monitor.save_final_summary() + if self.ckpt_manager is not None: + get_logger().info("Writing final checkpoint") + self.ckpt_manager.save(self.progress, step=self.progress.step) + await self.stop() + if clean_exit: + get_logger().success("Orchestrator finished.") + else: + get_logger().warning("Orchestrator cleanup complete (forced).") + try: + ctypes.CDLL("libc.so.6").malloc_trim(0) + except Exception as e: + get_logger().debug(f"malloc_trim(0) failed: {e}") + + async def main_loop(self) -> None: + """Consume ``FinishedRollout``\\ s from the dispatcher and route them + to the train / eval sink. Both sinks return a finalized batch (or + ``None``) from ``add()``; we just dispatch on the result.""" + while not self.stopped.is_set(): + if self.draining and self.dispatcher.is_idle: + get_logger().info("Pipeline drained, exiting main loop") + self.stopped.set() + break - # Offload base64 images to disk to free memory. No-op for text-only - # rollouts (no ``data:image`` URLs to find); cheap to call always. - offload_start = time.perf_counter() - num_offloaded = offload_images_to_disk(train_rollouts, config.output_dir) - if num_offloaded: - logger.info( - f"Offloaded {num_offloaded} unique images to disk in {time.perf_counter() - offload_start:.2f}s" + try: + rollout: FinishedRollout = await asyncio.wait_for(self.dispatcher.out_q.get(), timeout=0.5) + except asyncio.TimeoutError: + continue + + if isinstance(rollout, EvalRollout): + assert self.eval_sink is not None # eval rollouts only emitted when eval is configured + eval_batch = self.eval_sink.add(rollout) + if eval_batch is not None: + self.finalize_eval_batch(eval_batch) + continue + + assert isinstance(rollout, TrainRollout) + train_batch = await self.train_sink.add(rollout) + # In drain mode any late-arriving train batch is dropped — we + # don't want to ship past ``max_steps`` + if train_batch is not None and not self.draining and not self.stopped.is_set(): + await self.finalize_train_batch(train_batch) + + async def finalize_train_batch(self, batch: TrainBatch) -> None: + """Ship one ``TrainBatch`` out to the trainer and handle the I/O + side-effects (ckpt, save_rollouts, teacher logprobs, sender.send, + metrics, heartbeat, progress, eval trigger). The sink has already + done all data-transformation work.""" + config = self.config + step = self.progress.step + + # Sink-to-sink cycle time — the actual time between batches, not + # including the orchestrator's ship I/O (overlapped with the + # dispatcher producing the next batch) + now = time.perf_counter() + step_time = (now - self.last_batch_at) if self.last_batch_at is not None else 0.0 + self.last_batch_at = now + + save_ckpt_time = await self.maybe_save_ckpt(step) + + if config.max_steps is not None and step >= config.max_steps: + self.draining = True + self.dispatcher.disable_train_scheduling() + n_cancelled = await self.dispatcher.cancel_inflight_train_rollouts() + get_logger().info( + f"Draining pipeline (cancelled {n_cancelled} in-flight train rollout(s); " + f"any in-flight evals will complete)" ) + return - # Convert rollouts to training samples - parallel_preprocess_start = time.perf_counter() - - # We only expect to backfill tokens for training_mode=sft against an - # external teacher API (OpenAI/etc.), which returns no token IDs — - # reconstruct via tokenizer/renderer. The vLLM-served paths (RL/OPD - # renderer + MITO, and training_mode=sft against a local vLLM teacher) - # already populate tokens via prompt_token_ids/token_ids, so we - # short-circuit the 256-way fanout. - needs_backfill = any(step["tokens"] is None for rollout in train_rollouts for step in rollout["trajectory"]) - if needs_backfill: - logger.info( - "Backfilling tokens for rollout trajectories (expected for training_mode=sft against an external teacher API)" + if batch.metrics.n_trainable == 0: + self.consecutive_empty_batches += 1 + get_logger().warning( + f"Step {step}: post-batch filters dropped all {len(batch.rollouts)} rollouts " + f"(consecutive empty batches: {self.consecutive_empty_batches}/{MAX_CONSECUTIVE_EMPTY_BATCHES})" ) - await asyncio.gather( - *( - asyncio.to_thread( - backfill_rollout_tokens, - rollout, - tokenizer, - renderer=renderer, - ) - for rollout in train_rollouts + if self.consecutive_empty_batches >= MAX_CONSECUTIVE_EMPTY_BATCHES: + raise RuntimeError( + f"{self.consecutive_empty_batches} consecutive zero-trainable batches — " + "check filter config (pre_batch_filters / post_batch_filters) or task difficulty." ) + return + self.consecutive_empty_batches = 0 + if batch.metrics.n_trainable / len(batch.rollouts) <= 0.1: + get_logger().warning( + f"Only {batch.metrics.n_trainable}/{len(batch.rollouts)} rollouts in the batch are trainable " + f"({batch.metrics.n_trainable / len(batch.rollouts):.1%}) — consider reviewing task difficulty / filter config" ) - # Process rollouts in parallel - results = await asyncio.gather( - *( - asyncio.to_thread(interleave_rollout, r, mm_token_type_ids_mapping=mm_token_type_ids_mapping) - for r in train_rollouts - ) - ) - - # Collect results and assign advantages. Metrics are computed over all - # rollouts; only non-filtered samples are sent to the trainer. - train_examples: list[TrainingSample] = [] - rollout_prefill_lens: list[int] = [] - rollout_decode_lens: list[int] = [] - rollout_samples_per_rollout: list[int] = [] - num_prefill_tokens = 0 - num_decode_tokens = 0 - for rollout, samples in zip(train_rollouts, results): - rollout_prefill_tokens = 0 - rollout_decode_tokens = 0 - if samples is None: - samples = [] - rollout_samples_per_rollout.append(len(samples)) - for sample in samples: - sample.advantage = rollout["advantage"] - sample.reward = rollout["reward"] - sample.env_name = rollout["env_name"] - sample.training_mode = config.training_mode - sample_decode_tokens = sum(sample.completion_mask) - sample_prefill_tokens = len(sample.prompt_ids) + len(sample.completion_mask) - sample_decode_tokens - rollout_decode_tokens += sample_decode_tokens - rollout_prefill_tokens += sample_prefill_tokens - if not rollout["is_filtered"]: - train_examples.append(sample) - rollout_prefill_lens.append(rollout_prefill_tokens) - rollout_decode_lens.append(rollout_decode_tokens) - num_prefill_tokens += rollout_prefill_tokens - num_decode_tokens += rollout_decode_tokens - - parallel_preprocess_time = time.perf_counter() - parallel_preprocess_start - logger.debug( - f"Converted {len(train_rollouts)} rollouts ({num_unique_examples} unique examples) " - f"to {len(train_examples)} training examples" + # Materialize at the I/O boundary so prime-rl metadata travels with + # the raw vf payload on disk + in wandb sample tables + rollout_dicts = [r.to_dict() for r in batch.rollouts] + step_path = get_step_path(get_rollout_dir(config.output_dir), step) + await asyncio.to_thread( + save_rollouts, rollout_dicts, step_path / "train_rollouts.jsonl", exclude_keys={"trajectory"} ) - # Compute teacher logprobs (opd only - sft trains on teacher tokens directly) - teacher_logprobs_time = 0 - if config.training_mode == "opd" and teacher_inference is not None: + teacher_logprobs_time = 0.0 # opd only + if config.training_mode == "opd" and self.teacher_inference is not None: assert config.teacher is not None - logger.info(f"Computing teacher logprobs for {len(train_examples)} training examples") - teacher_logprobs_start_time = time.perf_counter() + t = time.perf_counter() teacher_logprobs_list = await compute_teacher_logprobs( - clients=teacher_inference.train_clients, + clients=self.teacher_inference.train_clients, model_name=config.teacher.model.name, - samples=train_examples, + samples=batch.samples, ) - for train_example, teacher_logprobs in zip(train_examples, teacher_logprobs_list): - train_example.teacher_logprobs = teacher_logprobs - teacher_logprobs_time = time.perf_counter() - teacher_logprobs_start_time - logger.debug(f"Computed teacher logprobs in {teacher_logprobs_time:.2f}s") - - training_batch = TrainingBatch( - examples=train_examples, - step=progress.step, + for ex, lp in zip(batch.samples, teacher_logprobs_list): + ex.teacher_logprobs = lp + teacher_logprobs_time = time.perf_counter() - t + + await self.sender.send(TrainingBatch(examples=batch.samples, step=step)) + self.update_dispatch_gate() + + metrics = self.metrics.build( + step=step, + rollouts=batch.rollouts, + metrics=batch.metrics, + progress=self.progress, + step_time=step_time, + save_ckpt_time=save_ckpt_time, + teacher_logprobs_time=teacher_logprobs_time, + pre_filter_seen=self.train_sink.pre_filter_seen, + pre_filter_dropped=self.train_sink.pre_filter_dropped, + pre_filter_dropped_by_name=dict(self.train_sink.pre_filter_dropped_by_name), ) - - await training_batch_sender.send(training_batch) - - step_time = time.perf_counter() - step_start_time - - # Gather metrics in dataframes - results_df = pd.DataFrame( - { - "example_id": [rollout["example_id"] for rollout in train_rollouts], - "env_name": [rollout["env_name"] for rollout in train_rollouts], - "reward": [rollout["reward"] for rollout in train_rollouts], - "is_truncated": [rollout["is_truncated"] for rollout in train_rollouts], - "is_filtered": [rollout["is_filtered"] for rollout in train_rollouts], - "stop_condition": [rollout.get("stop_condition") for rollout in train_rollouts], - "seq_len": [get_seq_len(rollout) for rollout in train_rollouts], - "prefill_len": rollout_prefill_lens, - "decode_len": rollout_decode_lens, - "samples_per_rollout": rollout_samples_per_rollout, - "num_turns": [len(rollout["trajectory"]) for rollout in train_rollouts], - } - ) - - # Separate DataFrames for env reward function metrics, filter flags, and per-rollout timings - # to avoid column name collisions - metrics_df = pd.DataFrame([rollout["metrics"] for rollout in train_rollouts]) - filter_df = pd.DataFrame([rollout["filters"] for rollout in train_rollouts]) - timing_df = pd.DataFrame( - [ - { - "total": rollout["timing"]["total"], - "setup": rollout["timing"]["setup"]["duration"], - "generation": rollout["timing"]["generation"]["duration"], - "model": rollout["timing"]["model"]["duration"], - "env": rollout["timing"]["env"]["duration"], - "scoring": rollout["timing"]["scoring"]["duration"], - "overhead": rollout["timing"]["overhead"], - } - for rollout in train_rollouts - ] - ) - - # Update progress metrics - num_tokens = int(results_df.seq_len.sum()) - progress.total_tokens += num_tokens - progress.total_samples += num_rollouts - progress.total_problems += num_unique_examples - - def compute_solve_rates(df): - """Compute solve_none, solve_all, effective_batch_size for a set of rollouts.""" - reward_per_problem = df.groupby(["env_name", "example_id"]).reward.sum() - solve_none = (reward_per_problem == 0).mean() - solve_all = (reward_per_problem == config.group_size).mean() - return solve_none, solve_all, 1 - solve_none - solve_all - - # Group by (env_name, example_id) to average across rollouts within each problem - by_example = results_df.groupby(["env_name", "example_id"]) - - solve_none, solve_all, effective_batch_size = compute_solve_rates(results_df) - to_log = { - # Progress metrics - "progress/tokens": num_tokens, - "progress/prefill_tokens": num_prefill_tokens, - "progress/decode_tokens": num_decode_tokens, - "progress/samples": num_rollouts, - "progress/problems": num_unique_examples, - "progress/total_tokens": progress.total_tokens, - "progress/total_samples": progress.total_samples, - "progress/total_problems": progress.total_problems, - # Sequence length metrics - "seq_len/all/mean": by_example.seq_len.mean().mean(), - "seq_len/all/max": by_example.seq_len.mean().max(), - "seq_len/all/min": by_example.seq_len.mean().min(), - "prefill_len/all/mean": by_example.prefill_len.mean().mean(), - "prefill_len/all/max": by_example.prefill_len.mean().max(), - "prefill_len/all/min": by_example.prefill_len.mean().min(), - "decode_len/all/mean": by_example.decode_len.mean().mean(), - "decode_len/all/max": by_example.decode_len.mean().max(), - "decode_len/all/min": by_example.decode_len.mean().min(), - "is_truncated/all/mean": by_example.is_truncated.mean().mean(), - "is_truncated/all/max": by_example.is_truncated.mean().max(), - "stop_condition/all/generation_truncated": ( - results_df.is_truncated & (results_df.stop_condition != "prompt_too_long") - ).mean(), - **{ - f"stop_condition/all/{sc}": rate - for sc, rate in results_df.stop_condition.dropna().value_counts(normalize=True).items() - }, - "samples_per_rollout/all/mean": by_example.samples_per_rollout.mean().mean(), - "samples_per_rollout/all/max": by_example.samples_per_rollout.mean().max(), - "samples_per_rollout/all/min": by_example.samples_per_rollout.mean().min(), - "num_turns/all/mean": by_example.num_turns.mean().mean(), - "num_turns/all/max": by_example.num_turns.mean().max(), - "num_turns/all/min": by_example.num_turns.mean().min(), - **{ - f"timing/all/{key}/{stat}": getattr( - timing_df[key].groupby([results_df.env_name, results_df.example_id]).mean(), - stat, - )() - for key in timing_df.columns - for stat in ("mean", "max", "min") - }, - # Train reward - "reward/all/mean": by_example.reward.mean().mean(), - "reward/all/max": by_example.reward.mean().max(), - "reward/all/min": by_example.reward.mean().min(), - # Solve / batch metrics - "solve_none/all": solve_none, - "solve_all/all": solve_all, - "effective_batch_size/all": effective_batch_size, - **{f"batch/{env}": r for env, r in results_df.env_name.value_counts(normalize=True).items()}, - # Time metrics - "time/step": step_time, - "time/generate_completions": generate_completions_time, - "time/teacher_logprobs": teacher_logprobs_time, - "time/save_ckpt": save_ckpt_time, - "time/parallel_preprocess": parallel_preprocess_time, - # Scheduler metrics - **scheduler.get_metrics(), - # Buffer metrics - **buffer.get_metrics(), - # Event loop lag metrics - **event_loop_lag_monitor.get_metrics(), - # Rollout filter metrics (detection rate per filter + overall drop rate) - "filters/all/is_filtered": results_df.is_filtered.astype(float).mean(), - **{f"filters/all/{name}": filter_df[name].astype(float).mean() for name in filter_df.columns}, - # W&B axis - "step": progress.step, - } - - # Per-env metrics - per_env_columns = [ - "seq_len", - "prefill_len", - "decode_len", - "is_truncated", - "samples_per_rollout", - "num_turns", - ] - - for env, env_df in results_df.groupby("env_name"): - env_by_example = env_df.groupby("example_id") - for col in per_env_columns: - to_log[f"{col}/{env}/mean"] = env_by_example[col].mean().mean() - to_log[f"{col}/{env}/max"] = env_by_example[col].mean().max() - if col != "is_truncated": - to_log[f"{col}/{env}/min"] = env_by_example[col].mean().min() - env_timing_df = timing_df.loc[env_df.index] - for key in timing_df.columns: - per_example = env_timing_df.groupby(env_df["example_id"])[key].mean() - to_log[f"timing/{env}/{key}/mean"] = per_example.mean() - to_log[f"timing/{env}/{key}/max"] = per_example.max() - to_log[f"timing/{env}/{key}/min"] = per_example.min() - to_log[f"reward/{env}/mean"] = env_by_example.reward.mean().mean() - to_log[f"reward/{env}/max"] = env_by_example.reward.mean().max() - to_log[f"reward/{env}/min"] = env_by_example.reward.mean().min() - solve_none, solve_all, effective_batch_size = compute_solve_rates(env_df) - to_log[f"solve_none/{env}"] = solve_none - to_log[f"solve_all/{env}"] = solve_all - to_log[f"effective_batch_size/{env}"] = effective_batch_size - to_log[f"stop_condition/{env}/generation_truncated"] = ( - env_df.is_truncated & (env_df.stop_condition != "prompt_too_long") - ).mean() - for sc, rate in env_df.stop_condition.dropna().value_counts(normalize=True).items(): - to_log[f"stop_condition/{env}/{sc}"] = rate - env_metrics_df = metrics_df.loc[env_df.index] - for metric in metrics_df.columns: - to_log[f"metrics/{env}/{metric}"] = env_metrics_df.groupby(env_df["example_id"])[metric].mean().mean() - to_log[f"filters/{env}/is_filtered"] = env_df.is_filtered.astype(float).mean() - env_filter_df = filter_df.loc[env_df.index] - for name in filter_df.columns: - to_log[f"filters/{env}/{name}"] = env_filter_df[name].astype(float).mean() - - # Log metrics to monitor(s) - monitor.log(to_log, step=progress.step) - - # Log samples to monitor(s) if enabled. - monitor.log_samples(train_rollouts, step=progress.step) - - # Log distributions (rewards, advantages) if enabled - monitor.log_distributions( + self.monitor.log(metrics, step=step) + self.monitor.log_samples(rollout_dicts, step=step) + self.monitor.log_distributions( distributions={ - "rewards": [r["reward"] for r in train_rollouts], - "advantages": [r["advantage"] for r in train_rollouts], + "rewards": [r.reward for r in batch.rollouts], + "advantages": [r.advantage for r in batch.rollouts if r.advantage is not None], }, - step=progress.step, + step=step, ) - if usage_reporter and run_id: - usage_reporter.report_training_usage( - run_id=run_id, - step=progress.step, - tokens=num_prefill_tokens + num_decode_tokens, - ) - - reward_mean = by_example.reward.mean().mean() - step_message = f"Step {progress.step} | Time: {step_time:.2f}s | Reward: {reward_mean:.4f} | Seq. Length: {by_example.seq_len.mean().mean():.1f} tokens/sample | Max. Off-Policy Level: {scheduler.max_off_policy_level}" - logger.success(step_message) - - # Increment step - progress.step += 1 - is_first_step = False - - # Free large per-step objects to prevent memory accumulation - del train_rollouts, train_examples, training_batch - del results_df, metrics_df - gc.collect() - # Return free glibc heap pages to the OS. numpy/pandas allocate array data - # via malloc (outside Python's allocator), so gc.collect() alone doesn't - # reclaim the RSS. malloc_trim(0) forces glibc to return freed pages. - try: - ctypes.CDLL("libc.so.6").malloc_trim(0) - except Exception as e: - logger.warning(f"malloc_trim(0) failed - RSS may grow unboundedly: {e}") - - event_loop_lag_monitor.reset() - - # Send heartbeat if configured - if heart is not None: - heart.beat() - - if config.eval and eval_envs is not None: - logger.info("Running final evals") - eval_results = await asyncio.gather( - *[ - eval_env.evaluate( - model_name=student_inference.model_name, - get_client=student_inference.get_eval_client, - step=progress.step, - cache_salt=str(ckpt_step), + if self.usage_reporter is not None: + run_id = os.getenv("RUN_ID", "") + if run_id: + self.usage_reporter.report_training_usage( + run_id=run_id, + step=step, + tokens=batch.metrics.num_prefill_tokens + batch.metrics.num_decode_tokens, ) - for eval_env in eval_envs - ] + if self.heart is not None: + self.heart.beat() + + num_rollouts = len(batch.rollouts) + num_unique_examples = len({r.group_id for r in batch.rollouts}) + num_tokens = sum( + r.raw["token_usage"]["final_input_tokens"] + r.raw["token_usage"]["final_output_tokens"] + for r in batch.rollouts ) - - # Save final eval rollouts to disk - eval_rollouts = [o for outputs in eval_results for o in outputs] - if eval_rollouts: - step_path = get_step_path(get_rollout_dir(config.output_dir), progress.step) - await asyncio.to_thread( - save_rollouts, eval_rollouts, step_path / "eval_rollouts.jsonl", exclude_keys={"trajectory"} + self.progress.total_tokens += num_tokens + self.progress.total_samples += num_rollouts + self.progress.total_problems += num_unique_examples + + self.log_train_batch(batch, step=step, step_time=step_time) + + self.train_sink.reset_pre_filter_stats() + self.progress.step += 1 + self.maybe_trigger_eval(self.progress.step) + + def maybe_trigger_eval(self, step: int) -> None: + """Fire eligible eval epochs and flip to ``PREFER_EVAL`` if anything + fires. No-op when eval is not configured.""" + if self.eval_source is None: + return + fired = self.eval_source.trigger(step) + if not fired: + return + reason = f"eval was triggered at step {step}" + self.dispatcher.switch_mode(DispatcherMode.PREFER_EVAL, reason=reason) + now = time.perf_counter() + for env_name in fired: + self.eval_triggered_at[(env_name, step)] = now + assert self.eval_envs is not None + total_rollouts = sum( + self.eval_envs.get(env_name).config.group_size * len(self.eval_envs.get(env_name).examples) + for env_name in fired + ) + get_logger().info(f"Starting evals in {', '.join(fired)} ({total_rollouts} total rollouts)") + + def collect_pipeline_view(self) -> tuple[str, dict[str, float]]: + """Pipeline view for the orchestrator's ``PeriodicLogger``. Returns + ``(console_body, wandb_payload)``. Per-env ``(env=N, …)`` + breakdowns inline only when there's more than one train / eval env; + the eval halves drop entirely when nothing is accumulating.""" + disp_gauges = self.dispatcher.gauges() + disp_drain = self.dispatcher.metrics.drained( + train_envs={e.name for e in self.train_envs}, + eval_envs={e.name for e in self.eval_envs} if self.eval_envs is not None else set(), + ) + watcher_gauges = self.watcher.gauges() + lag_stats = EventLoopLagStats.from_monitor(self.lag_monitor) + + inflight_by_env = self.dispatcher.inflight_by_env + inflight_train = self.dispatcher.inflight_train_count + inflight_eval = self.dispatcher.inflight_eval_count + train_batch, train_target, _train_unit = self.train_sink.batch_progress() + train_buffered = self.train_sink.buffered_count() + train_batch_by_env = self.train_sink.pending_batch_by_env() + eval_batches = self.eval_sink.batch_progress() if self.eval_sink is not None else [] + multi_train = len(self.train_envs) > 1 + multi_eval = self.eval_envs is not None and len(self.eval_envs) > 1 + + # Train batch: finalized-group survivors only (0→target). Partial-group + # arrivals are surfaced as a separate ``(+N buffered)`` addendum + train_pct = train_batch / train_target if train_target else 0.0 + train_batch_part = f"Train batch {train_batch}/{train_target} ({train_pct:.1%})" + if multi_train: + pairs = [(e.name, train_batch_by_env.get(e.name, 0)) for e in self.train_envs] + train_batch_part += " (" + ", ".join(f"{n}={v}" for n, v in pairs) + ")" + if train_buffered: + train_batch_part += f" (+{train_buffered} buffered)" + + eval_batch_part = "" + for env, _step, eb, exp, _ebuf in eval_batches: + eval_pct = eb / exp if exp else 0.0 + eval_batch_part += f" | {env} {eb}/{exp} ({eval_pct:.1%})" + + # Unified inflight tail: total, then train/eval split, then per-env + # (only when more than one env of a kind makes the split ambiguous) + inflight_part = ( + f"{inflight_train + inflight_eval} inflight rollouts (train={inflight_train}, eval={inflight_eval}" + ) + if multi_train or multi_eval: + env_pairs = [(e.name, inflight_by_env.get(("train", e.name), 0)) for e in self.train_envs] + if self.eval_envs is not None: + env_pairs += [(e.name, inflight_by_env.get(("eval", e.name), 0)) for e in self.eval_envs] + inflight_part += " | " + ", ".join(f"{n}={v}" for n, v in env_pairs) + inflight_part += ")" + + body = train_batch_part + eval_batch_part + "; " + inflight_part + + payload: dict[str, float] = {**disp_gauges, **disp_drain, **watcher_gauges} + if lag_stats.n > 0: + payload["event_loop_lag/min"] = lag_stats.min + payload["event_loop_lag/mean"] = lag_stats.mean + payload["event_loop_lag/median"] = lag_stats.median + payload["event_loop_lag/p90"] = lag_stats.p90 + payload["event_loop_lag/p99"] = lag_stats.p99 + payload["event_loop_lag/max"] = lag_stats.max + payload["event_loop_lag/n"] = float(lag_stats.n) + return body, payload + + def log_train_batch(self, batch: TrainBatch, *, step: int, step_time: float) -> None: + """Per-step ``Step …`` success line. Multi-env runs append an + indented ``╰─`` line per env. ``Error`` is relative to arrivals at + the sink (errored rollouts may have been group-dropped before + reaching ``batch.rollouts``).""" + n_arrivals_total = sum(batch.metrics.arrivals_by_env.values()) + n_errors_total = sum(batch.metrics.errors_by_env.values()) + n_survivors = len(batch.rollouts) + n_trainable = batch.metrics.n_trainable + error_rate = (n_errors_total / n_arrivals_total) if n_arrivals_total else 0.0 + trainable_rate = (n_trainable / n_survivors) if n_survivors else 0.0 + reward_mean = sum(r.reward for r in batch.rollouts) / max(n_survivors, 1) + max_off_policy = max((r.off_policy_steps for r in batch.rollouts), default=0) + turns_mean = sum(len(r.raw.get("trajectory") or []) for r in batch.rollouts) / max(n_survivors, 1) + truncation_rate = sum(1 for r in batch.rollouts if r.is_truncated) / max(n_survivors, 1) + + head = ( + f"Step {step} | {format_time(step_time):>7} | Reward {reward_mean:.4f} | " + f"Trainable {n_trainable}/{n_survivors} ({trainable_rate:.1%}) | " + f"Turns {turns_mean:.1f} | Max Off-Policy {max_off_policy} | " + f"Error {error_rate:.1%} | Truncation {truncation_rate:.1%}" + ) + if len(self.train_envs) <= 1: + get_logger().success(head) + return + + env_names = sorted(set(batch.metrics.arrivals_by_env) | {r.env_name for r in batch.rollouts}) + name_width = max(len(n) for n in env_names) if env_names else 0 + lines = [head] + for env_name in env_names: + env_rollouts = [r for r in batch.rollouts if r.env_name == env_name] + n_env_arrivals = batch.metrics.arrivals_by_env.get(env_name, 0) + n_env_errors = batch.metrics.errors_by_env.get(env_name, 0) + ratio = (n_env_arrivals / n_arrivals_total) if n_arrivals_total else 0.0 + env_error_rate = (n_env_errors / n_env_arrivals) if n_env_arrivals else 0.0 + env_reward = (sum(r.reward for r in env_rollouts) / len(env_rollouts)) if env_rollouts else 0.0 + env_max_off_policy = max((r.off_policy_steps for r in env_rollouts), default=0) + env_turns = ( + sum(len(r.raw.get("trajectory") or []) for r in env_rollouts) / len(env_rollouts) + if env_rollouts + else 0.0 ) - - monitor.save_final_summary() - - # Write final checkpoint - if ckpt_manager is not None: - logger.info("Writing final checkpoint") - ckpt_manager.save(progress, buffer, step=progress.step) - - # Bounded best-effort cleanup. Each await below may block on a remote peer - # (env-server ZMQ recv, inference admin httpx aclose, etc.). The outer - # asyncio.wait gives the whole sequence a single deadline; if anything - # wedges past SHUTDOWN_TIMEOUT_S we force-exit the process. Individual - # awaits intentionally do NOT have their own timeouts — asyncio.wait_for - # would itself hang on an uncancellable await, which is exactly the - # failure mode we're guarding against. - async def _graceful_shutdown() -> None: - training_batch_sender.close() - await scheduler.stop() - if inference_metrics_collector is not None: - await inference_metrics_collector.stop() - await student_inference.stop() - if teacher_inference is not None: - await teacher_inference.stop() - event_loop_lag_monitor_task.cancel() - # Shutdown env processes (also registered as atexit handler for crash safety) - train_envs.shutdown() - if eval_envs is not None: - eval_envs.shutdown() - - shutdown_task = asyncio.create_task(_graceful_shutdown()) - _, pending = await asyncio.wait({shutdown_task}, timeout=SHUTDOWN_TIMEOUT_S) - - if pending: - logger.warning( - f"Orchestrator shutdown did not complete within {SHUTDOWN_TIMEOUT_S}s; " - "forcing process exit. Training artifacts are already persisted." + env_truncation = sum(1 for r in env_rollouts if r.is_truncated) / len(env_rollouts) if env_rollouts else 0.0 + lines.append( + f"╰─ {env_name:<{name_width}} | Ratio {ratio:.1%} | Reward {env_reward:.4f} | " + f"Turns {env_turns:.1f} | Max Off-Policy {env_max_off_policy} | " + f"Error {env_error_rate:.1%} | Truncation {env_truncation:.1%}" + ) + get_logger().success("\n\t\t ".join(lines)) + + def finalize_eval_batch(self, batch: EvalBatch) -> None: + """Persist + log one completed eval epoch (save_rollouts, + monitor.log_eval_samples, monitor.log).""" + if not batch.rollouts: + get_logger().warning(f"Eval @ step={batch.step} env={batch.env_name}: no surviving rollouts, skipping log") + return + + rollout_dicts = [r.to_dict() for r in batch.rollouts] + step_path = get_step_path(get_rollout_dir(self.config.output_dir), batch.step) + save_rollouts( + rollout_dicts, + step_path / f"eval_rollouts_{batch.env_name}.jsonl", + exclude_keys={"trajectory"}, + ) + self.monitor.log_eval_samples(rollout_dicts, env_name=batch.env_name, step=batch.step) + self.monitor.log(batch.metrics.to_wandb_dict(env_name=batch.env_name, step=batch.step), step=batch.step) + + n_total = batch.metrics.n_rollouts + error_rate = ((batch.metrics.n_cancelled + batch.metrics.n_errored) / n_total) if n_total else 0.0 + max_off_policy = max((r.off_policy_steps for r in batch.rollouts), default=0) + triggered_at = self.eval_triggered_at.pop((batch.env_name, batch.step), None) + elapsed = (time.perf_counter() - triggered_at) if triggered_at is not None else 0.0 + + get_logger().success( + f"Evaluated {batch.env_name} (Step {batch.step}) | " + f"{format_time(elapsed):>7} | Reward {batch.metrics.reward_mean:.4f} | " + f"Turns {batch.metrics.num_turns_mean:.1f} | Max Off-Policy {max_off_policy} | " + f"Error {error_rate:.1%} | Truncation {batch.metrics.truncation_rate:.1%}" ) - os._exit(0) - # asyncio.wait swallows task exceptions; re-raise so a fast cleanup - # failure surfaces the same way as it did when each step was awaited - # directly. - await shutdown_task + async def maybe_save_ckpt(self, step: int) -> float: + """Save the checkpoint if we're at an interval boundary. Returns + elapsed time (0.0 when no save happened).""" + if self.ckpt_manager is None or self.config.ckpt is None or not self.config.ckpt.interval: + return 0.0 + if step <= 0: + return 0.0 + # Skip only the drain-entry step (step == max_steps, which never ships): + # it would double-save with the final checkpoint in ``start()`` (also at + # progress.step == max_steps). The last *shipped* step (max_steps - 1) is + # NOT skipped — the trainer saves there (its is_last_step is max_steps), + # so the orchestrator must too or resume from that interval ckpt breaks. + near_end = self.config.max_steps is not None and step >= self.config.max_steps + if near_end: + return 0.0 + if step % self.config.ckpt.interval != 0: + return 0.0 + get_logger().info(f"Saving checkpoint at step {step}") + t = time.perf_counter() + await asyncio.to_thread(self.ckpt_manager.save, self.progress, step) + return time.perf_counter() - t + + def update_dispatch_gate(self) -> None: + """Pause/resume the dispatcher based on how far the orchestrator's + next batch would run ahead of ``policy.version``. Called from two + sites: after shipping a batch (step advances) and from + ``on_new_version`` (policy advances).""" + lead = (self.progress.step + 1) - self.policy.version + gate = self.dispatcher.dispatch_allowed + was_set = gate.is_set() + if lead > TARGET_LAG: + if was_set: + get_logger().info( + "Pausing dispatcher to prevent orchestrator from racing from trainer. Waiting for new policy..." + ) + gate.clear() + else: + if not was_set: + get_logger().info("Resuming dispatcher") + gate.set() + + async def on_new_version(self, step: int) -> None: + """``VersionObserver`` hook: the watcher just advanced ``policy.version``; + re-evaluate the dispatch gate (may resume if the trainer caught up).""" + self.update_dispatch_gate() + + async def stop(self) -> None: + """Bounded best-effort teardown of all components. Has a global + timeout so a wedged peer can't keep the process alive forever — + training artifacts are already persisted before this is reached.""" + + async def teardown() -> None: + if self.sender is not None: + self.sender.close() + if self.dispatcher is not None: + await self.dispatcher.stop() + if self.watcher is not None: + await self.watcher.stop() + if self.periodic_logger is not None: + await self.periodic_logger.stop() + if self.lag_task is not None: + await safe_cancel(self.lag_task) + self.lag_task = None + for task in self.component_tasks: + await safe_cancel(task) + self.component_tasks.clear() + if self.inference_metrics is not None: + await self.inference_metrics.stop() + if self.student_inference is not None: + await self.student_inference.stop() + if self.teacher_inference is not None: + await self.teacher_inference.stop() + if self.train_envs is not None: + self.train_envs.shutdown() + if self.eval_envs is not None: + self.eval_envs.shutdown() + if self.usage_reporter is not None: + self.usage_reporter.close() + + task = asyncio.create_task(teardown()) + _, pending = await asyncio.wait({task}, timeout=SHUTDOWN_TIMEOUT_S) + if pending: + get_logger().warning( + f"Orchestrator shutdown did not complete within {SHUTDOWN_TIMEOUT_S}s; " + "forcing process exit. Training artifacts are already persisted." + ) + os._exit(0) + await task - if usage_reporter: - usage_reporter.close() - logger.success("Orchestrator finished.") +@clean_exit +async def run_orchestrator(config: OrchestratorConfig) -> None: + """Top-level entrypoint. Wrapped in ``@clean_exit`` so wandb is flushed + on exit (success or crash); keeps that out of the class. + """ + await Orchestrator(config).start() - # Optionally, print benchmark table - if config.bench: - print_benchmark(to_col_format(monitor.history)) +def main() -> None: + from prime_rl.utils.config import cli + from prime_rl.utils.process import set_proc_title -def main(): - """Main entry-point for orchestrator. Run using `uv run orchestrator`""" set_proc_title("Orchestrator") import uvloop uvloop.install() - asyncio.run(orchestrate(cli(OrchestratorConfig))) - - -async def setup_student_inference_pool( - *, - config: OrchestratorConfig, - tokenizer, - logger, -): - """Set up the student inference pool (rollouts when rl/opd, evals + weight sync always). - - Routing policy is driven by ``config.renderer``: - - - ``renderer is not None`` → renderer-backed TITO client (``/v1/generate``). - Default for both text-only and VLM rollouts; required for VLMs. - - ``renderer is None`` → MITO (``openai_chat_completions``). - - Eval clients always use MITO. In sft mode ``renderer`` is forced to ``None`` - by a config validator, so the student pool is plain MITO end-to-end. - """ - client_config = config.student.client - model_name = config.student.model.name - - if config.renderer is not None: - renderer = create_renderer(tokenizer, config.renderer) - logger.info(f"Initialized {type(renderer).__name__} for {model_name}") - inference_pool = await setup_inference_pool( - client_config, - model_name=model_name, - train_client_type="renderer", - eval_client_type="openai_chat_completions", - renderer_config=config.renderer, - pool_size=config.pool_size, - ) - logger.info("Using direct renderer rollout client") - return renderer, inference_pool - - logger.info("Using MITO (openai_chat_completions) for rollouts") - inference_pool = await setup_inference_pool( - client_config, - model_name=model_name, - train_client_type="openai_chat_completions", - eval_client_type="openai_chat_completions", - ) - return None, inference_pool + asyncio.run(run_orchestrator(cli(OrchestratorConfig))) if __name__ == "__main__": diff --git a/src/prime_rl/orchestrator/periodic_logger.py b/src/prime_rl/orchestrator/periodic_logger.py new file mode 100644 index 0000000000..4d4cf3d542 --- /dev/null +++ b/src/prime_rl/orchestrator/periodic_logger.py @@ -0,0 +1,64 @@ +"""PeriodicLogger: orchestrator's pipeline view, fires every ``interval`` +seconds. ``collect()`` returns ``(console_body, wandb_payload)`` in one +call so drain-on-read counters fire exactly once per tick. Wandb writes +land on the ``_timestamp`` axis.""" + +from __future__ import annotations + +import asyncio +import time +from typing import Callable + +import wandb + +from prime_rl.utils.async_utils import safe_cancel +from prime_rl.utils.logger import get_logger + + +class PeriodicLogger: + def __init__( + self, + *, + name: str, + collect: Callable[[], tuple[str, dict[str, float]]], + metric_keys: list[str], + interval: float, + wandb_enabled: bool, + ) -> None: + self.name = name + self.collect = collect + self.interval = interval + self.wandb_enabled = wandb_enabled + self.task: asyncio.Task | None = None + self.stopped = asyncio.Event() + + if self.wandb_enabled: + for key in metric_keys: + wandb.define_metric(key, step_metric="_timestamp") + + async def start(self) -> None: + self.task = asyncio.create_task(self.run(), name=f"{self.name}_periodic_logger") + + async def run(self) -> None: + try: + while not self.stopped.is_set(): + try: + await asyncio.wait_for(self.stopped.wait(), timeout=self.interval) + except asyncio.TimeoutError: + pass + self.emit() + except asyncio.CancelledError: + return + + def emit(self) -> None: + body, payload = self.collect() + get_logger().info(body) + if self.wandb_enabled and payload: + payload["_timestamp"] = time.time() + wandb.log(payload) + + async def stop(self) -> None: + self.stopped.set() + if self.task is not None: + await safe_cancel(self.task) + self.task = None diff --git a/src/prime_rl/orchestrator/scheduler.py b/src/prime_rl/orchestrator/scheduler.py deleted file mode 100644 index 44a5e0663c..0000000000 --- a/src/prime_rl/orchestrator/scheduler.py +++ /dev/null @@ -1,591 +0,0 @@ -from __future__ import annotations - -import asyncio -import time -from collections import Counter, defaultdict -from dataclasses import dataclass, field - -import verifiers as vf -from aiolimiter import AsyncLimiter - -from prime_rl.configs.orchestrator import OrchestratorConfig -from prime_rl.orchestrator.buffer import Buffer -from prime_rl.orchestrator.envs import TrainEnvs -from prime_rl.orchestrator.vf_utils import get_seq_len -from prime_rl.utils.async_utils import safe_cancel, safe_cancel_all -from prime_rl.utils.client import InferencePool -from prime_rl.utils.logger import ProgressTracker, get_logger -from prime_rl.utils.utils import ( - get_broadcast_dir, - get_latest_ckpt_step, - get_step_path, - wait_for_path, -) - - -@dataclass -class InflightRequest: - """Metadata for an in-flight request.""" - - off_policy_steps: int - client_config: vf.ClientConfig - env_name: str - group_id: int | None = None - rollout_count: int = 1 - - -@dataclass -class GroupState: - """Tracks the state of a rollout group (one example × N rollouts).""" - - example: dict - rollouts_to_schedule: int - completed_rollouts: list[vf.RolloutOutput] = field(default_factory=list) - pinned_client: vf.ClientConfig | None = None - failed_rollouts: int = 0 - - -class Scheduler: - """ - Asynchronously manages scheduling of rollout requests and policy updates. - Keeps a constant number of rollouts in-flight (continuous batching) and - updates the policy as soon as it becomes available. - - References: - - AReal: https://arxiv.org/abs/2505.24298v1 - - PipelineRL: https://arxiv.org/abs/2509.19128v1 - """ - - def __init__( - self, - train_envs: TrainEnvs, - student_inference: InferencePool, - teacher_inference: InferencePool | None, - buffer: Buffer, - config: OrchestratorConfig, - max_inflight_rollouts: int, - max_off_policy_steps: int, - tasks_per_minute: int | None, - lora_name: str | None = None, - ): - self.logger = get_logger() - if tasks_per_minute is not None: - self.rate_limiter = AsyncLimiter(max_rate=tasks_per_minute, time_period=60) - else: - self.rate_limiter = None - self.train_envs = train_envs - self.buffer = buffer - self.config = config - self.batch_size = config.batch_size - self.token_batch_size = config.token_batch_size - self.group_size = config.group_size - self.max_inflight_rollouts = max_inflight_rollouts - self.max_off_policy_steps = max_off_policy_steps - self.lora_name = lora_name - self.json_logging = config.log.json_logging - - # student_inference is the weight-sync target. teacher_inference is set - # in opd (for logprobs) and sft (for rollouts). rollout_inference is - # whichever pool serves train rollouts for this mode. - self.student_inference = student_inference - self.teacher_inference = teacher_inference - if config.training_mode == "sft": - assert teacher_inference is not None - self.rollout_inference = teacher_inference - else: - self.rollout_inference = student_inference - # model_name is the name to send on rollout requests - matches the rollout pool - self.model_name = self.rollout_inference.model_name - - group_scoring_envs = [env.name for env in train_envs if env.requires_group_scoring] - if group_scoring_envs: - self.logger.info(f"Group rollout scoring active for env(s): {', '.join(group_scoring_envs)}") - - # Track in-flight requests: task -> info - self.inflight_requests: dict[asyncio.Task, InflightRequest] = {} - - # Track in-progress groups while rollouts are generated independently. - self.next_group_id = 0 - self.groups: dict[int, GroupState] = {} - - self.step, self.ckpt_step = 0, 0 - self.checkpoint_ready = asyncio.Event() - self.checkpoint_ready.set() - self.update_weights_time, self.wait_for_ckpt_time = 0, 0 - self.update_policy_task: asyncio.Task | None = None - self.inflight_policy_update_task: asyncio.Task | None = None - self.policy_update_lock = asyncio.Lock() - self.cancelled_rollouts_count = 0 - self.empty_rollouts_by_env: dict[str, int] = defaultdict(int) - self.errored_rollouts_by_env: dict[str, int] = defaultdict(int) - self.errors_by_type: dict[str, int] = defaultdict(int) - self.total_rollouts_by_env: dict[str, int] = defaultdict(int) - self.dropped_groups_by_env: dict[str, int] = defaultdict(int) - self.last_batch_generation_time = 0.0 - - @property - def uses_token_batching(self) -> bool: - return self.token_batch_size is not None - - @property - def batch_target(self) -> int: - if self.uses_token_batching: - assert self.token_batch_size is not None - return self.token_batch_size - assert self.batch_size is not None - return self.batch_size - - def get_batch_progress_increment(self, rollouts: list[vf.RolloutOutput]) -> int: - if self.uses_token_batching: - return sum(get_seq_len(rollout) for rollout in rollouts) - return len(rollouts) - - def finalize_batch_rollouts(self, rollouts: list[vf.RolloutOutput]) -> list[vf.RolloutOutput]: - if self.batch_size is None: - return rollouts - return rollouts[: self.batch_size] - - async def cancel_inflight_rollouts(self): - """Cancel all in-flight rollout requests.""" - count = sum(info.rollout_count for info in self.inflight_requests.values()) - await safe_cancel_all(list(self.inflight_requests)) - self.inflight_requests.clear() - self.groups.clear() - self.cancelled_rollouts_count += count - - @staticmethod - def _client_identity(c: vf.ClientConfig) -> tuple[str, str | None]: - return ( - c.api_base_url, - c.extra_headers.get("X-data-parallel-rank"), - ) - - async def _select_least_loaded_client(self) -> vf.ClientConfig: - """Select the client with the fewest in-flight tasks. - - Uses (api_base_url, dp_rank) as identity rather than client_idx so that - load tracking survives elastic pool refreshes (which reassign indices). - """ - clients = self.rollout_inference.train_clients - while not clients: - await asyncio.sleep(1) - clients = self.rollout_inference.train_clients - inflight = Counter(self._client_identity(info.client_config) for info in self.inflight_requests.values()) - return min(clients, key=lambda c: inflight[self._client_identity(c)]) - - async def drop_group(self, group_id: int) -> int: - """Drop a group and cancel any remaining in-flight rollouts for it. Returns the number of cancelled rollouts.""" - tasks_to_cancel = [] - rollout_count = 0 - for task, info in list(self.inflight_requests.items()): - if info.group_id != group_id: - continue - self.inflight_requests.pop(task, None) - tasks_to_cancel.append(task) - rollout_count += info.rollout_count - self.groups.pop(group_id, None) - await safe_cancel_all(tasks_to_cancel) - return rollout_count - - async def schedule_rollout(self, group_id: int): - """Asynchronously schedules a rollout request (or a group request for group-scoring envs).""" - if self.rate_limiter: - await self.rate_limiter.acquire() - group = self.groups.get(group_id) - if group is None or group.rollouts_to_schedule <= 0: - return - - if group.pinned_client is not None: - client_config = group.pinned_client - else: - client_config = await self._select_least_loaded_client() - if group_id not in self.groups: - return - group.pinned_client = client_config - - env_name = group.example["env_name"] - env = self.train_envs.get(env_name) - - cache_salt = str(self.ckpt_step) - if env.requires_group_scoring: - rollout_count = group.rollouts_to_schedule - group.rollouts_to_schedule = 0 - task = asyncio.create_task( - env.run_group( - client=client_config, - example=group.example, - model_name=self.model_name, - group_size=rollout_count, - cache_salt=cache_salt, - ) - ) - else: - rollout_count = 1 - group.rollouts_to_schedule -= 1 - task = asyncio.create_task( - env.run_rollout( - client=client_config, - example=group.example, - model_name=self.model_name, - cache_salt=cache_salt, - ) - ) - self.inflight_requests[task] = InflightRequest( - off_policy_steps=0, - client_config=client_config, - env_name=env_name, - group_id=group_id, - rollout_count=rollout_count, - ) - - @property - def inflight_rollout_count(self) -> int: - return sum(info.rollout_count for info in self.inflight_requests.values()) - - @property - def inflight_sample_count(self) -> int: - pending = sum(g.rollouts_to_schedule for g in self.groups.values()) - return self.inflight_rollout_count + pending - - async def _schedule_next_request(self) -> bool: - remaining_capacity = self.max_inflight_rollouts - self.inflight_rollout_count - - if remaining_capacity <= 0: - return False - - for group_id, group in self.groups.items(): - if group.rollouts_to_schedule <= 0: - continue - env = self.train_envs.get(group.example["env_name"]) - cost = group.rollouts_to_schedule if env.requires_group_scoring else 1 - if cost <= remaining_capacity: - await self.schedule_rollout(group_id=group_id) - return True - - if remaining_capacity < self.group_size: - return False - - example = self.buffer.sample_examples(n=1)[0] - group_id = self.next_group_id - self.next_group_id += 1 - self.groups[group_id] = GroupState(example=example, rollouts_to_schedule=self.group_size) - await self.schedule_rollout(group_id=group_id) - return True - - async def _fill_inflight_requests(self) -> None: - while await self._schedule_next_request(): - pass - - async def update_policy_loop(self): - """Continuously checks for new policy checkpoints.""" - while True: - await self.maybe_update_policy() - await asyncio.sleep(1) - - def _compute_next_ckpt_step(self) -> int: - # The orchestrator always runs one step ahead of the trainer, so we must advance to at - # least step - 1. We additionally adopt anything fresher the trainer has already - # broadcast (so a fast trainer briefly running on-policy is fine). ``latest_ckpt_step`` - # is non-negative so it also clamps a self.step == 0 startup. - latest_ckpt_step = get_latest_ckpt_step(get_broadcast_dir(self.config.output_dir)) or 0 - return max(self.step - 1, latest_ckpt_step) - - async def _apply_policy_update(self, next_ckpt_step: int) -> None: - # If we're advancing to step - 1, the trainer hasn't broadcast it yet (otherwise - # we would've picked something newer); block until the file lands. - if next_ckpt_step == max(self.step - 1, 0): - self.logger.info( - f"Orchestrator paused: waiting for trainer to broadcast checkpoint {next_ckpt_step} " - f"(orchestrator is one step ahead). Training is progressing normally." - ) - self.checkpoint_ready.clear() - wait_for_ckpt_start_time = time.perf_counter() - await wait_for_path(get_step_path(get_broadcast_dir(self.config.output_dir), next_ckpt_step) / "STABLE") - self.wait_for_ckpt_time = time.perf_counter() - wait_for_ckpt_start_time - self.logger.info( - f"Orchestrator resumed: checkpoint {next_ckpt_step} ready (after {self.wait_for_ckpt_time:.2f}s)" - ) - - self.logger.debug( - f"Got new policy with step {next_ckpt_step}. Updating weights and cancelling old rollout requests." - ) - - update_weights_start_time = time.perf_counter() - weights_path = get_step_path(get_broadcast_dir(self.config.output_dir), next_ckpt_step) - await self.student_inference.update_weights(weights_path, lora_name=self.lora_name, step=next_ckpt_step) - self.update_weights_time = time.perf_counter() - update_weights_start_time - self.logger.debug(f"Updated weights to step {next_ckpt_step} in {self.update_weights_time:.2f}s") - - self.ckpt_step = next_ckpt_step - if self.lora_name is not None: - self.student_inference.update_model_name(self.lora_name) - # Only redirect rollout requests to the new LoRA when rollouts come from - # student inference (rl/opd). In sft, rollouts go to the teacher and - # the student's LoRA name is irrelevant to them. - if self.rollout_inference is self.student_inference: - self.model_name = self.lora_name - - self.checkpoint_ready.set() - await self._update_off_policy() - - async def _get_or_start_policy_update_task(self, next_ckpt_step: int) -> asyncio.Task: - async with self.policy_update_lock: - task = self.inflight_policy_update_task - if task is not None and not task.done(): - return task - - task = asyncio.create_task(self._apply_policy_update(next_ckpt_step)) - self.inflight_policy_update_task = task - - def _clear_inflight_policy_update(done_task: asyncio.Task) -> None: - if self.inflight_policy_update_task is done_task: - self.inflight_policy_update_task = None - - task.add_done_callback(_clear_inflight_policy_update) - return task - - async def maybe_update_policy(self): - """Updates the policy to the latest available checkpoint. Aborts rollout requests that are older than the max retention steps.""" - while True: - next_ckpt_step = self._compute_next_ckpt_step() - if next_ckpt_step <= self.ckpt_step: - return - - task = await self._get_or_start_policy_update_task(next_ckpt_step) - await asyncio.shield(task) - - async def _update_off_policy(self) -> None: - stale_group_ids = { - info.group_id - for info in self.inflight_requests.values() - if info.group_id is not None and info.off_policy_steps >= self.max_off_policy_steps - } - tasks_to_increment = [ - task - for task, info in list(self.inflight_requests.items()) - if info.group_id is None or info.group_id not in stale_group_ids - ] - - counts = await asyncio.gather(*(self.drop_group(gid) for gid in stale_group_ids)) - removed = sum(counts) - for task in tasks_to_increment: - info = self.inflight_requests.get(task) - if info is None: - continue - info.off_policy_steps += 1 - - self.cancelled_rollouts_count += removed - if removed: - self.logger.warning( - f"Cancelled {removed} old rollout requests (will refill naturally). " - f"Consider increasing max_off_policy_steps to avoid this." - ) - - async def generate_batch(self, step: int) -> list[vf.RolloutOutput]: - """Continuously generates a batch of rollouts.""" - self.step = step - - # Cancel the previous update policy task to avoid concurrent updates - if self.update_policy_task is not None: - await safe_cancel(self.update_policy_task) - - # Manually check the async barrier before starting the step, then re-create the update policy loop - # This ensures the orchestrator stays at most one step ahead of the trainer, while still - # listening for policy updates mid-step. - await self.maybe_update_policy() - self.update_policy_task = asyncio.create_task(self.update_policy_loop()) - - batch_start_time = time.perf_counter() - - self.logger.debug("Starting to generate batch rollouts") - - batch_rollouts: list[vf.RolloutOutput] = [] - batch_progress = 0 - pbar = ProgressTracker( - total=self.batch_target, desc="Generating rollouts (train)", json_logging=self.json_logging, step=step - ) - - while batch_progress < self.batch_target: - await self._fill_inflight_requests() - inflight_tasks = list(self.inflight_requests.keys()) - - finished_tasks, _ = await asyncio.wait( - inflight_tasks, - return_when=asyncio.FIRST_COMPLETED, - ) - await self.checkpoint_ready.wait() - - for finished_task in finished_tasks: - if batch_progress >= self.batch_target: - break - - rollout_info = self.inflight_requests.pop(finished_task, None) - if rollout_info is None: - continue - - group_id = rollout_info.group_id - env_name = rollout_info.env_name - - try: - group = self.groups.get(group_id) - if group is None: - continue - - result = finished_task.result() - rollouts: list[vf.RolloutOutput] = result if isinstance(result, list) else [result] - self.total_rollouts_by_env[env_name] += len(rollouts) - - # Partition rollouts into valid vs failed and tally per-rollout - # error metrics. Tally every failure (group-scoring envs return - # N rollouts per task) so error-rate metrics aren't deflated. - env = self.train_envs.get(env_name) - valid_rollouts: list[vf.RolloutOutput] = [] - for rollout in rollouts: - if rollout["error"] is not None: - self.errored_rollouts_by_env[env_name] += 1 - self.errors_by_type[rollout["error"]["error"]] += 1 - self.logger.warning( - f"Rollout failed in group {group_id} ({env_name}) - " - f"{rollout['error']['error_chain_repr']}" - ) - elif len(rollout["trajectory"]) == 0: - self.empty_rollouts_by_env[env_name] += 1 - self.logger.warning(f"Empty trajectory in group {group_id} ({env_name})") - else: - rollout["env_name"] = env_name - valid_rollouts.append(rollout) - - num_failed = len(rollouts) - len(valid_rollouts) - group.failed_rollouts += num_failed - - # Group-scoring envs compute scores over all N rollouts - # together; the surviving rollouts carry scores computed against - # the (now-missing) failed ones, so partial salvage is unsafe. - # Drop the whole group on any failure. - if num_failed > 0 and env.requires_group_scoring: - self.dropped_groups_by_env[env_name] += 1 - self.logger.warning( - f"Dropping group-scored group {group_id} ({env_name}) after rollout failure" - ) - await self.drop_group(group_id) - continue - - group.completed_rollouts.extend(valid_rollouts) - - # Wait until every dispatched rollout has come back (succeeded - # or failed) before finalizing. The group may finalize as a - # partial group (< group_size) when some rollouts - # errored - downstream advantage computation groups by - # (env_name, example_id), so variable-size groups are fine. - if len(group.completed_rollouts) + group.failed_rollouts < self.group_size: - continue - - if not group.completed_rollouts: - self.dropped_groups_by_env[env_name] += 1 - self.logger.warning( - f"Dropping group {group_id} ({env_name}) - all {self.group_size} rollouts failed" - ) - self.groups.pop(group_id, None) - continue - - if group.failed_rollouts > 0: - self.logger.warning( - f"Partial group {group_id} ({env_name}) - " - f"{len(group.completed_rollouts)}/{self.group_size} valid " - f"({group.failed_rollouts} failed)" - ) - - completed_rollouts = self.groups.pop(group_id).completed_rollouts - - except asyncio.CancelledError: - if group_id is not None: - await self.drop_group(group_id) - continue - except Exception as e: - self.logger.warning(f"Rollout failed: {e}") - if group_id is not None: - await self.drop_group(group_id) - continue - - self.buffer.update(completed_rollouts) - accepted_rollouts = self.buffer.sample_rollouts(n=len(completed_rollouts)) - - batch_rollouts.extend(accepted_rollouts) - progress_increment = self.get_batch_progress_increment(accepted_rollouts) - batch_progress += progress_increment - pbar.update(progress_increment) - - await self._fill_inflight_requests() - - batch_rollouts = self.finalize_batch_rollouts(batch_rollouts) - pbar.close() - self.last_batch_generation_time = time.perf_counter() - batch_start_time - return batch_rollouts - - async def stop(self) -> None: - await self.cancel_inflight_rollouts() - if self.update_policy_task is not None: - await safe_cancel(self.update_policy_task) - self.update_policy_task = None - if self.inflight_policy_update_task is not None: - await safe_cancel(self.inflight_policy_update_task) - self.inflight_policy_update_task = None - - @property - def max_off_policy_level(self) -> int: - steps = [info.off_policy_steps for info in self.inflight_requests.values()] - if not steps: - return 0 - return max(steps) - - @property - def mean_off_policy_level(self) -> float: - steps = [info.off_policy_steps for info in self.inflight_requests.values()] - if not steps: - return 0 - return sum(steps) / len(steps) - - @property - def async_level(self) -> int: - return self.step - self.ckpt_step - - def get_metrics(self) -> dict[str, float]: - total_rollouts = sum(self.total_rollouts_by_env.values()) - metrics = { - "time/wait_for_ckpt": self.wait_for_ckpt_time, - "time/update_weights": self.update_weights_time, - "scheduler/async_level": self.async_level, - "scheduler/inflight_rollouts": self.inflight_rollout_count, - "scheduler/inflight_samples": self.inflight_sample_count, - "scheduler/cancelled_rollouts": self.cancelled_rollouts_count, - "empty_rollouts/all": sum(self.empty_rollouts_by_env.values()) / max(total_rollouts, 1), - "errored_rollouts/all": sum(self.errored_rollouts_by_env.values()) / max(total_rollouts, 1), - "dropped_groups/all": sum(self.dropped_groups_by_env.values()), - "off_policy_level/all/max": self.max_off_policy_level, - "off_policy_level/all/mean": self.mean_off_policy_level, - } - for env_name in self.total_rollouts_by_env: - env_total = max(self.total_rollouts_by_env[env_name], 1) - metrics[f"empty_rollouts/{env_name}"] = self.empty_rollouts_by_env.get(env_name, 0) / env_total - metrics[f"errored_rollouts/{env_name}"] = self.errored_rollouts_by_env.get(env_name, 0) / env_total - for env_name, count in self.dropped_groups_by_env.items(): - metrics[f"dropped_groups/{env_name}"] = count - for error_type, count in self.errors_by_type.items(): - metrics[f"error/{error_type}/count"] = count - by_env: dict[str, list[int]] = {} - for info in self.inflight_requests.values(): - by_env.setdefault(info.env_name, []).append(info.off_policy_steps) - for env_name, steps in by_env.items(): - metrics[f"off_policy_level/{env_name}/max"] = max(steps) - metrics[f"off_policy_level/{env_name}/mean"] = sum(steps) / len(steps) - self.cancelled_rollouts_count = 0 - self.empty_rollouts_by_env.clear() - self.errored_rollouts_by_env.clear() - self.errors_by_type.clear() - self.total_rollouts_by_env.clear() - self.dropped_groups_by_env.clear() - - # Add train pool metrics (e.g. elastic pool server counts) - metrics.update(self.rollout_inference.get_metrics()) - - return metrics diff --git a/src/prime_rl/orchestrator/train_sink.py b/src/prime_rl/orchestrator/train_sink.py new file mode 100644 index 0000000000..54af787bec --- /dev/null +++ b/src/prime_rl/orchestrator/train_sink.py @@ -0,0 +1,312 @@ +"""TrainSink: three-level rollout sink for the training side. + +1. ``process_rollout`` — eager per-rollout tokenization (overlaps with + dispatcher producing more rollouts). Errored rollouts skip this. +2. ``process_group`` — filters errored rollouts, computes advantages over + survivors, runs the pre-batch filter pass. +3. ``process_batch`` — applies post-batch filter annotations and assembles + the trainer-bound ``TrainingSample`` list. Returns a ``TrainBatch``. + +``add()`` returns ``TrainBatch | None``. I/O concerns (ship to trainer, +save_rollouts, monitor.log, teacher logprobs) live on the orchestrator. +""" + +from __future__ import annotations + +import asyncio +import uuid +from collections import defaultdict + +from prime_rl.configs.orchestrator import AdvantageConfig, OrchestratorConfig +from prime_rl.orchestrator.advantage import assign_advantages, setup_advantage_fn +from prime_rl.orchestrator.envs import TrainEnvs +from prime_rl.orchestrator.filters import RolloutFilter, apply_filters +from prime_rl.orchestrator.trajectories import ( + backfill_rollout_tokens, + interleave_rollout, + offload_images_to_disk, +) +from prime_rl.orchestrator.types import TrainBatch, TrainBatchMetrics, TrainRollout +from prime_rl.transport import TrainingSample +from prime_rl.utils.logger import get_logger + + +class TrainSink: + """Three-level train sink. Constructed once, fed via ``add(rollout)``.""" + + def __init__( + self, + config: OrchestratorConfig, + *, + tokenizer, + renderer, + train_envs: TrainEnvs, + mm_token_type_ids_mapping: dict[int, int] | None, + batch_size: int | None, + token_batch_size: int | None, + advantage_config: AdvantageConfig | None, + pre_filters: list[RolloutFilter], + post_filters: list[RolloutFilter], + ) -> None: + assert (batch_size is None) != (token_batch_size is None), ( + "Exactly one of batch_size / token_batch_size must be set" + ) + self.config = config + self.tokenizer = tokenizer + self.renderer = renderer + self.train_envs = train_envs + self.mm_token_type_ids_mapping = mm_token_type_ids_mapping + self.batch_size = batch_size + self.token_batch_size = token_batch_size + # Built once — custom advantage funcs do an ``import_object`` and + # we don't want to pay that per group. ``None`` = reward-only path + self.advantage_fn = setup_advantage_fn(advantage_config) if advantage_config is not None else None + self.pre_filters = pre_filters + self.post_filters = post_filters + + # Keyed by the dispatcher's group UUID. ``(env_name, example_id)`` + # isn't unique — the same example can be re-sampled while an + # earlier group is still in flight + self.pending_groups: dict[uuid.UUID, list[TrainRollout]] = defaultdict(list) + self.pending_batch: list[TrainRollout] = [] + + # Reset by the orchestrator after each ship via ``reset_pre_filter_stats`` + self.pre_filter_seen = 0 + self.pre_filter_dropped = 0 + self.pre_filter_dropped_by_name: dict[str, int] = {} + + # Per-env arrival / error counters since the last ship; reset in + # ``process_batch``. Fuel for the per-env success log breakdown + self.arrivals_by_env: dict[str, int] = defaultdict(int) + self.errors_by_env: dict[str, int] = defaultdict(int) + + def group_size_for(self, env_name: str) -> int: + return self.train_envs.get(env_name).config.group_size + + def in_progress_groups(self) -> list[list[TrainRollout]]: + """Per-rollout groups currently accumulating in ``pending_groups`` — + i.e. groups that haven't hit ``group_size`` yet, so the pipeline log + can reflect partial-group progress. Skips group-scoring envs (whose + rollouts only make sense as a unit — the user expects per-group + fill, not per-rollout, for those).""" + out: list[list[TrainRollout]] = [] + for rollouts in self.pending_groups.values(): + if not rollouts: + continue + env_name = rollouts[0].env_name + if self.train_envs.get(env_name).requires_group_scoring: + continue + out.append(rollouts) + return out + + def batch_progress(self) -> tuple[int, int, str]: + """``(current, target, unit)`` for the train batch — counts only + ``pending_batch`` (survivors of finalized groups, queued for the + trainer), so it's an honest 0→target fill. Partial-group arrivals are + reported separately by ``buffered_count()``.""" + if self.batch_size is not None: + return len(self.pending_batch), self.batch_size, "rollouts" + assert self.token_batch_size is not None + tokens = sum( + r.raw["token_usage"]["final_input_tokens"] + r.raw["token_usage"]["final_output_tokens"] + for r in self.pending_batch + ) + return tokens, self.token_batch_size, "tokens" + + def buffered_count(self) -> int: + """Rollouts that have arrived but sit in not-yet-complete groups + (non-group-scoring envs) — buffered in the sink ahead of the batch.""" + return sum(len(group) for group in self.in_progress_groups()) + + def pending_batch_by_env(self) -> dict[str, int]: + """Per-env breakdown of ``batch_progress()`` (``pending_batch`` only); + values sum to the aggregate.""" + counts: dict[str, int] = defaultdict(int) + for r in self.pending_batch: + counts[r.env_name] += 1 + return dict(counts) + + async def add(self, rollout: TrainRollout) -> TrainBatch | None: + """Process one arrival; finalize the group on the ``group_size``-th + arrival; return a ``TrainBatch`` if the batch threshold is met.""" + await self.process_rollout(rollout) + env_name = rollout.env_name + self.arrivals_by_env[env_name] += 1 + if rollout.error is not None: + self.errors_by_env[env_name] += 1 + self.pending_groups[rollout.group_id].append(rollout) + if len(self.pending_groups[rollout.group_id]) >= self.group_size_for(env_name): + self.process_group(rollout.group_id) + ready = ( + len(self.pending_batch) >= self.batch_size + if self.batch_size is not None + else sum( + r.raw["token_usage"]["final_input_tokens"] + r.raw["token_usage"]["final_output_tokens"] + for r in self.pending_batch + ) + >= (self.token_batch_size or 0) + ) + if ready: + return self.process_batch() + return None + + async def process_rollout(self, rollout: TrainRollout) -> None: + """Tokenize the rollout eagerly. Backfills tokens if the env didn't + return them (SFT against external teacher APIs); errored rollouts + skip tokenization and get dropped at the group level.""" + if rollout.error is not None: + return + raw = rollout.raw + needs_backfill = any(s["tokens"] is None for s in raw.get("trajectory") or []) + if needs_backfill: + await asyncio.to_thread(backfill_rollout_tokens, raw, self.tokenizer, renderer=self.renderer) + samples = await asyncio.to_thread( + interleave_rollout, + raw, + mm_token_type_ids_mapping=self.mm_token_type_ids_mapping, + env_name=rollout.env_name, + ) + rollout.samples = samples or [] + # Offload base64 image bytes to disk as soon as the rollout is + # tokenized, so memory stays flat instead of holding every buffered + # rollout's images until the batch ships (no-op for text-only). + await asyncio.to_thread(offload_images_to_disk, [raw], self.config.output_dir) + + def process_group(self, group_id: uuid.UUID) -> None: + """Finalize one GRPO group: drop errored rollouts (the whole group + when ``requires_group_scoring`` and any failed), assign advantages, + run pre-batch filters, append survivors to ``pending_batch``.""" + group = self.pending_groups.pop(group_id, []) + if not group: + return + env_name = group[0].env_name + example_id = group[0].example_id + survivors = [r for r in group if r.error is None] + num_errored = len(group) - len(survivors) + + # Group-scoring envs: any failure makes survivors' rewards unsafe + # (computed relative to the missing ones) + env = self.train_envs.get(env_name) + if num_errored > 0 and env.requires_group_scoring: + get_logger().debug( + f"Finished group | env={env_name} example_id={example_id} | " + f"rollouts={len(group)} (errored={num_errored}) | dropped: group-scored partial" + ) + return + if not survivors: + get_logger().debug( + f"Finished group | env={env_name} example_id={example_id} | " + f"rollouts={len(group)} (errored={num_errored}) | dropped: all failed" + ) + return + + assign_advantages(survivors, self.advantage_fn) + + # Propagate to the pre-tokenized samples so the orchestrator can + # collect samples at ship time without re-walking rollouts + for r in survivors: + for sample in r.samples: + sample.advantage = r.advantage + sample.reward = r.reward + sample.env_name = r.env_name + sample.training_mode = self.config.training_mode + + if self.pre_filters: + apply_filters(self.pre_filters, survivors) + filtered_by_name: dict[str, int] = {} + num_filtered = 0 + for r in survivors: + self.pre_filter_seen += 1 + if r.is_filtered: + self.pre_filter_dropped += 1 + num_filtered += 1 + for name, hit in r.filter_results.items(): + if hit: + self.pre_filter_dropped_by_name[name] = self.pre_filter_dropped_by_name.get(name, 0) + 1 + filtered_by_name[name] = filtered_by_name.get(name, 0) + 1 + continue + # Reset annotations so the post-batch filter pass starts clean + r.filter_results = {} + r.is_filtered = False + self.pending_batch.append(r) + + # Per-group summary. One line per finalized group; per-filter + # detection breakdown lives at debug level in ``apply_filters`` + rewards = [r.reward for r in survivors] + avg_reward = sum(rewards) / len(rewards) if rewards else 0.0 + filter_str = ", ".join(f"{n}={c}" for n, c in filtered_by_name.items()) if filtered_by_name else "—" + get_logger().debug( + f"Finished group | env={env_name} example_id={example_id} | " + f"rollouts={len(group)} (errored={num_errored}, filtered={num_filtered}) | " + f"reward={avg_reward:.4f} | filters: {filter_str}" + ) + + def process_batch(self) -> TrainBatch: + """Pop a cohort off ``pending_batch`` (by rollout count when + ``batch_size`` is set, by token count when ``token_batch_size`` is + set), apply post-batch filter annotations, and assemble the + trainer-bound ``TrainingSample`` list. Overflow stays for the next + batch.""" + if self.batch_size is not None: + cohort = self.pending_batch[: self.batch_size] + self.pending_batch = self.pending_batch[self.batch_size :] + else: + assert self.token_batch_size is not None + cut = 0 + running = 0 + for i, r in enumerate(self.pending_batch): + running += r.raw["token_usage"]["final_input_tokens"] + r.raw["token_usage"]["final_output_tokens"] + cut = i + 1 + if running >= self.token_batch_size: + break + cohort = self.pending_batch[:cut] + self.pending_batch = self.pending_batch[cut:] + + if self.post_filters: + apply_filters(self.post_filters, cohort) + + # Samples are pre-built by ``process_rollout``; ``process_group`` + # already set advantage/reward on each sample + samples: list[TrainingSample] = [] + prefill_lens: list[int] = [] + decode_lens: list[int] = [] + samples_per_rollout: list[int] = [] + num_prefill = 0 + num_decode = 0 + for r in cohort: + samples_per_rollout.append(len(r.samples)) + prefill = 0 + decode = 0 + for sample in r.samples: + sample_decode = sum(sample.completion_mask) + sample_prefill = len(sample.prompt_ids) + len(sample.completion_mask) - sample_decode + decode += sample_decode + prefill += sample_prefill + if not r.is_filtered: + samples.append(sample) + prefill_lens.append(prefill) + decode_lens.append(decode) + num_prefill += prefill + num_decode += decode + + n_trainable = sum(1 for r in cohort if not r.is_filtered) + + metrics = TrainBatchMetrics( + n_trainable=n_trainable, + num_prefill_tokens=num_prefill, + num_decode_tokens=num_decode, + rollout_prefill_lens=prefill_lens, + rollout_decode_lens=decode_lens, + samples_per_rollout=samples_per_rollout, + samples_shipped=len(samples), + arrivals_by_env=dict(self.arrivals_by_env), + errors_by_env=dict(self.errors_by_env), + ) + self.arrivals_by_env.clear() + self.errors_by_env.clear() + return TrainBatch(rollouts=cohort, samples=samples, metrics=metrics) + + def reset_pre_filter_stats(self) -> None: + self.pre_filter_seen = 0 + self.pre_filter_dropped = 0 + self.pre_filter_dropped_by_name.clear() diff --git a/src/prime_rl/orchestrator/train_source.py b/src/prime_rl/orchestrator/train_source.py new file mode 100644 index 0000000000..db439f7539 --- /dev/null +++ b/src/prime_rl/orchestrator/train_source.py @@ -0,0 +1,59 @@ +"""TrainSource: weighted round-robin across train envs, infinite pull. + +Weights default to configured ``ratio`` (when every env sets one) or to +per-env dataset size. ``next_example`` reshuffles on cursor exhaustion.""" + +from __future__ import annotations + +import random + +from prime_rl.orchestrator.envs import TrainEnvs + + +class TrainSource: + """``next_example(available_permits)`` picks a weighted-RR env and + returns its next example (or ``None`` when the env's per-call permit + cost doesn't fit — the dispatch loop retries when permits free up). + Returned dicts carry ``env_name`` + ``example_id``.""" + + def __init__(self, train_envs: TrainEnvs, *, seed: int | None) -> None: + self.rng = random.Random(seed) + self.envs = list(train_envs) + if not self.envs: + raise ValueError("TrainSource needs at least one train env") + + self.examples: dict[str, list[dict]] = {} + self.cursors: dict[str, int] = {} + # Group-scoring envs reserve ``group_size`` permits up front; + # per-rollout envs need 1 + self.env_costs: dict[str, int] = {} + for env in self.envs: + rows: list[dict] = [] + for row in env.get_dataset(seed=seed): + ex = dict(row) + ex["env_name"] = env.name + rows.append(ex) + self.rng.shuffle(rows) + self.examples[env.name] = rows + self.cursors[env.name] = 0 + self.env_costs[env.name] = env.config.group_size if env.requires_group_scoring else 1 + + self.env_names = [e.name for e in self.envs] + configured_ratios = [e.config.ratio for e in self.envs] + if all(r is not None for r in configured_ratios): + self.weights: list[float] = [float(r) for r in configured_ratios] # type: ignore[arg-type] + else: + self.weights = [float(len(self.examples[name])) for name in self.env_names] + + def next_example(self, available_permits: int) -> dict | None: + env_name = self.rng.choices(self.env_names, weights=self.weights, k=1)[0] + if self.env_costs[env_name] > available_permits: + return None + rows = self.examples[env_name] + cursor = self.cursors[env_name] + if cursor >= len(rows): + self.rng.shuffle(rows) + cursor = 0 + example = rows[cursor] + self.cursors[env_name] = cursor + 1 + return example diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 5679253bf5..caf34ee0ee 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -204,6 +204,8 @@ def backfill_rollout_tokens( def interleave_rollout( output: vf.RolloutOutput, mm_token_type_ids_mapping: dict[int, int] | None = None, + *, + env_name: str = "", ) -> list[TrainingSample] | None: """ Convert vf.RolloutOutput to trainable rollouts by interleaving trajectory steps @@ -305,7 +307,7 @@ def make_sample(tokens: dict[str, Any], step_idx: int) -> TrainingSample: completion_temperatures=[temperature] * len(completion_ids), teacher_logprobs=None, advantage=None, - env_name=output["env_name"], + env_name=env_name, mm_token_type_ids=None, routed_experts=None, # deferred — finalized at end of interleave_rollout ) diff --git a/src/prime_rl/orchestrator/types.py b/src/prime_rl/orchestrator/types.py new file mode 100644 index 0000000000..a22412d3f1 --- /dev/null +++ b/src/prime_rl/orchestrator/types.py @@ -0,0 +1,209 @@ +"""Shared dataclasses for the orchestrator. Data carriers only; no behavior.""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field, fields +from typing import Literal, Protocol + +import verifiers as vf + +from prime_rl.transport import TrainingSample + + +@dataclass +class Policy: + """Mutable shared view of the policy. Passed by reference so observers + see new versions immediately.""" + + version: int = 0 + model_name: str = "" + + +@dataclass +class Progress: + """Persistent counters; ``step`` is the trainer-aligned step.""" + + step: int = 0 + total_tokens: int = 0 + total_samples: int = 0 + total_problems: int = 0 + + +Kind = Literal["train", "eval"] + + +@dataclass +class InflightRollout: + """Per-task scheduling state in the dispatcher; one entry per in-flight + ``run_rollout`` / ``run_group`` task.""" + + kind: Kind + env_name: str + group_id: uuid.UUID + policy_version: int + rollout_count: int + client_config: vf.ClientConfig | None = None + off_policy_steps: int = 0 + eval_step: int | None = None + + +@dataclass +class GroupState: + """Per-group dispatcher state: what's left to schedule + the pinned + client (for prefix-cache hits).""" + + kind: Kind + env_name: str + example: dict + rollouts_to_schedule: int + target_rollouts: int + emitted: int = 0 + eval_step: int | None = None + pinned_client: vf.ClientConfig | None = None + policy_version_at_start: int = 0 + + +@dataclass +class FinishedRollout: + """A completed rollout the sink receives. ``raw`` is the env's untouched + ``vf.RolloutOutput``; prime-rl metadata lives on typed fields. Train vs + eval is discriminated via ``isinstance``. ``rollout_id`` is the only + safe key for tracing one rollout — ``(env_name, example_id)`` collides + on re-sampling and ``group_id`` covers a whole group.""" + + raw: vf.RolloutOutput + env_name: str + example_id: int | str + group_id: uuid.UUID + policy_version: int + off_policy_steps: int + rollout_id: uuid.UUID = field(default_factory=uuid.uuid4) + + @property + def error(self) -> dict | None: + return self.raw.get("error") + + @property + def reward(self) -> float: + return float(self.raw.get("reward", 0.0)) + + @property + def is_truncated(self) -> bool: + return bool(self.raw.get("is_truncated", False)) + + def to_dict(self) -> vf.RolloutOutput: + """``raw`` + metadata merged for I/O (``save_rollouts``, + ``monitor.log_samples``). Shallow copy; never mutates ``self.raw``.""" + out: vf.RolloutOutput = dict(self.raw) # type: ignore[assignment] + for f in fields(self): + if f.name in ("raw", "samples"): + continue + val = getattr(self, f.name) + if f.name == "filter_results": + out["filters"] = dict(val) + continue + out[f.name] = str(val) if isinstance(val, uuid.UUID) else val + return out + + +@dataclass +class TrainRollout(FinishedRollout): + samples: list[TrainingSample] = field(default_factory=list) + advantage: float | None = None + is_filtered: bool = False + filter_results: dict[str, bool] = field(default_factory=dict) + + +@dataclass +class EvalRollout(FinishedRollout): + eval_step: int = 0 + + +@dataclass +class TrainBatchMetrics: + """Per-batch aggregates from ``TrainSink.process_batch``; consumed by + ``MetricsBuilder.build``. ``arrivals_by_env`` / ``errors_by_env`` count + rollouts at the sink (errored ones get dropped at the group level before + reaching ``TrainBatch.rollouts``).""" + + n_trainable: int + num_prefill_tokens: int + num_decode_tokens: int + rollout_prefill_lens: list[int] + rollout_decode_lens: list[int] + samples_per_rollout: list[int] + samples_shipped: int + arrivals_by_env: dict[str, int] = field(default_factory=dict) + errors_by_env: dict[str, int] = field(default_factory=dict) + + +@dataclass +class TrainBatch: + """``samples`` is the trainer-bound payload (post-filter survivors); + ``rollouts`` is the full cohort kept for orchestrator-side I/O.""" + + rollouts: list[TrainRollout] + samples: list[TrainingSample] + metrics: TrainBatchMetrics + + +@dataclass +class EvalBatchMetrics: + """Typed per-batch metrics from ``EvalSink.process_batch``. Final wandb + dict derived via ``to_wandb_dict`` at log time.""" + + n_rollouts: int + n_cancelled: int + n_errored: int + n_examples: int = 0 + group_size: int = 1 + reward_mean: float = 0.0 + completion_len_mean: float = 0.0 + completion_len_max: float = 0.0 + completion_len_min: float = 0.0 + truncation_rate: float = 0.0 + no_response_rate: float = 0.0 + num_turns_mean: float = 0.0 + num_turns_min: float = 0.0 + num_turns_max: float = 0.0 + pass_at_k: dict[str, float] = field(default_factory=dict) + + def to_wandb_dict(self, *, env_name: str, step: int) -> dict[str, float]: + prefix = f"eval/{env_name}" + out: dict[str, float] = { + "step": float(step), + f"{prefix}/cancelled_count": float(self.n_cancelled), + f"{prefix}/errored_count": float(self.n_errored), + } + if self.n_examples > 0: + out[f"{prefix}/avg@{self.group_size}"] = self.reward_mean + out[f"{prefix}/completion_len/mean"] = self.completion_len_mean + out[f"{prefix}/completion_len/max"] = self.completion_len_max + out[f"{prefix}/completion_len/min"] = self.completion_len_min + out[f"{prefix}/is_truncated/mean"] = self.truncation_rate + out[f"{prefix}/no_response/mean"] = self.no_response_rate + out[f"{prefix}/num_turns/mean"] = self.num_turns_mean + out[f"{prefix}/num_turns/min"] = self.num_turns_min + out[f"{prefix}/num_turns/max"] = self.num_turns_max + for k, v in self.pass_at_k.items(): + out[f"{prefix}/{k}"] = v + return out + + +@dataclass +class EvalBatch: + """One env's eval epoch. ``metrics`` is the typed view from + ``EvalSink.process_batch``.""" + + env_name: str + step: int + rollouts: list[EvalRollout] + metrics: EvalBatchMetrics + + +class VersionObserver(Protocol): + """Notified after each policy update; walked by the watcher after it + mutates ``Policy``.""" + + async def on_new_version(self, step: int) -> None: ... diff --git a/src/prime_rl/orchestrator/utils.py b/src/prime_rl/orchestrator/utils.py index 121cc57c54..5675ba3f34 100644 --- a/src/prime_rl/orchestrator/utils.py +++ b/src/prime_rl/orchestrator/utils.py @@ -1,82 +1,101 @@ import asyncio +import logging import time from concurrent.futures import ThreadPoolExecutor from itertools import cycle from pathlib import Path -from typing import Any -import pandas as pd +import orjson import verifiers as vf -from rich.console import Console -from rich.table import Table from verifiers.utils.client_utils import setup_openai_client +from verifiers.utils.save_utils import make_serializable +from prime_rl.configs.orchestrator import OrchestratorConfig from prime_rl.transport import TrainingSample -from prime_rl.utils.logger import get_logger +from prime_rl.utils.client import setup_inference_pool +from prime_rl.utils.logger import InterceptHandler, get_logger from prime_rl.utils.utils import ( - format_time, get_broadcast_dir, get_ckpt_dir, get_step_path, ) +async def setup_student_inference_pool(*, config: OrchestratorConfig, tokenizer): + """Build the student inference pool + matching renderer. Returns + ``(renderer | None, inference_pool)``; ``renderer`` is ``None`` on the + MITO path (``config.renderer is None``).""" + from renderers.base import create_renderer + + client_config = config.student.client + model_name = config.student.model.name + + if config.renderer is not None: + renderer = create_renderer(tokenizer, config.renderer) + get_logger().info(f"Initialized {type(renderer).__name__} for {model_name}") + inference_pool = await setup_inference_pool( + client_config, + model_name=model_name, + train_client_type="renderer", + eval_client_type="openai_chat_completions", + renderer_config=config.renderer, + pool_size=config.pool_size, + ) + get_logger().info("Using direct renderer rollout client") + return renderer, inference_pool + + get_logger().info("Using MITO (openai_chat_completions) for rollouts") + inference_pool = await setup_inference_pool( + client_config, + model_name=model_name, + train_client_type="openai_chat_completions", + eval_client_type="openai_chat_completions", + ) + return None, inference_pool + + +def get_model_completion_len(output: vf.RolloutOutput) -> int: + """Sum of model-generated completion tokens across all turns (excludes + environment-injected tokens between turns).""" + return sum(len(step["tokens"]["completion_ids"]) for step in output["trajectory"] if step.get("tokens")) + + +def get_tool_response_len(output: vf.RolloutOutput) -> int: + """Total tool-response tokens consumed across the whole rollout, read from a + harness-emitted metric (e.g. RLM's `rlm_total_tool_response_tokens`, deduped + across turns/branches/sub-RLMs). Returns 0 when no such metric is present.""" + metrics = output.get("metrics") or {} + for key, value in metrics.items(): + if key.endswith("total_tool_response_tokens") and isinstance(value, (int, float)): + return int(value) + return 0 + + +def save_rollouts(rollouts: list[vf.RolloutOutput], path: Path, exclude_keys: set[str] | None = None) -> None: + """Save rollouts to a JSONL file using verifiers serialization.""" + path.parent.mkdir(parents=True, exist_ok=True) + opts = orjson.OPT_APPEND_NEWLINE | orjson.OPT_SERIALIZE_NUMPY + with open(path, "wb") as f: + for rollout in rollouts: + row = {k: v for k, v in rollout.items() if k not in exclude_keys} if exclude_keys else rollout + f.write(orjson.dumps(row, default=make_serializable, option=opts)) + + +def intercept_vf_logging(logger: str = "verifiers", level: str = "DEBUG", prefix: str | None = None): + """Intercepts verifiers logging and routes through prime-rl logger with optional prefix.""" + vf_logger = logging.getLogger(logger) + vf_logger.handlers.clear() + vf_logger.addHandler(InterceptHandler(prefix=prefix)) + vf_logger.setLevel(level.upper()) + vf_logger.propagate = False + + def set_default_executor(max_workers: int = 64) -> None: """Scale the default asyncio thread pool so asyncio.to_thread has enough capacity.""" get_logger().info(f"Setting default executor to ThreadPoolExecutor(max_workers={max_workers})") asyncio.get_event_loop().set_default_executor(ThreadPoolExecutor(max_workers=max_workers)) -def print_benchmark(history: dict[str, list[Any]]) -> None: - """ - Print benchmark results as rich table. Shows formatted step time values. - First N rows show the per-step values, and the last row shows the mean, - std, min, and max values. - """ - history.pop("step") - assert all(len(v) for v in history.values()), "All metrics must have logged the same number of steps" - - # Turn metric history into pd.DataFrame - df = pd.DataFrame(dict(history.items())) - columns = { - "time/step": "Step Time", - } - df = df.rename(columns=columns) - df = df[list(columns.values())] - df = df.iloc[1:] # Exclude first row - - # Setup console - console = Console() - table = Table(title="Benchmark") - - # Add columns - table.add_column("Step", justify="right") - for col in df.columns: - table.add_column(col, justify="center", style="magenta") - - # Add formatted rows - formatted_df = pd.DataFrame(columns=df.columns) - formatted_df["Step Time"] = df["Step Time"].apply(format_time) - for step, row in formatted_df.iterrows(): - table.add_row(*([str(step)] + [str(x) for x in row])) - - # Separator - num_table_columns = 1 + len(df.columns) - table.add_row(*([""] * num_table_columns)) - - # Add row for formatted, aggregated statistics - mean_df = df.describe().loc[["mean", "std", "min", "max"], :] - formatted_mean_df = pd.DataFrame(columns=mean_df.columns) - formatted_mean_df["Step Time"] = mean_df["Step Time"].apply(format_time) - mean_row = ["Overall"] + formatted_mean_df.T.apply( - lambda row: f"{row['mean']} ± {row['std']} [{row['min']}, {row['max']}]", axis=1 - ).tolist() - table.add_row(*mean_row) - - # Display table - console.print(table) - - async def compute_teacher_logprobs( clients: list[vf.ClientConfig], model_name: str, diff --git a/src/prime_rl/orchestrator/vf_utils.py b/src/prime_rl/orchestrator/vf_utils.py deleted file mode 100644 index 7da1e6a29d..0000000000 --- a/src/prime_rl/orchestrator/vf_utils.py +++ /dev/null @@ -1,98 +0,0 @@ -import logging -from pathlib import Path - -import orjson -import verifiers as vf -from verifiers.utils.save_utils import make_serializable - -from prime_rl.utils.logger import InterceptHandler - - -# TODO: remove once usage is tracked by verifiers -def get_prompt_len(output: vf.RolloutOutput) -> int: - """ - Computes the number of prompt tokens from vf.RolloutOutput. Defined as the - number of prompt ids from the first trajectory step. If raw tokens are not - available, falls back to checking the usage of the first response. - """ - if not output["trajectory"]: - return 0 - first_step = output["trajectory"][0] - if first_step["tokens"] is not None: - return len(first_step["tokens"]["prompt_ids"]) - first_step_response = first_step["response"] - return (first_step_response.get("usage") or {}).get("prompt_tokens", 0) - - -# TODO: remove once usage is tracked by verifiers -def get_seq_len(output: vf.RolloutOutput) -> int: - """ - Computes the number of tokens from vf.RolloutOutput. Defined as the sum of prompt - and completion tokens from the last trajectory step. If raw tokens are not - available, falls back to checking the usage of the last response. - """ - if not output["trajectory"]: - return 0 - last_step = output["trajectory"][-1] - if last_step["tokens"] is not None: - return len(last_step["tokens"]["prompt_ids"]) + len(last_step["tokens"]["completion_ids"]) - last_step_response = last_step["response"] - return (last_step_response.get("usage") or {}).get("total_tokens", 0) - - -# TODO: remove once usage is tracked by verifiers -def get_completion_len(output: vf.RolloutOutput) -> int: - """ - Computes the number of completion tokens from vf.RolloutOutput. Defined as - the difference between the total number of tokens and the number of prompt - tokens. - """ - return get_seq_len(output) - get_prompt_len(output) - - -def get_model_completion_len(output: vf.RolloutOutput) -> int: - """ - Computes the number of model-generated completion tokens across all turns. - Unlike get_completion_len, this excludes environment responses injected - between turns in multi-turn rollouts. - """ - return sum(len(step["tokens"]["completion_ids"]) for step in output["trajectory"] if step.get("tokens")) - - -def get_num_turns(output: vf.RolloutOutput) -> int: - """Number of turns (trajectory steps) in a rollout.""" - return len(output["trajectory"]) - - -def get_tool_response_len(output: vf.RolloutOutput) -> int: - """ - Total tool-response tokens consumed across the whole rollout. - - Read from a harness-emitted metric (e.g. RLM's `rlm_total_tool_response_tokens`, - deduped across turns/branches/sub-RLMs). Returns 0 if no harness metric is - present, which makes this a no-op for envs without tool-response accounting. - """ - metrics = output.get("metrics") or {} - for key, value in metrics.items(): - if key.endswith("total_tool_response_tokens") and isinstance(value, (int, float)): - return int(value) - return 0 - - -def save_rollouts(rollouts: list[vf.RolloutOutput], path: Path, exclude_keys: set[str] | None = None) -> None: - """Save rollouts to a JSONL file using verifiers serialization.""" - path.parent.mkdir(parents=True, exist_ok=True) - opts = orjson.OPT_APPEND_NEWLINE | orjson.OPT_SERIALIZE_NUMPY - with open(path, "wb") as f: - for rollout in rollouts: - row = {k: v for k, v in rollout.items() if k not in exclude_keys} if exclude_keys else rollout - f.write(orjson.dumps(row, default=make_serializable, option=opts)) - - -def intercept_vf_logging(logger: str = "verifiers", level: str = "DEBUG", prefix: str | None = None): - """Intercepts verifiers logging and routes through prime-rl logger with optional prefix.""" - vf_logger = logging.getLogger(logger) - vf_logger.handlers.clear() - vf_logger.addHandler(InterceptHandler(prefix=prefix)) - vf_logger.setLevel(level.upper()) - vf_logger.propagate = False diff --git a/src/prime_rl/orchestrator/watcher.py b/src/prime_rl/orchestrator/watcher.py new file mode 100644 index 0000000000..9a245823bb --- /dev/null +++ b/src/prime_rl/orchestrator/watcher.py @@ -0,0 +1,122 @@ +"""WeightWatcher: polls the broadcast dir, advances ``Policy``, notifies +observers (dispatcher → off-policy cancel). Standalone async task; the +orchestrator's barrier bounds the in-flight lead.""" + +from __future__ import annotations + +import asyncio +import time + +from prime_rl.configs.orchestrator import OrchestratorConfig +from prime_rl.orchestrator.types import Policy, VersionObserver +from prime_rl.utils.async_utils import safe_cancel +from prime_rl.utils.client import InferencePool +from prime_rl.utils.logger import format_time, get_logger +from prime_rl.utils.pathing import get_broadcast_dir, get_step_path, wait_for_path +from prime_rl.utils.utils import get_latest_ckpt_step + + +class WeightWatcher: + """``await watcher.start()`` to drive the polling loop until ``stop()``.""" + + def __init__( + self, + config: OrchestratorConfig, + *, + policy: Policy, + inference: InferencePool, + observers: list[VersionObserver], + lora_name: str | None, + ckpt_step: int = 0, + poll_interval: float = 1.0, + ) -> None: + self.config = config + self.policy = policy + self.inference = inference + self.observers = observers + self.lora_name = lora_name + self.ckpt_step = ckpt_step + self.poll_interval = poll_interval + + self.last_update_weights_time: float = 0.0 + self.last_wait_for_ckpt_time: float = 0.0 + self.update_count: int = 0 + + self.task: asyncio.Task | None = None + self.update_lock = asyncio.Lock() + self.stopped = asyncio.Event() + + async def start(self) -> None: + self.task = asyncio.current_task() + try: + while not self.stopped.is_set(): + next_step = self.compute_next_ckpt_step() + if next_step > self.ckpt_step: + await self.apply_policy_update(next_step) + await asyncio.sleep(self.poll_interval) + except asyncio.CancelledError: + return + + async def stop(self) -> None: + self.stopped.set() + if self.task is not None: + await safe_cancel(self.task) + self.task = None + + def compute_next_ckpt_step(self) -> int: + """Next checkpoint to adopt — at least ``policy.version`` (we stay + one step ahead of the trainer) plus anything fresher already + published in ``broadcasts/``.""" + broadcast_dir = get_broadcast_dir(self.config.output_dir) + latest_ckpt_step = get_latest_ckpt_step(broadcast_dir) or 0 + return max(self.policy.version, latest_ckpt_step) + + async def apply_policy_update(self, next_step: int) -> None: + async with self.update_lock: + if next_step <= self.ckpt_step: + # Another caller raced us — bail without re-applying + return + + broadcast_dir = get_broadcast_dir(self.config.output_dir) + weights_path = get_step_path(broadcast_dir, next_step) + stable_marker = weights_path / "STABLE" + if not stable_marker.exists(): + get_logger().info( + f"Orchestrator paused: waiting for trainer to broadcast checkpoint {next_step}. " + "Training is progressing normally." + ) + t0 = time.perf_counter() + await wait_for_path(stable_marker) + self.last_wait_for_ckpt_time = time.perf_counter() - t0 + get_logger().info( + f"Orchestrator resumed: checkpoint {next_step} ready (after {format_time(self.last_wait_for_ckpt_time)})" + ) + + get_logger().debug(f"Updating weights to step {next_step}") + t1 = time.perf_counter() + await self.inference.update_weights(weights_path, lora_name=self.lora_name, step=next_step) + self.last_update_weights_time = time.perf_counter() - t1 + self.update_count += 1 + get_logger().debug(f"Updated weights to step {next_step} in {format_time(self.last_update_weights_time)}") + + self.ckpt_step = next_step + self.policy.version = next_step + if self.lora_name is not None: + self.inference.update_model_name(self.lora_name) + self.policy.model_name = self.lora_name + + for observer in self.observers: + try: + await observer.on_new_version(next_step) + except Exception as exc: + get_logger().warning( + f"Observer {type(observer).__name__}.on_new_version({next_step}) raised: {exc!r}" + ) + + def gauges(self) -> dict[str, float]: + return { + "watcher/policy_version": float(self.policy.version), + "watcher/update_count": float(self.update_count), + "watcher/last_update_weights_time": self.last_update_weights_time, + "watcher/last_wait_for_ckpt_time": self.last_wait_for_ckpt_time, + } diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index ba57f7c0d5..83afa666dc 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -25,7 +25,7 @@ setup_cp_params, shard_for_cp, ) -from prime_rl.utils.logger import setup_logger +from prime_rl.utils.logger import format_time, setup_logger from prime_rl.trainer.rl.loss import ( compute_entropy, compute_loss, @@ -525,13 +525,13 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: tensors[key].append(loss_tensor.detach().to("cpu")) # Debug log with *local, micro step* stats - micro_step_message = f"Micro Step {micro_step}/{len(micro_batches)} | Loss: {tensors['loss'][-1].mean().item():.4f} | Entropy: {tensors['entropy/all'][-1].mean().item():.4f}" + micro_step_message = f"Micro Step {micro_step}/{len(micro_batches)} | Loss {tensors['loss'][-1].mean().item():.4f} | Entropy {tensors['entropy/all'][-1].mean().item():.4f}" if micro_batch["training_mode"] != "sft": - micro_step_message += f" | Mismatch KL: {tensors['mismatch_kl/all'][-1].mean().item():.4f}" + micro_step_message += f" | Mismatch KL {tensors['mismatch_kl/all'][-1].mean().item():.4f}" if "max_vio" in tensors: - micro_step_message += f" | Max Vio: {tensors['max_vio'][-1].mean().item():.4f}" + micro_step_message += f" | Max Vio {tensors['max_vio'][-1].mean().item():.4f}" if "routing_confidence" in tensors: - micro_step_message += f" | Routing Conf.: {tensors['routing_confidence'][-1].mean().item():.4f}" + micro_step_message += f" | Routing Conf. {tensors['routing_confidence'][-1].mean().item():.4f}" logger.debug(micro_step_message) # compute_loss already divided by the global token count. Undo FSDP's per-rank averaging @@ -583,16 +583,16 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: # Log step metrics step_time = time.perf_counter() - step_start_time - step_message = f"Step {progress.step} | Time: {step_time:.2f}s | Loss: {tensor_stats['loss/mean']:.4f} | Entropy: {tensor_stats['entropy/all/mean']:.4f}" + step_message = f"Step {progress.step} | {format_time(step_time):>7} | Loss {tensor_stats['loss/mean']:.4f} | Entropy {tensor_stats['entropy/all/mean']:.4f}" if "mismatch_kl/all/mean" in tensor_stats: - step_message += f" | Mismatch KL: {tensor_stats['mismatch_kl/all/mean']:.4f}" + step_message += f" | Mismatch KL {tensor_stats['mismatch_kl/all/mean']:.4f}" if grad_norm is not None: - step_message += f" | Grad. Norm: {grad_norm:.4f}" - step_message += f" | LR: {current_lr:.2e} | Throughput: {throughput:.0f} tokens/s | MFU: {mfu:.1f}% | Peak Mem.: {peak_memory:.1f} GiB" + step_message += f" | Grad. Norm {grad_norm:.4f}" + step_message += f" | LR {current_lr:.2e} | Throughput {throughput:.0f} tokens/s | MFU {mfu:.1f}% | Peak Mem. {peak_memory:.1f} GiB" if "max_vio/mean" in tensor_stats: - step_message += f" | Max Vio: {tensor_stats['max_vio/mean']:.4f}" + step_message += f" | Max Vio {tensor_stats['max_vio/mean']:.4f}" if "routing_confidence/mean" in tensor_stats: - step_message += f" | Routing Conf.: {tensor_stats['routing_confidence/mean']:.4f}" + step_message += f" | Routing Conf. {tensor_stats['routing_confidence/mean']:.4f}" logger.success(step_message) # Log performance metrics diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index e9faff2bc2..87a2337379 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -22,7 +22,7 @@ from prime_rl.utils.cp import setup_cp_params, shard_for_cp from prime_rl.trainer.runs import Progress, get_multi_run_manager, setup_multi_run_manager from prime_rl.trainer.models.layers.lora import set_lora_num_tokens -from prime_rl.utils.logger import setup_logger +from prime_rl.utils.logger import format_time, setup_logger from prime_rl.trainer.optim import setup_optimizer from prime_rl.trainer.scheduler import setup_scheduler from prime_rl.trainer.model import ( @@ -325,7 +325,7 @@ def run_validation(step: int) -> None: if mean_loss != mean_loss: logger.warning(f"Validation at step {step} had no valid tokens") else: - logger.success(f"Validation | Step {step} | Loss: {mean_loss:.4f}") + logger.success(f"Validation | Step {step} | Loss {mean_loss:.4f}") monitor.log({"val/loss": mean_loss, "step": step}, step=step) gc_handler = GarbageCollection(config.gc.interval) if config.gc else None @@ -501,15 +501,15 @@ def run_validation(step: int) -> None: # Log step metrics step_time = time.perf_counter() - step_start_time - step_message = f"Step {progress.step} | Time: {step_time:.2f}s | Loss: {batch_loss:.4f}" + step_message = f"Step {progress.step} | {format_time(step_time):>7} | Loss {batch_loss:.4f}" if grad_norm is not None: - step_message += f" | Grad. Norm: {grad_norm:.4f}" - step_message += f" | LR: {current_lr:.2e} | Throughput: {throughput:.0f} tokens/s | MFU: {mfu:.1f}% | Peak Mem.: {peak_memory:.1f}/{max_memory:.1f} GiB ({peak_memory / max_memory * 100:.1f}%)" + step_message += f" | Grad. Norm {grad_norm:.4f}" + step_message += f" | LR {current_lr:.2e} | Throughput {throughput:.0f} tokens/s | MFU {mfu:.1f}% | Peak Mem. {peak_memory:.1f}/{max_memory:.1f} GiB ({peak_memory / max_memory * 100:.1f}%)" if is_moe_model: for name, label in (("max_vio", "Max Vio"), ("routing_confidence", "Routing Conf.")): value = moe_stats[name].item() if value > 0: - step_message += f" | {label}: {value:.4f}" + step_message += f" | {label} {value:.4f}" logger.success(step_message) # Log progress metrics diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index fbe4af4fdc..b9ee8f4b9d 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -2,6 +2,7 @@ import asyncio import os +from collections.abc import Mapping from itertools import cycle from pathlib import Path from typing import Protocol, runtime_checkable @@ -16,6 +17,16 @@ from prime_rl.configs.shared import ClientConfig from prime_rl.utils.logger import get_logger +# Identity tuple used by ``select_train_client`` to key load counts. ``api_base_url`` +# distinguishes servers; ``X-data-parallel-rank`` distinguishes DP shards within a +# server, since the router uses that header to route to specific GPU ranks. +ClientIdentity = tuple[str, str | None] + + +def client_identity(client: vf.ClientConfig) -> ClientIdentity: + """Stable identity for load balancing across inference clients.""" + return (client.api_base_url, client.extra_headers.get("X-data-parallel-rank")) + @runtime_checkable class InferencePool(Protocol): @@ -44,6 +55,15 @@ async def get_eval_client(self) -> vf.ClientConfig: """Get next eval client in round-robin fashion.""" ... + async def select_train_client(self, load: Mapping[ClientIdentity, int]) -> vf.ClientConfig: + """Pick the train client with lowest in-flight load. + + Waits for at least one train client to be available, then returns + the one with the smallest ``load[client_identity(client)]``. The + caller owns the in-flight counter; the pool just picks against it. + """ + ... + async def wait_for_ready(self, model_name: str, timeout: int | None = None) -> None: """Wait for inference pool to be ready.""" ... @@ -106,6 +126,11 @@ def eval_clients(self) -> list[vf.ClientConfig]: async def get_eval_client(self) -> vf.ClientConfig: return next(self._eval_cycle) + async def select_train_client(self, load: Mapping[ClientIdentity, int]) -> vf.ClientConfig: + while not self.train_clients: + await asyncio.sleep(0.5) + return min(self.train_clients, key=lambda c: load[client_identity(c)]) + async def wait_for_ready(self, model_name: str, timeout: int | None = None) -> None: await check_health( self._admin_clients, timeout=timeout if timeout is not None else self._wait_for_ready_timeout @@ -301,10 +326,10 @@ def _is_retryable_pause_error(exception: BaseException) -> bool: PAUSE_TOTAL_TIMEOUT_S = 300.0 -async def _pause_engines(admin_clients: list[AsyncClient]) -> None: +async def _pause_engines(admin_clients: list[AsyncClient], *, step: int) -> None: """Pause all inference engines, waiting for in-flight requests to drain.""" logger = get_logger() - logger.info("Pausing inference engines for weight update") + logger.info(f"Updating policy in-flight to v{step}") @retry( retry=retry_if_exception(_is_retryable_pause_error), @@ -321,7 +346,7 @@ async def _pause(client: AsyncClient) -> None: response.raise_for_status() await asyncio.gather(*[_pause(client) for client in admin_clients]) - logger.info("All inference engines paused") + logger.debug("All inference engines paused") async def _resume_engines(admin_clients: list[AsyncClient]) -> None: @@ -333,7 +358,7 @@ async def _resume(client: AsyncClient) -> None: response.raise_for_status() await asyncio.gather(*[_resume(client) for client in admin_clients]) - logger.info("All inference engines resumed") + logger.debug("All inference engines resumed") async def update_weights( @@ -364,7 +389,7 @@ async def _update_weights(admin_client: AsyncClient, weight_dir: str | None) -> response.raise_for_status() # Pause engines so all DP workers drain in-flight work and can join the NCCL broadcast - await _pause_engines(admin_clients) + await _pause_engines(admin_clients, step=step) try: # Create ready marker before servers enter receive path (used by NCCL broadcast) diff --git a/src/prime_rl/utils/elastic.py b/src/prime_rl/utils/elastic.py index 39f4962c1b..951b3673c1 100644 --- a/src/prime_rl/utils/elastic.py +++ b/src/prime_rl/utils/elastic.py @@ -11,6 +11,7 @@ import asyncio import socket import time +from collections.abc import Mapping from dataclasses import dataclass from pathlib import Path from typing import Literal @@ -21,7 +22,7 @@ from renderers import RendererConfig from prime_rl.configs.shared import ClientConfig -from prime_rl.utils.client import load_lora_adapter, setup_admin_clients, setup_clients +from prime_rl.utils.client import ClientIdentity, client_identity, load_lora_adapter, setup_admin_clients, setup_clients from prime_rl.utils.logger import get_logger # --- Shared discovery functions --- @@ -230,6 +231,11 @@ async def get_eval_client(self) -> vf.ClientConfig: self._eval_index += 1 return client + async def select_train_client(self, load: Mapping[ClientIdentity, int]) -> vf.ClientConfig: + while not self.train_clients: + await asyncio.sleep(self.sync_interval) + return min(self.train_clients, key=lambda c: load[client_identity(c)]) + @property def admin_clients(self) -> list[AsyncClient]: return list(self._admin_clients.values()) diff --git a/src/prime_rl/utils/logger.py b/src/prime_rl/utils/logger.py index dce8046d4e..9e48d935a6 100644 --- a/src/prime_rl/utils/logger.py +++ b/src/prime_rl/utils/logger.py @@ -245,3 +245,28 @@ def close(self): percent = int(100 * self.current / self.total) if percent > self._last_logged_percent: self._emit_progress(percent) + + +def format_time(seconds: float) -> str: + """ + Format a time in seconds to a human-readable format: + - >1d -> Xd Yh + - >1h -> Xh Ym + - >1m -> Xm Ys + - <1s -> Xms + - Else: Xs + """ + if seconds < 1: + return f"{seconds * 1000:.0f}ms" + if seconds < 60: + return f"{seconds:.1f}s" + if seconds < 3600: + m, s = divmod(seconds, 60) + return f"{int(m)}m {int(s)}s" + if seconds < 86400: + h, rem = divmod(seconds, 3600) + m = rem // 60 + return f"{int(h)}h {int(m)}m" + d, rem = divmod(seconds, 86400) + h = rem // 3600 + return f"{int(d)}d {int(h)}h" diff --git a/tests/integration/test_reverse_text.py b/tests/integration/test_reverse_text.py index 532779dd79..590bd2f851 100644 --- a/tests/integration/test_reverse_text.py +++ b/tests/integration/test_reverse_text.py @@ -93,7 +93,7 @@ def test_reward_in_range(rl_process: ProcessResult, test_no_error, output_dir: P """Tests that the reward is in range in the RL process""" with open(output_dir / "logs" / "orchestrator.log", "r") as f: orchestrator_stdout = strip_escape_codes(f.read()).splitlines() - check_reward_in_range(orchestrator_stdout, min_threshold=0.65) + check_reward_in_range(orchestrator_stdout, min_threshold=0.6) def test_mismatch_kl_in_range(rl_process: ProcessResult, test_no_error, output_dir: Path): @@ -113,4 +113,4 @@ def test_reward_in_range_resume(rl_resume_process: ProcessResult, test_no_error_ """Tests that the reward is in range in the RL resume process""" with open(output_dir / "logs" / "orchestrator.log", "r") as f: orchestrator_stdout = strip_escape_codes(f.read()).splitlines() - check_reward_in_range(orchestrator_stdout, min_threshold=0.65) + check_reward_in_range(orchestrator_stdout, min_threshold=0.6) diff --git a/tests/unit/orchestrator/test_advantage.py b/tests/unit/orchestrator/test_advantage.py index 6acd1057e4..89022c8428 100644 --- a/tests/unit/orchestrator/test_advantage.py +++ b/tests/unit/orchestrator/test_advantage.py @@ -1,4 +1,5 @@ import math +import uuid import pytest @@ -11,10 +12,11 @@ from prime_rl.orchestrator.advantage import ( AdvantageInputs, AdvantageOutputs, - compute_advantages, + assign_advantages, default_advantage_fn, setup_advantage_fn, ) +from prime_rl.orchestrator.types import TrainRollout def _make_rollout( @@ -247,104 +249,45 @@ def test_efficiency_turns_penalty(): assert result.advantages == pytest.approx([0.625, 0.125, -0.875, 0.125], abs=1e-6) -def test_compute_advantages_with_config(): - rewards = [1.0, 0.5, 0.8, 0.2, 0.9, 0.1] - lengths = [10, 12, 8, 15, 11, 9] - rollouts = [_make_rollout(r, l, example_id=i // 3) for i, (r, l) in enumerate(zip(rewards, lengths))] - - compute_advantages(rollouts, advantage_config=DefaultAdvantageConfig()) - - advantages = [r["advantage"] for r in rollouts] - assert len(advantages) == 6 - assert sum(advantages[:3]) == pytest.approx(0.0, abs=1e-5) - assert sum(advantages[3:]) == pytest.approx(0.0, abs=1e-5) - - -def test_compute_advantages_no_cross_group_leakage(): - """Per-problem grouping: each problem must be centered against its own mean, in-order. - - Two problems with very different reward scales — cross-group leakage would pull the - small-scale group's advantages toward the large-scale group's mean (and vice versa). - Distinct positional values also catch ordering bugs in the group→flat round-trip. - """ - rewards = [10.0, 20.0, 30.0, 0.0, 0.1, 0.2] - rollouts = [_make_rollout(r, example_id=i // 3) for i, r in enumerate(rewards)] - - compute_advantages(rollouts, advantage_config=DefaultAdvantageConfig()) - - advantages = [r["advantage"] for r in rollouts] - assert advantages == pytest.approx([-10.0, 0.0, 10.0, -0.1, 0.0, 0.1], abs=1e-5) - - -def test_compute_advantages_without_config(): - rewards = [1.0, 0.5, 0.8] - lengths = [10, 12, 8] - rollouts = [_make_rollout(r, l) for r, l in zip(rewards, lengths)] - - compute_advantages(rollouts, advantage_config=None) - - advantages = [r["advantage"] for r in rollouts] - assert advantages == rewards - - -def test_compute_advantages_partial_groups(): - """Partial groups (size < group_size) are advantaged against their own mean. - - Two groups of different sizes must round-trip cleanly: each group's advantages - must sum to zero and not leak into the other. - """ - # Group A (example_id=0): 4 rollouts. Group B (example_id=1): 2 rollouts. - rollouts = [ - _make_rollout(1.0, example_id=0), - _make_rollout(0.0, example_id=0), - _make_rollout(1.0, example_id=0), - _make_rollout(0.0, example_id=0), - _make_rollout(0.3, example_id=1), - _make_rollout(0.7, example_id=1), +def _train_rollouts(rewards: list[float]) -> list[TrainRollout]: + """Wrap a list of rewards into ``TrainRollout``\\ s sharing a single + ``group_id`` — ``assign_advantages`` works on one group at a time + (the sink groups by ``group_id`` upstream).""" + gid = uuid.uuid4() + return [ + TrainRollout( + raw={"reward": r, "trajectory": []}, + env_name="test", + example_id=0, + group_id=gid, + policy_version=0, + off_policy_steps=0, + ) + for r in rewards ] - compute_advantages(rollouts, advantage_config=DefaultAdvantageConfig()) - advantages = [r["advantage"] for r in rollouts] - # Group A: mean=0.5, advantages=[0.5, -0.5, 0.5, -0.5] - assert advantages[:4] == pytest.approx([0.5, -0.5, 0.5, -0.5], abs=1e-5) - # Group B: mean=0.5, advantages=[-0.2, 0.2] - assert advantages[4:] == pytest.approx([-0.2, 0.2], abs=1e-5) - - -def test_compute_advantages_singleton_group_gets_zero_advantage(): - """A group of size 1 has reward == mean, so its advantage is 0 (filterable downstream).""" - rollouts = [ - _make_rollout(0.5, example_id=0), - _make_rollout(0.8, example_id=0), - _make_rollout(0.3, example_id=1), # singleton group - ] - - compute_advantages(rollouts, advantage_config=DefaultAdvantageConfig()) - - advantages = [r["advantage"] for r in rollouts] - # Group 0: mean=0.65, advantages=[-0.15, 0.15] - assert advantages[:2] == pytest.approx([-0.15, 0.15], abs=1e-5) - # Group 1 (singleton): advantage=0 - assert advantages[2] == pytest.approx(0.0, abs=1e-5) +def test_assign_advantages_writes_field(): + rollouts = _train_rollouts([1.0, 0.5, 0.8]) + fn = setup_advantage_fn(DefaultAdvantageConfig()) + assign_advantages(rollouts, fn) + advs = [r.advantage for r in rollouts] + assert sum(advs) == pytest.approx(0.0, abs=1e-6) -def test_compute_advantages_disambiguates_example_id_across_envs(): - """example_id=0 in env A and example_id=0 in env B must not be grouped together.""" - rollouts = [ - _make_rollout(1.0, env_name="env_a", example_id=0), - _make_rollout(0.0, env_name="env_a", example_id=0), - _make_rollout(100.0, env_name="env_b", example_id=0), - _make_rollout(200.0, env_name="env_b", example_id=0), - ] +def test_assign_advantages_without_fn_is_reward(): + """``advantage_fn=None`` falls back to ``advantage = reward``.""" + rollouts = _train_rollouts([1.0, 0.5, 0.8]) + assign_advantages(rollouts, None) + assert [r.advantage for r in rollouts] == [1.0, 0.5, 0.8] - compute_advantages(rollouts, advantage_config=DefaultAdvantageConfig()) - advantages = [r["advantage"] for r in rollouts] - # env_a group: mean=0.5, advantages=[0.5, -0.5] - assert advantages[:2] == pytest.approx([0.5, -0.5], abs=1e-5) - # env_b group: mean=150, advantages=[-50, 50] - assert advantages[2:] == pytest.approx([-50.0, 50.0], abs=1e-5) +def test_assign_advantages_singleton_group_is_zero(): + """A group of size 1 has reward == mean, so its advantage is 0.""" + rollouts = _train_rollouts([0.7]) + fn = setup_advantage_fn(DefaultAdvantageConfig()) + assign_advantages(rollouts, fn) + assert rollouts[0].advantage == pytest.approx(0.0, abs=1e-6) def test_setup_advantage_fn_with_custom_config(): diff --git a/tests/unit/orchestrator/test_buffer.py b/tests/unit/orchestrator/test_buffer.py deleted file mode 100644 index e7d2b038c0..0000000000 --- a/tests/unit/orchestrator/test_buffer.py +++ /dev/null @@ -1,233 +0,0 @@ -import random -from unittest.mock import MagicMock - -import pytest -import verifiers as vf -from datasets import Dataset - -from prime_rl.configs.orchestrator import BufferConfig, EnvConfig -from prime_rl.orchestrator.buffer import Buffer -from prime_rl.orchestrator.envs import Envs, TrainEnv - - -def make_env(name: str, vf_env: vf.Environment, **config_kwargs) -> TrainEnv: - """Create a TrainEnv without calling vf.load_environment.""" - config = EnvConfig(id=name, name=name, **config_kwargs) - env = TrainEnv.__new__(TrainEnv) - env.config = config - env._env = vf_env - env._env_client = None - env._env_server_process = None - env.sampling_args = {} - return env - - -def make_envs(env_dict: dict[str, TrainEnv]) -> Envs: - """Create an Envs container from a dict of Env instances.""" - envs = Envs.__new__(Envs) - envs._envs = env_dict - return envs - - -@pytest.fixture(autouse=True) -def set_seed(): - random.seed(42) - - -@pytest.fixture -def mock_openai_client(): - """Return a mocked OpenAI client.""" - return MagicMock() - - -@pytest.fixture -def dummy_dataset() -> Dataset: - """Return a dummy dataset with 5 examples.""" - return Dataset.from_dict( - { - "question": ["q0", "q1", "q2", "q3", "q4"], - "answer": ["a0", "a1", "a2", "a3", "a4"], - } - ) - - -@pytest.fixture -def dummy_envs(mock_openai_client, dummy_dataset) -> Envs: - """Return an Envs with two dummy envs.""" - env_a = vf.SingleTurnEnv( - client=mock_openai_client, - model="test-model", - dataset=dummy_dataset, - rubric=vf.Rubric(), - ) - env_b = vf.SingleTurnEnv( - client=mock_openai_client, - model="test-model", - dataset=dummy_dataset, - rubric=vf.Rubric(), - ) - return make_envs( - { - "env_a": make_env("env_a", env_a), - "env_b": make_env("env_b", env_b), - } - ) - - -@pytest.fixture -def make_rollouts(): - def _make_rollouts( - buffer: Buffer, env_name: str, indices: list[int], rewards: list[float] - ) -> list[vf.RolloutOutput]: - all_rollouts = [] - eb = buffer.env_buffers[env_name] - examples = list(eb.examples.values()) - for idx, reward in zip(indices, rewards): - example = examples[idx] - rollouts = [ - vf.RolloutOutput( - example_id=example["example_id"], - task=example["env_name"], - prompt=example["prompt"], - prompt_ids=[0], - prompt_mask=[1], - completion_ids=[1], - completion_mask=[1], - completion_logprobs=[0.0], - is_truncated=False, - reward=reward, - advantage=1.0, - metrics={}, - ) - ] * 2 - for r in rollouts: - r["env_name"] = env_name - all_rollouts.extend(rollouts) - return all_rollouts - - return _make_rollouts - - -def get_normal_count(buffer: Buffer) -> int: - return sum(eb.num_normal for eb in buffer.env_buffers.values()) - - -def test_buffer_init_and_sample(dummy_envs): - buffer = Buffer(dummy_envs, BufferConfig()) - assert buffer.env_buffers["env_a"].num_normal == 5 - assert buffer.env_buffers["env_b"].num_normal == 5 - samples = buffer.sample_examples(2) - assert len(samples) == 2 - - -def test_buffer_problem_pool_assignment(dummy_envs, make_rollouts): - """Problems are moved to easy/hard pools based on reward thresholds.""" - buffer = Buffer(dummy_envs, BufferConfig(easy_threshold=1.0, hard_threshold=0.0)) - buffer.update(make_rollouts(buffer, "env_a", list(range(5)), rewards=[1.0, 1.0, 0.5, 0.5, 0.0])) - - assert len(buffer.env_buffers["env_a"].easy_examples) == 2 - assert len(buffer.env_buffers["env_a"].hard_examples) == 1 - # 2 normal from env_a + 5 from env_b = 7 - assert get_normal_count(buffer) == 7 - - -def test_buffer_online_difficulty_filtering(dummy_envs, make_rollouts): - """With online_difficulty_filtering=True, only partial reward rollouts are kept.""" - buffer = Buffer( - dummy_envs, - BufferConfig(online_difficulty_filtering=True), - ) - buffer.update(make_rollouts(buffer, "env_a", list(range(5)), rewards=[1.0, 0.5, 0.0, 0.5, 0.5])) - - # Only 3 problems with reward 0.5 -> 6 rollouts kept - assert len(buffer.rollout_buffer) == 6 - - -def test_buffer_no_filtering_by_default(dummy_envs, make_rollouts): - """With online_difficulty_filtering=False (default), all rollouts are kept.""" - buffer = Buffer(dummy_envs, BufferConfig()) - buffer.update(make_rollouts(buffer, "env_a", list(range(5)), rewards=[1.0, 0.5, 0.0, 0.5, 0.5])) - - # All 5 problems -> 10 rollouts kept - assert len(buffer.rollout_buffer) == 10 - - -def test_buffer_save_load_with_conversion(dummy_envs, make_rollouts, tmp_path): - """Easy/hard problems are partially converted to normal on load.""" - buffer = Buffer(dummy_envs, BufferConfig(easy_threshold=1.0, hard_threshold=0.0)) - buffer.update(make_rollouts(buffer, "env_a", list(range(5)), rewards=[1.0, 1.0, 0.5, 0.5, 0.0])) - buffer.save(tmp_path / "buffer") - - new_buffer = Buffer(dummy_envs, BufferConfig(easy_fraction=0.5, hash_keys=["prompt", "env_name"])) - new_buffer.load(tmp_path / "buffer") - - # 1 of 2 easy problems converted to normal - assert len(new_buffer.env_buffers["env_a"].easy_examples) == 1 - # 2 were normal + 5 from env_b + 1 converted from easy = 8 - assert get_normal_count(new_buffer) == 8 - - -def test_buffer_env_ratios(mock_openai_client, dummy_dataset): - env_a = vf.SingleTurnEnv(client=mock_openai_client, model="test-model", dataset=dummy_dataset, rubric=vf.Rubric()) - env_b = vf.SingleTurnEnv(client=mock_openai_client, model="test-model", dataset=dummy_dataset, rubric=vf.Rubric()) - envs = make_envs( - { - "env_a": make_env("env_a", env_a, ratio=0.8), - "env_b": make_env("env_b", env_b, ratio=0.2), - } - ) - - buffer = Buffer(envs, BufferConfig()) - assert buffer.env_buffers["env_a"].num_normal == 5 - assert buffer.env_buffers["env_b"].num_normal == 5 - - samples = buffer.sample_examples(100) - env_a_count = sum(1 for p in samples if p["env_name"] == "env_a") - assert 60 <= env_a_count <= 95 - - -def test_buffer_env_ratios_validation(): - """Validates that env ratios must be positive and all-or-nothing.""" - from pydantic import ValidationError - - from prime_rl.configs.orchestrator import TrainConfig, TrainEnvConfig - - with pytest.raises(ValidationError): - EnvConfig(id="env_a", ratio=-0.3) - - with pytest.raises(ValidationError, match="mix of set and unset"): - TrainConfig(env=[TrainEnvConfig(id="a", ratio=0.5), TrainEnvConfig(id="b")]) - - -def test_buffer_no_cross_env_pool_assignment(mock_openai_client, tmp_path): - """Pool assignments don't transfer if example_id exists but env changed.""" - original_dataset = Dataset.from_dict({"question": ["q0"], "answer": ["a0"]}) - original_env = vf.SingleTurnEnv( - client=mock_openai_client, - model="test-model", - dataset=original_dataset, - rubric=vf.Rubric(), - ) - original_env_set = make_envs({"env_a": make_env("env_a", original_env)}) - - buffer = Buffer(original_env_set, BufferConfig(easy_threshold=1.0)) - eb = buffer.env_buffers["env_a"] - example_id = list(eb.examples.keys())[0] - example = eb.examples.pop(example_id) - eb.easy_examples.append(example) - buffer.save(tmp_path / "buffer") - - new_dataset = Dataset.from_dict({"question": ["different_q"], "answer": ["different_a"]}) - new_env = vf.SingleTurnEnv( - client=mock_openai_client, - model="test-model", - dataset=new_dataset, - rubric=vf.Rubric(), - ) - new_env_set = make_envs({"env_b": make_env("env_b", new_env)}) - - new_buffer = Buffer(new_env_set, BufferConfig()) - new_buffer.load(tmp_path / "buffer") - - assert len(new_buffer.env_buffers["env_b"].easy_examples) == 0 - assert new_buffer.env_buffers["env_b"].num_normal == 1 diff --git a/tests/unit/orchestrator/test_filters.py b/tests/unit/orchestrator/test_filters.py index e77f51f61f..2643bf71bb 100644 --- a/tests/unit/orchestrator/test_filters.py +++ b/tests/unit/orchestrator/test_filters.py @@ -1,4 +1,5 @@ import math +import uuid from prime_rl.configs.orchestrator import GibberishFilterConfig, RepetitionFilterConfig from prime_rl.orchestrator.filters import ( @@ -8,10 +9,19 @@ setup_filter, setup_filters, ) - - -def _make_rollout(completion_ids, completion_logprobs, reward=1.0, multi_step=False): - """Create a minimal rollout dict matching the verifiers RolloutOutput structure.""" +from prime_rl.orchestrator.types import TrainRollout + + +def _make_rollout( + completion_ids: list[int], + completion_logprobs: list[float], + *, + reward: float = 1.0, + multi_step: bool = False, +) -> TrainRollout: + """Build a ``TrainRollout`` with a minimal ``vf.RolloutOutput``-shaped + raw payload — enough for the filters to inspect ``trajectory`` / + ``stop_condition`` / etc.""" if multi_step: mid = len(completion_ids) // 2 trajectory = [ @@ -40,12 +50,20 @@ def _make_rollout(completion_ids, completion_logprobs, reward=1.0, multi_step=Fa } } ] - return { + raw = { "trajectory": trajectory, "reward": reward, "stop_condition": None, "metrics": {}, } + return TrainRollout( + raw=raw, + env_name="test", + example_id=0, + group_id=uuid.uuid4(), + policy_version=0, + off_policy_steps=0, + ) def _make_gibberish_filter(vocab_size=128_000, token_id_threshold=100_000, logprob_offset=2.0, enforce=False): @@ -209,7 +227,7 @@ def test_setup_filters_multiple(): GibberishFilterConfig(), RepetitionFilterConfig(), ] - filters = setup_filters(configs, vocab_size=128_000) + filters = setup_filters(configs, vocab_size=128_000, kind="post-batch") assert len(filters) == 2 assert filters[0].name == "gibberish" assert filters[1].name == "repetition" @@ -229,12 +247,12 @@ def test_apply_filters_enforced_flags_rollout(): apply_filters([gibberish_filter], [rollout]) - assert rollout["reward"] == 1.0 - assert rollout["trajectory"][0]["tokens"]["completion_ids"] == [120_000] - assert rollout["trajectory"][0]["tokens"]["completion_mask"] == [1] - assert rollout["stop_condition"] is None - assert rollout["filters"] == {"gibberish": True} - assert rollout["is_filtered"] is True + assert rollout.reward == 1.0 + assert rollout.raw["trajectory"][0]["tokens"]["completion_ids"] == [120_000] + assert rollout.raw["trajectory"][0]["tokens"]["completion_mask"] == [1] + assert rollout.raw["stop_condition"] is None + assert rollout.filter_results == {"gibberish": True} + assert rollout.is_filtered is True def test_apply_filters_preserves_clean_rollouts(): @@ -248,12 +266,12 @@ def test_apply_filters_preserves_clean_rollouts(): apply_filters([gibberish_filter], [rollout]) - assert rollout["reward"] == 1.0 - assert rollout["trajectory"][0]["tokens"]["completion_ids"] == [50, 60, 70] - assert all(m == 1 for m in rollout["trajectory"][0]["tokens"]["completion_mask"]) - assert rollout["stop_condition"] is None - assert rollout["filters"] == {"gibberish": False} - assert rollout["is_filtered"] is False + assert rollout.reward == 1.0 + assert rollout.raw["trajectory"][0]["tokens"]["completion_ids"] == [50, 60, 70] + assert all(m == 1 for m in rollout.raw["trajectory"][0]["tokens"]["completion_mask"]) + assert rollout.raw["stop_condition"] is None + assert rollout.filter_results == {"gibberish": False} + assert rollout.is_filtered is False def test_apply_filters_first_filter_wins(): @@ -268,9 +286,9 @@ def test_apply_filters_first_filter_wins(): apply_filters([gibberish_filter, repetition_filter], [rollout]) - assert rollout["stop_condition"] is None - assert rollout["filters"] == {"gibberish": True, "repetition": False} - assert rollout["is_filtered"] is True + assert rollout.raw["stop_condition"] is None + assert rollout.filter_results == {"gibberish": True, "repetition": False} + assert rollout.is_filtered is True def test_apply_filters_empty_list(): @@ -279,9 +297,9 @@ def test_apply_filters_empty_list(): completion_logprobs=[-1.0, -1.0, -1.0], ) apply_filters([], [rollout]) - assert rollout["filters"] == {} - assert rollout["is_filtered"] is False - assert rollout["reward"] == 1.0 + assert rollout.filter_results == {} + assert rollout.is_filtered is False + assert rollout.reward == 1.0 def test_apply_filters_mixed_batch(): @@ -294,10 +312,10 @@ def test_apply_filters_mixed_batch(): apply_filters([gibberish_filter], [clean, dirty]) - assert clean["reward"] == 1.0 - assert dirty["reward"] == 1.0 - assert clean["is_filtered"] is False - assert dirty["is_filtered"] is True + assert clean.reward == 1.0 + assert dirty.reward == 1.0 + assert clean.is_filtered is False + assert dirty.is_filtered is True def test_apply_filters_enforced_preserves_rollout_tokens(): @@ -311,14 +329,14 @@ def test_apply_filters_enforced_preserves_rollout_tokens(): apply_filters([gibberish_filter], [rollout]) - assert rollout["trajectory"][0]["tokens"]["completion_ids"] == [10, 120_000, 30] - assert rollout["trajectory"][0]["tokens"]["completion_logprobs"] == [ + assert rollout.raw["trajectory"][0]["tokens"]["completion_ids"] == [10, 120_000, 30] + assert rollout.raw["trajectory"][0]["tokens"]["completion_logprobs"] == [ -1.0, gibberish_filter.logprob_threshold - 1.0, -0.5, ] - assert rollout["trajectory"][0]["tokens"]["completion_mask"] == [1, 1, 1] - assert rollout["is_filtered"] is True + assert rollout.raw["trajectory"][0]["tokens"]["completion_mask"] == [1, 1, 1] + assert rollout.is_filtered is True def test_apply_filters_preserves_existing_stop_condition(): @@ -329,12 +347,12 @@ def test_apply_filters_preserves_existing_stop_condition(): completion_logprobs=[gibberish_filter.logprob_threshold - 1.0], reward=1.0, ) - rollout["stop_condition"] = "generation_truncated" + rollout.raw["stop_condition"] = "generation_truncated" apply_filters([gibberish_filter], [rollout]) - assert rollout["stop_condition"] == "generation_truncated" - assert rollout["is_filtered"] is True + assert rollout.raw["stop_condition"] == "generation_truncated" + assert rollout.is_filtered is True # --- apply_filters tests (monitor-only, enforce=False) --- @@ -351,11 +369,11 @@ def test_apply_filters_monitor_only_tracks_detection(): apply_filters([gibberish_filter], [rollout]) - assert rollout["reward"] == 1.0 - assert all(m == 1 for m in rollout["trajectory"][0]["tokens"]["completion_mask"]) - assert rollout["stop_condition"] is None - assert rollout["filters"] == {"gibberish": True} - assert rollout["is_filtered"] is False + assert rollout.reward == 1.0 + assert all(m == 1 for m in rollout.raw["trajectory"][0]["tokens"]["completion_mask"]) + assert rollout.raw["stop_condition"] is None + assert rollout.filter_results == {"gibberish": True} + assert rollout.is_filtered is False def test_apply_filters_monitor_only_mixed_batch(): @@ -368,7 +386,7 @@ def test_apply_filters_monitor_only_mixed_batch(): apply_filters([gibberish_filter], [clean, dirty]) - assert clean["reward"] == 1.0 - assert dirty["reward"] == 1.0 - assert clean["is_filtered"] is False - assert dirty["is_filtered"] is False + assert clean.reward == 1.0 + assert dirty.reward == 1.0 + assert clean.is_filtered is False + assert dirty.is_filtered is False diff --git a/tests/unit/orchestrator/test_orchestrator_setup.py b/tests/unit/orchestrator/test_orchestrator_setup.py index a1554f152c..2372b004fd 100644 --- a/tests/unit/orchestrator/test_orchestrator_setup.py +++ b/tests/unit/orchestrator/test_orchestrator_setup.py @@ -1,10 +1,10 @@ import asyncio from types import SimpleNamespace -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch from renderers import Qwen3VLRendererConfig -from prime_rl.orchestrator.orchestrator import setup_student_inference_pool +from prime_rl.orchestrator.utils import setup_student_inference_pool def test_setup_student_inference_pool_uses_renderer_when_enabled(): @@ -20,21 +20,19 @@ async def run() -> None: renderer=renderer_settings, pool_size=None, ) - logger = MagicMock() renderer = object() inference_pool = object() with ( - patch("prime_rl.orchestrator.orchestrator.create_renderer", return_value=renderer) as create_renderer_mock, + patch("renderers.base.create_renderer", return_value=renderer) as create_renderer_mock, patch( - "prime_rl.orchestrator.orchestrator.setup_inference_pool", + "prime_rl.orchestrator.utils.setup_inference_pool", new=AsyncMock(return_value=inference_pool), ) as setup_pool_mock, ): returned_renderer, returned_pool = await setup_student_inference_pool( config=config, tokenizer=tokenizer, - logger=logger, ) assert returned_renderer is renderer @@ -66,20 +64,18 @@ async def run() -> None: model=SimpleNamespace(name="student-model"), ), ) - logger = MagicMock() inference_pool = object() with ( - patch("prime_rl.orchestrator.orchestrator.create_renderer") as create_renderer_mock, + patch("renderers.base.create_renderer") as create_renderer_mock, patch( - "prime_rl.orchestrator.orchestrator.setup_inference_pool", + "prime_rl.orchestrator.utils.setup_inference_pool", new=AsyncMock(return_value=inference_pool), ) as setup_pool_mock, ): renderer, returned_pool = await setup_student_inference_pool( config=config, tokenizer=tokenizer, - logger=logger, ) assert renderer is None diff --git a/tests/unit/orchestrator/test_scheduler.py b/tests/unit/orchestrator/test_scheduler.py deleted file mode 100644 index 009d0d3d82..0000000000 --- a/tests/unit/orchestrator/test_scheduler.py +++ /dev/null @@ -1,272 +0,0 @@ -import asyncio -from pathlib import Path -from types import SimpleNamespace -from unittest.mock import AsyncMock, MagicMock, patch - -import verifiers as vf - -from prime_rl.orchestrator.scheduler import GroupState, InflightRequest, Scheduler -from prime_rl.utils.async_utils import safe_cancel - - -def make_scheduler() -> Scheduler: - scheduler = Scheduler.__new__(Scheduler) - scheduler.step = 9 - scheduler.ckpt_step = 7 - scheduler.config = SimpleNamespace(output_dir=Path("/tmp/prime-rl-test")) - scheduler.logger = MagicMock() - scheduler.checkpoint_ready = asyncio.Event() - scheduler.checkpoint_ready.set() - scheduler.lora_name = None - scheduler.model_name = "test-model" - scheduler.update_weights_time = 0 - scheduler.wait_for_ckpt_time = 0 - scheduler.inflight_requests = {} - scheduler.groups = {} - scheduler.max_off_policy_steps = 1 - scheduler.cancelled_rollouts_count = 0 - scheduler.policy_update_lock = asyncio.Lock() - scheduler.inflight_policy_update_task = None - scheduler.update_policy_task = None - scheduler.rate_limiter = None - return scheduler - - -def test_update_off_policy_does_not_increment_interleaved_on_policy_tasks(): - async def run() -> None: - scheduler = Scheduler.__new__(Scheduler) - scheduler.max_off_policy_steps = 1 - scheduler.cancelled_rollouts_count = 0 - scheduler.logger = MagicMock() - - client = SimpleNamespace(api_base_url="http://test") - stale_task = asyncio.create_task(asyncio.sleep(60)) - survivor_task = asyncio.create_task(asyncio.sleep(60)) - interleaved_task = None - - scheduler.inflight_requests = { - stale_task: InflightRequest(off_policy_steps=1, client_config=client, env_name="test", group_id=1), - survivor_task: InflightRequest(off_policy_steps=0, client_config=client, env_name="test", group_id=2), - } - - async def drop_group(group_id: int) -> int: - tasks_to_remove = [ - task for task, info in list(scheduler.inflight_requests.items()) if info.group_id == group_id - ] - for task in tasks_to_remove: - scheduler.inflight_requests.pop(task, None) - task.cancel() - - await asyncio.sleep(0) - - nonlocal interleaved_task - if interleaved_task is None: - interleaved_task = asyncio.create_task(asyncio.sleep(60)) - scheduler.inflight_requests[interleaved_task] = InflightRequest( - off_policy_steps=0, - client_config=client, - env_name="test", - group_id=3, - ) - return len(tasks_to_remove) - - scheduler.drop_group = drop_group - - await scheduler._update_off_policy() - - assert stale_task not in scheduler.inflight_requests - assert scheduler.inflight_requests[survivor_task].off_policy_steps == 1 - assert interleaved_task is not None - assert scheduler.inflight_requests[interleaved_task].off_policy_steps == 0 - assert scheduler.cancelled_rollouts_count == 1 - - for task in (stale_task, survivor_task, interleaved_task): - if task is not None and not task.done(): - task.cancel() - await asyncio.sleep(0) - - asyncio.run(run()) - - -def test_maybe_update_policy_reuses_inflight_update_after_cancellation(): - async def run() -> None: - scheduler = make_scheduler() - started = asyncio.Event() - release = asyncio.Event() - applied_steps: list[int] = [] - - async def update_weights(weight_dir, lora_name=None, step=0) -> None: - applied_steps.append(step) - started.set() - await release.wait() - - scheduler.student_inference = SimpleNamespace( - update_weights=update_weights, - update_model_name=MagicMock(), - ) - scheduler.rollout_inference = scheduler.student_inference - scheduler._update_off_policy = AsyncMock() - - with ( - patch("prime_rl.orchestrator.scheduler.get_latest_ckpt_step", return_value=8), - patch("prime_rl.orchestrator.scheduler.wait_for_path", new=AsyncMock()), - ): - first = asyncio.create_task(scheduler.maybe_update_policy()) - await started.wait() - await safe_cancel(first) - - second = asyncio.create_task(scheduler.maybe_update_policy()) - await asyncio.sleep(0) - assert applied_steps == [8] - - release.set() - await second - - assert applied_steps == [8] - assert scheduler.ckpt_step == 8 - - asyncio.run(run()) - - -def test_stop_cancels_inflight_policy_update_task(): - async def run() -> None: - scheduler = make_scheduler() - started = asyncio.Event() - cancelled = asyncio.Event() - - async def update_weights(weight_dir, lora_name=None, step=0) -> None: - started.set() - try: - await asyncio.Future() - finally: - cancelled.set() - - scheduler.student_inference = SimpleNamespace( - update_weights=update_weights, - update_model_name=MagicMock(), - ) - scheduler.rollout_inference = scheduler.student_inference - scheduler._update_off_policy = AsyncMock() - - with ( - patch("prime_rl.orchestrator.scheduler.get_latest_ckpt_step", return_value=8), - patch("prime_rl.orchestrator.scheduler.wait_for_path", new=AsyncMock()), - ): - scheduler.update_policy_task = asyncio.create_task(scheduler.maybe_update_policy()) - await started.wait() - await asyncio.wait_for(scheduler.stop(), timeout=0.2) - - assert cancelled.is_set() - assert scheduler.update_policy_task is None - assert scheduler.inflight_policy_update_task is None - - asyncio.run(run()) - - -def test_client_identity_distinguishes_base_url_and_dp_rank(): - client_a = vf.ClientConfig( - api_base_url="http://worker-a:8000/v1", - extra_headers={"X-data-parallel-rank": "0"}, - ) - client_b = vf.ClientConfig( - api_base_url="http://worker-a:8000/v1", - extra_headers={"X-data-parallel-rank": "1"}, - ) - - assert Scheduler._client_identity(client_a) != Scheduler._client_identity(client_b) - - -def test_lora_policy_update_in_sft_keeps_teacher_model_name(): - """In sft mode, train_pool is the teacher. LoRA updates the student inference - pool but must not change scheduler.model_name (which is what gets sent to the - teacher endpoint on each rollout request).""" - - async def run() -> None: - scheduler = make_scheduler() - scheduler.model_name = "teacher-model" - scheduler.lora_name = "student-lora" - - student_inference = SimpleNamespace( - update_weights=AsyncMock(), - update_model_name=MagicMock(), - ) - teacher_inference = SimpleNamespace() - scheduler.student_inference = student_inference - scheduler.rollout_inference = teacher_inference # sft: train_pool != student_inference - scheduler._update_off_policy = AsyncMock() - - with ( - patch("prime_rl.orchestrator.scheduler.get_latest_ckpt_step", return_value=8), - patch("prime_rl.orchestrator.scheduler.wait_for_path", new=AsyncMock()), - ): - await scheduler.maybe_update_policy() - - student_inference.update_weights.assert_awaited_once() - student_inference.update_model_name.assert_called_once_with("student-lora") - assert scheduler.model_name == "teacher-model" - - asyncio.run(run()) - - -def test_lora_policy_update_in_rl_updates_model_name(): - """In rl/opd mode, train_pool is the student. LoRA updates redirect rollout - requests to the new LoRA name.""" - - async def run() -> None: - scheduler = make_scheduler() - scheduler.model_name = "student-model" - scheduler.lora_name = "student-lora" - - student_inference = SimpleNamespace( - update_weights=AsyncMock(), - update_model_name=MagicMock(), - ) - scheduler.student_inference = student_inference - scheduler.rollout_inference = student_inference # rl/opd: same pool - scheduler._update_off_policy = AsyncMock() - - with ( - patch("prime_rl.orchestrator.scheduler.get_latest_ckpt_step", return_value=8), - patch("prime_rl.orchestrator.scheduler.wait_for_path", new=AsyncMock()), - ): - await scheduler.maybe_update_policy() - - student_inference.update_weights.assert_awaited_once() - student_inference.update_model_name.assert_called_once_with("student-lora") - assert scheduler.model_name == "student-lora" - - asyncio.run(run()) - - -def test_schedule_rollout_uses_train_pool(): - """schedule_rollout dispatches to train_pool's clients with train_pool's model name.""" - - async def run() -> None: - scheduler = make_scheduler() - scheduler.model_name = "teacher-model" - teacher_client = vf.ClientConfig(api_base_url="http://teacher.example/v1") - env = SimpleNamespace( - requires_group_scoring=False, - run_rollout=AsyncMock(return_value=[]), - ) - scheduler.rollout_inference = SimpleNamespace(train_clients=[teacher_client]) - scheduler.train_envs = SimpleNamespace(get=MagicMock(return_value=env)) - scheduler.groups = { - 0: GroupState( - example={"env_name": "math", "example_id": "ex-1"}, - rollouts_to_schedule=1, - ) - } - - await scheduler.schedule_rollout(group_id=0) - await asyncio.gather(*scheduler.inflight_requests) - - env.run_rollout.assert_awaited_once_with( - client=teacher_client, - example={"env_name": "math", "example_id": "ex-1"}, - model_name="teacher-model", - cache_salt="7", - ) - assert scheduler.groups[0].pinned_client is teacher_client - - asyncio.run(run()) diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index 87df2646c6..77b9e99781 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -15,7 +15,7 @@ def interleave_rollout(output, *args, **kwargs): - output.setdefault("env_name", "test-env") + kwargs.setdefault("env_name", output.get("env_name", "test-env")) return _interleave_rollout(output, *args, **kwargs) diff --git a/tests/utils.py b/tests/utils.py index 23165d5ecd..dbad6c55ad 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -25,7 +25,7 @@ def check_number_goes_up_or_down( lines: list[str], start_step: int = 0, end_step: int = -1, - pattern: str = r"Reward:\s*(\d+\.\d{4})", + pattern: str = r"Reward:?\s+(\d+\.\d{4})", go_up: bool = True, ): """Helper to assert that a number in lines goes up from a specified start to end step""" @@ -100,18 +100,18 @@ def check_metric_in_range( def check_reward_goes_up(lines: list[str]): - return check_number_goes_up_or_down(lines, go_up=True, pattern=r"Reward:\s*(\d+\.\d{4})") + return check_number_goes_up_or_down(lines, go_up=True, pattern=r"Reward:?\s+(\d+\.\d{4})") def check_loss_goes_down(lines: list[str]): - return check_number_goes_up_or_down(lines, go_up=False, pattern=r"Loss:\s*(\d+\.\d{4})") + return check_number_goes_up_or_down(lines, go_up=False, pattern=r"Loss:?\s+(\d+\.\d{4})") def check_eval_avg_goes_up(lines: list[str], env_name: str): - """Assert that the last `Evaluated {env_name} ... Avg@K=X.XXXX` line reports a - higher score than the first one. Use for smoke tests with `interval = 1` - evals.""" - pattern = rf"Evaluated {re.escape(env_name)} .*Avg@\d+=(\d+\.\d{{4}})" + """Assert that the last `Evaluated {env_name} (Step N) | ... | Reward X.XXXX` + line reports a higher score than the first one. Use for smoke tests with + `interval = 1` evals.""" + pattern = rf"Evaluated {re.escape(env_name)} .*Reward:?\s+(\d+\.\d{{4}})" eval_lines = [line for line in lines if "SUCCESS" in line and re.search(pattern, line)] assert len(eval_lines) >= 2, f"Need at least 2 eval lines for {env_name!r}, found {len(eval_lines)}" start = float(re.search(pattern, eval_lines[0]).group(1)) @@ -132,7 +132,7 @@ def check_reward_in_range( check_metric_in_range( lines, metric_name="Reward", - pattern=r"Reward:\s*(\d+\.\d{4})", + pattern=r"Reward:?\s+(\d+\.\d{4})", step=step, min_threshold=min_threshold, max_threshold=max_threshold, @@ -146,7 +146,7 @@ def check_avg_reward_in_range( max_threshold: float | None = None, ): """Helper to assert that the average reward over the last N steps is within a threshold""" - pattern = r"Reward:\s*(\d+\.\d{4})" + pattern = r"Reward:?\s+(\d+\.\d{4})" step_lines = [line for line in lines if "SUCCESS" in line and "Step" in line and re.search(pattern, line)] assert len(step_lines) >= last_n_steps, ( f"Not enough step lines found. Expected at least {last_n_steps}, got {len(step_lines)}" @@ -179,7 +179,7 @@ def check_avg_mismatch_kl_in_range( max_threshold: float | None = None, ): """Helper to assert that the average mismatch KL over the last N steps is within a threshold""" - pattern = r"Mismatch KL:\s*(\d+\.\d{4})" + pattern = r"Mismatch KL:?\s+(\d+\.\d{4})" step_lines = [line for line in lines if "SUCCESS" in line and "Step" in line and re.search(pattern, line)] assert len(step_lines) >= last_n_steps, ( f"Not enough step lines found. Expected at least {last_n_steps}, got {len(step_lines)}" @@ -215,7 +215,7 @@ def check_mismatch_kl_in_range( check_metric_in_range( lines, metric_name="Mismatch KL", - pattern=r"Mismatch KL:\s*(\d+\.\d{4})", + pattern=r"Mismatch KL:?\s+(\d+\.\d{4})", step=step, min_threshold=min_threshold, max_threshold=max_threshold, diff --git a/uv.lock b/uv.lock index 1ea087e2d1..26522a57ef 100644 --- a/uv.lock +++ b/uv.lock @@ -66,6 +66,7 @@ members = [ "prime-rl-configs", "renderers", "reverse-text", + "rlm-swe", "science-env", "simpleqa-verified", "tau2-bench", @@ -369,6 +370,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/73/b6e24bd22e6720ca8ee9a85a0c4a2971af8497d8f3193fa05390cbd46e09/backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8", size = 15148, upload-time = "2022-10-05T19:19:30.546Z" }, ] +[[package]] +name = "bashlex" +version = "0.18" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/76/60/aae0bb54f9af5e0128ba90eb83d8d0d506ee8f0475c4fdda3deeda20b1d2/bashlex-0.18.tar.gz", hash = "sha256:5bb03a01c6d5676338c36fd1028009c8ad07e7d61d8a1ce3f513b7fff52796ee", size = 68742, upload-time = "2023-01-18T15:21:26.402Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/be/6985abb1011fda8a523cfe21ed9629e397d6e06fb5bae99750402b25c95b/bashlex-0.18-py2.py3-none-any.whl", hash = "sha256:91d73a23a3e51711919c1c899083890cdecffc91d8c088942725ac13e9dcfffa", size = 69539, upload-time = "2023-01-18T15:21:24.167Z" }, +] + [[package]] name = "bcrypt" version = "5.0.0" @@ -903,6 +913,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/23/18/4cedda786e7da429e7489549a9e5461530d4133130e541f25fb94f015776/cyclopts-4.11.2-py3-none-any.whl", hash = "sha256:838020120b939549ff7c8423aca29c86764b5dd1d8a5d7f3753a6327861f537b", size = 213537, upload-time = "2026-05-04T00:11:56.103Z" }, ] +[[package]] +name = "dataclasses-json" +version = "0.6.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "marshmallow", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "typing-inspect", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/64/a4/f71d9cf3a5ac257c993b5ca3f93df5f7fb395c725e7f1e6479d2514173c3/dataclasses_json-0.6.7.tar.gz", hash = "sha256:b6b3e528266ea45b9535223bc53ca645f5208833c29229e847b3f26a1cc55fc0", size = 32227, upload-time = "2024-06-09T16:20:19.103Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/be/d0d44e092656fe7a06b55e6103cbce807cdbdee17884a5367c68c9860853/dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a", size = 28686, upload-time = "2024-06-09T16:20:16.715Z" }, +] + [[package]] name = "datasets" version = "4.6.1" @@ -2555,6 +2578,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/89/e0/4486f11e51bbba8b0c041098859e869e304d1c261e59244baa3d295d47b7/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:77f0643abe7495da77fb436f50f8dab76dbc6e5fd25d39589a0f1fe6548bfa2b", size = 23015, upload-time = "2025-09-27T18:36:37.868Z" }, ] +[[package]] +name = "marshmallow" +version = "3.26.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/55/79/de6c16cc902f4fc372236926b0ce2ab7845268dcc30fb2fbb7f71b418631/marshmallow-3.26.2.tar.gz", hash = "sha256:bbe2adb5a03e6e3571b573f42527c6fe926e17467833660bebd11593ab8dfd57", size = 222095, upload-time = "2025-12-22T06:53:53.309Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/2f/5108cb3ee4ba6501748c4908b908e55f42a5b66245b4cfe0c99326e1ef6e/marshmallow-3.26.2-py3-none-any.whl", hash = "sha256:013fa8a3c4c276c24d26d84ce934dc964e2aa794345a0f8c7e5a7191482c8a73", size = 50964, upload-time = "2025-12-22T06:53:51.801Z" }, +] + [[package]] name = "math-env" version = "0.1.5" @@ -2899,6 +2934,26 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b4/68/04d7a8f0f786545cf9b8c280c57aa6befb5977af6e884b8b54191cbe44b3/msgspec-0.21.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ef3ec2296248d1f8b9231acb051b6d471dfde8f21819e86c9adaaa9f42918521", size = 227303, upload-time = "2026-04-12T21:44:13.709Z" }, ] +[[package]] +name = "multi-swe-bench" +version = "1.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dataclasses-json", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "docker", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "gitpython", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pygithub", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pyyaml", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "swe-rex", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "toml", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "tqdm", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "unidiff", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/48/ad/6b7cda600a50392c790b14ee420b9a3bb318a982a298c05f2d1c066a434f/multi_swe_bench-1.1.2.tar.gz", hash = "sha256:44944bc6608d7d9b8d4390f3ce0a3b2c69122ea6be6e35766c6fde2328f50392", size = 1267660, upload-time = "2025-12-18T07:16:09.584Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/a8/060eb46096742944d8d37c34094d4e0fb34b28c6291a877388543ea65660/multi_swe_bench-1.1.2-py3-none-any.whl", hash = "sha256:09a5770096d6a035383c5240762ffa8c87b1e8df7d374110de8fb781b4e5a9f9", size = 4942355, upload-time = "2025-12-18T07:16:07.468Z" }, +] + [[package]] name = "multidict" version = "6.7.1" @@ -2928,6 +2983,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6c/28/dd72947e59a6a8c856448a5e74da6201cb5502ddff644fbc790e4bd40b9a/multiprocess-0.70.18-py39-none-any.whl", hash = "sha256:e78ca805a72b1b810c690b6b4cc32579eba34f403094bbbae962b7b5bf9dfcb8", size = 133478, upload-time = "2025-04-17T03:11:26.253Z" }, ] +[[package]] +name = "mypy-extensions" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, +] + [[package]] name = "narwhals" version = "2.21.0" @@ -4033,6 +4097,7 @@ envs = [ { name = "opencode-science", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "opencode-swe", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "reverse-text", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "rlm-swe", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "science-env", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "simpleqa-verified", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "tau2-bench", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, @@ -4130,6 +4195,7 @@ requires-dist = [ { name = "reverse-text", marker = "extra == 'envs'", editable = "deps/verifiers/environments/reverse_text" }, { name = "rich", specifier = ">=14.0.0" }, { name = "ring-flash-attn", specifier = ">=0.1.8" }, + { name = "rlm-swe", marker = "extra == 'envs'", editable = "deps/research-environments/environments/rlm_swe" }, { name = "science-env", marker = "extra == 'envs'", editable = "deps/research-environments/environments/science_env" }, { name = "setproctitle", specifier = ">=1.3.0" }, { name = "simpleqa-verified", marker = "extra == 'envs'", editable = "deps/research-environments/environments/simpleqa_verified" }, @@ -4517,6 +4583,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1a/03/bef6fff907e212d67a0003f8ea4819307bba91b2856074a0763dd483ccc4/pyfiglet-1.0.2-py3-none-any.whl", hash = "sha256:889b351d79c99e50a3f619c8f8e6ffdb27fd8c939fc43ecbd7559bd57d5f93ea", size = 1085824, upload-time = "2023-09-13T20:56:18.707Z" }, ] +[[package]] +name = "pygithub" +version = "2.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyjwt", extra = ["crypto"], marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pynacl", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "requests", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "urllib3", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ab/c3/8465a311197e16cf5ab68789fe689535e90f6b61ab524cc32a39e67237ae/pygithub-2.9.1.tar.gz", hash = "sha256:59771d7ff63d54d427be2e7d0dad2208dfffc2b0a045fec959263787739b611c", size = 2594989, upload-time = "2026-04-14T07:26:13.622Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/aa/81a5506f089a26338bff17535e4339b3b22049ebd1bcdeff756c4d7a7559/pygithub-2.9.1-py3-none-any.whl", hash = "sha256:2ec78fca30092d51a42d76f4ddb02131b6f0c666a35dfdf364cf302cdda115b9", size = 449710, upload-time = "2026-04-14T07:26:12.382Z" }, +] + [[package]] name = "pygments" version = "2.20.0" @@ -4540,6 +4622,25 @@ crypto = [ { name = "cryptography", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] +[[package]] +name = "pynacl" +version = "1.6.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and platform_python_implementation != 'PyPy' and sys_platform == 'linux') or (platform_machine == 'x86_64' and platform_python_implementation != 'PyPy' and sys_platform == 'linux') or (platform_machine == 'aarch64' and platform_python_implementation == 'PyPy' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'x86_64' and platform_python_implementation == 'PyPy' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d9/9a/4019b524b03a13438637b11538c82781a5eda427394380381af8f04f467a/pynacl-1.6.2.tar.gz", hash = "sha256:018494d6d696ae03c7e656e5e74cdfd8ea1326962cc401bcf018f1ed8436811c", size = 3511692, upload-time = "2026-01-01T17:48:10.851Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/b4/e927e0653ba63b02a4ca5b4d852a8d1d678afbf69b3dbf9c4d0785ac905c/pynacl-1.6.2-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8845c0631c0be43abdd865511c41eab235e0be69c81dc66a50911594198679b0", size = 800020, upload-time = "2026-01-01T17:32:18.34Z" }, + { url = "https://files.pythonhosted.org/packages/7f/81/d60984052df5c97b1d24365bc1e30024379b42c4edcd79d2436b1b9806f2/pynacl-1.6.2-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:22de65bb9010a725b0dac248f353bb072969c94fa8d6b1f34b87d7953cf7bbe4", size = 1399174, upload-time = "2026-01-01T17:32:20.239Z" }, + { url = "https://files.pythonhosted.org/packages/68/f7/322f2f9915c4ef27d140101dd0ed26b479f7e6f5f183590fd32dfc48c4d3/pynacl-1.6.2-cp38-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:46065496ab748469cdd999246d17e301b2c24ae2fdf739132e580a0e94c94a87", size = 835085, upload-time = "2026-01-01T17:32:22.24Z" }, + { url = "https://files.pythonhosted.org/packages/3e/d0/f301f83ac8dbe53442c5a43f6a39016f94f754d7a9815a875b65e218a307/pynacl-1.6.2-cp38-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8a66d6fb6ae7661c58995f9c6435bda2b1e68b54b598a6a10247bfcdadac996c", size = 1437614, upload-time = "2026-01-01T17:32:23.766Z" }, + { url = "https://files.pythonhosted.org/packages/c4/58/fc6e649762b029315325ace1a8c6be66125e42f67416d3dbd47b69563d61/pynacl-1.6.2-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:26bfcd00dcf2cf160f122186af731ae30ab120c18e8375684ec2670dccd28130", size = 818251, upload-time = "2026-01-01T17:32:25.69Z" }, + { url = "https://files.pythonhosted.org/packages/c9/a8/b917096b1accc9acd878819a49d3d84875731a41eb665f6ebc826b1af99e/pynacl-1.6.2-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:c8a231e36ec2cab018c4ad4358c386e36eede0319a0c41fed24f840b1dac59f6", size = 1402859, upload-time = "2026-01-01T17:32:27.215Z" }, + { url = "https://files.pythonhosted.org/packages/85/42/fe60b5f4473e12c72f977548e4028156f4d340b884c635ec6b063fe7e9a5/pynacl-1.6.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:68be3a09455743ff9505491220b64440ced8973fe930f270c8e07ccfa25b1f9e", size = 791926, upload-time = "2026-01-01T17:32:29.314Z" }, + { url = "https://files.pythonhosted.org/packages/fa/f9/e40e318c604259301cc091a2a63f237d9e7b424c4851cafaea4ea7c4834e/pynacl-1.6.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:8b097553b380236d51ed11356c953bf8ce36a29a3e596e934ecabe76c985a577", size = 1363101, upload-time = "2026-01-01T17:32:31.263Z" }, +] + [[package]] name = "pyparsing" version = "3.3.2" @@ -4958,6 +5059,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/02/18ba0727a1c755c528d6a52b363d62c0b7a8e64cf961b3030c046107db4d/ring_flash_attn-0.1.8-py3-none-any.whl", hash = "sha256:296c929516c3b21f7bcdaeca44a99bb541779a7b63979eb0f67837dcb18a2bb9", size = 25437, upload-time = "2025-09-10T11:53:07.565Z" }, ] +[[package]] +name = "rlm-swe" +version = "0.4.2" +source = { editable = "deps/research-environments/environments/rlm_swe" } +dependencies = [ + { name = "multi-swe-bench", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "prime-sandboxes", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "swebench", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "verifiers", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] + +[package.metadata] +requires-dist = [ + { name = "multi-swe-bench", specifier = ">=1.1.2" }, + { name = "prime-sandboxes", specifier = ">=0.2.19" }, + { name = "swebench", specifier = "==4.1.0" }, + { name = "verifiers", specifier = ">=0.1.13.dev8" }, +] + [[package]] name = "rpds-py" version = "0.30.0" @@ -5303,6 +5423,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/65/5e726c372da8a5e35022a94388b12252710aad0c2351699c3d76ae8dba78/supervisor-4.3.0-py2.py3-none-any.whl", hash = "sha256:0bcb763fddafba410f35cbde226aa7f8514b9fb82eb05a0c85f6588d1c13f8db", size = 320736, upload-time = "2025-08-23T18:25:00.767Z" }, ] +[[package]] +name = "swe-rex" +version = "1.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "bashlex", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "fastapi", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pexpect", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pydantic", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "python-multipart", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "requests", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "rich", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "uvicorn", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/94/86/a069f93ec866151a4d476d546e60220e66b3788878b6e248b2df3ab2c5f1/swe_rex-1.4.0.tar.gz", hash = "sha256:14f8a24c49a63f9e251340b1109ac75a4aacbaece410f8599209de9bfca843c0", size = 41755, upload-time = "2025-08-14T01:19:20.22Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/0d/d06ab2aa78138055c297490762cd7b4d8ac58a544783f874c869cdb7b534/swe_rex-1.4.0-py3-none-any.whl", hash = "sha256:61261ad03eb23b717b5901cd5d229f24f6e1be2e120aad5c2e5ea3384a1d15ad", size = 47756, upload-time = "2025-08-14T01:19:18.93Z" }, +] + [[package]] name = "swebench" version = "4.1.0" @@ -5856,6 +5995,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, ] +[[package]] +name = "typing-inspect" +version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mypy-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/dc/74/1789779d91f1961fa9438e9a8710cdae6bd138c80d7303996933d117264a/typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78", size = 13825, upload-time = "2023-05-24T20:25:47.612Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/65/f3/107a22063bf27bdccf2024833d3445f4eea42b2e598abfbd46f6a63b6cb0/typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f", size = 8827, upload-time = "2023-05-24T20:25:45.287Z" }, +] + [[package]] name = "typing-inspection" version = "0.4.2"