diff --git a/docs/design/diloco-work-unit-dispatch.md b/docs/design/diloco-work-unit-dispatch.md index 8a944c091..f5f481c2c 100644 --- a/docs/design/diloco-work-unit-dispatch.md +++ b/docs/design/diloco-work-unit-dispatch.md @@ -543,30 +543,31 @@ either: ### DiLoCo server restart -_Implementation note: the proposal here described persisted queue -state across server restarts. The first live bringup uncovered a -cross-experiment hazard with that design — stale queues from a -previous run lingered in the persisted state and surfaced as -ghost queues on `/work/queues` after the operator switched -datasets. The implementation reverses the proposal: `_work_queues` -is **not** written to `server_state.pt`, and a pre-#46 file -containing them is loaded with a warning + dropped. The trade is -broader: on a server restart **all** in-flight + already-completed -units in the live epoch return to "available" and workers re-issue -from scratch on first contact. Justifies the cost as: server -restarts are rare; the cross-experiment ghost-queue surface was a -weekly footgun. The 409-on-length-mismatch guard described in the -next section still protects against the worst form of dataset -mismatch._ - -- ~~Queue state persisted in the server's checkpoint~~ — dropped - (see note above). The server doesn't checkpoint queues at all - now. -- On restart: bitmaps are reset. Workers re-register their datasets - on first contact (the wrap calls `/datasets/register` lazily on - the iterator's first request). The queue map is reconstructed - fresh; previously-trained rows within the epoch may be re-issued. - Crash recovery beyond "the queue resets" is out of scope. +The server is the authority for which rows have been consumed, so queue +state is persisted with its checkpoint and restored on restart (#105) — +the worker deliberately keeps **no** dataset-progress state of its own, +relying on the server to track and persist it. 
+ +- **Persisted in `server_state.pt`:** the per-`(dataset_id, shuffle_seed)` + `issued`/`completed` bitmaps + counters, and `_dataset_lengths` (the + first-registered row count per `dataset_id`). Queues are keyed on disk by + a `"dataset_id|seed"` string (tuple keys don't round-trip) and bitmaps + are stored as `bytes`. +- **On restart:** `load_state` rehydrates `_work_queues`. A worker + re-registering its dataset (the wrap calls `/datasets/register` lazily on + the iterator's first request) hits the *restored* queue — + `_handle_register_dataset` reuses any existing queue for the key — so + already-issued units stay issued and issuance resumes at the next + un-issued unit, not from 0. +- **Cross-experiment safety (the original ghost-queue worry):** a changed + dataset hashes to a different `dataset_id` → a fresh `(dataset_id, seed)` + key → any stale queue from a prior dataset is simply never matched and + sits inert. The 409-on-length-mismatch guard catches a + same-id/different-length config. An operator wanting a hard reset restarts + from the model weights and purges the rest of `output_dir`. +- **Save cadence:** queues are written on the normal `save_every_n_rounds` + cadence and flushed once more on graceful shutdown (SIGINT/SIGTERM), so a + clean stop doesn't lose units issued since the last autosave. ### Worker dataset mismatch diff --git a/docs/trainers/diloco.md b/docs/trainers/diloco.md index fbd3d743d..2c00b0f8f 100644 --- a/docs/trainers/diloco.md +++ b/docs/trainers/diloco.md @@ -967,11 +967,13 @@ under DiLoCo unless you have a measured reason otherwise. If a worker dies holding an issued unit, that unit is lost (the server's one-way issuance design — at most `N_workers` units lost per -epoch). The DiLoCo server's `_work_queues` is **not** persisted across -server restarts (pre-#46 it was, but cross-experiment state-bleed from -a stale checkpoint outweighed crash-recovery utility). 
On server -restart, workers re-register their datasets on first contact and the -queue map is reconstructed fresh. +epoch). The DiLoCo server's `_work_queues` **is** persisted with its +checkpoint and restored on restart, so a server bounce does not re-issue +already-consumed rows within the epoch: a re-registering worker resumes +at the next un-issued unit. A changed dataset hashes to a new +`dataset_id`, so stale queues from a prior dataset are never matched +(no cross-experiment bleed); the queue is flushed on graceful shutdown +as well as the periodic save cadence. Design details: `docs/design/diloco-work-unit-dispatch.md`. diff --git a/src/forgather/ml/diloco/server.py b/src/forgather/ml/diloco/server.py index 0d55efd1f..7c66a5866 100644 --- a/src/forgather/ml/diloco/server.py +++ b/src/forgather/ml/diloco/server.py @@ -3021,22 +3021,40 @@ def save_state(self, path: Optional[str] = None): checkpoint_path, self.get_global_params(), safetensors=self.safetensors ) - # Work-queue state is intentionally NOT persisted (#46). Earlier - # versions rode the per-queue bitmap into server_state.pt so a - # server restart preserved issuance — the rationale was crash - # recovery within a run. In practice the more common pattern - # was "operator changed the dataset, output_dir stayed the - # same" → workers register a fresh dataset_id, but the old - # queue lingers in /work/queues with a stale hint.length and - # an unrecognized dataset_id key. The 2026-05-26 bringup - # chased exactly this confusion. - # - # Trade: on a server crash, in-flight units (≤ N_workers, one - # per worker) become re-issuable. That matches the design's - # accepted worker-death budget. Workers re-register their - # datasets on startup anyway, so the server reconstructs the - # queue map on demand. Operators who really want mid-epoch - # resume can keep their own queue snapshot. 
+ # Work-unit dispatch state IS persisted: the server is the authority + # for which rows each worker has consumed, so a restart must not + # re-issue already-trained units within an epoch (#105). Snapshot the + # per-(dataset_id, shuffle_seed) issued/completed bitmaps under the + # lock (concurrent issuance must not mutate them mid-serialization), + # plus the per-dataset length snapshot used for the registration + # integrity check. Disk form keys queues by a "dataset_id|seed" + # string (tuple keys don't round-trip cleanly) and stores bytes, not + # bytearray. A dataset change hashes to a different dataset_id → a + # fresh key → any stale queue is simply never matched and sits inert + # (an operator wanting a hard reset restarts from the model weights + # and purges the rest). + with self._work_queues_lock: + work_queues = { + f"{ds}|{seed}": { + "total_units": q.total_units, + "issued": bytes(q.issued), + "completed": bytes(q.completed), + "hint_length": q.hint_length, + "issued_count": q.issued_count, + "completed_count": q.completed_count, + "by_worker": {w: dict(c) for w, c in q.by_worker.items()}, + "dataset_path": q.dataset_path, + "dataset_name": q.dataset_name, + "dataset_split": q.dataset_split, + "dataset_revision": q.dataset_revision, + "dataset_data_files": ( + list(q.dataset_data_files) if q.dataset_data_files else None + ), + } + for (ds, seed), q in self._work_queues.items() + } + dataset_lengths = dict(self._dataset_lengths) + # Snapshot the known-worker roster under the lock so a concurrent # registration can't mutate the dict mid-serialization (issue #103 # follow-up). 
Persisting it here lets a restarted server offer the @@ -3052,6 +3070,8 @@ def save_state(self, path: Optional[str] = None): "async_mode": self.async_mode, "total_submissions": self._total_submissions, "known_workers": known_workers, + "work_queues": work_queues, + "dataset_lengths": dataset_lengths, } torch.save(server_state, os.path.join(checkpoint_path, "server_state.pt")) @@ -3202,21 +3222,66 @@ def load_state(self, checkpoint_path: Optional[str] = None): # repopulates as workers register. self._known_workers = server_state.get("known_workers", {}) or {} - # Work-queue state is no longer persisted (#46). For - # backward-compat with pre-#46 checkpoints, surface a - # warning if any work-queue entries are present — they're - # being silently ignored, and the operator should know in - # case they were expecting mid-epoch resume from this - # checkpoint. - legacy_queues = server_state.get("work_queues") or {} - if legacy_queues: - logger.warning( - "Ignoring %d work-queue entr%s in legacy server_state.pt — " - "work-queue persistence was removed in #46. Workers will " - "re-register their datasets on connect; any in-flight " - "units from the prior run will be re-issued.", - len(legacy_queues), - "y" if len(legacy_queues) == 1 else "ies", + # Restore work-unit dispatch state (#105): the per-(dataset_id, + # shuffle_seed) issued/completed bitmaps and the per-dataset + # length snapshot. A worker re-registering its dataset hits the + # restored queue (``_handle_register_dataset`` reuses any existing + # queue for the key), so already-issued units stay issued and are + # not re-handed-out — the server is the authority for consumed + # rows. Absent on pre-feature checkpoints → start empty (today's + # rebuild-on-reregister behavior, which re-issues from unit 0). 
+ self._dataset_lengths = server_state.get("dataset_lengths", {}) or {} + restored = server_state.get("work_queues") or {} + self._work_queues = {} + for qkey, d in restored.items(): + # Skip-and-warn on ANY bad entry (malformed key, missing/typed + # field, or a bitmap whose length disagrees with total_units) + # rather than letting one corrupt/partial entry abort the whole + # load and brick a restart. A length mismatch would otherwise + # surface much later as an IndexError during issuance. + try: + ds, seed_str = qkey.rsplit("|", 1) + seed = int(seed_str) + total_units = int(d["total_units"]) + nbytes = (total_units + 7) // 8 + issued = bytearray(d["issued"]) + completed = bytearray(d.get("completed") or bytes(nbytes)) + if len(issued) != nbytes or len(completed) != nbytes: + raise ValueError( + f"bitmap length {len(issued)}/{len(completed)} " + f"!= expected {nbytes} for total_units={total_units}" + ) + queue = WorkQueue( + total_units=total_units, + issued=issued, + completed=completed, + hint_length=int(d["hint_length"]), + issued_count=int(d.get("issued_count", 0)), + completed_count=int(d.get("completed_count", 0)), + by_worker=d.get("by_worker") or {}, + dataset_path=d.get("dataset_path"), + dataset_name=d.get("dataset_name"), + dataset_split=d.get("dataset_split"), + dataset_revision=d.get("dataset_revision"), + dataset_data_files=d.get("dataset_data_files"), + ) + except (ValueError, AttributeError, KeyError, TypeError) as exc: + logger.warning( + "Skipping malformed work-queue entry %r in " + "server_state.pt: %s", + qkey, + exc, + ) + continue + self._work_queues[(ds, seed)] = queue + + if self._work_queues: + issued_total = sum(q.issued_count for q in self._work_queues.values()) + logger.info( + "Restored %d work queue(s) from checkpoint (%d unit(s) " + "already issued); re-registering workers resume mid-epoch.", + len(self._work_queues), + issued_total, ) logger.info( @@ -3334,14 +3399,37 @@ def run(self): self._start_health_monitor() 
self._start_bulk_listener() + # Flush state on SIGTERM (how the forgather_server scheduler stops a + # server job) as well as SIGINT (Ctrl-C), so a webui-triggered stop + # doesn't lose rounds / issued work-units since the last autosave + # (#105). The handler just raises KeyboardInterrupt to break + # serve_forever; the single save below in `finally` then covers every + # exit path. signal.signal only works on the main thread (run() is + # the blocking CLI entrypoint); fall back gracefully otherwise. + import signal as _signal + + def _raise_interrupt(signum, frame): + raise KeyboardInterrupt + + prev_sigterm = None + try: + prev_sigterm = _signal.signal(_signal.SIGTERM, _raise_interrupt) + except (ValueError, OSError): + pass # not on the main thread — SIGTERM stays default + try: self._server.serve_forever() except KeyboardInterrupt: - logger.info("Server interrupted by Ctrl-C") - if self.save_every_n_rounds > 0: - logger.info("Saving server state before shutdown...") - self.save_state() + logger.info("Server interrupted (signal) — shutting down") finally: + if self.save_every_n_rounds > 0: + try: + logger.info("Saving server state before shutdown...") + self.save_state() + except Exception as exc: + logger.error("Failed to save server state on shutdown: %s", exc) + if prev_sigterm is not None: + _signal.signal(_signal.SIGTERM, prev_sigterm) self._stop_health_monitor() self._stop_bulk_listener() self._running = False @@ -3383,6 +3471,16 @@ def stop(self): """Stop the background server.""" if not self._running: raise RuntimeError("Stop cannot be called, unless we are already running.") + # Flush state before teardown so a graceful stop doesn't lose the + # rounds / issued work-units accumulated since the last autosave + # (#105). save_state no-ops when clean; never let a save failure + # block shutdown. Mirrors the run() Ctrl-C path. 
+ if self.save_every_n_rounds > 0: + try: + logger.info("Saving server state before shutdown...") + self.save_state() + except Exception as exc: + logger.error("Failed to save server state on stop: %s", exc) self._stop_health_monitor() self._stop_bulk_listener() if self._server: diff --git a/tests/unit/ml/diloco/test_server_work_queue.py b/tests/unit/ml/diloco/test_server_work_queue.py index a9600371f..071f27f95 100644 --- a/tests/unit/ml/diloco/test_server_work_queue.py +++ b/tests/unit/ml/diloco/test_server_work_queue.py @@ -330,17 +330,14 @@ def test_queue_by_worker_counters(self, server): class TestPersistence: - def test_work_queues_NOT_persisted_across_save_load(self, tmp_path): - """Regression for #46. - - Work-queue persistence was removed because it created a - cross-experiment hazard: same output_dir, new dataset → - old queues with stale dataset_ids and stale hint.length lingered - in /work/queues, confusing the operator. Workers re-register - their datasets on connect, so the server reconstructs the - queue map fresh on demand. The cost is bounded re-work after - a server crash (≤ N_workers in-flight units), which matches - the existing worker-death budget.""" + def test_work_queues_persisted_and_restored(self, tmp_path): + """The server is the authority for which rows have been consumed, + so its per-(dataset_id, shuffle_seed) issued/completed bitmaps and + the per-dataset length snapshot MUST round-trip through + save_state/load_state (#105). 
Without this a restart re-issues + already-trained units within the epoch.""" + import torch as _torch + sd = _state_dict() ckpt = make_initial_checkpoint(sd, tmp_path) s = DiLoCoServer( @@ -356,7 +353,7 @@ def test_work_queues_NOT_persisted_across_save_load(self, tmp_path): c = DiLoCoClient(f"localhost:{s.port}", timeout=10) c.register_dataset("w0", "ds-1", 42, {"length": 5000}) u0 = c.request_work("w0", "ds-1", 42)["unit_id"] - c.request_work("w0", "ds-1", 42) + c.request_work("w0", "ds-1", 42) # issue a second unit c.complete_work("w0", "ds-1", 42, u0) save_dir = str(tmp_path / "saved") s._dirty = True @@ -364,8 +361,17 @@ def test_work_queues_NOT_persisted_across_save_load(self, tmp_path): finally: s.stop() - # Fresh server: queues come up empty, ready for fresh - # registrations from workers. + # The persisted file carries the queue + length snapshot. + loaded = _torch.load( + os.path.join(save_dir, "server_state.pt"), + map_location="cpu", + weights_only=False, + ) + assert "work_queues" in loaded and "dataset_lengths" in loaded + assert loaded["dataset_lengths"]["ds-1"] == 5000 + assert "ds-1|42" in loaded["work_queues"] + + # A fresh server restores the queue with issuance intact. s2 = DiLoCoServer( output_dir=str(tmp_path), from_checkpoint=save_dir, @@ -373,34 +379,61 @@ def test_work_queues_NOT_persisted_across_save_load(self, tmp_path): port=0, default_work_units=8, ) - assert s2._work_queues == {} - assert s2._dataset_lengths == {} - - # And the persisted file no longer contains a work_queues key - # at all — confirm via direct inspection so a future regression - # that re-adds the key without restoring the load path also - # fails this test. - import torch as _torch + assert s2._dataset_lengths == {"ds-1": 5000} + q = s2._work_queues[("ds-1", 42)] + assert q.issued_count == 2 # both issued units survived + assert q.completed_count == 1 # the completed one survived + assert q.hint_length == 5000 + # Bits 0 and 1 are set in the restored issued bitmap. 
+ assert q.issued[0] & 0b11 == 0b11 + + def test_restored_queue_resumes_issuance(self, tmp_path): + """After a restart, a worker re-registering its dataset reuses the + restored queue and is handed the NEXT un-issued unit — it does not + restart the epoch from unit 0 (#105).""" + sd = _state_dict() + ckpt = make_initial_checkpoint(sd, tmp_path) + s = DiLoCoServer( + output_dir=str(tmp_path), + from_checkpoint=ckpt, + num_workers=1, + port=0, + default_work_units=8, + ) + s.start() + time.sleep(0.2) + try: + c = DiLoCoClient(f"localhost:{s.port}", timeout=10) + c.register_dataset("w0", "ds-1", 42, {"length": 5000}) + issued = [c.request_work("w0", "ds-1", 42)["unit_id"] for _ in range(3)] + assert issued == [0, 1, 2] + save_dir = str(tmp_path / "saved") + s._dirty = True + s.save_state(save_dir) + finally: + s.stop() - loaded = _torch.load( - os.path.join(save_dir, "server_state.pt"), - map_location="cpu", - weights_only=False, + s2 = DiLoCoServer( + output_dir=str(tmp_path), + from_checkpoint=save_dir, + num_workers=1, + port=0, + default_work_units=8, ) - assert "work_queues" not in loaded - assert "dataset_lengths" not in loaded - - def test_legacy_checkpoint_with_work_queues_loads_with_warning( - self, tmp_path, caplog - ): - """A pre-#46 server_state.pt that contains a work_queues - entry must still load cleanly, but the operator should see a - clear warning that those queues are being dropped. Mid-epoch - resume that worked under the old behavior is intentionally - gone — workers re-register on connect, in-flight units get - re-issued (≤ N_workers extra units trained per crash).""" - import logging + s2.start() + time.sleep(0.2) + try: + c2 = DiLoCoClient(f"localhost:{s2.port}", timeout=10) + # Re-register the same dataset (what a relaunched worker does). + c2.register_dataset("w0", "ds-1", 42, {"length": 5000}) + # Issuance resumes at unit 3, not 0. 
+ assert c2.request_work("w0", "ds-1", 42)["unit_id"] == 3 + finally: + s2.stop() + def test_malformed_work_queue_key_skipped(self, tmp_path): + """A work_queues entry whose key isn't 'dataset_id|seed' is skipped + (logged), not fatal — load_state must not crash the server.""" import torch as _torch sd = _state_dict() @@ -420,41 +453,79 @@ def test_legacy_checkpoint_with_work_queues_loads_with_warning( finally: s.stop() - # Inject a legacy-shaped work_queues entry into the saved - # server_state.pt, then reload and confirm the warning fires - # and the queues themselves are NOT rehydrated. sp = os.path.join(save_dir, "server_state.pt") ss = _torch.load(sp, map_location="cpu", weights_only=False) ss["work_queues"] = { - "stale-ds|0": { - "dataset_id": "stale-ds", - "shuffle_seed": 0, + "no-seed-separator": { "total_units": 8, - "issued": bytes(b"\x00"), - "completed": bytes(b"\x00"), - "hint_length": 12345, - "issued_count": 0, - "completed_count": 0, - "by_worker": {}, + "issued": bytes(1), + "completed": bytes(1), + "hint_length": 1, } } _torch.save(ss, sp) - with caplog.at_level(logging.WARNING, logger="forgather.ml.diloco.server"): - s2 = DiLoCoServer( - output_dir=str(tmp_path), - from_checkpoint=save_dir, - num_workers=1, - port=0, - default_work_units=8, - ) + s2 = DiLoCoServer( + output_dir=str(tmp_path), + from_checkpoint=save_dir, + num_workers=1, + port=0, + default_work_units=8, + ) + assert s2._work_queues == {} # malformed entry skipped, no crash + + def test_corrupt_work_queue_value_does_not_brick_startup(self, tmp_path): + """A work_queues entry with a valid key but a missing field, or a + bitmap whose length disagrees with total_units, is skipped (logged) + — one bad/partial entry must not abort the whole load and prevent a + restart. 
Good entries in the same map still load.""" + import torch as _torch - assert s2._work_queues == {} - assert any( - "work-queue" in rec.message for rec in caplog.records - ), "Expected a warning about ignored legacy work_queues, got: " + repr( - [rec.message for rec in caplog.records] + sd = _state_dict() + ckpt = make_initial_checkpoint(sd, tmp_path) + s = DiLoCoServer( + output_dir=str(tmp_path), + from_checkpoint=ckpt, + num_workers=1, + port=0, + default_work_units=8, + ) + s.start() + time.sleep(0.2) + try: + c = DiLoCoClient(f"localhost:{s.port}", timeout=10) + c.register_dataset("w0", "good-ds", 1, {"length": 100}) + c.request_work("w0", "good-ds", 1) + save_dir = str(tmp_path / "saved") + s._dirty = True + s.save_state(save_dir) + finally: + s.stop() + + sp = os.path.join(save_dir, "server_state.pt") + ss = _torch.load(sp, map_location="cpu", weights_only=False) + # Valid key, missing 'total_units'/'issued'. + ss["work_queues"]["partial-ds|0"] = {"hint_length": 5000} + # Valid key, bitmap length inconsistent with total_units. + ss["work_queues"]["badbitmap-ds|0"] = { + "total_units": 1024, # expects 128 bytes + "issued": bytes(1), # but only 1 + "completed": bytes(1), + "hint_length": 1, + } + _torch.save(ss, sp) + + s2 = DiLoCoServer( + output_dir=str(tmp_path), + from_checkpoint=save_dir, + num_workers=1, + port=0, + default_work_units=8, ) + # The two bad entries are skipped; the good one survives. + assert ("good-ds", 1) in s2._work_queues + assert ("partial-ds", 0) not in s2._work_queues + assert ("badbitmap-ds", 0) not in s2._work_queues def test_legacy_checkpoint_loads_with_empty_queue_map(self, tmp_path): """Checkpoints written before this feature landed have no