
trainer: standardize checkpoint loading on the HF v5 meta contract #100

Merged: jdinalt merged 6 commits into dev from feature/diloco-meta-construction on May 30, 2026

@jdinalt
Owner

@jdinalt jdinalt commented May 29, 2026

Summary

Second of the DiLoCo-foundations PRs. Unifies how all trainers load a checkpoint onto a model on the HF v5 standard: construct on meta → materialize empty on device → load weights → initialize only the tensors the checkpoint did not provide (e.g. non-persistent RoPE buffers). This replaces two divergent, pre-v5 mechanisms:

  • the non-pipeline trainers' construct-on-device + no_init_weights() hack, and
  • the Pipeline Trainer's bespoke missing_buffers + reset_parameters walk (which predates v5 and didn't generalize to pure HF hub models).

From-scratch construction is unchanged (on-device, the model's own init). DiLoCo is just a remote checkpoint load, so it rides this same path (wired in the next PR).
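The four-step load order described above can be pictured with a toy model. This is a minimal sketch in plain Python, not the forgather or Hugging Face implementation: `ToyTensor`, `build_on_meta`, `materialize_empty`, `load`, and `init_missing` are hypothetical stand-ins for meta construction, `to_empty`, the checkpoint load, and the init-missing pass. Only the `_is_hf_initialized` attribute name is taken from this PR.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyTensor:
    """Hypothetical stand-in for a parameter or buffer (not a torch.Tensor)."""
    device: str = "meta"
    value: Optional[float] = None       # None models "no storage allocated"
    _is_hf_initialized: bool = False    # unflagged means "still needs init"


def build_on_meta():
    # Step 1: construct on meta. Names and shapes exist, storage does not.
    return {"embed.weight": ToyTensor(), "rope.inv_freq": ToyTensor()}


def materialize_empty(model, device):
    # Step 2: allocate uninitialized storage on the target device.
    for t in model.values():
        t.device, t.value = device, float("nan")  # nan models garbage memory


def load(model, checkpoint):
    # Step 3: copy checkpoint values in and flag every tensor that was filled.
    for key, value in checkpoint.items():
        model[key].value = value
        model[key]._is_hf_initialized = True


def init_missing(model, init_fn):
    # Step 4: initialize only what the checkpoint did not provide.
    touched = []
    for key, t in model.items():
        if not t._is_hf_initialized:
            t.value = init_fn(key)
            t._is_hf_initialized = True
            touched.append(key)
    return touched


model = build_on_meta()
materialize_empty(model, "cpu")
# The checkpoint has no "rope.inv_freq": non-persistent buffers are not saved.
load(model, {"embed.weight": 0.25})
touched = init_missing(model, init_fn=lambda key: 1.0)
```

In this toy run only `rope.inv_freq` is initialized; the loaded `embed.weight` keeps its checkpoint value.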

Changes

Model-side: modelsrc/transformer/init_weights.py

  • Add module_is_initialized; fix the _is_hf_initialized polarity: an unflagged tensor is freshly materialized and therefore gets initialized, matching HF v5 (previously the default skipped unflagged tensors). Make simple_weight_init and init_weights_by_regex flag-aware so loaded tensors are skipped.

Loader: src/forgather/ml/sharded_checkpoint.py

  • flag_loaded_tensors(module, loaded_keys) marks loaded tensors _is_hf_initialized=True after both the single-file and sharded load paths.
  • initialize_missing_weights(module) is the shared pass: model.apply(model._init_weights), gated so it only touches modules that still own an unflagged tensor. This gating means loaded weights are never clobbered, even by models carrying older bundled init code (caught by the integration test below). For split pipeline-stage modules that lack _init_weights, it falls back to a safe reset of non-persistent buffers.
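To illustrate why the gate protects loaded weights from older init code, here is a toy sketch under stated assumptions. `ToyModule`, `stale_init`, and `init_missing_gated` are hypothetical and only model the idea; the real `initialize_missing_weights` operates on torch modules and its internals are not shown in this PR.

```python
class ToyModule:
    """Hypothetical stand-in for an nn.Module: own tensors plus children."""

    def __init__(self, name, tensors=None, children=()):
        self.name = name
        self.tensors = tensors or {}   # name -> {"value": ..., "flagged": bool}
        self.children = list(children)

    def modules(self):
        # Depth-first walk, parent before children.
        yield self
        for child in self.children:
            yield from child.modules()


def stale_init(module):
    # Models older bundled init code: it overwrites every tensor it is handed
    # and ignores the loaded flag.
    for t in module.tensors.values():
        t["value"] = 0.0
        t["flagged"] = True


def init_missing_gated(root, init_fn):
    # Dispatch init_fn only to modules that directly own an unflagged tensor.
    visited = []
    for m in root.modules():
        if any(not t["flagged"] for t in m.tensors.values()):
            init_fn(m)
            visited.append(m.name)
    return visited


embed = ToyModule("embed", {"weight": {"value": 0.25, "flagged": True}})    # loaded
rope = ToyModule("rope", {"inv_freq": {"value": None, "flagged": False}})   # not in checkpoint
root = ToyModule("root", children=[embed, rope])
visited = init_missing_gated(root, stale_init)
```

Even though `stale_init` ignores flags, it is never called on `embed`, so the loaded weight survives.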

Trainers

  • trainer.py: checkpoint loading routes through meta (default+resume with a model_init → meta); post-load runs initialize_missing_weights gated on _constructed_on_meta; device+no_init_weights kept as a fallback for models that can't build on meta. Docstrings updated.
  • pipeline_trainer.py: its resume path now calls the same shared initialize_missing_weights — both trainers use one method.

Docs: docs/configuration/model-initialization.md documents the trainer/loader application of the contract.

Testing

  • tests/unit/ml/test_meta_init_contract.py — 12 unit tests: flag polarity, flag-aware skip/init, flag_loaded_tensors, both initialize_missing_weights paths.
  • tests/unit/ml/test_meta_checkpoint_load.py: real-model end-to-end test on models/tiny (construct-on-meta → materialize → load → init-missing → forward); asserts that loaded weights are preserved, RoPE inv_freq is recomputed and finite, and the forward pass runs. This test caught a real clobbering bug (stale bundled init re-initializing loaded embeddings), now guarded by the gated dispatch.
  • Full tests/unit/ml/ + tests/unit/forgather_server/ green (2737 passed). forgather ls clean.

Notes for review

  • Trainer._initialize_non_persistent_buffers is now unused (both call sites replaced by initialize_missing_weights) — left in place; candidate for removal.
  • A full CLI training smoke against a real project (beyond the in-process models/tiny integration test) is worth a manual pass — happy to run one.
  • Existing materialized model dirs (e.g. models/tiny) carry older bundled init_weights.py; the gated dispatch makes the contract correct for them regardless, but regenerating them would pick up the new flag-aware init source.

🤖 Generated with Claude Code

jdinalt and others added 2 commits May 29, 2026 22:46
Unify how all trainers load a checkpoint onto a model: construct on the
meta device, materialize empty on the target device, load the weights,
then initialize only the tensors the checkpoint did NOT provide (e.g.
non-persistent RoPE buffers). Replaces the non-pipeline trainers'
construct-on-device + no_init_weights hack and the pipeline trainer's
pre-v5 bespoke buffer-reset walk with one shared mechanism that works for
both forgather and pure-HF models, and for remote checkpoints (DiLoCo).

From-scratch construction is unchanged (on-device, the model's own init).

- modelsrc/transformer/init_weights.py: add module_is_initialized; fix
  the _is_hf_initialized polarity (unmarked tensor => initialize, per HF
  v5); make simple_weight_init and init_weights_by_regex flag-aware so
  loaded tensors are skipped.
- sharded_checkpoint.py: flag_loaded_tensors marks loaded keys after both
  the single-file and sharded load paths; initialize_missing_weights runs
  model.apply(model._init_weights) but GATES the dispatch to modules that
  still own an unflagged tensor — so loaded weights are never clobbered
  even by models carrying older bundled init code — with a safe
  non-persistent-buffer reset fallback for split pipeline-stage modules.
- trainer.py: checkpoint loading routes through meta (default+resume with
  a model_init => meta); post-load runs initialize_missing_weights gated
  on _constructed_on_meta; device+no_init_weights kept as a fallback;
  docstrings updated. (_initialize_non_persistent_buffers is now unused.)
- pipeline_trainer.py: its resume path now calls the shared
  initialize_missing_weights — both trainers use one method.
- docs/configuration/model-initialization.md: document the trainer/loader
  application of the contract.

Tests: 12 contract unit tests + a real-model end-to-end integration test
against models/tiny (construct-on-meta -> materialize -> load ->
init-missing -> forward). Full tests/unit/ml + tests/unit/forgather_server
green (2737 passed). The integration test caught (and now guards against)
a clobbering bug where stale bundled init re-initialized loaded weights.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

Review of PR #100 surfaced a real regression plus cleanups:

- FIX (regression): PipelineTrainer overrides _prepare_model but inherits
  the base _prepare post-load block, which now reads self._constructed_on_meta.
  The pipeline override never set it, so any pipeline RESUME would
  AttributeError. Give Trainer a class-level _constructed_on_meta = False
  default so the attribute is always defined; add a regression test.
- Remove the now-dead Trainer._initialize_non_persistent_buffers (both call
  sites were replaced by initialize_missing_weights).
- Perf: in the sharded loader, flag loaded tensors ONCE after the shard
  loop (using the module's full loaded-key set) instead of re-scanning the
  whole module per shard — was O(shards x params).
- Docs: note that initialize_missing_weights gates per-module (a module
  co-locating a loaded param with an unflagged buffer relies on the model's
  _init_weights being per-tensor flag-aware; forgather keeps RoPE buffers in
  dedicated modules); note that construct_model_on="device" does not run the
  init-missing pass on resume.

Tests: tests/unit/ml green (2049 passed) incl. the new class-default guard.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@jdinalt
Owner Author

jdinalt commented May 29, 2026

Local review pass (multi-agent) — addressed

Ran a high-effort local review. Outcome:

Fixed in deebd335:

  • Regression (real): PipelineTrainer overrides _prepare_model but inherits the base _prepare post-load block, which now reads self._constructed_on_meta — the pipeline override never set it, so any pipeline resume would AttributeError. Fixed with a class-level _constructed_on_meta = False default + a regression test. (Untested in CI, which is why the suite stayed green.)
  • Dead code: removed Trainer._initialize_non_persistent_buffers (both call sites replaced by initialize_missing_weights).
  • Perf: sharded loader now flags loaded tensors once after the shard loop instead of re-scanning the whole module per shard (was O(shards×params)).
  • Docs: documented the per-module gating caveat on initialize_missing_weights and that construct_model_on="device" skips the init-missing pass on resume.
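The regression in the first item above comes down to ordinary Python attribute lookup. The sketch below reuses the class names from this PR, but the bodies are hypothetical and heavily simplified; in particular, the base `_prepare_model` setting the flag to `True` is an assumption made for illustration.

```python
class Trainer:
    # Class-level default: the attribute exists even when a subclass override
    # of _prepare_model never assigns it.
    _constructed_on_meta = False

    def _prepare_model(self):
        self._constructed_on_meta = True   # assumed base behaviour

    def _prepare(self):
        self._prepare_model()
        # Post-load block inherited by every subclass.
        return "init-missing" if self._constructed_on_meta else "skip"


class PipelineTrainer(Trainer):
    def _prepare_model(self):
        # Override that does not touch _constructed_on_meta.
        pass


class TrainerWithoutDefault:
    """Same shape as Trainer, but without the class-level default."""

    def _prepare_model(self):
        self._constructed_on_meta = True

    def _prepare(self):
        self._prepare_model()
        return "init-missing" if self._constructed_on_meta else "skip"


class BrokenPipelineTrainer(TrainerWithoutDefault):
    def _prepare_model(self):
        pass


fixed = PipelineTrainer()._prepare()
try:
    BrokenPipelineTrainer()._prepare()
    broken = "no error"
except AttributeError:
    broken = "AttributeError"
```

With the class-level default, the inherited post-load block reads `False` instead of raising.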

Verified non-issues / pre-existing (noted, not changed):

  • DiLoCo worker not regressed: it stays on the on-device path in this PR; meta wiring is the next PR.
  • The _is_hf_initialized attribute is inert for load-then-save/serve callers (not part of state_dict).
  • init_weights_by_regex's all-or-nothing raise on a partial checkpoint with an init_prefix module is a pre-existing edge case, unchanged by the polarity flip.

Known limitation (documented): the module-granularity gate means a module co-locating a loaded param with an unflagged (not-in-checkpoint) buffer relies on the model's _init_weights being per-tensor flag-aware (HF's is; forgather's init_weights_by_regex is). Forgather keeps RoPE buffers in dedicated buffer-only modules, so it isn't triggered.
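A toy sketch of this limitation, assuming a module modeled as a plain dict of tensors. All function names here are hypothetical; the point is only that a module-level gate cannot protect a loaded tensor that shares a module with an unflagged one, so the init function itself has to check flags per tensor.

```python
def gate_passes(module):
    # Module-level gate: any unflagged tensor makes the whole module eligible.
    return any(not t["flagged"] for t in module.values())


def init_ignoring_flags(module):
    # Not flag-aware: resets every tensor in the module.
    for t in module.values():
        t["value"], t["flagged"] = 0.0, True


def init_respecting_flags(module):
    # Per-tensor flag-aware: resets only tensors that are still unflagged.
    for t in module.values():
        if not t["flagged"]:
            t["value"], t["flagged"] = 0.0, True


def make_mixed_module():
    # One loaded weight and one buffer the checkpoint did not contain.
    return {
        "weight": {"value": 0.25, "flagged": True},
        "inv_freq": {"value": None, "flagged": False},
    }


clobbered = make_mixed_module()
if gate_passes(clobbered):
    init_ignoring_flags(clobbered)

preserved = make_mixed_module()
if gate_passes(preserved):
    init_respecting_flags(preserved)
```

Keeping such buffers in dedicated buffer-only modules, as described above, avoids the mixed case entirely.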

Full tests/unit/ml green (2049 passed).
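For the sharded-loader perf item in the fix list, the difference between flagging per shard and flagging once can be counted directly. This is an illustrative sketch with hypothetical function names and made-up key names, not the loader's actual code.

```python
def flag_per_shard(param_names, shards):
    # Old shape: rescan every module tensor after each shard is loaded.
    loaded, flagged, checks = set(), set(), 0
    for shard_keys in shards:
        loaded.update(shard_keys)
        for name in param_names:
            checks += 1
            if name in loaded:
                flagged.add(name)
    return flagged, checks


def flag_once(param_names, shards):
    # New shape: accumulate the full loaded-key set, then scan once.
    loaded = set()
    for shard_keys in shards:
        loaded.update(shard_keys)
    flagged, checks = set(), 0
    for name in param_names:
        checks += 1
        if name in loaded:
            flagged.add(name)
    return flagged, checks


params = [f"layer{i}.weight" for i in range(6)] + ["rope.inv_freq"]
shards = [params[0:2], params[2:4], params[4:6]]   # rope.inv_freq is in no shard
old_flagged, old_checks = flag_per_shard(params, shards)
new_flagged, new_checks = flag_once(params, shards)
```

Both variants flag the same tensors; the per-shard variant performs shards × params membership checks, the single pass performs params.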

The init-missing pass must run AFTER load_checkpoint: load flags the
tensors it fills (flag_loaded_tensors) so init only touches the rest.
Running it first leaves everything unflagged and re-initializes the whole
model — the slow full-init the meta route exists to avoid — only to
overwrite it with the loaded weights.

The base trainer was already load-then-init, but the PipelineTrainer
called initialize_missing_weights inside _prepare_model, which runs
BEFORE the base _prepare's load_checkpoint — so the pipeline path inited
before loading. Restructure so both trainers share one post-load hook:

- Trainer._prepare: `if resume: self._restore_from_checkpoint()`, which
  does load_checkpoint() THEN (if _constructed_on_meta)
  self._initialize_missing_after_load().
- _initialize_missing_after_load() is overridable: base inits self.model;
  PipelineTrainer overrides it to init self.pipeline_modules (its
  materialized stages; self.model is the meta skeleton) and sets
  _constructed_on_meta=True in its _prepare_model. Removed the premature
  init from the pipeline resume branch (+ now-unused missing_buffers
  import).

Tests: ordering unit test (_restore_from_checkpoint load-before-init +
meta gate) and a real-model spy test (models/tiny) asserting _init_weights
is invoked only on unflagged RoPE modules after load, never on loaded
Linear/Embedding — fails if init runs before load. tests/unit/ml green
(2052 passed).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@jdinalt
Owner Author

jdinalt commented May 30, 2026

Ordering fix (reviewer-caught) — 92191cf4

Correct catch: initialize_missing_weights must run after load_checkpoint — otherwise nothing is flagged yet and the entire model gets re-initialized (the slow full-init the meta route exists to avoid), only to be overwritten by the load.

The base trainer was already load-then-init, but the PipelineTrainer called the init inside _prepare_model, which runs before the base _prepare's load_checkpoint — so the pipeline path initialized before loading.

Restructured so both trainers share one post-load hook:

  • Trainer._prepare: if resume: self._restore_from_checkpoint(), which calls load_checkpoint() and then (if _constructed_on_meta) self._initialize_missing_after_load().
  • _initialize_missing_after_load() is overridable: base inits self.model; PipelineTrainer overrides it to init self.pipeline_modules (its materialized stages; self.model is the meta skeleton) and sets _constructed_on_meta=True in its _prepare_model. Removed the premature init from the pipeline resume branch.
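The restructured hook can be sketched as follows. The method names follow the description above, but the bodies are hypothetical placeholders that only record call order; real loading and initialization are omitted.

```python
class Trainer:
    _constructed_on_meta = False

    def __init__(self):
        self.calls = []   # records call order for the demonstration

    def _prepare_model(self):
        pass

    def load_checkpoint(self):
        self.calls.append("load")

    def _initialize_missing_after_load(self):
        # Base behaviour: run the init-missing pass over self.model.
        self.calls.append("init:model")

    def _restore_from_checkpoint(self):
        # Load first so loaded tensors are flagged, then init only the rest.
        self.load_checkpoint()
        if self._constructed_on_meta:
            self._initialize_missing_after_load()

    def _prepare(self, resume):
        self._prepare_model()
        if resume:
            self._restore_from_checkpoint()


class PipelineTrainer(Trainer):
    def _prepare_model(self):
        # No init here: it would run before load_checkpoint.
        self._constructed_on_meta = True

    def _initialize_missing_after_load(self):
        # Override: init the materialized stages, not the meta skeleton.
        self.calls.append("init:pipeline_modules")


pipeline = PipelineTrainer()
pipeline._prepare(resume=True)
base = Trainer()
base._prepare(resume=True)
```

Recording the call order this way mirrors the shape of the ordering test listed below: load always precedes init, and the init step is skipped when the model was not constructed on meta.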

Ordering tests added (per request):

  • test_restore_from_checkpoint_loads_before_initializing — asserts call order [load, init] (+ a sibling for the not-meta skip).
  • test_init_missing_skips_loaded_modules_after_load (real models/tiny) — spies on _init_weights, asserts it fires only on unflagged RoPE modules after load, never on loaded Linear/Embedding. Fails if init runs before load.

tests/unit/ml green (2052 passed).

Audited every sharded_checkpoint load call site against the new HF-v5
"load then initialize-missing" contract. Two real fixes; everything else
verified safe (real-device construct, load-then-save, module=None raw
dict, or the already-fixed trainer paths).

- sharded_checkpoint.py module docstring: the "primary use-case" meta
  example omitted create_sharing_metadata/retie_parameters around
  to_empty and initialize_missing_weights after the load — so it taught a
  workflow that leaves RoPE inv_freq as garbage. Corrected to the full
  contract.
- model_conversion/hf_converter.py: builds the source model under
  no_init_weights() then runs a logit-comparison forward pass. Its
  non-persistent RoPE buffers were never initialized (no_init only no-ops
  torch.nn.init.*; inv_freq is computed in reset_parameters), so the
  source logits were garbage and the FG->HF equivalence check was
  meaningless. Added initialize_missing_weights(src_model) after the load.
  (Pre-existing latent bug, surfaced by the audit.)
- docs/checkpointing/sharded_checkpoint_api.md: all module-mode load
  examples now show to_empty -> load_checkpoint -> initialize_missing_weights,
  plus reference entries for initialize_missing_weights / flag_loaded_tensors.

Verified OK/NA (no change): inference_server (real-device from_config, not
meta), update_model & finalize (load-then-save, no forward), cli/model
(asserts non-meta), diloco server & fsdp2 (module=None raw dict), and all
trainer/FSDP2/pipeline paths (already load-then-init on this branch).

Tests: test_model_conversion (114) + meta contract/integration (18) green.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@jdinalt
Owner Author

jdinalt commented May 30, 2026

Audit: sharded_checkpoint load sites (99aba931)

Audited every sharded_checkpoint load call site in the repo against the new "load → initialize_missing_weights" contract (17 sites). Two real fixes; all others verified safe.

Fixed:

  • Module docstring (sharded_checkpoint.py): the canonical "primary use-case" meta example omitted create_sharing_metadata/retie_parameters + initialize_missing_weights — it taught a workflow that leaves RoPE inv_freq as to_empty garbage. Corrected to the full contract.
  • model_conversion/hf_converter.py (pre-existing latent bug): builds the source model under no_init_weights() then runs a logit-comparison forward. no_init_weights only no-ops torch.nn.init.*; RoPE inv_freq is computed in reset_parameters, so the source logits were garbage and the FG→HF equivalence check was meaningless. Added initialize_missing_weights(src_model) after the load.
  • docs/checkpointing/sharded_checkpoint_api.md: all module-mode load examples now show to_empty → load_checkpoint → initialize_missing_weights, plus reference entries for the two new functions.

Verified OK/NA (no change): inference_server/service.py (real-device from_config, not meta — so RoPE computed in __init__), update_model + finalize (load-then-save, no forward), cli/model.py (hard-asserts non-meta), diloco/server.py + FSDP2 (module=None raw-dict reads), and all trainer/FSDP2/pipeline paths (already load-then-init on this branch).

Tests: test_model_conversion (114) + meta contract/integration (18) green.

jdinalt and others added 2 commits May 30, 2026 07:41
The resume branch was `if resume: pass / else: <from-scratch init>`.
Invert to `if not resume: <from-scratch init>` and fold the resume-case
explanation into the comment. Behavior unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

Second (delta-scoped) review pass confirmed the post-first-review changes
are sound; two small fixes:

- docs/checkpointing/sharded_checkpoint_api.md: the meta-construct load
  examples use torch.device("meta") but didn't import torch — a verbatim
  copy-paste would NameError. Added `import torch`.
- pipeline_trainer._initialize_missing_after_load docstring: clarified that
  split stage modules take initialize_missing_weights' fallback path (no
  HF _init_weights), so it resets only non-persistent-buffer modules; the
  load does flag loaded tensors (load_checkpoint over model_parts), which
  would also protect the apply path, but the fallback is what runs.

Review notes (no code change): the "device"+resume path not running the
init-missing pass (RoPE buffers from no_init_weights) is pre-existing and
not introduced here — tracked separately. The pipeline load goes through
forgather load_checkpoint (which flags), NOT FSDP2's DCP set_model_state_dict.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@jdinalt jdinalt merged commit cc41e7d into dev May 30, 2026
1 check passed
@jdinalt jdinalt deleted the feature/diloco-meta-construction branch May 30, 2026 07:53
