56
56
import pickle
57
57
import threading
58
58
import time
59
- from typing import Awaitable , Optional , Protocol , Sequence , Type , TypeVar
59
+ from typing import Awaitable , Protocol , Sequence , TypeVar
60
60
61
61
from absl import logging
62
62
from etils import epath
@@ -101,7 +101,7 @@ def __call__(
101
101
path : epath .Path ,
102
102
parents : bool = False ,
103
103
exist_ok : bool = False ,
104
- mode : Optional [ int ] = None ,
104
+ mode : int | None = None ,
105
105
** kwargs ,
106
106
) -> Awaitable [None ]:
107
107
"""Creates the directory at path."""
@@ -112,10 +112,10 @@ async def _create_tmp_directory(
112
112
async_makedir_func : AsyncMakeDirFunc ,
113
113
tmp_dir : epath .Path ,
114
114
* ,
115
- path_permission_mode : Optional [ int ] = None ,
116
- checkpoint_metadata_store : Optional [
117
- checkpoint_metadata .MetadataStore
118
- ] = None ,
115
+ path_permission_mode : int | None = None ,
116
+ checkpoint_metadata_store : (
117
+ checkpoint_metadata .MetadataStore | None
118
+ ) = None ,
119
119
** kwargs ,
120
120
) -> epath .Path :
121
121
"""Creates a non-deterministic tmp directory for saving for given `final_dir`.
@@ -181,7 +181,7 @@ def _get_tmp_directory(final_path: epath.Path) -> epath.Path:
181
181
)
182
182
183
183
184
- def _get_tmp_directory_pattern (final_path_name : Optional [ str ] = None ) -> str :
184
+ def _get_tmp_directory_pattern (final_path_name : str | None = None ) -> str :
185
185
suffix = r'\.orbax-checkpoint-tmp'
186
186
if final_path_name is None :
187
187
return '(.+)' + suffix
@@ -197,10 +197,10 @@ def __init__(
197
197
temporary_path : epath .Path ,
198
198
final_path : epath .Path ,
199
199
* ,
200
- checkpoint_metadata_store : Optional [
201
- checkpoint_metadata .MetadataStore
202
- ] = None ,
203
- file_options : Optional [ options_lib .FileOptions ] = None ,
200
+ checkpoint_metadata_store : (
201
+ checkpoint_metadata .MetadataStore | None
202
+ ) = None ,
203
+ file_options : options_lib .FileOptions | None = None ,
204
204
):
205
205
self ._tmp_path = temporary_path
206
206
self ._final_path = final_path
@@ -269,7 +269,7 @@ def to_bytes(self) -> bytes:
269
269
270
270
@classmethod
271
271
def from_bytes (
272
- cls : Type ['ReadOnlyTemporaryPath' ],
272
+ cls : type ['ReadOnlyTemporaryPath' ],
273
273
data : bytes ,
274
274
) -> ReadOnlyTemporaryPath :
275
275
"""Deserializes the object from bytes.
@@ -291,10 +291,10 @@ def from_final(
291
291
cls ,
292
292
final_path : epath .Path ,
293
293
* ,
294
- checkpoint_metadata_store : Optional [
295
- checkpoint_metadata .MetadataStore
296
- ] = None ,
297
- file_options : Optional [ options_lib .FileOptions ] = None ,
294
+ checkpoint_metadata_store : (
295
+ checkpoint_metadata .MetadataStore | None
296
+ ) = None ,
297
+ file_options : options_lib .FileOptions | None = None ,
298
298
) -> ReadOnlyTemporaryPath :
299
299
"""Not implemented for ReadOnlyTemporaryPath."""
300
300
raise NotImplementedError (
@@ -324,10 +324,10 @@ def from_final(
324
324
cls ,
325
325
final_path : epath .Path ,
326
326
* ,
327
- checkpoint_metadata_store : Optional [
328
- checkpoint_metadata .MetadataStore
329
- ] = None ,
330
- file_options : Optional [ options_lib .FileOptions ] = None ,
327
+ checkpoint_metadata_store : (
328
+ checkpoint_metadata .MetadataStore | None
329
+ ) = None ,
330
+ file_options : options_lib .FileOptions | None = None ,
331
331
) -> AtomicRenameTemporaryPath :
332
332
return cls (
333
333
_get_tmp_directory (final_path ),
@@ -399,10 +399,10 @@ def from_final(
399
399
cls ,
400
400
final_path : epath .Path ,
401
401
* ,
402
- checkpoint_metadata_store : Optional [
403
- checkpoint_metadata .MetadataStore
404
- ] = None ,
405
- file_options : Optional [ options_lib .FileOptions ] = None ,
402
+ checkpoint_metadata_store : (
403
+ checkpoint_metadata .MetadataStore | None
404
+ ) = None ,
405
+ file_options : options_lib .FileOptions | None = None ,
406
406
) -> CommitFileTemporaryPath :
407
407
return cls (
408
408
final_path ,
@@ -465,9 +465,7 @@ def finalize(
465
465
async def create_all (
466
466
paths : Sequence [atomicity_types .TemporaryPath ],
467
467
* ,
468
- multiprocessing_options : Optional [
469
- options_lib .MultiprocessingOptions
470
- ] = None ,
468
+ multiprocessing_options : options_lib .MultiprocessingOptions | None = None ,
471
469
):
472
470
"""Creates all temporary paths in parallel."""
473
471
start = time .time ()
@@ -511,9 +509,7 @@ def create_all_async(
511
509
paths : Sequence [atomicity_types .TemporaryPath ],
512
510
completion_signals : Sequence [synchronization .HandlerAwaitableSignal ],
513
511
* ,
514
- multiprocessing_options : Optional [
515
- options_lib .MultiprocessingOptions
516
- ] = None ,
512
+ multiprocessing_options : options_lib .MultiprocessingOptions | None = None ,
517
513
subdirectories : Sequence [str ] | None = None ,
518
514
) -> future .Future :
519
515
"""Creates all temporary paths in parallel asynchronously.
0 commit comments