trainer: configurable checkpoint state components + empty-meta construction#104
Merged
Conversation
…uction Restores fine-grained control over what a checkpoint saves/loads, and makes "build the model empty on meta, weights supplied externally" a first-class path. Driven by one config knob; DiLoCo is the first consumer. - TrainingArguments.checkpoint_components (list[str]|None, None=all): selects which state components a run saves/loads. - BaseTrainer.get_active_state_components() filters get_state_components() by that field — applied at the single live consumer (CheckpointManager), so all five get_state_components() implementations stay untouched. - CheckpointManager skips _save_model / _load_model_from_checkpoint when "model" is filtered out (model_state_component is None); validate_checkpoint accepts a model-less checkpoint (one with the coordinator manifest) so it remains discoverable for resume. - Construction derives from the set: Trainer._model_weights_external() is true when "model" is excluded; _prepare_model then forces meta, skips the meta->default downgrade, and initializes the empty skeleton in place; _restore_from_checkpoint skips its own init in that case. PipelineTrainer also runs _initialize_params when weights are external (so a model-excluded resume still initializes the stages). - DiLoCo defaults in lm_training_project.yaml: construct_model_on=meta + checkpoint_components=[optimizer, scheduler, trainer, dataset, rng] (everything but model), overridable via the checkpoint_components var. The DiLoCo worker now builds empty on meta (allocation-free) and checkpoints training state but never model weights — the server owns the weights and the parameter sync fills them at register. Tests: tests/unit/ml/test_checkpoint_components.py (filter, external signal, manager skip save/load) + empty-meta build in test_meta_checkpoint_load.py. Docs: diloco-architecture.md section updated from planned to implemented; diloco.md + templates de-stale the "planned PR4" notes. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… model-less marker Address local-review low-severity findings: - get_active_state_components now raises on a checkpoint_components key outside the known vocabulary (model/optimizer/scheduler/trainer/dataset/ rng). A misspelled "model" must not silently drop the component and turn a normal run into a weights-external one (no-silent-fallback). A *known* key a given run doesn't produce is still allowed and ignored. - Model-less checkpoints are now marked explicitly: the CheckpointManager writes MODEL_EXCLUDED_MARKER when "model" is excluded, and validate_checkpoint accepts a model-less checkpoint ONLY with that marker. A normal checkpoint missing its weights (partial/corrupt save) without the marker stays invalid, so discovery falls back to an older complete one rather than selecting the partial and failing at load (replaces the broader manifest-present acceptance). Tests extended: unknown-key raises; known-but-unproduced ignored; marker present -> valid, removed -> invalid. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ion) Per project policy: permanent docs on dev must not narrate revision-to- revision deltas or PR history (dev accumulates WIP across many PRs and makes no dev-only compatibility guarantee). Rewrite the diloco-architecture.md "Checkpoint state selection + empty-meta construction" section as plain present-tense documentation — drop the "PR #102 was scoped to… / this follow-up made it real / as implemented" framing and the stale "manifest-only checkpoints are valid" line (it's the explicit MODEL_EXCLUDED_MARKER now). Also record the policy in CLAUDE.md: feature branches base on origin/dev, main is the release branch, and docs describe the feature as-is (WIP design docs may track PR continuity but must carry a TODO-remove). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Dataset position under DiLoCo is tracked by the server via work-units, not the local dataloader, so a local dataset checkpoint would be stale/misleading on resume. Exclude "dataset" from the DiLoCo checkpoint_components default (now [optimizer, scheduler, trainer, rng]); update the arch doc to match and explain why. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
trainer_options.md: add the checkpoint_components row + notes; update the construct_model_on row and meta notes to the current behavior (default→meta auto-route on resume, meta-init load contract, weights-external empty-meta). user_guide.md: add a "Selecting Checkpoint Components" section and note the weights-external (DiLoCo-style) mode; scope the "model always required" caveat to runs that save model weights. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Add a "Selecting active components (checkpoint_components)" subsection to the distributed-checkpoint abstraction doc: get_active_state_components() filters each trainer's full get_state_components() declaration at a single chokepoint (CheckpointManager consumes the filtered set), with the known-key vocabulary and typo guard. Documents the weights-external mode — excluding "model" skips model save/load, writes MODEL_EXCLUDED_MARKER, and validate_checkpoint accepts a model-less checkpoint only with that marker (partial/corrupt normal ones stay invalid). Cross-links trainer_options.md + user_guide.md. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…uild The previous change ran _initialize_params (rank-0 builds the full model on CPU and distributes initialized parameters) for all weights-external (DiLoCo) pipeline runs. That is the expensive last-resort construction path — it can OOM rank-0 and defeats the purpose of supplying weights externally. Under DiLoCo the parameter server owns the persistent weights and fills them via the sync at register; the only state the sync does not carry is the non-persistent buffers (e.g. RoPE inv_freq). So for the external case the pipeline trainer now runs the per-stage initialize-missing pass (_initialize_missing_after_load) — which recomputes exactly those buffers, locally and cheaply — instead of _initialize_params. The garbage persistent weights from to_empty() are overwritten by _apply_global_params before the first forward. From-scratch (no checkpoint, non-DiLoCo) still uses _initialize_params, and normal resume still loads then initializes-missing, both unchanged. Arch doc updated to match. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Unify the meta-construction "load -> init-missing" path so the external- weights (DiLoCo) case stops doing a full init that the parameter sync then overwrites. - New non-HF callback event on_load_model_weights, dispatched from _restore_from_checkpoint at the point a checkpoint would load, when checkpoint_components excludes "model". The handler loads the weights from its external source and flags them (_is_hf_initialized); the trainer's following initialize-missing pass then fills only what wasn't provided (non-persistent buffers, e.g. RoPE inv_freq). _restore_from_checkpoint now runs for resume OR external, composing a non-model checkpoint load with the external weight load. - DiLoCoCallback: the worker bring-up (registration + apply global params) moves from on_train_begin into _start_worker, invoked by on_load_model_weights; on_train_begin is an idempotent fallback. After start() it flags the server-provided state-dict keys (per-stage for pipeline) so the init pass skips them instead of clobbering them. - _prepare_model no longer inits the empty skeleton for the external case (moved to the hook); PipelineTrainer skips _initialize_params for external (the expensive rank-0 build) — the hook + per-stage initialize-missing recompute only the buffers. Removes the full-init-then-overwrite and the per-trainer external branches; DiLoCo is now "a checkpoint load whose source is the server." Arch doc updated. 433 diloco+checkpoint+meta tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…gle hook Two safety checks at the external-load point (no-silent-fallback): - _restore_from_checkpoint raises if checkpoint_components excludes "model" but no callback implements on_load_model_weights (BaseTrainer._has_event_handler). - _verify_external_weights_loaded() runs AFTER the hook and BEFORE initialize-missing: every persistent tensor (state_dict, keep_vars) must be flagged _is_hf_initialized, i.e. actually loaded. Since initialize-missing fills anything unloaded, this is the only place to catch "the source loaded nothing" — otherwise it would silently random-init and train against garbage. Non-persistent buffers (RoPE) are excluded (init recomputes them). - Both enumerate via a new overridable _materialized_modules() (base -> self.model; pipeline -> pipeline_modules), which also subsumes the pipeline _initialize_missing_after_load override. DiLoCo callback simplified (it's forgather-only, so the hook always fires): _start_worker merges into on_load_model_weights (no idempotency guard, single entry point); on_train_begin becomes a defensive assert that the worker was started, failing loud if the run is misconfigured (model not excluded). Tests: handler-presence + weights-loaded verification (incl. non-persistent buffer exclusion) in test_checkpoint_components.py; on_train_begin assert + no-op-after-load in test_diloco_callback.py (call sites moved to on_load_model_weights). 471 diloco+checkpoint+meta tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
DiLoCo trains better under a Warmup-Stable-Decay schedule (warmup + stable,
no fixed-length decay) than the default fixed cosine-decay, and the LR
scheduler state is saved with the worker checkpoint so it resumes cleanly.
- Move the WSD annealing math (annealing_tokens -> annealing_steps ->
decay_start_step) and the --annealing-tokens arg up into the shared
lm_training_project.yaml [globals], defaulting to NO decay phase
(annealing_tokens=0 -> decay_start_step=-1). Remove the now-duplicated
computation from finetune_v2.yaml and examples/pretrain/small-llm, which
inherit it.
- Swap WSDScheduler into the DiLoCo example configs (base_lm/diloco,
tiny_experiments/diloco {default,small}).
- wsd_scheduler.py: allow decay_steps=0 only when decay is disabled
(decay_start_step < 0 and not start_decay); give both asserts actionable
messages instead of bare failures.
- Fix a pre-existing typo in tiny_experiments/diloco/default.yaml:
"Trenable_dilocoue" -> "enable_diloco" (the override knob was silently
dead — it resolved to True only because the garbled name was undefined).
- Update WSDScheduler tests: decay_steps=0 is OK when decay is disabled,
raises when start_decay is requested.
Reviewed for errors (inheritance, render, assert interactions) before
landing. All affected configs render clean; WSD + optim tests pass.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…s FSDP2) General review round on PR #104 surfaced a critical regression and minor items. - CRITICAL: model save/load was gated on `model_state_component is None`, which is also true for a normal FSDP2 (world_size>1) run — it registers no "model" StateComponent and saves/loads via model_save_fn/model_load_fn hooks. That gating silently dropped FSDP2 model weights and mislabeled the checkpoint model-less. Fix: gate on an explicit `model_weights_external` flag passed to CheckpointManager from the trainer's `_model_weights_external()` (moved to BaseTrainer). FSDP2 normal runs save/load again; DiLoCo still skips + marks. Wired at all three manager construction sites; +2 regression tests. - DiLoCo callback now flags only the parameter names it actually syncs (named_parameters, matching ParamView.apply_global) rather than all state_dict keys — so _verify_external_weights_loaded still catches a persistent buffer the server never supplied instead of masking it. No change for shipped models (no persistent buffers). - Fix stale on_train_begin references in diloco_callback (error message, load_state_dict docstring, comments) after the hook rename. NB: finetune_v2 --start-annealing requiring an explicit annealing_tokens is intentional (not restored); the relaxed WSD assert now gives an actionable message pointing at it. 653 diloco+checkpoint+meta+optim tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Follow-up to the DiLoCo model-bundle PR (#102). Restores fine-grained control over what a checkpoint saves/loads (which existed before the CheckpointManager refactor) and makes "build the model empty on meta, weights supplied externally" a first-class trainer path. One config knob drives both; DiLoCo is the first consumer.
This implements the design recorded in
docs/trainers/diloco-architecture.md(now updated from "Planned" to implemented).The problem
A DiLoCo worker must never save/load model weights — the parameter server owns them, and a stale local copy from a different sync round would be wrong. But it should still checkpoint other state (trainer progress / LR / RNG, optionally optimizer). And it should build the model empty on meta (allocation-free) rather than on-device only to overwrite it.
The blockers:
_resolve_checkpoint→resume_from_checkpoint=pathmade_restore_from_checkpointload model weights; and explicitconstruct_model_on=metawas downgraded todefaulton a fresh run.Design — one knob, derived behavior
TrainingArguments.checkpoint_components(list[str] | None,None= all → unchanged behavior) selects which components a run saves/loads.BaseTrainer.get_active_state_components()filtersget_state_components()by that field, applied at the single live consumer (CheckpointManager.__init__) — so all fiveget_state_components()implementations (base/ddp/fsdp2/accel/pipeline) stay untouched."model"filtered out,model_state_component is None, and the manager skips_save_model/_load_model_from_checkpoint.validate_checkpointaccepts a model-less checkpoint (manifest present) so it's still discoverable for resume._model_weights_external()is true when"model"is excluded →_prepare_modelforces meta, skips the downgrade, and initializes the empty skeleton in place;_restore_from_checkpointskips its own init then.PipelineTraineralso runs_initialize_paramswhen weights are external (so a model-excluded resume still inits the stages).lm_training_project.yamlsets, under DiLoCo,construct_model_on: meta+checkpoint_components: [optimizer, scheduler, trainer, dataset, rng](everything butmodel), overridable via thecheckpoint_componentsvar. The inner-optimizer keep/skip choice is thus config, not policy.Tests
tests/unit/ml/test_checkpoint_components.py— the filter (None/subset/unknown-key),_model_weights_external, and the CheckpointManager skipping model save/load (and not clobbering live weights on load) while persisting non-model state.test_meta_checkpoint_load.py::test_meta_empty_no_checkpoint_fully_initializes— the empty-meta build (skeleton fully initialized, RoPE recomputed, forward runs) on a realmodels/tiny.Docs
diloco-architecture.md's section moved from "Planned (PR 4)" to "Checkpoint state selection + empty-meta construction" (implemented, with corrected touch-points);diloco.mdand the templates de-stale the "planned" notes.🤖 Generated with Claude Code