Skip to content

Commit 8bb0688

Browse files
lukebaumann authored and copybara-github committed
Removes staticmethods and classmethods from the Manager.
PiperOrigin-RevId: 797005948
1 parent 5756f63 commit 8bb0688

File tree

1 file changed

+22
-56
lines changed

1 file changed

+22
-56
lines changed

pathwaysutils/elastic/manager.py

Lines changed: 22 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -146,8 +146,7 @@ def slice_device_count(self, slice_index: int) -> int:
146146
f"Slice {slice_index=} not found in {self.slice_to_devices=}"
147147
) from error
148148

149-
@classmethod
150-
def is_error_due_to_slice_down(cls, error: Exception) -> bool:
149+
def is_error_due_to_slice_down(self, error: Exception) -> bool:
151150
"""Returns True if the error is due to slice down.
152151
153152
The error types that are considered due to slice down are
@@ -160,7 +159,8 @@ def is_error_due_to_slice_down(cls, error: Exception) -> bool:
160159
error: The error to check.
161160
"""
162161
return_value = isinstance(error, jax.errors.JaxRuntimeError) and any(
163-
error_type in str(error) for error_type in cls._ELASTIC_DOWN_ERROR_TYPES
162+
error_type in str(error)
163+
for error_type in self._ELASTIC_DOWN_ERROR_TYPES
164164
)
165165
if return_value:
166166
_logger.info("Caught an error due to slice down")
@@ -171,8 +171,7 @@ def is_error_due_to_slice_down(cls, error: Exception) -> bool:
171171

172172
return return_value
173173

174-
@classmethod
175-
def _simple_execution(cls, devices: Sequence[jax.Device]) -> jax.Array:
174+
def _simple_execution(self, devices: Sequence[jax.Device]) -> jax.Array:
176175
"""Simple execution to test if a slice is available.
177176
178177
This function is used to test if a slice is available. It executes a simple
@@ -192,7 +191,7 @@ def _simple_execution(cls, devices: Sequence[jax.Device]) -> jax.Array:
192191
raise ValueError("No devices")
193192

194193
test_input = np.zeros(len(devices), dtype=float) + (
195-
cls._SIMPLE_EXECUTION_TEST_VALUE - 1
194+
self._SIMPLE_EXECUTION_TEST_VALUE - 1
196195
)
197196

198197
return jax.pmap(lambda x: x + 1, devices=devices)(test_input)
@@ -374,7 +373,8 @@ def pop_snapshot(self) -> tuple[int, PyTree | None, PyTree | None]:
374373
the manager. Calls will raise an error if there are no snapshot to pop.
375374
376375
Returns:
377-
A tuple of the step and the snapshot.
376+
A tuple of the step, the snapshot of jax arrays, and the snapshot of
377+
controller variables.
378378
379379
Raises:
380380
ElasticRuntimeError: If there is no snapshot to pop.
@@ -391,46 +391,6 @@ def pop_snapshot(self) -> tuple[int, PyTree | None, PyTree | None]:
391391

392392
return step, snapshot_jax_arrays, snapshot_controller
393393

394-
@staticmethod
395-
def _get_snapshot_jax_arrays_size(snapshot_jax_arrays: PyTree | None) -> int:
396-
"""Returns the size of a snapshot.
397-
398-
Args:
399-
snapshot_jax_arrays: The snapshot to get the size of.
400-
"""
401-
return sum(leaf.nbytes for leaf in jax.tree.leaves(snapshot_jax_arrays))
402-
403-
@staticmethod
404-
def _put_snapshot_jax_arrays_on_host(
405-
snapshot_jax_arrays: PyTree | None,
406-
) -> PyTree | None:
407-
"""Puts a copy of the snapshot on the host.
408-
409-
Args:
410-
snapshot_jax_arrays: The snapshot to move to the host. Must be a PyTree of
411-
JAX arrays or None.
412-
413-
Returns:
414-
A copy of the snapshot on the host.
415-
"""
416-
417-
sharding_pinned_host = jax.tree.map(
418-
lambda x: x.sharding.with_memory_kind("pinned_host"),
419-
snapshot_jax_arrays,
420-
)
421-
return jax.device_put(
422-
snapshot_jax_arrays,
423-
sharding_pinned_host,
424-
donate=False,
425-
may_alias=False,
426-
)
427-
428-
@staticmethod
429-
def _put_snapshot_on_controller(
430-
snapshot: PyTree | None,
431-
) -> PyTree | None:
432-
return copy.deepcopy(snapshot)
433-
434394
# TODO: b/407772100 - Support multiple snapshots.
435395
@timing.timeit
436396
def maybe_snapshot(
@@ -459,22 +419,30 @@ def maybe_snapshot(
459419
_logger.info("Not saving a snapshot")
460420
return
461421

462-
total_nbytes = self._get_snapshot_jax_arrays_size(snapshot_jax_arrays)
422+
total_nbytes = sum(
423+
leaf.nbytes for leaf in jax.tree.leaves(snapshot_jax_arrays)
424+
)
463425

464426
_logger.info("Saving a snapshot of %s bytes on host", total_nbytes)
465427

466-
snapshot_jax_arrays_host = self._put_snapshot_jax_arrays_on_host(
467-
snapshot_jax_arrays
428+
sharding_pinned_host = jax.tree.map(
429+
lambda x: x.sharding.with_memory_kind("pinned_host"),
430+
snapshot_jax_arrays,
431+
)
432+
snapshot_jax_arrays_host = jax.device_put(
433+
snapshot_jax_arrays,
434+
sharding_pinned_host,
435+
donate=False,
436+
may_alias=False,
468437
)
469438
_logger.info("Snapshot dispatched")
470439

471440
if block:
472441
jax.block_until_ready(snapshot_jax_arrays_host)
473442
_logger.info("Snapshot completed")
474443

475-
snapshot_on_controller = self._put_snapshot_on_controller(
476-
snapshot_controller
477-
)
444+
snapshot_on_controller = copy.deepcopy(snapshot_controller)
445+
478446
self._snapshot = {
479447
"step": step,
480448
"snapshot_jax_arrays": snapshot_jax_arrays_host,
@@ -523,9 +491,7 @@ def get_resharded_snapshot(
523491
may_alias=False,
524492
)
525493

526-
snapshot_on_controller = self._put_snapshot_on_controller(
527-
snapshot_controller
528-
)
494+
snapshot_on_controller = copy.deepcopy(snapshot_controller)
529495

530496
self._snapshot = {
531497
"step": step,

0 commit comments

Comments
 (0)