Skip to content

trainer: configurable checkpoint state components + empty-meta construction#104

Merged
jdinalt merged 11 commits into
devfrom
feature/diloco-checkpoint-components
May 31, 2026
Merged

trainer: configurable checkpoint state components + empty-meta construction#104
jdinalt merged 11 commits into
devfrom
feature/diloco-checkpoint-components

Conversation

@jdinalt
Copy link
Copy Markdown
Owner

@jdinalt jdinalt commented May 30, 2026

Summary

Follow-up to the DiLoCo model-bundle PR (#102). Restores fine-grained control over what a checkpoint saves/loads (which existed before the CheckpointManager refactor) and makes "build the model empty on meta, weights supplied externally" a first-class trainer path. One config knob drives both; DiLoCo is the first consumer.

This implements the design recorded in docs/trainers/diloco-architecture.md (now updated from "Planned" to implemented).

The problem

A DiLoCo worker must never save/load model weights — the parameter server owns them, and a stale local copy from a different sync round would be wrong. But it should still checkpoint other state (trainer progress / LR / RNG, optionally optimizer). And it should build the model empty on meta (allocation-free) rather than on-device only to overwrite it.

The blockers: _resolve_checkpointresume_from_checkpoint=path made _restore_from_checkpoint load model weights; and explicit construct_model_on=meta was downgraded to default on a fresh run.

Design — one knob, derived behavior

  • TrainingArguments.checkpoint_components (list[str] | None, None = all → unchanged behavior) selects which components a run saves/loads.
  • BaseTrainer.get_active_state_components() filters get_state_components() by that field, applied at the single live consumer (CheckpointManager.__init__) — so all five get_state_components() implementations (base/ddp/fsdp2/accel/pipeline) stay untouched.
  • Model save/load gated on the component: with "model" filtered out, model_state_component is None, and the manager skips _save_model / _load_model_from_checkpoint. validate_checkpoint accepts a model-less checkpoint (manifest present) so it's still discoverable for resume.
  • Construction derives from the set: _model_weights_external() is true when "model" is excluded → _prepare_model forces meta, skips the downgrade, and initializes the empty skeleton in place; _restore_from_checkpoint skips its own init then. PipelineTrainer also runs _initialize_params when weights are external (so a model-excluded resume still inits the stages).
  • DiLoCo defaults in config: lm_training_project.yaml sets, under DiLoCo, construct_model_on: meta + checkpoint_components: [optimizer, scheduler, trainer, dataset, rng] (everything but model), overridable via the checkpoint_components var. The inner-optimizer keep/skip choice is thus config, not policy.

Tests

  • tests/unit/ml/test_checkpoint_components.py — the filter (None/subset/unknown-key), _model_weights_external, and the CheckpointManager skipping model save/load (and not clobbering live weights on load) while persisting non-model state.
  • test_meta_checkpoint_load.py::test_meta_empty_no_checkpoint_fully_initializes — the empty-meta build (skeleton fully initialized, RoPE recomputed, forward runs) on a real models/tiny.
  • Full checkpoint + meta suites green (290 pre-existing + new).

Docs

diloco-architecture.md's section moved from "Planned (PR 4)" to "Checkpoint state selection + empty-meta construction" (implemented, with corrected touch-points); diloco.md and the templates de-stale the "planned" notes.

🤖 Generated with Claude Code

jdinalt and others added 11 commits May 30, 2026 20:44
…uction

Restores fine-grained control over what a checkpoint saves/loads, and makes
"build the model empty on meta, weights supplied externally" a first-class
path. Driven by one config knob; DiLoCo is the first consumer.

- TrainingArguments.checkpoint_components (list[str]|None, None=all): selects
  which state components a run saves/loads.
- BaseTrainer.get_active_state_components() filters get_state_components() by
  that field — applied at the single live consumer (CheckpointManager), so
  all five get_state_components() implementations stay untouched.
- CheckpointManager skips _save_model / _load_model_from_checkpoint when
  "model" is filtered out (model_state_component is None); validate_checkpoint
  accepts a model-less checkpoint (one with the coordinator manifest) so it
  remains discoverable for resume.
- Construction derives from the set: Trainer._model_weights_external() is true
  when "model" is excluded; _prepare_model then forces meta, skips the
  meta->default downgrade, and initializes the empty skeleton in place;
  _restore_from_checkpoint skips its own init in that case. PipelineTrainer
  also runs _initialize_params when weights are external (so a model-excluded
  resume still initializes the stages).
- DiLoCo defaults in lm_training_project.yaml: construct_model_on=meta +
  checkpoint_components=[optimizer, scheduler, trainer, dataset, rng]
  (everything but model), overridable via the checkpoint_components var.

The DiLoCo worker now builds empty on meta (allocation-free) and checkpoints
training state but never model weights — the server owns the weights and the
parameter sync fills them at register.

Tests: tests/unit/ml/test_checkpoint_components.py (filter, external signal,
manager skip save/load) + empty-meta build in test_meta_checkpoint_load.py.
Docs: diloco-architecture.md section updated from planned to implemented;
diloco.md + templates de-stale the "planned PR4" notes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… model-less marker

Address local-review low-severity findings:
- get_active_state_components now raises on a checkpoint_components key
  outside the known vocabulary (model/optimizer/scheduler/trainer/dataset/
  rng). A misspelled "model" must not silently drop the component and turn a
  normal run into a weights-external one (no-silent-fallback). A *known* key
  a given run doesn't produce is still allowed and ignored.
- Model-less checkpoints are now marked explicitly: the CheckpointManager
  writes MODEL_EXCLUDED_MARKER when "model" is excluded, and
  validate_checkpoint accepts a model-less checkpoint ONLY with that marker.
  A normal checkpoint missing its weights (partial/corrupt save) without the
  marker stays invalid, so discovery falls back to an older complete one
  rather than selecting the partial and failing at load (replaces the
  broader manifest-present acceptance).

Tests extended: unknown-key raises; known-but-unproduced ignored; marker
present -> valid, removed -> invalid.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ion)

