Merged
49 changes: 25 additions & 24 deletions docs/design/diloco-work-unit-dispatch.md
@@ -543,30 +543,31 @@ either:

### DiLoCo server restart

-_Implementation note: the proposal here described persisted queue
-state across server restarts. The first live bringup uncovered a
-cross-experiment hazard with that design — stale queues from a
-previous run lingered in the persisted state and surfaced as
-ghost queues on `/work/queues` after the operator switched
-datasets. The implementation reverses the proposal: `_work_queues`
-is **not** written to `server_state.pt`, and a pre-#46 file
-containing them is loaded with a warning + dropped. The trade is
-broader: on a server restart **all** in-flight + already-completed
-units in the live epoch return to "available" and workers re-issue
-from scratch on first contact. Justifies the cost as: server
-restarts are rare; the cross-experiment ghost-queue surface was a
-weekly footgun. The 409-on-length-mismatch guard described in the
-next section still protects against the worst form of dataset
-mismatch._
-
-- ~~Queue state persisted in the server's checkpoint~~ — dropped
-(see note above). The server doesn't checkpoint queues at all
-now.
-- On restart: bitmaps are reset. Workers re-register their datasets
-on first contact (the wrap calls `/datasets/register` lazily on
-the iterator's first request). The queue map is reconstructed
-fresh; previously-trained rows within the epoch may be re-issued.
-Crash recovery beyond "the queue resets" is out of scope.
+The server is the authority for which rows have been consumed, so queue
+state is persisted with its checkpoint and restored on restart (#105) —
+the worker deliberately keeps **no** dataset-progress state of its own,
+relying on the server to track and persist it.
+
+- **Persisted in `server_state.pt`:** the per-`(dataset_id, shuffle_seed)`
+`issued`/`completed` bitmaps + counters, and `_dataset_lengths` (the
+first-registered row count per `dataset_id`). Queues are keyed on disk by
+a `"dataset_id|seed"` string (tuple keys don't round-trip) and bitmaps
+are stored as `bytes`.
+- **On restart:** `load_state` rehydrates `_work_queues`. A worker
+re-registering its dataset (the wrap calls `/datasets/register` lazily on
+the iterator's first request) hits the *restored* queue —
+`_handle_register_dataset` reuses any existing queue for the key — so
+already-issued units stay issued and issuance resumes at the next
+un-issued unit, not from 0.
+- **Cross-experiment safety (the original ghost-queue worry):** a changed
+dataset hashes to a different `dataset_id` → a fresh `(dataset_id, seed)`
+key → any stale queue from a prior dataset is simply never matched and
+sits inert. The 409-on-length-mismatch guard catches a same-id/different-
+length config. An operator wanting a hard reset restarts from the model
+weights and purges the rest of `output_dir`.
+- **Save cadence:** queues are written on the normal `save_every_n_rounds`
+cadence and flushed once more on graceful shutdown (SIGINT/SIGTERM), so a
+clean stop doesn't lose units issued since the last autosave.
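
For readers of this hunk, the on-disk key and bitmap handling described in the bullets above can be sketched roughly as follows. This is a minimal illustration, not the server's code: `QueueSnapshot`, `to_disk`, and `from_disk` are hypothetical names, the real queue carries more fields, and the example `dataset_id` containing a `|` is invented to show why the key is split from the right.

```python
from dataclasses import dataclass


@dataclass
class QueueSnapshot:
    """Hypothetical stand-in for the server's per-queue state."""

    total_units: int
    issued: bytearray  # one bit per work unit


def to_disk(queues):
    """Flatten {(dataset_id, seed): QueueSnapshot} into checkpoint-friendly types."""
    return {
        f"{dataset_id}|{seed}": {
            "total_units": q.total_units,
            "issued": bytes(q.issued),  # immutable copy of the bitmap
        }
        for (dataset_id, seed), q in queues.items()
    }


def from_disk(saved):
    """Rebuild the in-memory map, skipping any entry that fails validation."""
    restored = {}
    for key, entry in saved.items():
        try:
            # Split from the right so a "|" inside dataset_id survives.
            dataset_id, seed_str = key.rsplit("|", 1)
            seed = int(seed_str)
            total_units = int(entry["total_units"])
            issued = bytearray(entry["issued"])
            expected = (total_units + 7) // 8  # bytes needed for one bit per unit
            if len(issued) != expected:
                raise ValueError(f"bitmap is {len(issued)} bytes, expected {expected}")
        except (ValueError, KeyError, TypeError, AttributeError):
            continue  # one bad entry must not abort the whole restore
        restored[(dataset_id, seed)] = QueueSnapshot(total_units, issued)
    return restored


# Invented example data: a dataset_id that itself contains "|", seed 7.
live = {("sha256:abc|v2", 7): QueueSnapshot(10, bytearray(b"\x0f\x00"))}
saved = to_disk(live)
saved["broken"] = {"total_units": 10, "issued": b"\x00"}  # no separator, wrong length
roundtrip = from_disk(saved)
```

The malformed `"broken"` entry is dropped while the valid queue round-trips unchanged, which is the skip-rather-than-abort behavior the restore path aims for.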

### Worker dataset mismatch

12 changes: 7 additions & 5 deletions docs/trainers/diloco.md
@@ -967,11 +967,13 @@ under DiLoCo unless you have a measured reason otherwise.

If a worker dies holding an issued unit, that unit is lost (the
server's one-way issuance design — at most `N_workers` units lost per
-epoch). The DiLoCo server's `_work_queues` is **not** persisted across
-server restarts (pre-#46 it was, but cross-experiment state-bleed from
-a stale checkpoint outweighed crash-recovery utility). On server
-restart, workers re-register their datasets on first contact and the
-queue map is reconstructed fresh.
+epoch). The DiLoCo server's `_work_queues` **is** persisted with its
+checkpoint and restored on restart, so a server bounce does not re-issue
+already-consumed rows within the epoch: a re-registering worker resumes
+at the next un-issued unit. A changed dataset hashes to a new
+`dataset_id`, so stale queues from a prior dataset are never matched
+(no cross-experiment bleed); the queue is flushed on graceful shutdown
+as well as the periodic save cadence.
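
To make "resumes at the next un-issued unit" concrete, here is a small sketch of issuance over a restored bitmap. It rests on two assumptions that this diff does not confirm: bits are packed least-significant-bit first within each byte, and units are handed out lowest-index first. `is_issued` and `issue_next` are hypothetical helpers, not functions from `server.py`.

```python
def is_issued(bitmap, unit):
    """Return True if `unit` is marked in the bitmap (LSB-first layout assumed)."""
    return bool(bitmap[unit // 8] & (1 << (unit % 8)))


def issue_next(bitmap, total_units):
    """Mark and return the lowest un-issued unit, or None when all are issued."""
    for unit in range(total_units):
        if not is_issued(bitmap, unit):
            bitmap[unit // 8] |= 1 << (unit % 8)
            return unit
    return None


total_units = 10
bitmap = bytearray((total_units + 7) // 8)

# Before the restart: three units are handed out.
before = [issue_next(bitmap, total_units) for _ in range(3)]

# Checkpoint, then restart: the bitmap is stored as bytes and rehydrated.
checkpoint = bytes(bitmap)
restored = bytearray(checkpoint)
after = issue_next(restored, total_units)

# Without persistence the server would start from an empty bitmap instead.
fresh = issue_next(bytearray((total_units + 7) // 8), total_units)
```

With the restored bitmap the next unit follows on from the three already issued, whereas the fresh bitmap starts again at unit 0, which is the re-issuance the persisted state is meant to avoid.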

Design details: `docs/design/diloco-work-unit-dispatch.md`.
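
The flush on graceful shutdown mentioned in the paragraph above follows a common pattern: do the final save once in `finally`, and never let a save failure block teardown. A minimal sketch under stated assumptions: `run_with_final_flush` and its three callables are hypothetical stand-ins for the serve loop, `save_state`, and the listener/monitor teardown, and signal wiring is left out.

```python
import logging

logger = logging.getLogger("shutdown_sketch")


def run_with_final_flush(serve, save, teardown):
    """Run `serve` until interrupted, then save once and always tear down."""
    try:
        serve()
    except KeyboardInterrupt:
        logger.info("interrupted, shutting down")
    finally:
        try:
            save()  # single save covers normal exit and interrupt alike
        except Exception as exc:
            # A failed save is logged, but teardown still has to happen.
            logger.error("failed to save state on shutdown: %s", exc)
        teardown()


events = []


def interrupted_serve():
    raise KeyboardInterrupt  # stands in for Ctrl-C or a SIGTERM handler


def failing_save():
    events.append("save attempted")
    raise OSError("disk full")


run_with_final_flush(interrupted_serve, failing_save, lambda: events.append("teardown"))
```

Even with the save raising, teardown still runs, so a full disk cannot leave background threads alive.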

168 changes: 133 additions & 35 deletions src/forgather/ml/diloco/server.py
@@ -3021,22 +3021,40 @@ def save_state(self, path: Optional[str] = None):
             checkpoint_path, self.get_global_params(), safetensors=self.safetensors
         )

-        # Work-queue state is intentionally NOT persisted (#46). Earlier
-        # versions rode the per-queue bitmap into server_state.pt so a
-        # server restart preserved issuance — the rationale was crash
-        # recovery within a run. In practice the more common pattern
-        # was "operator changed the dataset, output_dir stayed the
-        # same" → workers register a fresh dataset_id, but the old
-        # queue lingers in /work/queues with a stale hint.length and
-        # an unrecognized dataset_id key. The 2026-05-26 bringup
-        # chased exactly this confusion.
-        #
-        # Trade: on a server crash, in-flight units (≤ N_workers, one
-        # per worker) become re-issuable. That matches the design's
-        # accepted worker-death budget. Workers re-register their
-        # datasets on startup anyway, so the server reconstructs the
-        # queue map on demand. Operators who really want mid-epoch
-        # resume can keep their own queue snapshot.
+        # Work-unit dispatch state IS persisted: the server is the authority
+        # for which rows each worker has consumed, so a restart must not
+        # re-issue already-trained units within an epoch (#105). Snapshot the
+        # per-(dataset_id, shuffle_seed) issued/completed bitmaps under the
+        # lock (concurrent issuance must not mutate them mid-serialization),
+        # plus the per-dataset length snapshot used for the registration
+        # integrity check. Disk form keys queues by a "dataset_id|seed"
+        # string (tuple keys don't round-trip cleanly) and stores bytes, not
+        # bytearray. A dataset change hashes to a different dataset_id → a
+        # fresh key → any stale queue is simply never matched and sits inert
+        # (an operator wanting a hard reset restarts from the model weights
+        # and purges the rest).
+        with self._work_queues_lock:
+            work_queues = {
+                f"{ds}|{seed}": {
+                    "total_units": q.total_units,
+                    "issued": bytes(q.issued),
+                    "completed": bytes(q.completed),
+                    "hint_length": q.hint_length,
+                    "issued_count": q.issued_count,
+                    "completed_count": q.completed_count,
+                    "by_worker": {w: dict(c) for w, c in q.by_worker.items()},
+                    "dataset_path": q.dataset_path,
+                    "dataset_name": q.dataset_name,
+                    "dataset_split": q.dataset_split,
+                    "dataset_revision": q.dataset_revision,
+                    "dataset_data_files": (
+                        list(q.dataset_data_files) if q.dataset_data_files else None
+                    ),
+                }
+                for (ds, seed), q in self._work_queues.items()
+            }
+            dataset_lengths = dict(self._dataset_lengths)
+
         # Snapshot the known-worker roster under the lock so a concurrent
         # registration can't mutate the dict mid-serialization (issue #103
         # follow-up). Persisting it here lets a restarted server offer the
@@ -3052,6 +3070,8 @@ def save_state(self, path: Optional[str] = None):
             "async_mode": self.async_mode,
             "total_submissions": self._total_submissions,
             "known_workers": known_workers,
+            "work_queues": work_queues,
+            "dataset_lengths": dataset_lengths,
         }
         torch.save(server_state, os.path.join(checkpoint_path, "server_state.pt"))

@@ -3202,21 +3222,66 @@ def load_state(self, checkpoint_path: Optional[str] = None):
         # repopulates as workers register.
         self._known_workers = server_state.get("known_workers", {}) or {}

-        # Work-queue state is no longer persisted (#46). For
-        # backward-compat with pre-#46 checkpoints, surface a
-        # warning if any work-queue entries are present — they're
-        # being silently ignored, and the operator should know in
-        # case they were expecting mid-epoch resume from this
-        # checkpoint.
-        legacy_queues = server_state.get("work_queues") or {}
-        if legacy_queues:
-            logger.warning(
-                "Ignoring %d work-queue entr%s in legacy server_state.pt — "
-                "work-queue persistence was removed in #46. Workers will "
-                "re-register their datasets on connect; any in-flight "
-                "units from the prior run will be re-issued.",
-                len(legacy_queues),
-                "y" if len(legacy_queues) == 1 else "ies",
+        # Restore work-unit dispatch state (#105): the per-(dataset_id,
+        # shuffle_seed) issued/completed bitmaps and the per-dataset
+        # length snapshot. A worker re-registering its dataset hits the
+        # restored queue (``_handle_register_dataset`` reuses any existing
+        # queue for the key), so already-issued units stay issued and are
+        # not re-handed-out — the server is the authority for consumed
+        # rows. Absent on pre-feature checkpoints → start empty (today's
+        # rebuild-on-reregister behavior, which re-issues from unit 0).
+        self._dataset_lengths = server_state.get("dataset_lengths", {}) or {}
+        restored = server_state.get("work_queues") or {}
+        self._work_queues = {}
+        for qkey, d in restored.items():
+            # Skip-and-warn on ANY bad entry (malformed key, missing/typed
+            # field, or a bitmap whose length disagrees with total_units)
+            # rather than letting one corrupt/partial entry abort the whole
+            # load and brick a restart. A length mismatch would otherwise
+            # surface much later as an IndexError during issuance.
+            try:
+                ds, seed_str = qkey.rsplit("|", 1)
+                seed = int(seed_str)
+                total_units = int(d["total_units"])
+                nbytes = (total_units + 7) // 8
+                issued = bytearray(d["issued"])
+                completed = bytearray(d.get("completed") or bytes(nbytes))
+                if len(issued) != nbytes or len(completed) != nbytes:
+                    raise ValueError(
+                        f"bitmap length {len(issued)}/{len(completed)} "
+                        f"!= expected {nbytes} for total_units={total_units}"
+                    )
+                queue = WorkQueue(
+                    total_units=total_units,
+                    issued=issued,
+                    completed=completed,
+                    hint_length=int(d["hint_length"]),
+                    issued_count=int(d.get("issued_count", 0)),
+                    completed_count=int(d.get("completed_count", 0)),
+                    by_worker=d.get("by_worker") or {},
+                    dataset_path=d.get("dataset_path"),
+                    dataset_name=d.get("dataset_name"),
+                    dataset_split=d.get("dataset_split"),
+                    dataset_revision=d.get("dataset_revision"),
+                    dataset_data_files=d.get("dataset_data_files"),
+                )
+            except (ValueError, AttributeError, KeyError, TypeError) as exc:
+                logger.warning(
+                    "Skipping malformed work-queue entry %r in "
+                    "server_state.pt: %s",
+                    qkey,
+                    exc,
+                )
+                continue
+            self._work_queues[(ds, seed)] = queue
+
+        if self._work_queues:
+            issued_total = sum(q.issued_count for q in self._work_queues.values())
+            logger.info(
+                "Restored %d work queue(s) from checkpoint (%d unit(s) "
+                "already issued); re-registering workers resume mid-epoch.",
+                len(self._work_queues),
+                issued_total,
             )

         logger.info(
@@ -3334,14 +3399,37 @@ def run(self):
         self._start_health_monitor()
         self._start_bulk_listener()

+        # Flush state on SIGTERM (how the forgather_server scheduler stops a
+        # server job) as well as SIGINT (Ctrl-C), so a webui-triggered stop
+        # doesn't lose rounds / issued work-units since the last autosave
+        # (#105). The handler just raises KeyboardInterrupt to break
+        # serve_forever; the single save below in `finally` then covers every
+        # exit path. signal.signal only works on the main thread (run() is
+        # the blocking CLI entrypoint); fall back gracefully otherwise.
+        import signal as _signal
+
+        def _raise_interrupt(signum, frame):
+            raise KeyboardInterrupt
+
+        prev_sigterm = None
+        try:
+            prev_sigterm = _signal.signal(_signal.SIGTERM, _raise_interrupt)
+        except (ValueError, OSError):
+            pass  # not on the main thread — SIGTERM stays default
+
         try:
             self._server.serve_forever()
         except KeyboardInterrupt:
-            logger.info("Server interrupted by Ctrl-C")
-            if self.save_every_n_rounds > 0:
-                logger.info("Saving server state before shutdown...")
-                self.save_state()
+            logger.info("Server interrupted (signal) — shutting down")
         finally:
+            if self.save_every_n_rounds > 0:
+                try:
+                    logger.info("Saving server state before shutdown...")
+                    self.save_state()
+                except Exception as exc:
+                    logger.error("Failed to save server state on shutdown: %s", exc)
+            if prev_sigterm is not None:
+                _signal.signal(_signal.SIGTERM, prev_sigterm)
             self._stop_health_monitor()
             self._stop_bulk_listener()
             self._running = False
@@ -3383,6 +3471,16 @@ def stop(self):
         """Stop the background server."""
         if not self._running:
             raise RuntimeError("Stop cannot be called, unless we are already running.")
+        # Flush state before teardown so a graceful stop doesn't lose the
+        # rounds / issued work-units accumulated since the last autosave
+        # (#105). save_state no-ops when clean; never let a save failure
+        # block shutdown. Mirrors the run() Ctrl-C path.
+        if self.save_every_n_rounds > 0:
+            try:
+                logger.info("Saving server state before shutdown...")
+                self.save_state()
+            except Exception as exc:
+                logger.error("Failed to save server state on stop: %s", exc)
         self._stop_health_monitor()
         self._stop_bulk_listener()
         if self._server: