trainer: standardize checkpoint loading on the HF v5 meta contract#100
Conversation
Unify how all trainers load a checkpoint onto a model: construct on the meta device, materialize empty on the target device, load the weights, then initialize only the tensors the checkpoint did NOT provide (e.g. non-persistent RoPE buffers). Replaces the non-pipeline trainers' construct-on-device + no_init_weights hack and the pipeline trainer's pre-v5 bespoke buffer-reset walk with one shared mechanism that works for both forgather and pure-HF models, and for remote checkpoints (DiLoCo). From-scratch construction is unchanged (on-device, the model's own init). - modelsrc/transformer/init_weights.py: add module_is_initialized; fix the _is_hf_initialized polarity (unmarked tensor => initialize, per HF v5); make simple_weight_init and init_weights_by_regex flag-aware so loaded tensors are skipped. - sharded_checkpoint.py: flag_loaded_tensors marks loaded keys after both the single-file and sharded load paths; initialize_missing_weights runs model.apply(model._init_weights) but GATES the dispatch to modules that still own an unflagged tensor — so loaded weights are never clobbered even by models carrying older bundled init code — with a safe non-persistent-buffer reset fallback for split pipeline-stage modules. - trainer.py: checkpoint loading routes through meta (default+resume with a model_init => meta); post-load runs initialize_missing_weights gated on _constructed_on_meta; device+no_init_weights kept as a fallback; docstrings updated. (_initialize_non_persistent_buffers is now unused.) - pipeline_trainer.py: its resume path now calls the shared initialize_missing_weights — both trainers use one method. - docs/configuration/model-initialization.md: document the trainer/loader application of the contract. Tests: 12 contract unit tests + a real-model end-to-end integration test against models/tiny (construct-on-meta -> materialize -> load -> init-missing -> forward). Full tests/unit/ml + tests/unit/forgather_server green (2737 passed). The integration test caught (and now guards against) a clobbering bug where stale bundled init re-initialized loaded weights. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Review of PR #100 surfaced a real regression plus cleanups: - FIX (regression): PipelineTrainer overrides _prepare_model but inherits the base _prepare post-load block, which now reads self._constructed_on_meta. The pipeline override never set it, so any pipeline RESUME would AttributeError. Give Trainer a class-level _constructed_on_meta = False default so the attribute is always defined; add a regression test. - Remove the now-dead Trainer._initialize_non_persistent_buffers (both call sites were replaced by initialize_missing_weights). - Perf: in the sharded loader, flag loaded tensors ONCE after the shard loop (using the module's full loaded-key set) instead of re-scanning the whole module per shard — was O(shards x params). - Docs: note that initialize_missing_weights gates per-module (a module co-locating a loaded param with an unflagged buffer relies on the model's _init_weights being per-tensor flag-aware; forgather keeps RoPE buffers in dedicated modules); note that construct_model_on="device" does not run the init-missing pass on resume. Tests: tests/unit/ml green (2049 passed) incl. the new class-default guard. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Local review pass (multi-agent) — addressedRan a high-effort local review. Outcome: Fixed in
Verified non-issues / pre-existing (noted, not changed):
Known limitation (documented): the module-granularity gate means a module co-locating a loaded param with an unflagged (not-in-checkpoint) buffer relies on the model's Full |
The init-missing pass must run AFTER load_checkpoint: load flags the tensors it fills (flag_loaded_tensors) so init only touches the rest. Running it first leaves everything unflagged and re-initializes the whole model — the slow full-init the meta route exists to avoid — only to overwrite it with the loaded weights. The base trainer was already load-then-init, but the PipelineTrainer called initialize_missing_weights inside _prepare_model, which runs BEFORE the base _prepare's load_checkpoint — so the pipeline path inited before loading. Restructure so both trainers share one post-load hook: - Trainer._prepare: `if resume: self._restore_from_checkpoint()`, which does load_checkpoint() THEN (if _constructed_on_meta) self._initialize_missing_after_load(). - _initialize_missing_after_load() is overridable: base inits self.model; PipelineTrainer overrides it to init self.pipeline_modules (its materialized stages; self.model is the meta skeleton) and sets _constructed_on_meta=True in its _prepare_model. Removed the premature init from the pipeline resume branch (+ now-unused missing_buffers import). Tests: ordering unit test (_restore_from_checkpoint load-before-init + meta gate) and a real-model spy test (models/tiny) asserting _init_weights is invoked only on unflagged RoPE modules after load, never on loaded Linear/Embedding — fails if init runs before load. tests/unit/ml green (2052 passed). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Ordering fix (reviewer-caught) —
|
Audited every sharded_checkpoint load call site against the new HF-v5 "load then initialize-missing" contract. Two real fixes; everything else verified safe (real-device construct, load-then-save, module=None raw dict, or the already-fixed trainer paths). - sharded_checkpoint.py module docstring: the "primary use-case" meta example omitted create_sharing_metadata/retie_parameters around to_empty and initialize_missing_weights after the load — so it taught a workflow that leaves RoPE inv_freq as garbage. Corrected to the full contract. - model_conversion/hf_converter.py: builds the source model under no_init_weights() then runs a logit-comparison forward pass. Its non-persistent RoPE buffers were never initialized (no_init only no-ops torch.nn.init.*; inv_freq is computed in reset_parameters), so the source logits were garbage and the FG->HF equivalence check was meaningless. Added initialize_missing_weights(src_model) after the load. (Pre-existing latent bug, surfaced by the audit.) - docs/checkpointing/sharded_checkpoint_api.md: all module-mode load examples now show to_empty -> load_checkpoint -> initialize_missing_weights, plus reference entries for initialize_missing_weights / flag_loaded_tensors. Verified OK/NA (no change): inference_server (real-device from_config, not meta), update_model & finalize (load-then-save, no forward), cli/model (asserts non-meta), diloco server & fsdp2 (module=None raw dict), and all trainer/FSDP2/pipeline paths (already load-then-init on this branch). Tests: test_model_conversion (114) + meta contract/integration (18) green. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Audit: sharded_checkpoint load sites (
|
The resume branch was `if resume: pass / else: <from-scratch init>`. Invert to `if not resume: <from-scratch init>` and fold the resume-case explanation into the comment. Behavior unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Second (delta-scoped) review pass confirmed the post-first-review changes
are sound; two small fixes:
- docs/checkpointing/sharded_checkpoint_api.md: the meta-construct load
examples use torch.device("meta") but didn't import torch — a verbatim
copy-paste would NameError. Added `import torch`.
- pipeline_trainer._initialize_missing_after_load docstring: clarified that
split stage modules take initialize_missing_weights' fallback path (no
HF _init_weights), so it resets only non-persistent-buffer modules; the
load does flag loaded tensors (load_checkpoint over model_parts), which
would also protect the apply path, but the fallback is what runs.
Review notes (no code change): the "device"+resume path not running the
init-missing pass (RoPE buffers from no_init_weights) is pre-existing and
not introduced here — tracked separately. The pipeline load goes through
forgather load_checkpoint (which flags), NOT FSDP2's DCP set_model_state_dict.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Summary
Second of the DiLoCo-foundations PRs. Unifies how all trainers load a checkpoint onto a model on the HF v5 standard: construct on meta → materialize empty on device → load weights → initialize only the tensors the checkpoint did not provide (e.g. non-persistent RoPE buffers). This replaces two divergent, pre-v5 mechanisms:
no_init_weights()hack, andmissing_buffers+reset_parameterswalk (which predates v5 and didn't generalize to pure HF hub models).From-scratch construction is unchanged (on-device, the model's own init). DiLoCo is just a remote checkpoint load, so it rides this same path (wired in the next PR).
Changes
Model-side —
modelsrc/transformer/init_weights.pymodule_is_initialized; fix the_is_hf_initializedpolarity (an unflagged tensor is freshly materialized and is initialized — matching HF v5; previously the default skipped unflagged tensors). Makesimple_weight_initandinit_weights_by_regexflag-aware so loaded tensors are skipped.Loader —
src/forgather/ml/sharded_checkpoint.pyflag_loaded_tensors(module, loaded_keys)marks loaded tensors_is_hf_initialized=Trueafter both the single-file and sharded load paths.initialize_missing_weights(module)is the shared pass:model.apply(model._init_weights)gated so it only touches modules that still own an unflagged tensor. This gating means loaded weights are never clobbered even by models carrying older bundled init code (caught by the integration test below). Safe non-persistent-buffer reset fallback for split pipeline-stage modules that lack_init_weights.Trainers
trainer.py: checkpoint loading routes through meta (default+resume with amodel_init→ meta); post-load runsinitialize_missing_weightsgated on_constructed_on_meta;device+no_init_weightskept as a fallback for models that can't build on meta. Docstrings updated.pipeline_trainer.py: its resume path now calls the same sharedinitialize_missing_weights— both trainers use one method.Docs —
docs/configuration/model-initialization.mddocuments the trainer/loader application of the contract.Testing
tests/unit/ml/test_meta_init_contract.py— 12 unit tests: flag polarity, flag-aware skip/init,flag_loaded_tensors, bothinitialize_missing_weightspaths.tests/unit/ml/test_meta_checkpoint_load.py— real-model end-to-end onmodels/tiny: construct-on-meta → materialize → load → init-missing → forward; asserts loaded weights preserved, RoPEinv_freqrecomputed finite, forward runs. This test caught a real clobbering bug (stale bundled init re-initializing loaded embeddings), now guarded by the gated dispatch.tests/unit/ml/+tests/unit/forgather_server/green (2737 passed).forgather lsclean.Notes for review
Trainer._initialize_non_persistent_buffersis now unused (both call sites replaced byinitialize_missing_weights) — left in place; candidate for removal.models/tinyintegration test) is worth a manual pass — happy to run one.models/tiny) carry older bundledinit_weights.py; the gated dispatch makes the contract correct for them regardless, but regenerating them would pick up the new flag-aware init source.🤖 Generated with Claude Code