Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions docs/trainers/diloco-architecture.md
Original file line number Diff line number Diff line change
Expand Up @@ -852,19 +852,24 @@ parameter values.
**Cause:** BFloat16 has ~3 digits of precision. Very small pseudo-gradients
(difference between global and local params) may be rounded to zero.

**Mitigation:** Disable bf16 communication with `--no-bf16` or
`bf16_comm=False`. This doubles bandwidth usage.
**Mitigation:** Disable bf16 communication by starting the server with
`--no-bf16`. `bf16_comm` is server-authoritative (the whole group shares one
wire precision), so every worker adopts it from `/info` — there is no worker
flag. This doubles bandwidth usage.

### Fragment sync deadlock

**Symptom:** Workers hang when using `--num-fragments > 1` in sync mode.
**Symptom:** Workers hang when the server runs with `--num-fragments > 1` in
sync mode.

**Cause:** Per-fragment barriers require all workers to submit the same fragment
in the same round. If workers have different `sync_every` values (e.g., from
DyLU) or different `num_fragments`, their fragment schedules won't align.
in the same round. Misaligned `sync_every` or `num_fragments` across workers
would break this.

**Requirement:** All workers in synchronous fragment mode must use the same
`sync_every` and `num_fragments`.
`sync_every` and `num_fragments`. This is now guaranteed automatically: both
are server-authoritative and adopted by every worker from `/info`, so they
cannot diverge. Set them on the server (`--sync-every`, `--num-fragments`).

### Async staleness drift

Expand Down
76 changes: 42 additions & 34 deletions docs/trainers/diloco.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ Server arguments:
- `--dylu-base-sync-every N`: Base sync interval for the fastest worker (default: 500)
- `--from-checkpoint FROM_CHECKPOINT`: Load model from specified checkpoint path. Overrides loading from newest.

Group-wide worker settings (must match across the group, so they live on the
server; every worker adopts them from `/info` — there are no worker flags):
- `--sync-every N`: Local optimizer steps between syncs (H). Default: 500.
Under `--dylu`, the DyLU base rate is used instead.
- `--num-fragments N`: Streaming-sync fragments every worker splits the model
into (1 = no streaming). Default: 1.
- `--no-bf16`: Send full-precision pseudo-gradients instead of bfloat16
(centralized: one wire precision for the whole group). Default: bf16 on.

```bash
# Load a specific checkpoint and save checkpoints to specified directory.
forgather diloco server -o path/to/output --from-checkpoint output_models/my_model/checkpoint-1000 -n 2
Expand All @@ -133,38 +142,33 @@ explicit `--output-dir` the operator passes, so workers get distinct
per-worker dirs for free):

```bash
# sync mode
forgather diloco worker \
--server 192.168.1.100:8512 \
--sync-every 500 \
--worker-id w0 \
-p my_project -t train.yaml \
train -d 0

# with DyLU - server adjusts sync frequency dynamically
forgather diloco worker \
--server 192.168.1.100:8512 \
--sync-every 500 \
--worker-id w1 \
--dylu \
--heartbeat-interval 30 \
-p my_project -t train.yaml \
train -d 1
```

Worker arguments:
- `--server`: Server address as `host:port`
- `--sync-every`: Local steps between syncs (default: 500)
- `--worker-id`: Unique worker identity. Drives the per-worker output-dir
suffix the project template appends to `ns.output_dir`, and the
uniqueness key the server enforces on `/register`. Auto-generated when
omitted but operators typically set it explicitly so logs / output dirs
are predictable.
- `--no-bf16`: Send full-precision pseudo-gradients instead of bfloat16
- `--dylu`: Enable dynamic sync frequency adjustment from server
- `--heartbeat-interval`: Seconds between heartbeats for speed reporting (default: 30)
- `--heartbeat-interval`: Seconds between heartbeats for speed reporting
(default: 30). Client-local; validated against the server's
`--heartbeat-timeout` at startup.
- `-d`: CUDA visible devices

**Server-authoritative settings.** `sync_every`, `bf16_comm`, `dylu`, and
`num_fragments` must match across every worker in the group for the sync
barrier / outer step / fragment barriers to be coherent. The server is
their sole authority: the worker fetches them from the server's `/info` at
startup and logs the values it adopted. There are no `--sync-every` /
`--no-bf16` / `--dylu` / `--num-fragments` worker flags — set them on the
**server** instead (`--dylu-base-sync-every`, `--dylu`, etc.).

Dataset partitioning across workers is handled by the server's **work-unit
dispatch**: each worker registers its train dataset with the DiLoCo server
on first iteration and pulls per-unit row ranges on demand, so no row is
Expand Down Expand Up @@ -373,18 +377,20 @@ updates at approximately the same wall-clock rate.

DyLU requires:
1. **Server**: `--dylu` flag and `--dylu-base-sync-every` (default: 500)
2. **Workers**: `--dylu` flag and `--heartbeat-interval` (default: 30s)
2. **Workers**: `--heartbeat-interval` (default: 30s). DyLU enablement and the
base `sync_every` are taken from the server's `/info` — there is no worker
`--dylu` flag.

Workers periodically report their training speed via heartbeats. The server
computes the recommended sync interval and returns it in the heartbeat response.
Workers adjust their `sync_every` dynamically.

```bash
# Server with DyLU
# Server with DyLU (this is where dylu / sync_every are configured)
forgather diloco server -o ./model -n 3 --async --dylu --dylu-base-sync-every 500

# Worker with DyLU enabled
forgather diloco worker --server host:8512 --sync-every 500 --worker-id w0 --dylu -- train
# Worker — picks up dylu + sync_every from the server's /info
forgather diloco worker --server host:8512 --worker-id w0 -- train
```

### Staleness Tracking
Expand Down Expand Up @@ -432,12 +438,11 @@ communication becomes fully overlapped.
### CLI Usage

```bash
# Worker with 4 streaming fragments
# Streaming fragments are configured on the server, not the worker —
# the worker reads num_fragments (and sync_every) from /info.
forgather diloco worker \
--server 192.168.1.100:8512 \
--sync-every 500 \
--worker-id w0 \
--num-fragments 4 \
-p my_project -t train.yaml \
train
```
Expand Down Expand Up @@ -649,12 +654,11 @@ configuration works for both DiLoCo and standalone training.
```python
from forgather.ml.trainer.callbacks import DiLoCoCallback

# Explicit configuration
# Explicit configuration (client-local knobs only)
callback = DiLoCoCallback(
server_addr="192.168.1.100:8512",
sync_every=500,
bf16_comm=True,
num_fragments=1,
worker_id="w0",
heartbeat_interval=30.0,
)

# Or rely on environment variables (set by `forgather diloco worker`)
Expand All @@ -669,17 +673,21 @@ trainer = Trainer(
trainer.train()
```

All constructor parameters fall back to `DILOCO_*` environment variables:
The callback's client-local constructor parameters fall back to `DILOCO_*`
environment variables:

| Parameter | Env Var | Default |
|-----------|---------|---------|
| `server_addr` | `DILOCO_SERVER` | `""` (no-op) |
| `sync_every` | `DILOCO_SYNC_EVERY` | `500` |
| `worker_id` | `DILOCO_WORKER_ID` | auto-generated |
| `bf16_comm` | `DILOCO_BF16_COMM` | `True` |
| `dylu` | `DILOCO_DYLU` | `False` |
| `heartbeat_interval` | `DILOCO_HEARTBEAT_INTERVAL` | `30.0` |
| `num_fragments` | `DILOCO_NUM_FRAGMENTS` | `1` |

`sync_every`, `bf16_comm`, `dylu`, and `num_fragments` are **not** callback
parameters or env vars — they must match across the group, so the worker
reads them from the server's `/info` at startup (set them on the server).
The `DiLoCoWorker` class still accepts them directly for the low-level
programmatic API; it's only the callback / CLI surface that defers to the
server.

### Configuration Template

Expand All @@ -696,8 +704,8 @@ Or add the callback directly in your project template:
== super()
diloco_callback: !singleton:forgather.ml.trainer.callbacks:DiLoCoCallback
server_addr: {{ diloco_server | default(None) }}
sync_every: {{ diloco_sync_every | default(None) }}
num_fragments: {{ diloco_num_fragments | default(None) }}
worker_id: {{ diloco_worker_id | default(None) }}
heartbeat_interval: {{ diloco_heartbeat_interval | default(None) }}
```

See `examples/tiny_experiments/diloco/` for a complete working example.
Expand Down
25 changes: 11 additions & 14 deletions src/forgather/cli/diloco.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@ def outer_optimizer_factory(params):
dn_buffer_size=dn_buffer_size,
dylu_enabled=dylu,
dylu_base_sync_every=dylu_base,
sync_every=args.sync_every,
bf16_comm=args.bf16_comm,
num_fragments=args.num_fragments,
heartbeat_timeout=heartbeat_timeout,
min_workers=min_workers,
default_work_units=default_work_units,
Expand Down Expand Up @@ -306,14 +309,13 @@ def _worker_cmd(args):
This wraps the standard training command, injecting DiLoCo configuration
via environment variables that the training script picks up.
"""
# Set DiLoCo environment variables for the training script
# Set DiLoCo environment variables for the training script. Only
# client-local knobs are forwarded; sync_every / bf16_comm / dylu /
# num_fragments are server-authoritative and resolved from /info by
# the worker at startup (no client override).
env = os.environ.copy()
env["DILOCO_SERVER"] = args.server
env["DILOCO_SYNC_EVERY"] = str(args.sync_every)
env["DILOCO_BF16_COMM"] = "0" if args.no_bf16 else "1"
env["DILOCO_DYLU"] = "1" if getattr(args, "dylu", False) else "0"
env["DILOCO_HEARTBEAT_INTERVAL"] = str(getattr(args, "heartbeat_interval", 30.0))
env["DILOCO_NUM_FRAGMENTS"] = str(getattr(args, "num_fragments", 1))

if args.worker_id:
env["DILOCO_WORKER_ID"] = args.worker_id
Expand Down Expand Up @@ -351,15 +353,10 @@ def _worker_cmd(args):
cmd_args.extend(remainder)

cmd_str = " ".join(cmd_args)
diloco_info = (
f"DiLoCo: server={args.server}, sync_every={args.sync_every}, "
f"bf16={'yes' if not args.no_bf16 else 'no'}"
)
num_frags = getattr(args, "num_fragments", 1)
if num_frags > 1:
diloco_info += f", fragments={num_frags}"
if getattr(args, "dylu", False):
diloco_info += ", dylu=yes"
# sync_every / bf16 / dylu / num_fragments come from the server's /info
# at startup, so they aren't known here — the worker logs them once it
# negotiates with the server.
diloco_info = f"DiLoCo: server={args.server}"
if args.worker_id:
diloco_info += f", worker_id={args.worker_id}"

Expand Down
64 changes: 38 additions & 26 deletions src/forgather/cli/diloco_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,39 @@ def create_diloco_parser(global_args):
default=500,
help="DyLU base sync_every for the fastest worker (default: 500)",
)
# Group-wide worker settings (issue #53 follow-up). These must match
# across every worker, so they're owned by the server and the workers
# adopt them from /info — there are no corresponding worker flags.
server_parser.add_argument(
"--sync-every",
type=int,
default=500,
help=(
"Local optimizer steps between syncs (H), applied by every\n"
"worker. Under --dylu the DyLU base rate is used instead.\n"
"(default: 500)"
),
)
server_parser.add_argument(
"--num-fragments",
type=int,
default=1,
help=(
"Streaming-sync fragments every worker splits the model into.\n"
"1 = no streaming. Must be uniform across the group, so it's\n"
"set here, not per worker. (default: 1)"
),
)
server_parser.add_argument(
"--no-bf16",
dest="bf16_comm",
action="store_false",
help=(
"Send full-precision pseudo-gradients instead of bfloat16.\n"
"Centralized: the group's wire precision is set on the server\n"
"and adopted by every worker. (default: bf16 enabled)"
),
)
server_parser.add_argument(
"--heartbeat-timeout",
type=float,
Expand Down Expand Up @@ -256,35 +289,24 @@ def create_diloco_parser(global_args):
required=True,
help="DiLoCo server address as host:port",
)
worker_parser.add_argument(
"--sync-every",
type=int,
default=500,
help="Number of local optimizer steps between syncs (default: 500)",
)
# NOTE: sync_every / bf16_comm / dylu / num_fragments are NOT worker
# flags. They must match across the group, so the server is their sole
# authority — the worker reads them from /info at startup. See
# DiLoCoCallback (server-authoritative settings).
worker_parser.add_argument(
"--worker-id",
type=str,
default=None,
help="Worker ID (auto-generated if not provided)",
)
worker_parser.add_argument(
"--no-bf16",
action="store_true",
help="Disable bfloat16 communication (send full precision pseudo-gradients)",
)
worker_parser.add_argument(
"--dylu",
action="store_true",
help="Enable Dynamic Local Updates - adapt sync_every based on server recommendations",
)
worker_parser.add_argument(
"--heartbeat-interval",
type=float,
default=30.0,
help=(
"Seconds between heartbeats to server. Enables server-side\n"
"health monitoring and DyLU speed reporting. 0 = disabled.\n"
"Client-local; validated against the server's heartbeat-timeout.\n"
"(default: 30)"
),
)
Expand All @@ -295,16 +317,6 @@ def create_diloco_parser(global_args):
default=None,
help='CUDA Visible Devices e.g. "0,1"',
)
worker_parser.add_argument(
"--num-fragments",
type=int,
default=1,
help=(
"Number of fragments for streaming sync. When > 1, splits the model\n"
"into N fragments that sync at staggered intervals in background\n"
"threads, overlapping communication with computation. (default: 1)"
),
)
worker_parser.add_argument(
"--dry-run",
action="store_true",
Expand Down
11 changes: 11 additions & 0 deletions src/forgather/ml/diloco/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,17 @@ def get_status(self) -> dict:
"""Get server status."""
return self._request_json("GET", "/status")

def get_info(self) -> dict:
"""Get the server's static-ish configuration.

Returns the negotiation facts a worker needs before registering:
the authoritative ``expected_client_settings`` (sync_every,
bf16_comm, dylu, num_fragments_*, heartbeat_timeout), the model
fingerprint (``model_hash``), and topology info. Distinct from
``get_status`` which is the live, rapidly-changing snapshot.
"""
return self._request_json("GET", "/info")

# ------------------------------------------------------------------
# Work-unit dispatch (see docs/design/diloco-work-unit-dispatch.md)
# ------------------------------------------------------------------
Expand Down
Loading
Loading