@@ -49,16 +49,18 @@ class CloudPathwaysArrayHandler(type_handlers.ArrayHandler):
49
49
50
50
def __init__ (
51
51
self ,
52
- read_timeout : datetime .timedelta | None = None ,
52
+ timeout : datetime .timedelta | None = None ,
53
53
use_ocdbt : bool = False ,
54
54
):
55
- """Constructor .
55
+ """Orbax array handler for Pathways on Cloud with Persistence API .
56
56
57
57
Args:
58
- read_timeout : Duration indicating the timeout for reading arrays
58
+ timeout : Duration indicating the timeout for reading and writing arrays
59
59
use_ocdbt: allows using Tensorstore OCDBT driver.
60
60
"""
61
- self ._read_timeout = read_timeout
61
+ if timeout is None :
62
+ timeout = datetime .timedelta (hours = 1 )
63
+ self .timeout = timeout
62
64
63
65
if use_ocdbt :
64
66
raise ValueError ("OCDBT not supported for Pathways." )
@@ -92,7 +94,7 @@ async def serialize(
92
94
93
95
self ._wait_for_directory_creation_signals ()
94
96
locations , names = extract_parent_dir_and_name (infos )
95
- f = functools .partial (helper .write_one_array , timeout = self ._read_timeout )
97
+ f = functools .partial (helper .write_one_array , timeout = self .timeout )
96
98
futures_results = list (map (f , locations , names , values ))
97
99
98
100
return [
@@ -181,7 +183,7 @@ async def deserialize(
181
183
grouped_global_shapes ,
182
184
grouped_shardings ,
183
185
global_mesh .devices ,
184
- timeout = self ._read_timeout ,
186
+ timeout = self .timeout ,
185
187
)
186
188
# each persistence call is awaited serially.
187
189
read_future .result ()
@@ -191,7 +193,7 @@ async def deserialize(
191
193
192
194
193
195
def register_pathways_handlers (
194
- read_timeout : datetime .timedelta | None = None ,
196
+ timeout : datetime .timedelta | None = None ,
195
197
):
196
198
"""Function that must be called before saving or restoring with Pathways."""
197
199
logger .debug (
@@ -200,7 +202,7 @@ def register_pathways_handlers(
200
202
type_handlers .register_type_handler (
201
203
jax .Array ,
202
204
CloudPathwaysArrayHandler (
203
- read_timeout = read_timeout ,
205
+ timeout = timeout ,
204
206
),
205
207
override = True ,
206
208
)
0 commit comments