Orbax Loading and Sharding Support feature #21903
base: master

````diff
@@ -1,3 +1,4 @@
 import os
 import warnings
 
+import numpy as np
````
````diff
@@ -8,7 +9,7 @@
 from keras.src.callbacks.monitor_callback import (
     MonitorCallback,  # For metric monitoring logic
 )
-from keras.src.utils.io_utils import print_msg
+from keras.src.saving.saving_lib import DiskIOStore
 from keras.src.utils.module_utils import ocp
 
 # Context and AsyncOptions are accessed through the lazy-loaded ocp module
````
````diff
@@ -62,6 +63,11 @@ class OrbaxCheckpoint(MonitorCallback):
     This callback saves the model's weights and optimizer state asynchronously
     using Orbax, allowing training to continue without blocking for I/O.
 
+    **Multi-host Support**: When running in a multi-host distributed training
+    environment with JAX backend, this callback automatically coordinates
+    checkpointing across all hosts to ensure consistency and proper
+    synchronization. Multi-host checkpointing is only supported on JAX.
+
     Example:
 
     ```python
````
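For context, a hedged sketch of how this callback might be wired up in a multi-host JAX job. The constructor arguments (`directory`, `save_freq`, `max_to_keep`) are inferred from attributes visible in this diff, and the `keras.callbacks.OrbaxCheckpoint` export path is an assumption; `jax.distributed.initialize()` is the standard JAX entry point for multi-process setups.

```python
# Hedged sketch: multi-host usage. Argument names are inferred from this
# diff; the public export path is an assumption, not confirmed by the PR.
import jax
import keras

# On each host of a multi-process job (arguments depend on the cluster):
# jax.distributed.initialize()

ckpt_cb = keras.callbacks.OrbaxCheckpoint(  # assumed export path
    directory="/tmp/orbax_ckpts",
    save_freq="epoch",  # or an integer number of batches
    max_to_keep=3,      # oldest checkpoints pruned via LatestN
)

# model.fit(x_train, y_train, epochs=5, callbacks=[ckpt_cb])
# ckpt_cb.is_multihost_enabled()  # True only under an initialized JAX setup
```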
````diff
@@ -138,6 +144,9 @@ def __init__(
         self._current_epoch = 0  # Keep track of epoch
         self._total_batches_seen = 0  # Global batch counter for step tracking
 
+        # Multi-host support
+        self._multihost_initialized = self._is_multihost_initialized()
+
         if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
             raise ValueError(
                 f"Unrecognized save_freq: {self.save_freq}. "
````
````diff
@@ -151,14 +160,18 @@
                 ocp.training.preservation_policies.LatestN(max_to_keep)
             )
 
-        # Use AnyPreservationPolicy to combine them.
+        # Use AnyPreservationPolicy to combine them, or use directly
+        # if single policy
         preservation_policy = None
         if policies:
-            preservation_policy = (
-                ocp.training.preservation_policies.AnyPreservationPolicy(
-                    policies
-                )
-            )
+            if len(policies) == 1:
+                preservation_policy = policies[0]
+            else:
+                preservation_policy = (
+                    ocp.training.preservation_policies.AnyPreservationPolicy(
+                        policies
+                    )
+                )
 
         # Create the V1 Checkpointer with direct parameter passing
         # Orbax will handle directory creation on all processes as needed
````
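The single-policy fast path avoids wrapping one policy in `AnyPreservationPolicy`. A minimal standalone sketch of the same decision, reusing the lazy-loaded `ocp` module from the imports above:

```python
# Mirrors the branch above: wrap in AnyPreservationPolicy only when several
# policies must be combined; a single policy is used directly.
from keras.src.utils.module_utils import ocp  # lazy-loaded Orbax module


def combine_preservation_policies(policies):
    if not policies:
        return None  # Checkpointer falls back to its default behavior
    if len(policies) == 1:
        return policies[0]
    return ocp.training.preservation_policies.AnyPreservationPolicy(policies)
```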
````diff
@@ -167,6 +180,54 @@
             preservation_policy=preservation_policy,
         )
 
+    def _is_multihost_initialized(self):
+        """Check if multi-host environment is initialized."""
+        # Multi-host checkpointing is only supported on JAX backend
+        if backend.backend() != "jax":
+            return False
+
+        multihost = ocp.multihost
+        # Check if JAX distributed client is initialized
+        # (indicates multihost setup)
+        return multihost.is_jax_distributed_client_initialized()
+
+    def _sync_processes(self, key=None):
+        """Synchronize all processes across hosts."""
+        if not self._multihost_initialized:
+            return  # No-op for single host
+
+        multihost = ocp.multihost
+        sync_key = key or "orbax_checkpoint_sync"
+        multihost.sync_global_processes(sync_key)
+
+    def is_multihost_enabled(self):
+        """Return True if multi-host checkpointing is enabled and initialized.
+
+        This method can be used to check if the callback is operating in
+        a multi-host distributed training environment. Multi-host checkpointing
+        is only supported on JAX backend.
+
+        Returns:
+            bool: True if multi-host support is active, False otherwise.
+        """
+        return self._multihost_initialized
+
+    def is_primary_host(self):
+        """Return True if this process is the primary host in multi-host setup.
+
+        In multi-host environments, only the primary host typically handles
+        logging and coordination tasks. Multi-host checkpointing is only
+        supported on JAX backend.
+
+        Returns:
+            bool: True if this is the primary host, False otherwise.
+            Always returns True in single-host environments.
+        """
+        if not self._multihost_initialized:
+            return True  # Single host is always primary
+        multihost = ocp.multihost
+        return multihost.is_primary_host()
+
     def _should_save_on_batch(self, batch):
         """Check if we should save on this batch."""
         if self.save_freq == "epoch":
````
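A small illustrative sketch of how the two public helpers could be used to keep logging on one host; `ckpt_cb` is assumed to be an instance of this callback.

```python
# Hypothetical helper: emit a message only on the primary host. On non-JAX
# backends or single-host runs, is_primary_host() always returns True.
def log_on_primary(ckpt_cb, message):
    if ckpt_cb.is_primary_host():
        print(f"[primary host] {message}")


# if ckpt_cb.is_multihost_enabled():
#     log_on_primary(ckpt_cb, "multi-host checkpointing is active")
```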
````diff
@@ -186,7 +247,7 @@ def _should_save_on_batch(self, batch):
         return False
 
     def _save_checkpoint(self, step, logs=None):
-        """Save a checkpoint at the given step."""
+        """Save a checkpoint at the given step with multi-host coordination."""
 
         # --- Prepare Composite State (Backend-Agnostic) ---
         state_tree = _get_state_tree(self.model)
````
````diff
@@ -202,16 +263,10 @@ def _save_checkpoint(self, step, logs=None):
                 "non_trainable_variables"
             ]
         else:
-            composite_state = state_tree
-
-        # --- Save Logic (V1 API) ---
-        # All processes participate in distributed checkpointing
-        # Checkpointer is configured to save unconditionally when
-        # save_pytree is called
-        if self.verbose > 0:
-            print_msg(
-                f"OrbaxCheckpoint: Triggering async save for step {step}..."
-            )
+            composite_state = {
+                "model_config": self.model.get_config(),
+                **state_tree,
+            }
 
         # Use a single with statement. If context_options is empty,
         # Context() uses defaults.
````
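For illustration, a hedged sketch of the composite state assembled when `save_weights_only=False`. The real nested structure comes from the internal `_get_state_tree` helper, so the flat variable lists below are an approximation, not the actual layout.

```python
# Approximation of the saved pytree; the actual nesting is produced by the
# internal _get_state_tree helper, so this structure is an assumption.
import keras

model = keras.Sequential([keras.layers.Dense(4)])
model.build((None, 8))

state_tree = {
    "trainable_variables": [v.numpy() for v in model.trainable_variables],
    "non_trainable_variables": [
        v.numpy() for v in model.non_trainable_variables
    ],
}
composite_state = {
    "model_config": model.get_config(),  # lets a loader rebuild the model
    **state_tree,
}
```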
````diff
@@ -221,6 +276,33 @@
         else:
             self.checkpointer.save_pytree(step, composite_state)
 
+        # Save assets separately since PyTree can't handle binary data
+        if not self.save_weights_only:
+            self._save_assets(step)
+
+    def _save_assets(self, step):
+        """Save model assets to a separate directory."""
+        from keras.src.saving.saving_lib import _save_state
+
+        assets_dir = os.path.join(self.directory, "assets", str(step))
+
+        try:
+            assets_store = DiskIOStore(assets_dir, mode="w")
+        except FileExistsError:
+            # Directory already exists, skip asset saving
+            return
+        try:
+            # Use the same recursive saving logic as _save_state
+            visited = set()
+            _save_state(
+                self.model,
+                None,  # No weights store
+                assets_store,  # Assets store
+                "",  # Root path
+                visited,
+            )
+        finally:
+            assets_store.close()
+
     def on_train_batch_end(self, batch, logs=None):
         if self._should_save_on_batch(batch):
             # Handle save_best_only logic for batch-level saving
````

Review thread on the `assets_dir = os.path.join(...)` line:

> **hertschuh** (Collaborator): How do you find the assets when reloading? Don't you need to put this file in the specific checkpoint subfolder instead? Because …
>
> **amitsrivastava78** (Author): @hertschuh thanks for the comments. I strongly feel we should not support the assets feature, as Orbax does not support this, …
>
> **amitsrivastava78** (Author): Also, this background saving can conflict with the preservation policy, for example …
````diff
@@ -282,18 +364,16 @@ def on_train_end(self, logs=None):
         except Exception:
             pass  # Ignore errors during cleanup
 
+        # Multi-host synchronization: ensure all hosts complete cleanup
+        self._sync_processes("checkpoint_cleanup")
+
     def wait_until_finished(self):
         """Wait for any in-progress checkpoint operations to complete.
         This method blocks until all asynchronous checkpoint save operations
-        have completed. It should be called before attempting to load
-        checkpoints if there might be pending save operations.
+        have completed across all hosts in a multi-host setup.
         """
-        # Wait for any async operations to complete
-        if hasattr(self.checkpointer, "wait"):
-            self.checkpointer.wait()
-        else:
-            # Fallback for older Orbax versions that don't have wait() method
-            while self.checkpointer.is_saving_in_progress():
-                import time
-
-                time.sleep(0.1)
+        # Wait for any async operations to complete on this host
+        self.checkpointer.wait()
+
+        # Multi-host synchronization: ensure all hosts complete
+        self._sync_processes("checkpoint_wait_complete")
````
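A brief usage sketch: blocking on pending async saves before touching the checkpoint directory. `ckpt_cb` is again assumed to be an instance of this callback.

```python
# After training, make sure every host has finished writing before any
# process reads the checkpoint directory.
def finish_checkpointing(ckpt_cb):
    ckpt_cb.wait_until_finished()  # local wait, then cross-host barrier
    if ckpt_cb.is_primary_host():
        print("All checkpoint writes are complete on every host.")
```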