134 changes: 107 additions & 27 deletions keras/src/callbacks/orbax_checkpoint.py
@@ -1,3 +1,4 @@
import os
import warnings

import numpy as np
@@ -8,7 +9,7 @@
from keras.src.callbacks.monitor_callback import (
MonitorCallback, # For metric monitoring logic
)
from keras.src.utils.io_utils import print_msg
from keras.src.saving.saving_lib import DiskIOStore
from keras.src.utils.module_utils import ocp

# Context and AsyncOptions are accessed through the lazy-loaded ocp module
@@ -62,6 +63,11 @@ class OrbaxCheckpoint(MonitorCallback):
This callback saves the model's weights and optimizer state asynchronously
using Orbax, allowing training to continue without blocking for I/O.

**Multi-host Support**: When running in a multi-host distributed training
environment with JAX backend, this callback automatically coordinates
checkpointing across all hosts to ensure consistency and proper
synchronization. Multi-host checkpointing is only supported on JAX.

Example:

```python
@@ -138,6 +144,9 @@ def __init__(
self._current_epoch = 0 # Keep track of epoch
self._total_batches_seen = 0 # Global batch counter for step tracking

# Multi-host support
self._multihost_initialized = self._is_multihost_initialized()

if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
raise ValueError(
f"Unrecognized save_freq: {self.save_freq}. "
@@ -151,14 +160,18 @@
ocp.training.preservation_policies.LatestN(max_to_keep)
)

# Use AnyPreservationPolicy to combine them.
# Use AnyPreservationPolicy to combine multiple policies, or use a
# single policy directly
preservation_policy = None
if policies:
preservation_policy = (
ocp.training.preservation_policies.AnyPreservationPolicy(
policies
if len(policies) == 1:
preservation_policy = policies[0]
else:
preservation_policy = (
ocp.training.preservation_policies.AnyPreservationPolicy(
policies
)
)
)

# Create the V1 Checkpointer with direct parameter passing
# Orbax will handle directory creation on all processes as needed
@@ -167,6 +180,54 @@
preservation_policy=preservation_policy,
)

def _is_multihost_initialized(self):
"""Check if multi-host environment is initialized."""
# Multi-host checkpointing is only supported on JAX backend
if backend.backend() != "jax":
return False

multihost = ocp.multihost
# Check if JAX distributed client is initialized
# (indicates multihost setup)
return multihost.is_jax_distributed_client_initialized()

def _sync_processes(self, key=None):
"""Synchronize all processes across hosts."""
if not self._multihost_initialized:
return # No-op for single host

multihost = ocp.multihost
sync_key = key or "orbax_checkpoint_sync"
multihost.sync_global_processes(sync_key)

def is_multihost_enabled(self):
"""Return True if multi-host checkpointing is enabled and initialized.

This method can be used to check if the callback is operating in
a multi-host distributed training environment. Multi-host checkpointing
is only supported on JAX backend.

Returns:
bool: True if multi-host support is active, False otherwise.
"""
return self._multihost_initialized

def is_primary_host(self):
"""Return True if this process is the primary host in multi-host setup.

In multi-host environments, only the primary host typically handles
logging and coordination tasks. Multi-host checkpointing is only
supported on JAX backend.

Returns:
bool: True if this is the primary host, False otherwise.
Always returns True in single-host environments.
"""
if not self._multihost_initialized:
return True # Single host is always primary
multihost = ocp.multihost
return multihost.is_primary_host()

def _should_save_on_batch(self, batch):
"""Check if we should save on this batch."""
if self.save_freq == "epoch":
@@ -186,7 +247,7 @@ def _should_save_on_batch(self, batch):
return False

def _save_checkpoint(self, step, logs=None):
"""Save a checkpoint at the given step."""
"""Save a checkpoint at the given step with multi-host coordination."""

# --- Prepare Composite State (Backend-Agnostic) ---
state_tree = _get_state_tree(self.model)
@@ -202,16 +263,10 @@ def _save_checkpoint(self, step, logs=None):
"non_trainable_variables"
]
else:
composite_state = state_tree

# --- Save Logic (V1 API) ---
# All processes participate in distributed checkpointing
# Checkpointer is configured to save unconditionally when
# save_pytree is called
if self.verbose > 0:
print_msg(
f"OrbaxCheckpoint: Triggering async save for step {step}..."
)
composite_state = {
"model_config": self.model.get_config(),
**state_tree,
}

# Use a single with statement. If context_options is empty,
# Context() uses defaults.
@@ -221,6 +276,33 @@
else:
self.checkpointer.save_pytree(step, composite_state)

# Save assets separately since PyTree can't handle binary data
if not self.save_weights_only:
self._save_assets(step)

def _save_assets(self, step):
"""Save model assets to a separate directory."""
from keras.src.saving.saving_lib import _save_state

assets_dir = os.path.join(self.directory, "assets", str(step))
Collaborator:

How do you find the assets when reloading? Don't you need to put this file in the specific checkpoint subfolder instead? Because step is not necessarily the checkpoint number.

Collaborator (Author):

@hertschuh thanks for the comments. I strongly feel we should not support the assets feature, as Orbax does not support it.

When assets are saved: assets (e.g., custom layer data such as binary files or numpy arrays) are not handled by Orbax. Instead, Keras saves them separately using _save_assets(step), which writes to checkpoint_dir/step/assets/ via DiskIOStore.

Problem: if asset saving started immediately after save_pytree_async without waiting, the thread might try to write to checkpoint_dir/step/assets/ before Orbax has created the step/ directory. This would fail with a "directory not found" error. Saving them in the background again causes synchronization issues.

Let me know what you think about this?

Collaborator (Author):

Also, this background saving can conflict with the preservation policy. For example, the asset thread runs independently; Orbax doesn't know it's writing and may clean up the directory anyway. Background saving overlaps I/O but doesn't prevent Orbax's policy from running concurrently.
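
A minimal, self-contained sketch of the ordering problem described in the comment above; `async_save_pytree` and `write_assets` here are hypothetical stand-ins for the background Orbax save and the Keras asset writer, not the real APIs:

```python
import os
import threading
import time

def async_save_pytree(checkpoint_dir, step):
    """Hypothetical stand-in: the step directory is created in the
    background at some point after this call returns."""
    def _worker():
        time.sleep(0.5)  # simulate background checkpoint I/O
        os.makedirs(os.path.join(checkpoint_dir, str(step)), exist_ok=True)
    thread = threading.Thread(target=_worker)
    thread.start()
    return thread

def write_assets(checkpoint_dir, step):
    """Hypothetical stand-in for the asset writer: it needs the step
    directory to already exist."""
    assets_dir = os.path.join(checkpoint_dir, str(step), "assets")
    os.mkdir(assets_dir)  # FileNotFoundError if `<step>/` is missing

checkpoint_dir = "/tmp/ckpt_demo"
os.makedirs(checkpoint_dir, exist_ok=True)

pending = async_save_pytree(checkpoint_dir, step=3)
# Racy ordering: `3/` may not exist yet when the asset write runs.
#   write_assets(checkpoint_dir, step=3)  # -> FileNotFoundError
pending.join()                        # wait for the background save first
write_assets(checkpoint_dir, step=3)  # safe: `3/` now exists
```

In other words, asset writing either has to block on the async save (giving up most of the I/O overlap) or be dropped entirely, which is what the discussion above proposes.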

try:
assets_store = DiskIOStore(assets_dir, mode="w")
except FileExistsError:
# Directory already exists, skip asset saving
return
try:
# Use the same recursive saving logic as _save_state
visited = set()
_save_state(
self.model,
None, # No weights store
assets_store, # Assets store
"", # Root path
visited,
)
finally:
assets_store.close()

def on_train_batch_end(self, batch, logs=None):
if self._should_save_on_batch(batch):
# Handle save_best_only logic for batch-level saving
@@ -282,18 +364,16 @@ def on_train_end(self, logs=None):
except Exception:
pass # Ignore errors during cleanup

# Multi-host synchronization: ensure all hosts complete cleanup
self._sync_processes("checkpoint_cleanup")

def wait_until_finished(self):
"""Wait for any in-progress checkpoint operations to complete.
This method blocks until all asynchronous checkpoint save operations
have completed. It should be called before attempting to load
checkpoints if there might be pending save operations.
have completed across all hosts in a multi-host setup.
"""
# Wait for any async operations to complete
if hasattr(self.checkpointer, "wait"):
self.checkpointer.wait()
else:
# Fallback for older Orbax versions that don't have wait() method
while self.checkpointer.is_saving_in_progress():
import time
# Wait for any async operations to complete on this host
self.checkpointer.wait()

time.sleep(0.1)
# Multi-host synchronization: ensure all hosts complete
self._sync_processes("checkpoint_wait_complete")
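
As a usage illustration of the public surface touched by this diff, here is a sketch only: it assumes a compiled Keras `model` with training data already defined, and the constructor keywords are inferred from the attributes referenced above, so the exact signature may differ.

```python
from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint

# Constructor keywords (directory, save_freq, max_to_keep) are assumed
# from the attributes used in this diff and may not match exactly.
checkpoint_cb = OrbaxCheckpoint(
    directory="/tmp/orbax_ckpts",
    save_freq="epoch",
    max_to_keep=3,
)

model.fit(x_train, y_train, epochs=5, callbacks=[checkpoint_cb])

# In a multi-host JAX run, keep host-side logging on the primary host only.
if checkpoint_cb.is_primary_host():
    print("multi-host checkpointing active:",
          checkpoint_cb.is_multihost_enabled())

# Block until all pending async saves have completed (across all hosts,
# when multi-host is initialized) before reading checkpoints back.
checkpoint_cb.wait_until_finished()
```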