From 32e679b6bdef1e1e9351cc1c5d416654e1867846 Mon Sep 17 00:00:00 2001 From: Jason dinAlt Date: Fri, 29 May 2026 08:07:11 +0000 Subject: [PATCH 1/2] diloco: make worker settings server-authoritative via /info MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit sync_every, bf16_comm, dylu and num_fragments must match across every worker in a group for the sync barrier / outer step / fragment barriers to be coherent. They were redundantly re-specified on the worker CLI, in the callback templates, and in the webui — a foot-gun that let a worker silently diverge from the group. The server already advertises them via GET /info; now the worker actually consumes them and there is no client override (a divergent value is never useful). Server (/info): - expected_client_settings is now complete and non-null: sync_every is advertised as dylu_base_sync_every whether or not DyLU is on, plus num_fragments_default and heartbeat_timeout. Added settings_authority and a coarse model_hash (param name/shape fingerprint, consumed by the upcoming model-bundle work). Worker: - DiLoCoClient.get_info(). - DiLoCoCallback drops the four must-match constructor args / DILOCO_* reads; on_train_begin fetches /info (leader-only, broadcast to DDP followers), applies the settings verbatim and logs them. A server that doesn't advertise sync_every (too old) is fatal, not silently defaulted. heartbeat_interval stays client-local and is validated against the server's heartbeat_timeout. Surface removal: - `forgather diloco worker` drops --sync-every / --no-bf16 / --dylu / --num-fragments and no longer emits the matching DILOCO_* env vars. - scheduler no longer forwards the four settings. - callbacks/diloco.yaml + mixins/diloco.yaml drop the four knobs. - SubmitModal removes the four editable inputs and shows the server's values read-only instead (stable layout). - docs/trainers/diloco.md updated to match. Tests: extended /info shape + model_hash determinism; callback adoption matrix + heartbeat validation; scheduler no longer forwards the four. All DiLoCo + forgather_server suites green (974 passed). Co-Authored-By: Claude Opus 4.8 (1M context) --- docs/trainers/diloco.md | 67 +++--- src/forgather/cli/diloco.py | 22 +- src/forgather/cli/diloco_args.py | 31 +-- src/forgather/ml/diloco/client.py | 11 + src/forgather/ml/diloco/server.py | 52 ++++- .../ml/trainer/callbacks/diloco_callback.py | 163 +++++++++----- templatelib/examples/callbacks/diloco.yaml | 14 +- templatelib/examples/mixins/diloco.yaml | 55 ++--- .../test_scheduler_diloco_env.py | 36 +-- tests/unit/ml/diloco/test_dashboard.py | 32 +++ tests/unit/ml/diloco/test_diloco_callback.py | 212 +++++++++++++----- tools/forgather_server/scheduler.py | 12 +- tools/forgather_server/webui/src/api.ts | 7 + .../webui/src/components/SubmitModal.tsx | 212 ++++++------------ 14 files changed, 524 insertions(+), 402 deletions(-) diff --git a/docs/trainers/diloco.md b/docs/trainers/diloco.md index df636eca..6bdf75ea 100644 --- a/docs/trainers/diloco.md +++ b/docs/trainers/diloco.md @@ -133,38 +133,33 @@ explicit `--output-dir` the operator passes, so workers get distinct per-worker dirs for free): ```bash -# sync mode forgather diloco worker \ --server 192.168.1.100:8512 \ - --sync-every 500 \ --worker-id w0 \ -p my_project -t train.yaml \ train -d 0 - -# with DyLU - server adjusts sync frequency dynamically -forgather diloco worker \ - --server 192.168.1.100:8512 \ - --sync-every 500 \ - --worker-id w1 \ - --dylu \ - --heartbeat-interval 30 \ - -p my_project -t train.yaml \ - train -d 1 ``` Worker arguments: - `--server`: Server address as `host:port` -- `--sync-every`: Local steps between syncs (default: 500) - `--worker-id`: Unique worker identity. Drives the per-worker output-dir suffix the project template appends to `ns.output_dir`, and the uniqueness key the server enforces on `/register`. Auto-generated when omitted but operators typically set it explicitly so logs / output dirs are predictable. -- `--no-bf16`: Send full-precision pseudo-gradients instead of bfloat16 -- `--dylu`: Enable dynamic sync frequency adjustment from server -- `--heartbeat-interval`: Seconds between heartbeats for speed reporting (default: 30) +- `--heartbeat-interval`: Seconds between heartbeats for speed reporting + (default: 30). Client-local; validated against the server's + `--heartbeat-timeout` at startup. - `-d`: CUDA visible devices +**Server-authoritative settings.** `sync_every`, `bf16_comm`, `dylu`, and +`num_fragments` must match across every worker in the group for the sync +barrier / outer step / fragment barriers to be coherent. The server is +their sole authority: the worker fetches them from the server's `/info` at +startup and logs the values it adopted. There are no `--sync-every` / +`--no-bf16` / `--dylu` / `--num-fragments` worker flags — set them on the +**server** instead (`--dylu-base-sync-every`, `--dylu`, etc.). + Dataset partitioning across workers is handled by the server's **work-unit dispatch**: each worker registers its train dataset with the DiLoCo server on first iteration and pulls per-unit row ranges on demand, so no row is @@ -373,18 +368,20 @@ updates at approximately the same wall-clock rate. DyLU requires: 1. **Server**: `--dylu` flag and `--dylu-base-sync-every` (default: 500) -2. **Workers**: `--dylu` flag and `--heartbeat-interval` (default: 30s) +2. **Workers**: `--heartbeat-interval` (default: 30s). DyLU enablement and the + base `sync_every` are taken from the server's `/info` — there is no worker + `--dylu` flag. Workers periodically report their training speed via heartbeats. The server computes the recommended sync interval and returns it in the heartbeat response. Workers adjust their `sync_every` dynamically. ```bash -# Server with DyLU +# Server with DyLU (this is where dylu / sync_every are configured) forgather diloco server -o ./model -n 3 --async --dylu --dylu-base-sync-every 500 -# Worker with DyLU enabled -forgather diloco worker --server host:8512 --sync-every 500 --worker-id w0 --dylu -- train +# Worker — picks up dylu + sync_every from the server's /info +forgather diloco worker --server host:8512 --worker-id w0 -- train ``` ### Staleness Tracking @@ -432,12 +429,11 @@ communication becomes fully overlapped. ### CLI Usage ```bash -# Worker with 4 streaming fragments +# Streaming fragments are configured on the server, not the worker — +# the worker reads num_fragments (and sync_every) from /info. forgather diloco worker \ --server 192.168.1.100:8512 \ - --sync-every 500 \ --worker-id w0 \ - --num-fragments 4 \ -p my_project -t train.yaml \ train ``` @@ -649,12 +645,11 @@ configuration works for both DiLoCo and standalone training. ```python from forgather.ml.trainer.callbacks import DiLoCoCallback -# Explicit configuration +# Explicit configuration (client-local knobs only) callback = DiLoCoCallback( server_addr="192.168.1.100:8512", - sync_every=500, - bf16_comm=True, - num_fragments=1, + worker_id="w0", + heartbeat_interval=30.0, ) # Or rely on environment variables (set by `forgather diloco worker`) @@ -669,17 +664,21 @@ trainer = Trainer( trainer.train() ``` -All constructor parameters fall back to `DILOCO_*` environment variables: +The callback's client-local constructor parameters fall back to `DILOCO_*` +environment variables: | Parameter | Env Var | Default | |-----------|---------|---------| | `server_addr` | `DILOCO_SERVER` | `""` (no-op) | -| `sync_every` | `DILOCO_SYNC_EVERY` | `500` | | `worker_id` | `DILOCO_WORKER_ID` | auto-generated | -| `bf16_comm` | `DILOCO_BF16_COMM` | `True` | -| `dylu` | `DILOCO_DYLU` | `False` | | `heartbeat_interval` | `DILOCO_HEARTBEAT_INTERVAL` | `30.0` | -| `num_fragments` | `DILOCO_NUM_FRAGMENTS` | `1` | + +`sync_every`, `bf16_comm`, `dylu`, and `num_fragments` are **not** callback +parameters or env vars — they must match across the group, so the worker +reads them from the server's `/info` at startup (set them on the server). +The `DiLoCoWorker` class still accepts them directly for the low-level +programmatic API; it's only the callback / CLI surface that defers to the +server. ### Configuration Template @@ -696,8 +695,8 @@ Or add the callback directly in your project template: == super() diloco_callback: !singleton:forgather.ml.trainer.callbacks:DiLoCoCallback server_addr: {{ diloco_server | default(None) }} - sync_every: {{ diloco_sync_every | default(None) }} - num_fragments: {{ diloco_num_fragments | default(None) }} + worker_id: {{ diloco_worker_id | default(None) }} + heartbeat_interval: {{ diloco_heartbeat_interval | default(None) }} ``` See `examples/tiny_experiments/diloco/` for a complete working example. diff --git a/src/forgather/cli/diloco.py b/src/forgather/cli/diloco.py index fa13d39c..cf103cd0 100644 --- a/src/forgather/cli/diloco.py +++ b/src/forgather/cli/diloco.py @@ -306,14 +306,13 @@ def _worker_cmd(args): This wraps the standard training command, injecting DiLoCo configuration via environment variables that the training script picks up. """ - # Set DiLoCo environment variables for the training script + # Set DiLoCo environment variables for the training script. Only + # client-local knobs are forwarded; sync_every / bf16_comm / dylu / + # num_fragments are server-authoritative and resolved from /info by + # the worker at startup (no client override). env = os.environ.copy() env["DILOCO_SERVER"] = args.server - env["DILOCO_SYNC_EVERY"] = str(args.sync_every) - env["DILOCO_BF16_COMM"] = "0" if args.no_bf16 else "1" - env["DILOCO_DYLU"] = "1" if getattr(args, "dylu", False) else "0" env["DILOCO_HEARTBEAT_INTERVAL"] = str(getattr(args, "heartbeat_interval", 30.0)) - env["DILOCO_NUM_FRAGMENTS"] = str(getattr(args, "num_fragments", 1)) if args.worker_id: env["DILOCO_WORKER_ID"] = args.worker_id @@ -351,15 +350,10 @@ def _worker_cmd(args): cmd_args.extend(remainder) cmd_str = " ".join(cmd_args) - diloco_info = ( - f"DiLoCo: server={args.server}, sync_every={args.sync_every}, " - f"bf16={'yes' if not args.no_bf16 else 'no'}" - ) - num_frags = getattr(args, "num_fragments", 1) - if num_frags > 1: - diloco_info += f", fragments={num_frags}" - if getattr(args, "dylu", False): - diloco_info += ", dylu=yes" + # sync_every / bf16 / dylu / num_fragments come from the server's /info + # at startup, so they aren't known here — the worker logs them once it + # negotiates with the server. + diloco_info = f"DiLoCo: server={args.server}" if args.worker_id: diloco_info += f", worker_id={args.worker_id}" diff --git a/src/forgather/cli/diloco_args.py b/src/forgather/cli/diloco_args.py index df98cda3..84566bb3 100644 --- a/src/forgather/cli/diloco_args.py +++ b/src/forgather/cli/diloco_args.py @@ -256,28 +256,16 @@ def create_diloco_parser(global_args): required=True, help="DiLoCo server address as host:port", ) - worker_parser.add_argument( - "--sync-every", - type=int, - default=500, - help="Number of local optimizer steps between syncs (default: 500)", - ) + # NOTE: sync_every / bf16_comm / dylu / num_fragments are NOT worker + # flags. They must match across the group, so the server is their sole + # authority — the worker reads them from /info at startup. See + # DiLoCoCallback (server-authoritative settings). worker_parser.add_argument( "--worker-id", type=str, default=None, help="Worker ID (auto-generated if not provided)", ) - worker_parser.add_argument( - "--no-bf16", - action="store_true", - help="Disable bfloat16 communication (send full precision pseudo-gradients)", - ) - worker_parser.add_argument( - "--dylu", - action="store_true", - help="Enable Dynamic Local Updates - adapt sync_every based on server recommendations", - ) worker_parser.add_argument( "--heartbeat-interval", type=float, @@ -285,6 +273,7 @@ def create_diloco_parser(global_args): help=( "Seconds between heartbeats to server. Enables server-side\n" "health monitoring and DyLU speed reporting. 0 = disabled.\n" + "Client-local; validated against the server's heartbeat-timeout.\n" "(default: 30)" ), ) @@ -295,16 +284,6 @@ def create_diloco_parser(global_args): default=None, help='CUDA Visible Devices e.g. "0,1"', ) - worker_parser.add_argument( - "--num-fragments", - type=int, - default=1, - help=( - "Number of fragments for streaming sync. When > 1, splits the model\n" - "into N fragments that sync at staggered intervals in background\n" - "threads, overlapping communication with computation. (default: 1)" - ), - ) worker_parser.add_argument( "--dry-run", action="store_true", diff --git a/src/forgather/ml/diloco/client.py b/src/forgather/ml/diloco/client.py index b2e28d54..4b790e90 100644 --- a/src/forgather/ml/diloco/client.py +++ b/src/forgather/ml/diloco/client.py @@ -672,6 +672,17 @@ def get_status(self) -> dict: """Get server status.""" return self._request_json("GET", "/status") + def get_info(self) -> dict: + """Get the server's static-ish configuration. + + Returns the negotiation facts a worker needs before registering: + the authoritative ``expected_client_settings`` (sync_every, + bf16_comm, dylu, num_fragments_*, heartbeat_timeout), the model + fingerprint (``model_hash``), and topology info. Distinct from + ``get_status`` which is the live, rapidly-changing snapshot. + """ + return self._request_json("GET", "/info") + # ------------------------------------------------------------------ # Work-unit dispatch (see docs/design/diloco-work-unit-dispatch.md) # ------------------------------------------------------------------ diff --git a/src/forgather/ml/diloco/server.py b/src/forgather/ml/diloco/server.py index e26ecbed..8de7d7ed 100644 --- a/src/forgather/ml/diloco/server.py +++ b/src/forgather/ml/diloco/server.py @@ -25,6 +25,7 @@ """ import base64 +import hashlib import io import json import logging @@ -537,6 +538,14 @@ def _initialize(self, model_state_dict: Dict[str, torch.Tensor]): ] ) + # Coarse model fingerprint advertised via /info. Workers use it to + # decide whether a cached model-definition bundle is still valid + # (issue #53) and as an early, pre-construction compatibility gate + # (the per-parameter /register fingerprint stays the fine-grained + # check). Derived from the parameter (name, shape) set the server + # holds; deterministic across the server's lifetime. + self._model_hash = self._compute_model_hash(model_state_dict) + # Outer optimizer self.outer_optimizer = self.outer_optimizer_factory( self._param_list.parameters() @@ -654,6 +663,23 @@ def _find_available_port(start_port: int = 8512, max_attempts: int = 100) -> int f"No available port in range {start_port}-{start_port + max_attempts}" ) + @staticmethod + def _compute_model_hash(model_state_dict: Dict[str, torch.Tensor]) -> str: + """Coarse, deterministic fingerprint of the model's parameter set. + + Hashes the sorted ``(name, shape)`` pairs. Stable across restarts + and machines for the same architecture; changes when the parameter + topology changes. Advertised via /info so workers can validate a + cached model-definition bundle and gate compatibility before + constructing the model. + """ + h = hashlib.sha256() + for name in sorted(model_state_dict.keys()): + shape = tuple(model_state_dict[name].shape) + h.update(name.encode("utf-8")) + h.update(repr(shape).encode("utf-8")) + return h.hexdigest() + def get_global_params(self) -> Dict[str, torch.Tensor]: """Get current global parameters as a state dict.""" return { @@ -2109,16 +2135,30 @@ def _handle_info(self, handler: BaseHTTPRequestHandler): "model_size_mb": round(self._model_size_mb, 2), "dylu_enabled": self.dylu_enabled, "dylu_base_sync_every": self.dylu_base_sync_every, + # Coarse model fingerprint (issue #53). Workers validate a + # cached model-definition bundle against this and use it as an + # early compatibility gate before constructing the model. + "model_hash": self._model_hash, + # The server is the sole authority for these settings: they must + # match across the group for the sync barrier / outer step / + # fragment barriers to be coherent, so the worker takes them + # verbatim (no client override). ``settings_authority`` signals + # that intent to clients and tooling. + "settings_authority": "server", "expected_client_settings": { - # DyLU servers want all workers ramped to the base rate so - # the per-worker scaling has a known anchor. Non-DyLU - # servers leave sync_every up to the worker. - "sync_every": ( - self.dylu_base_sync_every if self.dylu_enabled else None - ), + # Every worker ramps to the same inner-step cadence. We + # advertise ``dylu_base_sync_every`` as the canonical + # ``sync_every`` whether or not DyLU is enabled — it's the + # operator-set anchor either way, and a non-null value lets + # the worker drop its own --sync-every entirely. + "sync_every": self.dylu_base_sync_every, "dylu": self.dylu_enabled, "bf16_comm": True, "num_fragments_min": 1, + "num_fragments_default": 1, + # Exposed so the worker can validate its (client-local) + # heartbeat send cadence against the server's death timeout. + "heartbeat_timeout": self.heartbeat_timeout, }, } _send_json_response(handler, response) diff --git a/src/forgather/ml/trainer/callbacks/diloco_callback.py b/src/forgather/ml/trainer/callbacks/diloco_callback.py index f279cea3..88034c00 100644 --- a/src/forgather/ml/trainer/callbacks/diloco_callback.py +++ b/src/forgather/ml/trainer/callbacks/diloco_callback.py @@ -53,14 +53,6 @@ def _env_float(name: str, default: float) -> float: return float(val) -def _env_int(name: str, default: int) -> int: - """Read an int from an environment variable.""" - val = os.environ.get(name, "") - if not val: - return default - return int(val) - - class DiLoCoCallback(TrainerCallback): """ Trainer callback that manages a DiLoCoWorker for distributed local-SGD training. @@ -81,33 +73,31 @@ class DiLoCoCallback(TrainerCallback): we end up here without a server, we fail loudly. Likewise, a *configured* but *unreachable* server is fatal at - startup — the callback does a ``/status`` round-trip before + startup — the callback does a ``/info`` round-trip before proceeding so the operator sees the failure in the TTY pane, not five hundred steps later when the first sync fails. + **Server-authoritative settings.** ``sync_every``, ``bf16_comm``, + ``dylu`` and ``num_fragments`` must match across the whole group for + the sync barrier / outer step / fragment barriers to be coherent, so + they are owned by the server: the worker reads them verbatim from the + server's ``/info`` at startup (the leader fetches and broadcasts to + DDP followers). There is no client override — a divergent value is + never useful. Only genuinely client-local knobs remain as constructor + args / env vars. + Parameters ---------- server_addr : str, optional DiLoCo server address (``"host:port"``). Falls back to ``DILOCO_SERVER`` env var. - sync_every : int, optional - Local optimizer steps between syncs. Falls back to - ``DILOCO_SYNC_EVERY`` env var. Default ``500``. worker_id : str, optional Unique worker ID. Falls back to ``DILOCO_WORKER_ID`` env var. Auto-generated if unset. - bf16_comm : bool, optional - Cast pseudo-gradients to bfloat16. Falls back to - ``DILOCO_BF16_COMM`` env var. Default ``True``. - dylu : bool, optional - Enable Dynamic Local Updates. Falls back to ``DILOCO_DYLU`` env var. - Default ``False``. heartbeat_interval : float, optional - Seconds between heartbeats. Falls back to + Seconds between heartbeats (a client-local send cadence, validated + against the server's ``heartbeat_timeout``). Falls back to ``DILOCO_HEARTBEAT_INTERVAL`` env var. Default ``30.0``. - num_fragments : int, optional - Number of streaming fragments. Falls back to - ``DILOCO_NUM_FRAGMENTS`` env var. Default ``1`` (no streaming). timeout : float, optional Client timeout in seconds. Default ``600``. max_sync_retries : int, optional @@ -117,12 +107,8 @@ class DiLoCoCallback(TrainerCallback): def __init__( self, server_addr: Optional[str] = None, - sync_every: Optional[int] = None, worker_id: Optional[str] = None, - bf16_comm: Optional[bool] = None, - dylu: Optional[bool] = None, heartbeat_interval: Optional[float] = None, - num_fragments: Optional[int] = None, timeout: float = 600, max_sync_retries: int = 3, auth_token: Optional[str] = None, @@ -134,24 +120,21 @@ def __init__( from forgather.ml.diloco import diloco_server_addr self.server_addr = server_addr or diloco_server_addr() - self.sync_every = ( - sync_every if sync_every is not None else _env_int("DILOCO_SYNC_EVERY", 500) - ) self.worker_id = worker_id or os.environ.get("DILOCO_WORKER_ID", "") or None - self.bf16_comm = ( - bf16_comm if bf16_comm is not None else _env_bool("DILOCO_BF16_COMM", True) - ) - self.dylu = dylu if dylu is not None else _env_bool("DILOCO_DYLU", False) self.heartbeat_interval = ( heartbeat_interval if heartbeat_interval is not None else _env_float("DILOCO_HEARTBEAT_INTERVAL", 30.0) ) - self.num_fragments = ( - num_fragments - if num_fragments is not None - else _env_int("DILOCO_NUM_FRAGMENTS", 1) - ) + # Server-authoritative settings: these MUST match across the group + # for the sync barrier / outer step / fragment barriers to be + # coherent, so the worker takes them verbatim from the server's + # /info (resolved in on_train_begin) with no client override. They + # stay None until then. + self.sync_every: Optional[int] = None + self.bf16_comm: Optional[bool] = None + self.dylu: Optional[bool] = None + self.num_fragments: Optional[int] = None self.timeout = timeout self.max_sync_retries = max_sync_retries # Security (issue #90): bearer token + TLS verification. ``None`` @@ -176,6 +159,54 @@ def active(self) -> bool: """Whether DiLoCo integration is configured (server_addr is set).""" return bool(self.server_addr) + def _resolve_server_settings(self, info: Dict[str, Any]) -> Dict[str, Any]: + """Extract the server-authoritative settings from an /info payload. + + These four (sync_every, bf16_comm, dylu, num_fragments) must match + across the group, so the server is the sole authority — there is no + client override. A server that doesn't advertise them (too old, or + ``expected_client_settings.sync_every`` is null) is a fatal + misconfiguration, not something to paper over with a client default + (no-silent-fallback). + """ + from forgather.ml.diloco.client import DiLoCoServerUnreachable + + ecs = info.get("expected_client_settings") or {} + sync_every = ecs.get("sync_every") + if sync_every is None: + raise DiLoCoServerUnreachable( + f"DiLoCoCallback: server at {self.server_addr!r} did not " + f"advertise a sync_every in /info " + f"(expected_client_settings={ecs!r}). It is likely an " + f"older server predating server-authoritative settings; " + f"upgrade the diloco server." + ) + return { + "sync_every": int(sync_every), + "bf16_comm": bool(ecs.get("bf16_comm", True)), + "dylu": bool(ecs.get("dylu", False)), + "num_fragments": int(ecs.get("num_fragments_default", 1)), + "heartbeat_timeout": ecs.get("heartbeat_timeout"), + } + + def _validate_heartbeat(self, heartbeat_timeout) -> None: + """Fail loud if the client's heartbeat cadence can't beat the + server's death timeout. ``heartbeat_interval`` stays a client knob + (it's a genuinely local send cadence, not a must-match value), but + a cadence at or above the server's timeout guarantees spurious + eviction, so reject it up front. ``heartbeat_timeout <= 0`` means + death detection is disabled — nothing to validate.""" + if heartbeat_timeout and heartbeat_timeout > 0: + if self.heartbeat_interval >= heartbeat_timeout: + raise ValueError( + f"DiLoCoCallback: heartbeat_interval=" + f"{self.heartbeat_interval}s is >= the server's " + f"heartbeat_timeout={heartbeat_timeout}s; the worker " + f"would be evicted between heartbeats. Set " + f"--heartbeat-interval (or DILOCO_HEARTBEAT_INTERVAL) " + f"well below {heartbeat_timeout}s." + ) + def on_train_begin( self, args: MinimalTrainingArguments, @@ -272,25 +303,26 @@ def on_train_begin( f"group '{base_id}' member '{worker_id}'." ) - # Reachability pre-check: a /status round-trip before we - # bother building the worker. Surfaces "server URL wrong", - # "server down", "wrong port", "firewall" while the operator - # is still watching the TTY, instead of 500 local steps later - # when the first sync fails. + # Reachability pre-check + settings negotiation in one /info + # round-trip, before we bother building the worker. Surfaces + # "server URL wrong", "server down", "wrong port", "firewall" + # while the operator is still watching the TTY, instead of 500 + # local steps later when the first sync fails. /info also carries + # the server-authoritative settings (sync_every, bf16_comm, dylu, + # num_fragments) the worker must adopt verbatim. # - # DDP rank 0 only — followers don't talk to the server, so - # they shouldn't probe it either. They'll still hit the - # broadcast collective inside DiLoCoWorker.start(), so if the - # leader fails the probe and aborts, followers will deadlock - # waiting for the broadcast — but the leader's exception is - # the actionable signal the operator needs, and the worker's - # train loop will get torn down by the trainer once the - # leader's process exits. + # DDP rank 0 only — followers don't talk to the server. The leader + # fetches /info and broadcasts the resolved settings to followers + # so every rank syncs in lockstep. If the leader fails the probe + # and aborts, followers deadlock on the broadcast below (or the + # one inside DiLoCoWorker.start()) — but the leader's exception is + # the actionable signal, and the trainer tears the followers down + # once the leader's process exits. import torch.distributed as dist - is_leader = ( - not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0 - ) + ddp = dist.is_available() and dist.is_initialized() + is_leader = not ddp or dist.get_rank() == 0 + settings: Optional[Dict[str, Any]] = None if is_leader: probe = DiLoCoClient( self.server_addr, @@ -299,14 +331,33 @@ def on_train_begin( verify_tls=self.verify_tls, ) try: - probe.get_status() + info = probe.get_info() except Exception as exc: raise DiLoCoServerUnreachable( - f"DiLoCoCallback: /status round-trip to " + f"DiLoCoCallback: /info round-trip to " f"{self.server_addr!r} failed at startup: {exc}. " f"The server must be running and reachable before " f"workers can register." ) from exc + settings = self._resolve_server_settings(info) + if ddp: + holder = [settings] + dist.broadcast_object_list(holder, src=0) + settings = holder[0] + + self.sync_every = settings["sync_every"] + self.bf16_comm = settings["bf16_comm"] + self.dylu = settings["dylu"] + self.num_fragments = settings["num_fragments"] + self._validate_heartbeat(settings["heartbeat_timeout"]) + logger.info( + "DiLoCoCallback: using server settings sync_every=%s bf16_comm=%s " + "dylu=%s num_fragments=%s", + self.sync_every, + self.bf16_comm, + self.dylu, + self.num_fragments, + ) self._worker = DiLoCoWorker( model=model, diff --git a/templatelib/examples/callbacks/diloco.yaml b/templatelib/examples/callbacks/diloco.yaml index 7b9dfb0f..8f0c93a5 100644 --- a/templatelib/examples/callbacks/diloco.yaml +++ b/templatelib/examples/callbacks/diloco.yaml @@ -2,9 +2,13 @@ ## DiLoCo Callback ## Adds DiLoCoCallback to the trainer callback list for distributed local-SGD -## training. All parameters default to environment variables set by the -## `forgather diloco worker` CLI, so this template works with both explicit -## configuration and CLI-based env vars. +## training. Client-local parameters default to environment variables set by +## the `forgather diloco worker` CLI, so this template works with both +## explicit configuration and CLI-based env vars. +## +## sync_every / bf16_comm / dylu / num_fragments are NOT set here: they must +## match across the group, so the server is their sole authority and the +## worker reads them from the server's /info at startup (no client override). ## ## When server_addr is empty (no DILOCO_SERVER), the callback is a no-op. @@ -12,11 +16,7 @@ => super() diloco_callback: !singleton:forgather.ml.trainer.callbacks:DiLoCoCallback server_addr: {{ diloco_server | default(None) }} - sync_every: {{ diloco_sync_every | default(None) }} worker_id: {{ diloco_worker_id | default(None) }} - bf16_comm: {{ diloco_bf16_comm | default(None) }} - dylu: {{ diloco_dylu | default(None) }} heartbeat_interval: {{ diloco_heartbeat_interval | default(None) }} - num_fragments: {{ diloco_num_fragments | default(None) }} timeout: {{ diloco_timeout | default(600) }} max_sync_retries: {{ diloco_max_sync_retries | default(3) }} diff --git a/templatelib/examples/mixins/diloco.yaml b/templatelib/examples/mixins/diloco.yaml index 153aa6cc..0697bcd7 100644 --- a/templatelib/examples/mixins/diloco.yaml +++ b/templatelib/examples/mixins/diloco.yaml @@ -12,8 +12,8 @@ ## **Import with context**: callers must use ## ``{% from "mixins/diloco.yaml" import … with context %}``. Without ## ``with context`` the macro can't see the caller's Jinja variables, -## so ``-- set diloco_num_fragments = 4`` in a leaf would render as -## ``num_fragments: null`` instead of ``4``. The current Forgather +## so ``-- set diloco_worker_id = 'w0'`` in a leaf would render as +## ``worker_id: null`` instead of ``w0``. The current Forgather ## dynamic-args also need the context to flow through. ## ## Activation: ``DILOCO_SERVER`` env var (set by the scheduler when @@ -22,60 +22,43 @@ ## by the leaf and the same config behaves as a vanilla training run. -- macro callback_entry() - ## DiLoCo callback. All DILOCO_* env vars are read by the - ## callback constructor when these args are None; the webui - ## sets them based on the operator's selections. + ## DiLoCo callback. Client-local DILOCO_* env vars are read by the + ## callback constructor when these args are None; the webui sets them + ## based on the operator's selections. + ## + ## sync_every / bf16_comm / dylu / num_fragments are intentionally + ## absent: they must match across the group, so the server is their + ## sole authority and the worker reads them from the server's /info at + ## startup (no client override). diloco_callback: !singleton:forgather.ml.trainer.callbacks:DiLoCoCallback server_addr: {{ diloco_server | toyaml(None) }} - sync_every: {{ diloco_sync_every | toyaml(None) }} worker_id: {{ diloco_worker_id | toyaml(None) }} - bf16_comm: {{ diloco_bf16_comm | toyaml(None) }} - dylu: {{ diloco_dylu | toyaml(None) }} heartbeat_interval: {{ diloco_heartbeat_interval | toyaml(None) }} - num_fragments: {{ diloco_num_fragments | toyaml(None) }} timeout: {{ diloco_timeout | toyaml(600) }} max_sync_retries: {{ diloco_max_sync_retries | toyaml(3) }} -- endmacro -- macro dynamic_args() - ## DiLoCo override knobs surfaced on the train submit modal. The - ## webui's DiLoCo radio group writes to these when the operator - ## picks a server. Optional from the CLI side; unset means - ## env-var fallback (or callback default). + ## DiLoCo knobs surfaced on the train submit modal. The webui's DiLoCo + ## radio group writes to these when the operator picks a server. + ## Optional from the CLI side; unset means env-var fallback (or + ## callback default). + ## + ## sync_every / num_fragments / dylu / bf16_comm are NOT exposed: they + ## are server-authoritative (read from /info), so there is no operator + ## knob — the webui shows the server's values read-only instead. diloco_server: names: "--diloco-server" type: str group: "DiLoCo" help: "DiLoCo server address host:port. When unset, the DiLoCo callback isn't constructed and the config runs as a vanilla finetune." - diloco_sync_every: - names: "--diloco-sync-every" - type: int - group: "DiLoCo" - min: 1 - help: "Local optimizer steps between syncs." - diloco_num_fragments: - names: "--diloco-fragments" - type: int - group: "DiLoCo" - min: 1 - help: "Number of streaming-sync fragments (1 = no streaming)." - diloco_dylu: - names: "--diloco-dylu" - action: "store_true" - group: "DiLoCo" - help: "Enable Dynamic Local Updates (server must support it)." - diloco_bf16_comm: - names: "--diloco-bf16-comm" - type: bool - group: "DiLoCo" - help: "Cast pseudo-gradients to bf16 before sending." diloco_heartbeat_interval: names: "--diloco-heartbeat-interval" type: float group: "DiLoCo" min: 0 - help: "Seconds between heartbeats to the server." + help: "Seconds between heartbeats to the server (client-local; validated against the server's heartbeat-timeout)." diloco_worker_id: names: "--diloco-worker-id" type: str diff --git a/tests/unit/forgather_server/test_scheduler_diloco_env.py b/tests/unit/forgather_server/test_scheduler_diloco_env.py index 26cc513b..c3361671 100644 --- a/tests/unit/forgather_server/test_scheduler_diloco_env.py +++ b/tests/unit/forgather_server/test_scheduler_diloco_env.py @@ -15,7 +15,7 @@ def test_empty_dict_emits_nothing(): def test_missing_server_addr_short_circuits(): # Server addr is the gate — even if other keys are set, no env # vars get emitted when the worker is opting out of DiLoCo entirely. - assert _diloco_env_from_job_params({"sync_every": 500}, QID) == {} + assert _diloco_env_from_job_params({"heartbeat_interval": 15}, QID) == {} def test_server_addr_alone_yields_minimal_env(): @@ -30,7 +30,10 @@ def test_server_addr_alone_yields_minimal_env(): } -def test_full_payload_translates_all_keys(): +def test_full_payload_forwards_only_client_local_keys(): + # sync_every / num_fragments / dylu / bf16_comm are server-authoritative + # now — the worker reads them from /info, so the scheduler must NOT + # forward them even when the (legacy) submission carries them. env = _diloco_env_from_job_params( { "server_addr": "host:8512", @@ -45,23 +48,21 @@ def test_full_payload_translates_all_keys(): ) assert env == { "DILOCO_SERVER": "host:8512", - "DILOCO_SYNC_EVERY": "500", - "DILOCO_NUM_FRAGMENTS": "4", - "DILOCO_DYLU": "1", - "DILOCO_BF16_COMM": "0", "DILOCO_HEARTBEAT_INTERVAL": "15.0", "DILOCO_WORKER_ID": "w1", } -def test_bf16_true_translates_to_1(): - env = _diloco_env_from_job_params({"server_addr": "h:1", "bf16_comm": True}, QID) - assert env["DILOCO_BF16_COMM"] == "1" - - -def test_dylu_false_translates_to_0(): - env = _diloco_env_from_job_params({"server_addr": "h:1", "dylu": False}, QID) - assert env["DILOCO_DYLU"] == "0" +def test_server_authoritative_keys_never_forwarded(): + # Even alone, the must-match settings are not translated to env vars. + env = _diloco_env_from_job_params( + {"server_addr": "h:1", "sync_every": 500, "dylu": True, "bf16_comm": True}, + QID, + ) + assert "DILOCO_SYNC_EVERY" not in env + assert "DILOCO_DYLU" not in env + assert "DILOCO_BF16_COMM" not in env + assert "DILOCO_NUM_FRAGMENTS" not in env def test_empty_worker_id_falls_back_to_queue_id(): @@ -83,10 +84,9 @@ def test_explicit_worker_id_wins_over_queue_id(): def test_none_typed_fields_are_skipped(): - # None values shouldn't emit env vars — the callback's constructor - # then falls back to its own DILOCO_* env reads (which would be - # unset), then to its default. ``worker_id`` is the one exception: - # it always gets set (queue_id fallback). + # None values shouldn't emit env vars. ``worker_id`` is the one + # exception: it always gets set (queue_id fallback). The + # server-authoritative fields are never forwarded regardless. env = _diloco_env_from_job_params( { "server_addr": "h:1", diff --git a/tests/unit/ml/diloco/test_dashboard.py b/tests/unit/ml/diloco/test_dashboard.py index e3b5535b..121ddfc7 100644 --- a/tests/unit/ml/diloco/test_dashboard.py +++ b/tests/unit/ml/diloco/test_dashboard.py @@ -18,6 +18,19 @@ def _make_state_dict(dim=8, num_layers=2, seed=42): return {f"layer{i}.weight": torch.randn(dim, dim) for i in range(num_layers)} +def test_compute_model_hash_is_deterministic_and_shape_sensitive(): + """The /info model_hash depends on (name, shape) only — stable across + runs/values, changes when the parameter topology changes.""" + a = {"w": torch.randn(4, 4), "b": torch.randn(4)} + a2 = {"w": torch.zeros(4, 4), "b": torch.ones(4)} # same shapes, diff values + c = {"w": torch.randn(4, 8), "b": torch.randn(4)} # different shape + h = DiLoCoServer._compute_model_hash + assert h(a) == h(a2) # value-independent + assert h(a) != h(c) # shape-sensitive + # Insertion order doesn't matter (names are sorted). + assert h({"b": torch.randn(4), "w": torch.randn(4, 4)}) == h(a) + + def _simple_sgd(params): return torch.optim.SGD(params, lr=1.0, momentum=0.5) @@ -290,3 +303,22 @@ def test_info_carries_output_dir(self, server, tmp_path): _, _, body = _get(f"http://localhost:{server.port}/info") data = json.loads(body) assert data.get("output_dir") == str(tmp_path) + + def test_info_advertises_authoritative_settings(self, server): + """/info is the authority for the must-match worker settings, so + every field is present and non-null (the worker takes them + verbatim). A non-DyLU server still advertises a sync_every.""" + _, _, body = _get(f"http://localhost:{server.port}/info") + data = json.loads(body) + assert data["settings_authority"] == "server" + assert isinstance(data["model_hash"], str) and data["model_hash"] + exp = data["expected_client_settings"] + # Non-DyLU server still advertises a concrete sync_every (the + # operator-set dylu_base_sync_every), so the worker can drop its + # own --sync-every entirely. + assert exp["sync_every"] == server.dylu_base_sync_every + assert exp["sync_every"] is not None + assert exp["bf16_comm"] is True + assert exp["dylu"] is False + assert exp["num_fragments_default"] == 1 + assert exp["heartbeat_timeout"] == server.heartbeat_timeout diff --git a/tests/unit/ml/diloco/test_diloco_callback.py b/tests/unit/ml/diloco/test_diloco_callback.py index 26a06318..8963c880 100644 --- a/tests/unit/ml/diloco/test_diloco_callback.py +++ b/tests/unit/ml/diloco/test_diloco_callback.py @@ -52,6 +52,36 @@ def _make_control(): _CLIENT_PATCH = "forgather.ml.diloco.client.DiLoCoClient" +def _info( + sync_every=500, + dylu=False, + bf16_comm=True, + num_fragments_default=1, + heartbeat_timeout=0, +): + """A minimal /info payload as the server would return it. The four + server-authoritative settings live under expected_client_settings.""" + return { + "mode": "sync", + "num_parameters": 64, + "model_hash": "deadbeef", + "settings_authority": "server", + "expected_client_settings": { + "sync_every": sync_every, + "dylu": dylu, + "bf16_comm": bf16_comm, + "num_fragments_min": 1, + "num_fragments_default": num_fragments_default, + "heartbeat_timeout": heartbeat_timeout, + }, + } + + +def _stub_info(MockClient, **kwargs): + """Point the mocked DiLoCoClient's get_info at an _info() payload.""" + MockClient.return_value.get_info.return_value = _info(**kwargs) + + class TestFailFastWhenUnconfigured: """The callback was reworked from "silent no-op when DILOCO_SERVER is unset" to "fail fast on misconfiguration." The previous silent @@ -117,25 +147,28 @@ def test_on_train_begin_raises_when_server_unreachable( self, MockClient, MockWorker ): """on_train_begin raises DiLoCoServerUnreachable when the - /status probe at startup fails. Surfaces the failure while + /info probe at startup fails. Surfaces the failure while the operator's still watching the TTY, not 500 steps in.""" from forgather.ml.diloco.client import DiLoCoServerUnreachable - MockClient.return_value.get_status.side_effect = ConnectionError("refused") + MockClient.return_value.get_info.side_effect = ConnectionError("refused") cb = DiLoCoCallback(server_addr="unreachable:9999") args, state, control = _make_args(), _make_state(), _make_control() model = TinyModel() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - with pytest.raises(DiLoCoServerUnreachable, match="/status round-trip"): + with pytest.raises(DiLoCoServerUnreachable, match="/info round-trip"): cb.on_train_begin(args, state, control, model=model, optimizer=optimizer) # Worker was never started since the probe came first. MockWorker.return_value.start.assert_not_called() class TestEnvVarConfiguration: - """Environment variable reading and constructor override.""" + """Environment variable reading for the client-local knobs. The + server-authoritative settings (sync_every / bf16_comm / dylu / + num_fragments) are no longer constructor args or env vars — the + worker reads them from /info (see TestServerAuthoritativeSettings).""" def test_server_addr_from_env(self): """DILOCO_SERVER env var provides server_addr.""" @@ -150,68 +183,121 @@ def test_explicit_overrides_env(self): cb = DiLoCoCallback(server_addr="explicit:8512") assert cb.server_addr == "explicit:8512" - def test_sync_every_from_env(self): - """DILOCO_SYNC_EVERY env var provides sync_every.""" - with patch.dict(os.environ, {"DILOCO_SYNC_EVERY": "200"}): - cb = DiLoCoCallback() - assert cb.sync_every == 200 - - def test_sync_every_explicit_overrides_env(self): - """Explicit sync_every overrides env var.""" - with patch.dict(os.environ, {"DILOCO_SYNC_EVERY": "200"}): - cb = DiLoCoCallback(sync_every=300) - assert cb.sync_every == 300 - def test_worker_id_from_env(self): """DILOCO_WORKER_ID env var provides worker_id.""" with patch.dict(os.environ, {"DILOCO_WORKER_ID": "w42"}): cb = DiLoCoCallback() assert cb.worker_id == "w42" - def test_bf16_comm_from_env(self): - """DILOCO_BF16_COMM env var provides bf16_comm.""" - with patch.dict(os.environ, {"DILOCO_BF16_COMM": "false"}): - cb = DiLoCoCallback() - assert cb.bf16_comm is False - - def test_bf16_comm_default_true(self): - """bf16_comm defaults to True when env var is unset.""" - cb = DiLoCoCallback() - assert cb.bf16_comm is True - - def test_dylu_from_env(self): - """DILOCO_DYLU env var provides dylu.""" - with patch.dict(os.environ, {"DILOCO_DYLU": "1"}): - cb = DiLoCoCallback() - assert cb.dylu is True - def test_heartbeat_interval_from_env(self): """DILOCO_HEARTBEAT_INTERVAL env var provides heartbeat_interval.""" with patch.dict(os.environ, {"DILOCO_HEARTBEAT_INTERVAL": "15.0"}): cb = DiLoCoCallback() assert cb.heartbeat_interval == 15.0 - def test_num_fragments_from_env(self): - """DILOCO_NUM_FRAGMENTS env var provides num_fragments.""" - with patch.dict(os.environ, {"DILOCO_NUM_FRAGMENTS": "4"}): - cb = DiLoCoCallback() - assert cb.num_fragments == 4 + def test_server_authoritative_settings_not_constructor_args(self): + """The removed must-match settings are not accepted as kwargs.""" + for kw in ("sync_every", "bf16_comm", "dylu", "num_fragments"): + with pytest.raises(TypeError): + DiLoCoCallback(server_addr="host:8512", **{kw: 1}) def test_defaults_without_env(self): - """Default values when no env vars are set.""" + """Default values when no env vars are set. The server-authoritative + settings stay None until /info is read in on_train_begin.""" # Clear any DILOCO_* env vars env = {k: v for k, v in os.environ.items() if not k.startswith("DILOCO_")} with patch.dict(os.environ, env, clear=True): cb = DiLoCoCallback() assert cb.server_addr == "" - assert cb.sync_every == 500 assert cb.worker_id is None - assert cb.bf16_comm is True - assert cb.dylu is False assert cb.heartbeat_interval == 30.0 - assert cb.num_fragments == 1 assert cb.timeout == 600 assert cb.max_sync_retries == 3 + # Not resolved until on_train_begin negotiates with the server. + assert cb.sync_every is None + assert cb.bf16_comm is None + assert cb.dylu is None + assert cb.num_fragments is None + + +class TestServerAuthoritativeSettings: + """sync_every / bf16_comm / dylu / num_fragments are taken verbatim + from the server's /info, with no client override.""" + + @patch(_CLIENT_PATCH) + @patch(_WORKER_PATCH) + def test_worker_built_with_server_settings(self, MockWorker, MockClient): + mock_instance = MockWorker.return_value + mock_instance.sync_metrics = {} + _stub_info( + MockClient, + sync_every=250, + dylu=True, + bf16_comm=False, + num_fragments_default=3, + ) + + cb = DiLoCoCallback(server_addr="host:8512") + args, state, control = _make_args(), _make_state(), _make_control() + model = TinyModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + cb.on_train_begin(args, state, control, model=model, optimizer=optimizer) + + _, kwargs = MockWorker.call_args + assert kwargs["sync_every"] == 250 + assert kwargs["dylu"] is True + assert kwargs["bf16_comm"] is False + assert kwargs["num_fragments"] == 3 + # And the callback's own copies were updated. + assert cb.sync_every == 250 + assert cb.dylu is True + + @patch(_CLIENT_PATCH) + @patch(_WORKER_PATCH) + def test_missing_sync_every_is_fatal(self, MockWorker, MockClient): + """A server that doesn't advertise sync_every (too old) is fatal, + not silently defaulted.""" + from forgather.ml.diloco.client import DiLoCoServerUnreachable + + MockClient.return_value.get_info.return_value = _info(sync_every=None) + cb = DiLoCoCallback(server_addr="host:8512") + args, state, control = _make_args(), _make_state(), _make_control() + model = TinyModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + with pytest.raises(DiLoCoServerUnreachable, match="sync_every"): + cb.on_train_begin(args, state, control, model=model, optimizer=optimizer) + MockWorker.return_value.start.assert_not_called() + + @patch(_CLIENT_PATCH) + @patch(_WORKER_PATCH) + def test_heartbeat_interval_at_or_above_timeout_raises( + self, MockWorker, MockClient + ): + """A heartbeat cadence >= the server's death timeout is rejected + up front (it guarantees spurious eviction).""" + _stub_info(MockClient, heartbeat_timeout=20) + cb = DiLoCoCallback(server_addr="host:8512", heartbeat_interval=30.0) + args, state, control = _make_args(), _make_state(), _make_control() + model = TinyModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + with pytest.raises(ValueError, match="heartbeat_interval"): + cb.on_train_begin(args, state, control, model=model, optimizer=optimizer) + MockWorker.return_value.start.assert_not_called() + + @patch(_CLIENT_PATCH) + @patch(_WORKER_PATCH) + def test_heartbeat_timeout_zero_disables_validation(self, MockWorker, MockClient): + """heartbeat_timeout=0 means death detection is off — any cadence + is allowed.""" + mock_instance = MockWorker.return_value + mock_instance.sync_metrics = {} + _stub_info(MockClient, heartbeat_timeout=0) + cb = DiLoCoCallback(server_addr="host:8512", heartbeat_interval=999.0) + args, state, control = _make_args(), _make_state(), _make_control() + model = TinyModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + cb.on_train_begin(args, state, control, model=model, optimizer=optimizer) + mock_instance.start.assert_called_once() class TestWorkerLifecycle: @@ -227,8 +313,9 @@ def test_worker_created_on_train_begin(self, MockWorker, MockClient): """on_train_begin creates and starts a DiLoCoWorker.""" mock_instance = MockWorker.return_value mock_instance.sync_metrics = {} + _stub_info(MockClient) # server advertises sync_every=500, defaults - cb = DiLoCoCallback(server_addr="host:8512", sync_every=100) + cb = DiLoCoCallback(server_addr="host:8512") args, state, control = _make_args(), _make_state(), _make_control() model = TinyModel() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) @@ -239,7 +326,7 @@ def test_worker_created_on_train_begin(self, MockWorker, MockClient): model=model, optimizer=optimizer, server_addr="host:8512", - sync_every=100, + sync_every=500, # from /info worker_id=None, bf16_comm=True, timeout=600, @@ -252,8 +339,8 @@ def test_worker_created_on_train_begin(self, MockWorker, MockClient): verify_tls=True, ) mock_instance.start.assert_called_once() - # Pre-probe should have happened first. - MockClient.return_value.get_status.assert_called_once() + # Pre-probe (now /info) should have happened first. + MockClient.return_value.get_info.assert_called_once() @patch(_CLIENT_PATCH) @patch(_WORKER_PATCH) @@ -261,6 +348,7 @@ def test_worker_stopped_on_train_end(self, MockWorker, MockClient): """on_train_end stops the worker.""" mock_instance = MockWorker.return_value mock_instance.sync_metrics = {} + _stub_info(MockClient) cb = DiLoCoCallback(server_addr="host:8512") args, state, control = _make_args(), _make_state(), _make_control() @@ -289,18 +377,22 @@ def test_missing_model_raises(self, MockWorker, MockClient): @patch(_CLIENT_PATCH) @patch(_WORKER_PATCH) def test_custom_parameters_passed_to_worker(self, MockWorker, MockClient): - """All callback parameters are forwarded to DiLoCoWorker.""" + """Client-local params come from the constructor; the four + server-authoritative ones come from /info.""" mock_instance = MockWorker.return_value mock_instance.sync_metrics = {} + _stub_info( + MockClient, + sync_every=200, + dylu=True, + bf16_comm=False, + num_fragments_default=4, + ) cb = DiLoCoCallback( server_addr="remote:9999", - sync_every=200, worker_id="test_worker", - bf16_comm=False, - dylu=True, heartbeat_interval=10.0, - num_fragments=4, timeout=300, max_sync_retries=5, ) @@ -314,13 +406,13 @@ def test_custom_parameters_passed_to_worker(self, MockWorker, MockClient): model=model, optimizer=optimizer, server_addr="remote:9999", - sync_every=200, + sync_every=200, # from /info worker_id="test_worker", - bf16_comm=False, + bf16_comm=False, # from /info timeout=300, - dylu=True, + dylu=True, # from /info heartbeat_interval=10.0, - num_fragments=4, + num_fragments=4, # from /info max_sync_retries=5, param_view=None, auth_token=None, @@ -340,6 +432,7 @@ def test_pipeline_trainer_builds_param_view_and_group_kwargs( ): mock_instance = MockWorker.return_value mock_instance.sync_metrics = {} + _stub_info(MockClient) # Fake pipeline trainer: only the attributes the callback reads. fake_trainer = MagicMock() @@ -382,6 +475,7 @@ def test_non_pipeline_trainer_keeps_solo_path(self, MockWorker, MockClient): group kwargs.""" mock_instance = MockWorker.return_value mock_instance.sync_metrics = {} + _stub_info(MockClient) fake_trainer = MagicMock() fake_trainer.pipeline_modules = None # not a pipeline trainer @@ -422,6 +516,7 @@ def test_metrics_injected_on_log(self, MockWorker, MockClient): "diloco/local_step": 42, "diloco/total_sync_time": 10.5, } + _stub_info(MockClient) cb = DiLoCoCallback(server_addr="host:8512") args, state, control = _make_args(), _make_state(), _make_control() @@ -455,6 +550,7 @@ def test_no_crash_when_logs_is_none(self, MockWorker, MockClient): """on_log handles None logs gracefully.""" mock_instance = MockWorker.return_value mock_instance.sync_metrics = {"diloco/sync_count": 1} + _stub_info(MockClient) cb = DiLoCoCallback(server_addr="host:8512") args, state, control = _make_args(), _make_state(), _make_control() @@ -484,6 +580,7 @@ def test_state_dict_captures_worker_state(self, MockWorker, MockClient): mock_instance._reconnections = 1 mock_instance._dylu_adjustments = 3 mock_instance._fragment_syncs = 8 + _stub_info(MockClient) cb = DiLoCoCallback(server_addr="host:8512") args, state, control = _make_args(), _make_state(), _make_control() @@ -527,6 +624,7 @@ def test_deferred_state_applied_on_train_begin(self, MockWorker, MockClient): """Pending state from load_state_dict is applied when worker starts.""" mock_instance = MockWorker.return_value mock_instance.sync_metrics = {} + _stub_info(MockClient) cb = DiLoCoCallback(server_addr="host:8512") @@ -568,6 +666,7 @@ def test_no_pending_state_when_not_loaded(self, MockWorker, MockClient): """on_train_begin works fine without any pending state.""" mock_instance = MockWorker.return_value mock_instance.sync_metrics = {} + _stub_info(MockClient) cb = DiLoCoCallback(server_addr="host:8512") args, state, control = _make_args(), _make_state(), _make_control() @@ -601,6 +700,7 @@ def test_roundtrip_state_dict(self, MockWorker, MockClient): mock_instance._reconnections = 0 mock_instance._dylu_adjustments = 0 mock_instance._fragment_syncs = 0 + _stub_info(MockClient) cb = DiLoCoCallback(server_addr="host:8512") args, state, control = _make_args(), _make_state(), _make_control() diff --git a/tools/forgather_server/scheduler.py b/tools/forgather_server/scheduler.py index a669f62d..98d12061 100644 --- a/tools/forgather_server/scheduler.py +++ b/tools/forgather_server/scheduler.py @@ -836,14 +836,10 @@ def _diloco_env_from_job_params( token = _diloco_token_for_server_addr(str(server)) if token: env[_DILOCO_TOKEN_ENV_VAR] = token - if diloco.get("sync_every") is not None: - env["DILOCO_SYNC_EVERY"] = str(int(diloco["sync_every"])) - if diloco.get("num_fragments") is not None: - env["DILOCO_NUM_FRAGMENTS"] = str(int(diloco["num_fragments"])) - if diloco.get("dylu") is not None: - env["DILOCO_DYLU"] = "1" if bool(diloco["dylu"]) else "0" - if diloco.get("bf16_comm") is not None: - env["DILOCO_BF16_COMM"] = "1" if bool(diloco["bf16_comm"]) else "0" + # sync_every / num_fragments / dylu / bf16_comm are server-authoritative + # (they must match across the group); the worker reads them from the + # server's /info at startup, so we no longer forward them from the + # submission. Only client-local knobs are forwarded. if diloco.get("heartbeat_interval") is not None: env["DILOCO_HEARTBEAT_INTERVAL"] = str(float(diloco["heartbeat_interval"])) # Always-set: operator-supplied value if present, otherwise the diff --git a/tools/forgather_server/webui/src/api.ts b/tools/forgather_server/webui/src/api.ts index 05fbdecc..e9789afc 100644 --- a/tools/forgather_server/webui/src/api.ts +++ b/tools/forgather_server/webui/src/api.ts @@ -1149,11 +1149,18 @@ export interface DiLoCoInfo { model_size_mb?: number; dylu_enabled?: boolean; dylu_base_sync_every?: number; + // Coarse model fingerprint; the worker validates a cached model + // bundle and uses it as an early compatibility gate (issue #53). + model_hash?: string; + // Marks the four settings below as server-owned (no client override). + settings_authority?: string; expected_client_settings?: { sync_every?: number | null; dylu?: boolean; bf16_comm?: boolean; num_fragments_min?: number; + num_fragments_default?: number; + heartbeat_timeout?: number; }; } diff --git a/tools/forgather_server/webui/src/components/SubmitModal.tsx b/tools/forgather_server/webui/src/components/SubmitModal.tsx index 1cdefa24..ea30e287 100644 --- a/tools/forgather_server/webui/src/components/SubmitModal.tsx +++ b/tools/forgather_server/webui/src/components/SubmitModal.tsx @@ -169,14 +169,10 @@ export function SubmitModal({ project, config, onClose, onSubmitted }: Props) { const [selectedDiLoCoBase, setSelectedDiLoCoBase] = useState( persistedDiLoCo.base, ); - // Per-knob form state for the dependent fields. Strings (free form) - // so empty == "use config / env default"; coerced on submit. - const [diSyncEvery, setDiSyncEvery] = useState(persistedDiLoCo.syncEvery); - const [diNumFragments, setDiNumFragments] = useState( - persistedDiLoCo.numFragments, - ); - const [diDylu, setDiDylu] = useState(persistedDiLoCo.dylu); - const [diBf16, setDiBf16] = useState(persistedDiLoCo.bf16Comm); + // sync_every / num_fragments / dylu / bf16_comm are server-authoritative + // (the worker reads them from /info); they have no form state here. Only + // client-local knobs remain. Strings (free form) so empty == "use config + // / env default"; coerced on submit. const [diHeartbeat, setDiHeartbeat] = useState( persistedDiLoCo.heartbeatInterval, ); @@ -187,22 +183,10 @@ export function SubmitModal({ project, config, onClose, onSubmitted }: Props) { useEffect(() => { const cur: DiLoCoPersisted = { base: selectedDiLoCoBase, - syncEvery: diSyncEvery, - numFragments: diNumFragments, - dylu: diDylu, - bf16Comm: diBf16, heartbeatInterval: diHeartbeat, }; persistSet(dilocoStorageKey, JSON.stringify(cur)); - }, [ - dilocoStorageKey, - selectedDiLoCoBase, - diSyncEvery, - diNumFragments, - diDylu, - diBf16, - diHeartbeat, - ]); + }, [dilocoStorageKey, selectedDiLoCoBase, diHeartbeat]); // If the persisted base isn't currently in the server list (server // went offline, was renamed, etc.) fall back to "None" AND surface a // warning. The silent fallback was the failure mode that produced @@ -244,18 +228,9 @@ export function SubmitModal({ project, config, onClose, onSubmitted }: Props) { useEffect(() => { const info: DiLoCoInfo | undefined = dilocoInfoQ.data; if (!selectedDiLoCoBase || !info) return; - const exp = info.expected_client_settings ?? {}; - if (exp.sync_every != null && diSyncEvery === "") { - setDiSyncEvery(String(exp.sync_every)); - } - if (typeof exp.dylu === "boolean") { - // A DyLU server requires the worker to opt in; otherwise the - // server's per-worker recommendations are ignored. - setDiDylu(exp.dylu); - } - if (typeof exp.bf16_comm === "boolean") { - setDiBf16(exp.bf16_comm); - } + // sync_every / dylu / bf16_comm / num_fragments are server-authoritative + // and resolved by the worker from /info — nothing to seed into the form. + // The picker shows the server's values read-only instead. // Seed --model-id-or-path from the server's output_dir so the // worker constructs its model against the same checkpoint the // server loaded from. Catches the operator-misconfiguration @@ -642,10 +617,6 @@ export function SubmitModal({ project, config, onClose, onSubmitted }: Props) { // picked a server in the radio group ("" = None). const diloco = buildDiLoCoPayload({ base: selectedDiLoCoBase, - syncEvery: diSyncEvery, - numFragments: diNumFragments, - dylu: diDylu, - bf16Comm: diBf16, heartbeatInterval: diHeartbeat, workerId: diWorkerId, }); @@ -848,14 +819,6 @@ export function SubmitModal({ project, config, onClose, onSubmitted }: Props) { servers={dilocoServersQ.data} selectedBase={selectedDiLoCoBase} onSelectBase={setSelectedDiLoCoBase} - syncEvery={diSyncEvery} - setSyncEvery={setDiSyncEvery} - numFragments={diNumFragments} - setNumFragments={setDiNumFragments} - dylu={diDylu} - setDylu={setDiDylu} - bf16Comm={diBf16} - setBf16Comm={setDiBf16} heartbeatInterval={diHeartbeat} setHeartbeatInterval={setDiHeartbeat} workerId={diWorkerId} @@ -958,19 +921,11 @@ function formatNproc(v: number | string | null): string { interface DiLoCoPersisted { base: string; - syncEvery: string; - numFragments: string; - dylu: boolean; - bf16Comm: boolean; heartbeatInterval: string; } const DEFAULT_DILOCO_PERSISTED: DiLoCoPersisted = { base: "", - syncEvery: "", - numFragments: "", - dylu: false, - bf16Comm: true, heartbeatInterval: "", }; @@ -978,14 +933,6 @@ interface DiLoCoPickerProps { servers: DiLoCoServer[]; selectedBase: string; onSelectBase: (base: string) => void; - syncEvery: string; - setSyncEvery: (v: string) => void; - numFragments: string; - setNumFragments: (v: string) => void; - dylu: boolean; - setDylu: (v: boolean) => void; - bf16Comm: boolean; - setBf16Comm: (v: boolean) => void; heartbeatInterval: string; setHeartbeatInterval: (v: string) => void; workerId: string; @@ -1006,14 +953,6 @@ function DiLoCoPicker(props: DiLoCoPickerProps) { servers, selectedBase, onSelectBase, - syncEvery, - setSyncEvery, - numFragments, - setNumFragments, - dylu, - setDylu, - bf16Comm, - setBf16Comm, heartbeatInterval, setHeartbeatInterval, workerId, @@ -1138,16 +1077,67 @@ function DiLoCoPicker(props: DiLoCoPickerProps) { {info.num_parameters !== undefined && ( <> · {info.num_parameters.toLocaleString()} params )} - {info.dylu_enabled && ( - <> - {" "} - · DyLU base sync_every={" "} - {info.dylu_base_sync_every} - - )} )} + {/* Server-authoritative settings (issue #95 follow-up). These + must match across the group, so the server owns them and + the worker reads them from /info — there is no operator + knob. Shown read-only so the operator can see what the + worker will use. */} + {info && ( +
+ + Managed by server (read-only) + +
+ + sync_every:{" "} + + {info.expected_client_settings?.sync_every ?? "—"} + + + + num_fragments:{" "} + + {info.expected_client_settings?.num_fragments_default ?? + 1} + + + + dylu:{" "} + + {info.expected_client_settings?.dylu ? "on" : "off"} + + + + bf16_comm:{" "} + + {info.expected_client_settings?.bf16_comm === false + ? "off" + : "on"} + + +
+
+ )} +
- - - -
)} @@ -1235,17 +1177,17 @@ function DiLoCoPicker(props: DiLoCoPickerProps) { interface DiLoCoFormSnapshot { base: string; - syncEvery: string; - numFragments: string; - dylu: boolean; - bf16Comm: boolean; heartbeatInterval: string; workerId: string; } /** Construct the ``job_params.diloco`` payload from the form snapshot. * Returns null when the operator picked "None" — callers should skip - * ``job_params.diloco`` entirely in that case. */ + * ``job_params.diloco`` entirely in that case. + * + * sync_every / num_fragments / dylu / bf16_comm are intentionally absent: + * they are server-authoritative and the worker reads them from /info, so + * the submission never carries them. */ function buildDiLoCoPayload( s: DiLoCoFormSnapshot, ): Record | null { @@ -1260,19 +1202,7 @@ function buildDiLoCoPayload( const serverAddr = s.base.replace(/\/$/, ""); const payload: Record = { server_addr: serverAddr, - dylu: s.dylu, - bf16_comm: s.bf16Comm, }; - const sync = s.syncEvery.trim(); - if (sync) { - const n = Number(sync); - if (Number.isFinite(n)) payload.sync_every = Math.max(1, Math.floor(n)); - } - const frags = s.numFragments.trim(); - if (frags) { - const n = Number(frags); - if (Number.isFinite(n)) payload.num_fragments = Math.max(1, Math.floor(n)); - } const hb = s.heartbeatInterval.trim(); if (hb) { const n = Number(hb); From 1be9c9d7a056582c563657c9a4f56aed49553463 Mon Sep 17 00:00:00 2001 From: Jason dinAlt Date: Fri, 29 May 2026 16:11:12 +0000 Subject: [PATCH 2/2] diloco: give the server real knobs for the authoritative settings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Review of the previous commit (and the server modal) surfaced that three of the four settings made server-authoritative had no server-side knob — /info just hardcoded them. That made streaming-fragment mode and full-precision comm unreachable, and forced sync_every through the misnamed dylu_base_sync_every. - DiLoCoServer gains sync_every / bf16_comm / num_fragments constructor args; /info advertises the real values (sync_every = dylu_base when DyLU is on, else the dedicated sync_every). - `forgather diloco server` gains --sync-every / --num-fragments / --no-bf16. Threaded through the spawn chain (build_diloco_server_command -> launcher -> scheduler) and surfaced in the DiLoCoServerModal as a "Worker settings (group-wide)" section. bf16_comm stays centralized on the server (one wire precision for the whole group), per the design call. - Fix DDP follower deadlock: the callback now broadcasts an error sentinel from the leader on /info failure so every rank raises together, instead of followers blocking forever in broadcast_object_list when the leader aborts (e.g. a too-old server with null sync_every). - Docs/messages: diloco-architecture.md and diloco.md document the new server flags and stop pointing at the removed worker flags; the worker DyLU-under-pipeline error points at the server --dylu; fixed a misleading test_routes_diloco fixture. Tests: server /info reflects configured group settings (+ DyLU-base case); build_diloco_server_command emits the new flags. Full tests/unit/ml/diloco + tests/unit/forgather_server green (978 passed). Co-Authored-By: Claude Opus 4.8 (1M context) --- docs/trainers/diloco-architecture.md | 17 +++-- docs/trainers/diloco.md | 9 +++ src/forgather/cli/diloco.py | 3 + src/forgather/cli/diloco_args.py | 33 +++++++++ src/forgather/ml/diloco/server.py | 30 +++++--- src/forgather/ml/diloco/worker.py | 5 +- .../ml/trainer/callbacks/diloco_callback.py | 31 +++++--- .../forgather_server/test_routes_diloco.py | 2 +- .../test_scheduler_diloco_server_token.py | 36 ++++++++++ tests/unit/ml/diloco/test_dashboard.py | 65 +++++++++++++++-- tools/forgather_server/diloco_server_ops.py | 12 ++++ tools/forgather_server/launcher.py | 6 ++ tools/forgather_server/scheduler.py | 3 + .../src/components/DiLoCoServerModal.tsx | 72 +++++++++++++++++++ 14 files changed, 289 insertions(+), 35 deletions(-) diff --git a/docs/trainers/diloco-architecture.md b/docs/trainers/diloco-architecture.md index 423d8a48..1ff7a1d4 100644 --- a/docs/trainers/diloco-architecture.md +++ b/docs/trainers/diloco-architecture.md @@ -852,19 +852,24 @@ parameter values. **Cause:** BFloat16 has ~3 digits of precision. Very small pseudo-gradients (difference between global and local params) may be rounded to zero. -**Mitigation:** Disable bf16 communication with `--no-bf16` or -`bf16_comm=False`. This doubles bandwidth usage. +**Mitigation:** Disable bf16 communication by starting the server with +`--no-bf16`. `bf16_comm` is server-authoritative (the whole group shares one +wire precision), so every worker adopts it from `/info` — there is no worker +flag. This doubles bandwidth usage. ### Fragment sync deadlock -**Symptom:** Workers hang when using `--num-fragments > 1` in sync mode. +**Symptom:** Workers hang when the server runs with `--num-fragments > 1` in +sync mode. **Cause:** Per-fragment barriers require all workers to submit the same fragment -in the same round. If workers have different `sync_every` values (e.g., from -DyLU) or different `num_fragments`, their fragment schedules won't align. +in the same round. Misaligned `sync_every` or `num_fragments` across workers +would break this. **Requirement:** All workers in synchronous fragment mode must use the same -`sync_every` and `num_fragments`. +`sync_every` and `num_fragments`. This is now guaranteed automatically: both +are server-authoritative and adopted by every worker from `/info`, so they +cannot diverge. Set them on the server (`--sync-every`, `--num-fragments`). ### Async staleness drift diff --git a/docs/trainers/diloco.md b/docs/trainers/diloco.md index 6bdf75ea..8f19a518 100644 --- a/docs/trainers/diloco.md +++ b/docs/trainers/diloco.md @@ -118,6 +118,15 @@ Server arguments: - `--dylu-base-sync-every N`: Base sync interval for the fastest worker (default: 500) - `--from-checkpoint FROM_CHECKPOINT`: Load model from specified checkpoint path. Overrides loading from newest. +Group-wide worker settings (must match across the group, so they live on the +server; every worker adopts them from `/info` — there are no worker flags): +- `--sync-every N`: Local optimizer steps between syncs (H). Default: 500. + Under `--dylu`, the DyLU base rate is used instead. +- `--num-fragments N`: Streaming-sync fragments every worker splits the model + into (1 = no streaming). Default: 1. +- `--no-bf16`: Send full-precision pseudo-gradients instead of bfloat16 + (centralized: one wire precision for the whole group). Default: bf16 on. + ```bash # Load a specific checkpoint and save checkpoints to specified directory. forgather diloco server -o path/to/output --from-checkpoint output_models/my_model/checkpoint-1000 -n 2 diff --git a/src/forgather/cli/diloco.py b/src/forgather/cli/diloco.py index cf103cd0..29ac10b2 100644 --- a/src/forgather/cli/diloco.py +++ b/src/forgather/cli/diloco.py @@ -195,6 +195,9 @@ def outer_optimizer_factory(params): dn_buffer_size=dn_buffer_size, dylu_enabled=dylu, dylu_base_sync_every=dylu_base, + sync_every=args.sync_every, + bf16_comm=args.bf16_comm, + num_fragments=args.num_fragments, heartbeat_timeout=heartbeat_timeout, min_workers=min_workers, default_work_units=default_work_units, diff --git a/src/forgather/cli/diloco_args.py b/src/forgather/cli/diloco_args.py index 84566bb3..1b34cb60 100644 --- a/src/forgather/cli/diloco_args.py +++ b/src/forgather/cli/diloco_args.py @@ -117,6 +117,39 @@ def create_diloco_parser(global_args): default=500, help="DyLU base sync_every for the fastest worker (default: 500)", ) + # Group-wide worker settings (issue #53 follow-up). These must match + # across every worker, so they're owned by the server and the workers + # adopt them from /info — there are no corresponding worker flags. + server_parser.add_argument( + "--sync-every", + type=int, + default=500, + help=( + "Local optimizer steps between syncs (H), applied by every\n" + "worker. Under --dylu the DyLU base rate is used instead.\n" + "(default: 500)" + ), + ) + server_parser.add_argument( + "--num-fragments", + type=int, + default=1, + help=( + "Streaming-sync fragments every worker splits the model into.\n" + "1 = no streaming. Must be uniform across the group, so it's\n" + "set here, not per worker. (default: 1)" + ), + ) + server_parser.add_argument( + "--no-bf16", + dest="bf16_comm", + action="store_false", + help=( + "Send full-precision pseudo-gradients instead of bfloat16.\n" + "Centralized: the group's wire precision is set on the server\n" + "and adopted by every worker. (default: bf16 enabled)" + ), + ) server_parser.add_argument( "--heartbeat-timeout", type=float, diff --git a/src/forgather/ml/diloco/server.py b/src/forgather/ml/diloco/server.py index 8de7d7ed..38b45fdc 100644 --- a/src/forgather/ml/diloco/server.py +++ b/src/forgather/ml/diloco/server.py @@ -426,6 +426,9 @@ def __init__( dn_buffer_size: int = 0, dylu_enabled: bool = False, dylu_base_sync_every: int = 500, + sync_every: int = 500, + bf16_comm: bool = True, + num_fragments: int = 1, heartbeat_timeout: float = 120.0, min_workers: int = 1, auth_token: Optional[str] = None, @@ -456,6 +459,16 @@ def __init__( self.dn_buffer_size = dn_buffer_size self.dylu_enabled = dylu_enabled self.dylu_base_sync_every = dylu_base_sync_every + # Group-wide worker settings the server is authoritative for (issue + # #53 follow-up). These MUST match across the group for the sync / + # fragment barriers to be coherent, so the operator sets them on + # the server and workers adopt them verbatim from /info. ``bf16_comm`` + # is centralized here too (the server doesn't need it to decode an + # upload, but a single operator-facing knob keeps the group's wire + # format consistent rather than per-worker). + self.sync_every = sync_every + self.bf16_comm = bf16_comm + self.num_fragments = num_fragments self.heartbeat_timeout = heartbeat_timeout self.default_work_units = default_work_units self.outer_optimizer_factory = ( @@ -2146,16 +2159,17 @@ def _handle_info(self, handler: BaseHTTPRequestHandler): # that intent to clients and tooling. "settings_authority": "server", "expected_client_settings": { - # Every worker ramps to the same inner-step cadence. We - # advertise ``dylu_base_sync_every`` as the canonical - # ``sync_every`` whether or not DyLU is enabled — it's the - # operator-set anchor either way, and a non-null value lets - # the worker drop its own --sync-every entirely. - "sync_every": self.dylu_base_sync_every, + # The group's inner-step cadence. Under DyLU every worker + # ramps to the base rate (the per-worker scaling anchor); + # otherwise it's the operator-set server ``sync_every``. + # Always non-null so the worker can drop its own knob. + "sync_every": ( + self.dylu_base_sync_every if self.dylu_enabled else self.sync_every + ), "dylu": self.dylu_enabled, - "bf16_comm": True, + "bf16_comm": self.bf16_comm, "num_fragments_min": 1, - "num_fragments_default": 1, + "num_fragments_default": self.num_fragments, # Exposed so the worker can validate its (client-local) # heartbeat send cadence against the server's death timeout. "heartbeat_timeout": self.heartbeat_timeout, diff --git a/src/forgather/ml/diloco/worker.py b/src/forgather/ml/diloco/worker.py index 719f063e..6c9b1123 100644 --- a/src/forgather/ml/diloco/worker.py +++ b/src/forgather/ml/diloco/worker.py @@ -221,8 +221,9 @@ def __init__( f"DiLoCo DyLU is not compatible with pipeline " f"groups (pp_world_size={self.pp_world_size}). " "Per-rank sync-every adjustments would desync " - "the group's barrier. Disable --diloco-dylu or " - "use a non-pipeline trainer." + "the group's barrier. DyLU is server-controlled — " + "restart the diloco server without --dylu, or use a " + "non-pipeline trainer." ) else: self._is_leader = self._ddp_rank == 0 diff --git a/src/forgather/ml/trainer/callbacks/diloco_callback.py b/src/forgather/ml/trainer/callbacks/diloco_callback.py index 88034c00..e1490cc6 100644 --- a/src/forgather/ml/trainer/callbacks/diloco_callback.py +++ b/src/forgather/ml/trainer/callbacks/diloco_callback.py @@ -312,17 +312,19 @@ def on_train_begin( # num_fragments) the worker must adopt verbatim. # # DDP rank 0 only — followers don't talk to the server. The leader - # fetches /info and broadcasts the resolved settings to followers - # so every rank syncs in lockstep. If the leader fails the probe - # and aborts, followers deadlock on the broadcast below (or the - # one inside DiLoCoWorker.start()) — but the leader's exception is - # the actionable signal, and the trainer tears the followers down - # once the leader's process exits. + # fetches /info and broadcasts the result to followers so every + # rank syncs in lockstep. Crucially, the leader broadcasts an + # *error sentinel* on failure (server down, too-old server, etc.) + # rather than raising before the collective — otherwise followers + # would block forever inside broadcast_object_list waiting on a + # rank 0 that already exited. With the sentinel, every rank raises + # the same actionable error together and the job fails fast. import torch.distributed as dist ddp = dist.is_available() and dist.is_initialized() is_leader = not ddp or dist.get_rank() == 0 settings: Optional[Dict[str, Any]] = None + nego_error: Optional[str] = None if is_leader: probe = DiLoCoClient( self.server_addr, @@ -332,18 +334,25 @@ def on_train_begin( ) try: info = probe.get_info() + settings = self._resolve_server_settings(info) + except DiLoCoServerUnreachable as exc: + # _resolve_server_settings already produced an actionable + # message (e.g. server too old to advertise sync_every). + nego_error = str(exc) except Exception as exc: - raise DiLoCoServerUnreachable( + nego_error = ( f"DiLoCoCallback: /info round-trip to " f"{self.server_addr!r} failed at startup: {exc}. " f"The server must be running and reachable before " f"workers can register." - ) from exc - settings = self._resolve_server_settings(info) + ) if ddp: - holder = [settings] + holder = [(settings, nego_error)] dist.broadcast_object_list(holder, src=0) - settings = holder[0] + settings, nego_error = holder[0] + if nego_error is not None: + # Raised on every rank (leader and followers) — no deadlock. + raise DiLoCoServerUnreachable(nego_error) self.sync_every = settings["sync_every"] self.bf16_comm = settings["bf16_comm"] diff --git a/tests/unit/forgather_server/test_routes_diloco.py b/tests/unit/forgather_server/test_routes_diloco.py index 3afd0167..aaa9695d 100644 --- a/tests/unit/forgather_server/test_routes_diloco.py +++ b/tests/unit/forgather_server/test_routes_diloco.py @@ -302,7 +302,7 @@ def test_info_proxies(self, client, no_local_servers, monkeypatch): upstream = { "output_dir": "/tmp/m", "num_parameters": 1234, - "expected_client_settings": {"sync_every": None}, + "expected_client_settings": {"sync_every": 500}, } def fake_handler(request: httpx.Request) -> httpx.Response: diff --git a/tests/unit/forgather_server/test_scheduler_diloco_server_token.py b/tests/unit/forgather_server/test_scheduler_diloco_server_token.py index 0d58722a..05c59f30 100644 --- a/tests/unit/forgather_server/test_scheduler_diloco_server_token.py +++ b/tests/unit/forgather_server/test_scheduler_diloco_server_token.py @@ -126,6 +126,42 @@ def test_build_diloco_command_bulk_port_flags(): assert "--no-bulk-auth" in cmd +def test_build_diloco_command_group_settings(): + """sync_every/num_fragments/bf16_comm (server-authoritative, adopted by + workers from /info) surface on the spawn argv.""" + from forgather_server.diloco_server_ops import build_diloco_server_command + + cmd = build_diloco_server_command( + output_dir="/tmp/out", + num_workers=1, + port=8512, + sync_every=250, + num_fragments=4, + bf16_comm=False, + ) + assert "--sync-every" in cmd + assert cmd[cmd.index("--sync-every") + 1] == "250" + assert "--num-fragments" in cmd + assert cmd[cmd.index("--num-fragments") + 1] == "4" + assert "--no-bf16" in cmd + + +def test_build_diloco_command_group_settings_defaults(): + """Defaults: --sync-every emitted (always meaningful), but + --num-fragments (1) and --no-bf16 (bf16 on) are omitted for a readable + argv.""" + from forgather_server.diloco_server_ops import build_diloco_server_command + + cmd = build_diloco_server_command( + output_dir="/tmp/out", + num_workers=1, + port=8512, + ) + assert "--sync-every" in cmd + assert "--num-fragments" not in cmd + assert "--no-bf16" not in cmd + + def test_build_diloco_command_no_bulk_port_omits_flags(): """No --bulk-port → bulk_tls/bulk_auth are silently dropped.""" from forgather_server.diloco_server_ops import build_diloco_server_command diff --git a/tests/unit/ml/diloco/test_dashboard.py b/tests/unit/ml/diloco/test_dashboard.py index 121ddfc7..4333a048 100644 --- a/tests/unit/ml/diloco/test_dashboard.py +++ b/tests/unit/ml/diloco/test_dashboard.py @@ -307,18 +307,69 @@ def test_info_carries_output_dir(self, server, tmp_path): def test_info_advertises_authoritative_settings(self, server): """/info is the authority for the must-match worker settings, so every field is present and non-null (the worker takes them - verbatim). A non-DyLU server still advertises a sync_every.""" + verbatim). A non-DyLU server advertises its own ``sync_every``.""" _, _, body = _get(f"http://localhost:{server.port}/info") data = json.loads(body) assert data["settings_authority"] == "server" assert isinstance(data["model_hash"], str) and data["model_hash"] exp = data["expected_client_settings"] - # Non-DyLU server still advertises a concrete sync_every (the - # operator-set dylu_base_sync_every), so the worker can drop its - # own --sync-every entirely. - assert exp["sync_every"] == server.dylu_base_sync_every + # Non-DyLU server advertises its dedicated sync_every (not None). + assert exp["sync_every"] == server.sync_every assert exp["sync_every"] is not None - assert exp["bf16_comm"] is True + assert exp["bf16_comm"] == server.bf16_comm assert exp["dylu"] is False - assert exp["num_fragments_default"] == 1 + assert exp["num_fragments_default"] == server.num_fragments assert exp["heartbeat_timeout"] == server.heartbeat_timeout + + def test_info_reflects_configured_group_settings(self, tmp_path): + """A server configured with non-default group settings advertises + exactly those — the operator's single source of truth for the + whole group.""" + sd = _make_state_dict() + ckpt = make_initial_checkpoint(sd, tmp_path / "initial") + srv = DiLoCoServer( + output_dir=str(tmp_path), + from_checkpoint=str(ckpt), + num_workers=1, + port=0, + sync_every=250, + num_fragments=4, + bf16_comm=False, + outer_optimizer_factory=_simple_sgd, + ) + srv.start() + time.sleep(0.2) + try: + _, _, body = _get(f"http://localhost:{srv.port}/info") + exp = json.loads(body)["expected_client_settings"] + assert exp["sync_every"] == 250 + assert exp["num_fragments_default"] == 4 + assert exp["bf16_comm"] is False + finally: + srv.stop() + + def test_info_sync_every_uses_dylu_base_when_dylu_enabled(self, tmp_path): + """Under DyLU the advertised sync_every is the DyLU base rate (the + per-worker scaling anchor), not the plain sync_every.""" + sd = _make_state_dict() + ckpt = make_initial_checkpoint(sd, tmp_path / "initial") + srv = DiLoCoServer( + output_dir=str(tmp_path), + from_checkpoint=str(ckpt), + num_workers=1, + port=0, + async_mode=True, + dylu_enabled=True, + dylu_base_sync_every=128, + sync_every=999, # ignored while DyLU is on + outer_optimizer_factory=_simple_sgd, + ) + srv.start() + time.sleep(0.2) + try: + _, _, body = _get(f"http://localhost:{srv.port}/info") + exp = json.loads(body)["expected_client_settings"] + assert exp["dylu"] is True + assert exp["sync_every"] == 128 + finally: + srv.stop() diff --git a/tools/forgather_server/diloco_server_ops.py b/tools/forgather_server/diloco_server_ops.py index 1db6f10e..8cb9e4d5 100644 --- a/tools/forgather_server/diloco_server_ops.py +++ b/tools/forgather_server/diloco_server_ops.py @@ -33,6 +33,9 @@ def build_diloco_server_command( dn_buffer_size: int = 0, dylu: bool = False, dylu_base_sync_every: int = 500, + sync_every: int = 500, + bf16_comm: bool = True, + num_fragments: int = 1, from_checkpoint: Optional[str] = None, save_every: int = 10, save_total_limit: int = 3, @@ -80,6 +83,15 @@ def build_diloco_server_command( # (500) is fine, but surface the explicit value so the spawn # argv reflects the operator's intent. cmd.extend(["--dylu-base-sync-every", str(int(dylu_base_sync_every))]) + # Group-wide worker settings the server is authoritative for (issue + # #53 follow-up). sync_every is always meaningful (the non-DyLU + # cadence); num_fragments/bf16 only when they diverge from the CLI + # default, keeping argv readable. + cmd.extend(["--sync-every", str(int(sync_every))]) + if num_fragments and int(num_fragments) > 1: + cmd.extend(["--num-fragments", str(int(num_fragments))]) + if bf16_comm is False: + cmd.append("--no-bf16") if from_checkpoint: cmd.extend(["--from-checkpoint", from_checkpoint]) # save_every: 0 disables periodic save — the CLI accepts 0, so pass diff --git a/tools/forgather_server/launcher.py b/tools/forgather_server/launcher.py index b008c9d5..4f71bb2d 100644 --- a/tools/forgather_server/launcher.py +++ b/tools/forgather_server/launcher.py @@ -476,6 +476,9 @@ def spawn_diloco_server_process( dn_buffer_size: int = 0, dylu: bool = False, dylu_base_sync_every: int = 500, + sync_every: int = 500, + bf16_comm: bool = True, + num_fragments: int = 1, from_checkpoint: Optional[str] = None, save_every: int = 10, save_total_limit: int = 3, @@ -513,6 +516,9 @@ def spawn_diloco_server_process( dn_buffer_size=dn_buffer_size, dylu=dylu, dylu_base_sync_every=dylu_base_sync_every, + sync_every=sync_every, + bf16_comm=bf16_comm, + num_fragments=num_fragments, from_checkpoint=from_checkpoint, save_every=save_every, save_total_limit=save_total_limit, diff --git a/tools/forgather_server/scheduler.py b/tools/forgather_server/scheduler.py index 98d12061..ce902a89 100644 --- a/tools/forgather_server/scheduler.py +++ b/tools/forgather_server/scheduler.py @@ -582,6 +582,9 @@ def _build_diloco_server(item, gpu_indices, tty_path): dn_buffer_size=int(p.get("dn_buffer_size", 0) or 0), dylu=bool(p.get("dylu", False)), dylu_base_sync_every=int(p.get("dylu_base_sync_every", 500) or 500), + sync_every=int(p.get("sync_every", 500) or 500), + bf16_comm=bool(p.get("bf16_comm", True)), + num_fragments=int(p.get("num_fragments", 1) or 1), from_checkpoint=p.get("from_checkpoint") or None, save_every=int(p.get("save_every", 10) or 0), save_total_limit=int(p.get("save_total_limit", 3) or 0), diff --git a/tools/forgather_server/webui/src/components/DiLoCoServerModal.tsx b/tools/forgather_server/webui/src/components/DiLoCoServerModal.tsx index 1573df47..f6390724 100644 --- a/tools/forgather_server/webui/src/components/DiLoCoServerModal.tsx +++ b/tools/forgather_server/webui/src/components/DiLoCoServerModal.tsx @@ -21,6 +21,12 @@ interface PersistedAdHoc { dnBufferSize: number; dylu: boolean; dyluBase: number; + // Group-wide worker settings the server is authoritative for. They + // must match across the group, so the operator sets them here and the + // workers adopt them from /info (no per-worker knob). + syncEvery: number; + numFragments: number; + bf16Comm: boolean; fromCheckpoint: string; saveEvery: number; saveTotalLimit: number; @@ -50,6 +56,9 @@ const DEFAULT_AD_HOC: PersistedAdHoc = { dnBufferSize: 0, dylu: false, dyluBase: 500, + syncEvery: 500, + numFragments: 1, + bf16Comm: true, fromCheckpoint: "", saveEvery: 10, saveTotalLimit: 3, @@ -144,6 +153,9 @@ export function DiLoCoServerModal({ dnBufferSize: pickNum(editingService.args, "dn_buffer_size", 0), dylu: pickBool(editingService.args, "dylu", false), dyluBase: pickNum(editingService.args, "dylu_base_sync_every", 500), + syncEvery: pickNum(editingService.args, "sync_every", 500), + numFragments: pickNum(editingService.args, "num_fragments", 1), + bf16Comm: pickBool(editingService.args, "bf16_comm", true), fromCheckpoint: pickStr(editingService.args, "from_checkpoint", ""), saveEvery: pickNum(editingService.args, "save_every", 10), saveTotalLimit: pickNum(editingService.args, "save_total_limit", 3), @@ -172,6 +184,9 @@ export function DiLoCoServerModal({ const [dnBufferSize, setDnBufferSize] = useState(seed.dnBufferSize); const [dylu, setDylu] = useState(seed.dylu); const [dyluBase, setDyluBase] = useState(seed.dyluBase); + const [syncEvery, setSyncEvery] = useState(seed.syncEvery); + const [numFragments, setNumFragments] = useState(seed.numFragments); + const [bf16Comm, setBf16Comm] = useState(seed.bf16Comm); const [fromCheckpoint, setFromCheckpoint] = useState(seed.fromCheckpoint); const [saveEvery, setSaveEvery] = useState(seed.saveEvery); const [saveTotalLimit, setSaveTotalLimit] = useState(seed.saveTotalLimit); @@ -215,6 +230,10 @@ export function DiLoCoServerModal({ no_nesterov: noNesterov, heartbeat_timeout: heartbeatTimeout, min_workers: minWorkers, + // Group-wide worker settings (adopted from /info; no worker knob). + sync_every: syncEvery, + num_fragments: numFragments, + bf16_comm: bf16Comm, }; if (dnBufferSize > 0) args.dn_buffer_size = dnBufferSize; if (dylu) args.dylu_base_sync_every = dyluBase; @@ -251,6 +270,9 @@ export function DiLoCoServerModal({ dnBufferSize, dylu, dyluBase, + syncEvery, + numFragments, + bf16Comm, fromCheckpoint: trimmedFromCheckpoint, saveEvery, saveTotalLimit, @@ -432,6 +454,56 @@ export function DiLoCoServerModal({ +
+ Worker settings (group-wide) +
+ These must match across every worker, so they're set here and + adopted by each worker from /info — there are no + per-worker flags. +
+
+ + + +
+
+
Sync mode