Per project policy: permanent docs on dev must not narrate revision-to-
revision deltas or PR history (dev accumulates WIP across many PRs and makes
no dev-only compatibility guarantee). Rewrite the diloco-architecture.md
"Checkpoint state selection + empty-meta construction" section as plain
present-tense documentation — drop the "PR #102 was scoped to… / this
follow-up made it real / as implemented" framing and the stale "manifest-only
checkpoints are valid" line (it's the explicit MODEL_EXCLUDED_MARKER now).

Also record the policy in CLAUDE.md: feature branches base on origin/dev,
main is the release branch, and docs describe the feature as-is (WIP design
docs may track PR continuity but must carry a TODO-remove).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Dataset position under DiLoCo is tracked by the server via work-units, not
the local dataloader, so a local dataset checkpoint would be stale/misleading
on resume. Exclude "dataset" from the DiLoCo checkpoint_components default
(now [optimizer, scheduler, trainer, rng]); update the arch doc to match and
explain why.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
trainer_options.md: add the checkpoint_components row + notes; update the
construct_model_on row and meta notes to the current behavior (default→meta
auto-route on resume, meta-init load contract, weights-external empty-meta).
user_guide.md: add a "Selecting Checkpoint Components" section and note the
weights-external (DiLoCo-style) mode; scope the "model always required"
caveat to runs that save model weights.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Add a "Selecting active components (checkpoint_components)" subsection to the
distributed-checkpoint abstraction doc: get_active_state_components() filters
each trainer's full get_state_components() declaration at a single chokepoint
(CheckpointManager consumes the filtered set), with the known-key vocabulary
and typo guard. Documents the weights-external mode — excluding "model" skips
model save/load, writes MODEL_EXCLUDED_MARKER, and validate_checkpoint accepts
a model-less checkpoint only with that marker (partial/corrupt normal ones
stay invalid). Cross-links trainer_options.md + user_guide.md.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…uild

The previous change ran _initialize_params (rank-0 builds the full model on
CPU and distributes initialized parameters) for all weights-external
(DiLoCo) pipeline runs. That is the expensive last-resort construction path
— it can OOM rank-0 and defeats the purpose of supplying weights externally.

Under DiLoCo the parameter server owns the persistent weights and fills them
via the sync at register; the only state the sync does not carry is the
non-persistent buffers (e.g. RoPE inv_freq). So for the external case the
pipeline trainer now runs the per-stage initialize-missing pass
(_initialize_missing_after_load) — which recomputes exactly those buffers,
locally and cheaply — instead of _initialize_params. The garbage persistent
weights from to_empty() are overwritten by _apply_global_params before the
first forward.

From-scratch (no checkpoint, non-DiLoCo) still uses _initialize_params, and
normal resume still loads then initializes-missing, both unchanged. Arch doc
updated to match.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Unify the meta-construction "load -> init-missing" path so the external-
weights (DiLoCo) case stops doing a full init that the parameter sync then
overwrites.

- New non-HF callback event on_load_model_weights, dispatched from
  _restore_from_checkpoint at the point a checkpoint would load, when
  checkpoint_components excludes "model". The handler loads the weights from
  its external source and flags them (_is_hf_initialized); the trainer's
  following initialize-missing pass then fills only what wasn't provided
  (non-persistent buffers, e.g. RoPE inv_freq). _restore_from_checkpoint now
  runs for resume OR external, composing a non-model checkpoint load with the
  external weight load.
- DiLoCoCallback: the worker bring-up (registration + apply global params)
  moves from on_train_begin into _start_worker, invoked by
  on_load_model_weights; on_train_begin is an idempotent fallback. After
  start() it flags the server-provided state-dict keys (per-stage for
  pipeline) so the init pass skips them instead of clobbering them.
- _prepare_model no longer inits the empty skeleton for the external case
  (moved to the hook); PipelineTrainer skips _initialize_params for external
  (the expensive rank-0 build) — the hook + per-stage initialize-missing
  recompute only the buffers.

Removes the full-init-then-overwrite and the per-trainer external branches;
DiLoCo is now "a checkpoint load whose source is the server." Arch doc
updated. 433 diloco+checkpoint+meta tests pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…gle hook

Two safety checks at the external-load point (no-silent-fallback):
- _restore_from_checkpoint raises if checkpoint_components excludes "model"
  but no callback implements on_load_model_weights (BaseTrainer._has_event_handler).
- _verify_external_weights_loaded() runs AFTER the hook and BEFORE
  initialize-missing: every persistent tensor (state_dict, keep_vars) must be
  flagged _is_hf_initialized, i.e. actually loaded. Since initialize-missing
  fills anything unloaded, this is the only place to catch "the source loaded
  nothing" — otherwise it would silently random-init and train against garbage.
  Non-persistent buffers (RoPE) are excluded (init recomputes them).
- Both enumerate via a new overridable _materialized_modules() (base ->
  self.model; pipeline -> pipeline_modules), which also subsumes the pipeline
  _initialize_missing_after_load override.

DiLoCo callback simplified (it's forgather-only, so the hook always fires):
_start_worker merges into on_load_model_weights (no idempotency guard, single
entry point); on_train_begin becomes a defensive assert that the worker was
started, failing loud if the run is misconfigured (model not excluded).

Tests: handler-presence + weights-loaded verification (incl. non-persistent
buffer exclusion) in test_checkpoint_components.py; on_train_begin assert +
no-op-after-load in test_diloco_callback.py (call sites moved to
on_load_model_weights). 471 diloco+checkpoint+meta tests pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
DiLoCo trains better under a Warmup-Stable-Decay schedule (warmup + stable,
no fixed-length decay) than the default fixed cosine-decay, and the LR
scheduler state is saved with the worker checkpoint so it resumes cleanly.

- Move the WSD annealing math (annealing_tokens -> annealing_steps ->
  decay_start_step) and the --annealing-tokens arg up into the shared
  lm_training_project.yaml [globals], defaulting to NO decay phase
  (annealing_tokens=0 -> decay_start_step=-1). Remove the now-duplicated
  computation from finetune_v2.yaml and examples/pretrain/small-llm, which
  inherit it.
- Swap WSDScheduler into the DiLoCo example configs (base_lm/diloco,
  tiny_experiments/diloco {default,small}).
- wsd_scheduler.py: allow decay_steps=0 only when decay is disabled
  (decay_start_step < 0 and not start_decay); give both asserts actionable
  messages instead of bare failures.
- Fix a pre-existing typo in tiny_experiments/diloco/default.yaml:
  "Trenable_dilocoue" -> "enable_diloco" (the override knob was silently
  dead — it resolved to True only because the garbled name was undefined).
- Update WSDScheduler tests: decay_steps=0 is OK when decay is disabled,
  raises when start_decay is requested.

Reviewed for errors (inheritance, render, assert interactions) before
landing. All affected configs render clean; WSD + optim tests pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…s FSDP2)

General review round on PR #104 surfaced a critical regression and minor items.

- CRITICAL: model save/load was gated on `model_state_component is None`, which
  is also true for a normal FSDP2 (world_size>1) run — it registers no "model"
  StateComponent and saves/loads via model_save_fn/model_load_fn hooks. That
  gating silently dropped FSDP2 model weights and mislabeled the checkpoint
  model-less. Fix: gate on an explicit `model_weights_external` flag passed to
  CheckpointManager from the trainer's `_model_weights_external()` (moved to
  BaseTrainer). FSDP2 normal runs save/load again; DiLoCo still skips + marks.
  Wired at all three manager construction sites; +2 regression tests.
- DiLoCo callback now flags only the parameter names it actually syncs
  (named_parameters, matching ParamView.apply_global) rather than all
  state_dict keys — so _verify_external_weights_loaded still catches a
  persistent buffer the server never supplied instead of masking it. No change
  for shipped models (no persistent buffers).
- Fix stale on_train_begin references in diloco_callback (error message,
  load_state_dict docstring, comments) after the hook rename.

NB: finetune_v2 --start-annealing requiring an explicit annealing_tokens is
intentional (not restored); the relaxed WSD assert now gives an actionable
message pointing at it.

653 diloco+checkpoint+meta+optim tests pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@jdinalt jdinalt merged commit 050f9f2 into dev May 31, 2026
1 check passed
@jdinalt jdinalt deleted the feature/diloco-checkpoint-components branch May 31, 2026 00:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant