Skip to content

Commit 6f2daad

Browse files
lukebaumanncopybara-github
authored andcommitted
Internal change
PiperOrigin-RevId: 813440675
1 parent 2fa0623 commit 6f2daad

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

pathwaysutils/_initialize.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ def initialize() -> None:
9292
profiling.monkey_patch_jax()
9393
# TODO: b/365549911 - Remove when OCDBT-compatible
9494
if _is_persistence_enabled():
95-
orbax_handler.register_pathways_handlers(datetime.timedelta(hours=1))
95+
orbax_handler.register_pathways_handlers(
96+
timeout=datetime.timedelta(hours=1),
97+
)
9698

9799
# Turn off JAX compilation cache because Pathways handles its own
98100
# compilation cache.

pathwaysutils/persistence/orbax_handler.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,18 @@ class CloudPathwaysArrayHandler(type_handlers.ArrayHandler):
4949

5050
def __init__(
5151
self,
52-
read_timeout: datetime.timedelta | None = None,
52+
timeout: datetime.timedelta | None = None,
5353
use_ocdbt: bool = False,
5454
):
55-
"""Constructor.
55+
"""Orbax array handler for Pathways on Cloud with Persistence API.
5656
5757
Args:
58-
read_timeout: Duration indicating the timeout for reading arrays
58+
timeout: Duration indicating the timeout for reading and writing arrays
5959
use_ocdbt: allows using Tensorstore OCDBT driver.
6060
"""
61-
self._read_timeout = read_timeout
61+
if timeout is None:
62+
timeout = datetime.timedelta(hours=1)
63+
self.timeout = timeout
6264

6365
if use_ocdbt:
6466
raise ValueError("OCDBT not supported for Pathways.")
@@ -92,7 +94,7 @@ async def serialize(
9294

9395
self._wait_for_directory_creation_signals()
9496
locations, names = extract_parent_dir_and_name(infos)
95-
f = functools.partial(helper.write_one_array, timeout=self._read_timeout)
97+
f = functools.partial(helper.write_one_array, timeout=self.timeout)
9698
futures_results = list(map(f, locations, names, values))
9799

98100
return [
@@ -181,7 +183,7 @@ async def deserialize(
181183
grouped_global_shapes,
182184
grouped_shardings,
183185
global_mesh.devices,
184-
timeout=self._read_timeout,
186+
timeout=self.timeout,
185187
)
186188
# each persistence call is awaited serially.
187189
read_future.result()
@@ -191,7 +193,7 @@ async def deserialize(
191193

192194

193195
def register_pathways_handlers(
194-
read_timeout: datetime.timedelta | None = None,
196+
timeout: datetime.timedelta | None = None,
195197
):
196198
"""Function that must be called before saving or restoring with Pathways."""
197199
logger.debug(
@@ -200,7 +202,7 @@ def register_pathways_handlers(
200202
type_handlers.register_type_handler(
201203
jax.Array,
202204
CloudPathwaysArrayHandler(
203-
read_timeout=read_timeout,
205+
timeout=timeout,
204206
),
205207
override=True,
206208
)

0 commit comments

Comments
 (0)