@@ -146,8 +146,7 @@ def slice_device_count(self, slice_index: int) -> int:
146
146
f"Slice { slice_index = } not found in { self .slice_to_devices = } "
147
147
) from error
148
148
149
- @classmethod
150
- def is_error_due_to_slice_down (cls , error : Exception ) -> bool :
149
+ def is_error_due_to_slice_down (self , error : Exception ) -> bool :
151
150
"""Returns True if the error is due to slice down.
152
151
153
152
The error types that are considered due to slice down are
@@ -160,7 +159,8 @@ def is_error_due_to_slice_down(cls, error: Exception) -> bool:
160
159
error: The error to check.
161
160
"""
162
161
return_value = isinstance (error , jax .errors .JaxRuntimeError ) and any (
163
- error_type in str (error ) for error_type in cls ._ELASTIC_DOWN_ERROR_TYPES
162
+ error_type in str (error )
163
+ for error_type in self ._ELASTIC_DOWN_ERROR_TYPES
164
164
)
165
165
if return_value :
166
166
_logger .info ("Caught an error due to slice down" )
@@ -171,8 +171,7 @@ def is_error_due_to_slice_down(cls, error: Exception) -> bool:
171
171
172
172
return return_value
173
173
174
- @classmethod
175
- def _simple_execution (cls , devices : Sequence [jax .Device ]) -> jax .Array :
174
+ def _simple_execution (self , devices : Sequence [jax .Device ]) -> jax .Array :
176
175
"""Simple execution to test if a slice is available.
177
176
178
177
This function is used to test if a slice is available. It executes a simple
@@ -192,7 +191,7 @@ def _simple_execution(cls, devices: Sequence[jax.Device]) -> jax.Array:
192
191
raise ValueError ("No devices" )
193
192
194
193
test_input = np .zeros (len (devices ), dtype = float ) + (
195
- cls ._SIMPLE_EXECUTION_TEST_VALUE - 1
194
+ self ._SIMPLE_EXECUTION_TEST_VALUE - 1
196
195
)
197
196
198
197
return jax .pmap (lambda x : x + 1 , devices = devices )(test_input )
@@ -374,7 +373,8 @@ def pop_snapshot(self) -> tuple[int, PyTree | None, PyTree | None]:
374
373
the manager. Calls will raise an error if there is no snapshot to pop.
375
374
376
375
Returns:
377
- A tuple of the step and the snapshot.
376
+ A tuple of the step, the snapshot of jax arrays, and the snapshot of
377
+ controller variables.
378
378
379
379
Raises:
380
380
ElasticRuntimeError: If there is no snapshot to pop.
@@ -391,46 +391,6 @@ def pop_snapshot(self) -> tuple[int, PyTree | None, PyTree | None]:
391
391
392
392
return step , snapshot_jax_arrays , snapshot_controller
393
393
394
- @staticmethod
395
- def _get_snapshot_jax_arrays_size (snapshot_jax_arrays : PyTree | None ) -> int :
396
- """Returns the size of a snapshot.
397
-
398
- Args:
399
- snapshot_jax_arrays: The snapshot to get the size of.
400
- """
401
- return sum (leaf .nbytes for leaf in jax .tree .leaves (snapshot_jax_arrays ))
402
-
403
- @staticmethod
404
- def _put_snapshot_jax_arrays_on_host (
405
- snapshot_jax_arrays : PyTree | None ,
406
- ) -> PyTree | None :
407
- """Puts a copy of the snapshot on the host.
408
-
409
- Args:
410
- snapshot_jax_arrays: The snapshot to move to the host. Must be a PyTree of
411
- JAX arrays or None.
412
-
413
- Returns:
414
- A copy of the snapshot on the host.
415
- """
416
-
417
- sharding_pinned_host = jax .tree .map (
418
- lambda x : x .sharding .with_memory_kind ("pinned_host" ),
419
- snapshot_jax_arrays ,
420
- )
421
- return jax .device_put (
422
- snapshot_jax_arrays ,
423
- sharding_pinned_host ,
424
- donate = False ,
425
- may_alias = False ,
426
- )
427
-
428
- @staticmethod
429
- def _put_snapshot_on_controller (
430
- snapshot : PyTree | None ,
431
- ) -> PyTree | None :
432
- return copy .deepcopy (snapshot )
433
-
434
394
# TODO: b/407772100 - Support multiple snapshots.
435
395
@timing .timeit
436
396
def maybe_snapshot (
@@ -459,22 +419,30 @@ def maybe_snapshot(
459
419
_logger .info ("Not saving a snapshot" )
460
420
return
461
421
462
- total_nbytes = self ._get_snapshot_jax_arrays_size (snapshot_jax_arrays )
422
+ total_nbytes = sum (
423
+ leaf .nbytes for leaf in jax .tree .leaves (snapshot_jax_arrays )
424
+ )
463
425
464
426
_logger .info ("Saving a snapshot of %s bytes on host" , total_nbytes )
465
427
466
- snapshot_jax_arrays_host = self ._put_snapshot_jax_arrays_on_host (
467
- snapshot_jax_arrays
428
+ sharding_pinned_host = jax .tree .map (
429
+ lambda x : x .sharding .with_memory_kind ("pinned_host" ),
430
+ snapshot_jax_arrays ,
431
+ )
432
+ snapshot_jax_arrays_host = jax .device_put (
433
+ snapshot_jax_arrays ,
434
+ sharding_pinned_host ,
435
+ donate = False ,
436
+ may_alias = False ,
468
437
)
469
438
_logger .info ("Snapshot dispatched" )
470
439
471
440
if block :
472
441
jax .block_until_ready (snapshot_jax_arrays_host )
473
442
_logger .info ("Snapshot completed" )
474
443
475
- snapshot_on_controller = self ._put_snapshot_on_controller (
476
- snapshot_controller
477
- )
444
+ snapshot_on_controller = copy .deepcopy (snapshot_controller )
445
+
478
446
self ._snapshot = {
479
447
"step" : step ,
480
448
"snapshot_jax_arrays" : snapshot_jax_arrays_host ,
@@ -523,9 +491,7 @@ def get_resharded_snapshot(
523
491
may_alias = False ,
524
492
)
525
493
526
- snapshot_on_controller = self ._put_snapshot_on_controller (
527
- snapshot_controller
528
- )
494
+ snapshot_on_controller = copy .deepcopy (snapshot_controller )
529
495
530
496
self ._snapshot = {
531
497
"step" : step ,
0 commit comments