Skip to content

Commit 2855966

Browse files
dicentra13Orbax Authors
authored andcommitted
Modernize type annotations in atomicity modules.
PiperOrigin-RevId: 787073942
1 parent 6d98ef9 commit 2855966

File tree

3 files changed

+33
-38
lines changed

3 files changed

+33
-38
lines changed

checkpoint/orbax/checkpoint/_src/path/atomicity.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
import pickle
5757
import threading
5858
import time
59-
from typing import Awaitable, Optional, Protocol, Sequence, Type, TypeVar
59+
from typing import Awaitable, Protocol, Sequence, TypeVar
6060

6161
from absl import logging
6262
from etils import epath
@@ -101,7 +101,7 @@ def __call__(
101101
path: epath.Path,
102102
parents: bool = False,
103103
exist_ok: bool = False,
104-
mode: Optional[int] = None,
104+
mode: int | None = None,
105105
**kwargs,
106106
) -> Awaitable[None]:
107107
"""Creates the directory at path."""
@@ -112,10 +112,10 @@ async def _create_tmp_directory(
112112
async_makedir_func: AsyncMakeDirFunc,
113113
tmp_dir: epath.Path,
114114
*,
115-
path_permission_mode: Optional[int] = None,
116-
checkpoint_metadata_store: Optional[
117-
checkpoint_metadata.MetadataStore
118-
] = None,
115+
path_permission_mode: int | None = None,
116+
checkpoint_metadata_store: (
117+
checkpoint_metadata.MetadataStore | None
118+
) = None,
119119
**kwargs,
120120
) -> epath.Path:
121121
"""Creates a non-deterministic tmp directory for saving for given `final_dir`.
@@ -181,7 +181,7 @@ def _get_tmp_directory(final_path: epath.Path) -> epath.Path:
181181
)
182182

183183

184-
def _get_tmp_directory_pattern(final_path_name: Optional[str] = None) -> str:
184+
def _get_tmp_directory_pattern(final_path_name: str | None = None) -> str:
185185
suffix = r'\.orbax-checkpoint-tmp'
186186
if final_path_name is None:
187187
return '(.+)' + suffix
@@ -197,10 +197,10 @@ def __init__(
197197
temporary_path: epath.Path,
198198
final_path: epath.Path,
199199
*,
200-
checkpoint_metadata_store: Optional[
201-
checkpoint_metadata.MetadataStore
202-
] = None,
203-
file_options: Optional[options_lib.FileOptions] = None,
200+
checkpoint_metadata_store: (
201+
checkpoint_metadata.MetadataStore | None
202+
) = None,
203+
file_options: options_lib.FileOptions | None = None,
204204
):
205205
self._tmp_path = temporary_path
206206
self._final_path = final_path
@@ -269,7 +269,7 @@ def to_bytes(self) -> bytes:
269269

270270
@classmethod
271271
def from_bytes(
272-
cls: Type['ReadOnlyTemporaryPath'],
272+
cls: type['ReadOnlyTemporaryPath'],
273273
data: bytes,
274274
) -> ReadOnlyTemporaryPath:
275275
"""Deserializes the object from bytes.
@@ -291,10 +291,10 @@ def from_final(
291291
cls,
292292
final_path: epath.Path,
293293
*,
294-
checkpoint_metadata_store: Optional[
295-
checkpoint_metadata.MetadataStore
296-
] = None,
297-
file_options: Optional[options_lib.FileOptions] = None,
294+
checkpoint_metadata_store: (
295+
checkpoint_metadata.MetadataStore | None
296+
) = None,
297+
file_options: options_lib.FileOptions | None = None,
298298
) -> ReadOnlyTemporaryPath:
299299
"""Not implemented for ReadOnlyTemporaryPath."""
300300
raise NotImplementedError(
@@ -324,10 +324,10 @@ def from_final(
324324
cls,
325325
final_path: epath.Path,
326326
*,
327-
checkpoint_metadata_store: Optional[
328-
checkpoint_metadata.MetadataStore
329-
] = None,
330-
file_options: Optional[options_lib.FileOptions] = None,
327+
checkpoint_metadata_store: (
328+
checkpoint_metadata.MetadataStore | None
329+
) = None,
330+
file_options: options_lib.FileOptions | None = None,
331331
) -> AtomicRenameTemporaryPath:
332332
return cls(
333333
_get_tmp_directory(final_path),
@@ -399,10 +399,10 @@ def from_final(
399399
cls,
400400
final_path: epath.Path,
401401
*,
402-
checkpoint_metadata_store: Optional[
403-
checkpoint_metadata.MetadataStore
404-
] = None,
405-
file_options: Optional[options_lib.FileOptions] = None,
402+
checkpoint_metadata_store: (
403+
checkpoint_metadata.MetadataStore | None
404+
) = None,
405+
file_options: options_lib.FileOptions | None = None,
406406
) -> CommitFileTemporaryPath:
407407
return cls(
408408
final_path,
@@ -465,9 +465,7 @@ def finalize(
465465
async def create_all(
466466
paths: Sequence[atomicity_types.TemporaryPath],
467467
*,
468-
multiprocessing_options: Optional[
469-
options_lib.MultiprocessingOptions
470-
] = None,
468+
multiprocessing_options: options_lib.MultiprocessingOptions | None = None,
471469
):
472470
"""Creates all temporary paths in parallel."""
473471
start = time.time()
@@ -511,9 +509,7 @@ def create_all_async(
511509
paths: Sequence[atomicity_types.TemporaryPath],
512510
completion_signals: Sequence[synchronization.HandlerAwaitableSignal],
513511
*,
514-
multiprocessing_options: Optional[
515-
options_lib.MultiprocessingOptions
516-
] = None,
512+
multiprocessing_options: options_lib.MultiprocessingOptions | None = None,
517513
subdirectories: Sequence[str] | None = None,
518514
) -> future.Future:
519515
"""Creates all temporary paths in parallel asynchronously.

checkpoint/orbax/checkpoint/_src/path/atomicity_defaults.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
than we would want to introduce into the base `atomicity` module.
2020
"""
2121

22-
from typing import Type
2322
from etils import epath
2423
from orbax.checkpoint._src.path import atomicity
2524
from orbax.checkpoint._src.path import atomicity_types
@@ -28,7 +27,7 @@
2827

2928
def get_item_default_temporary_path_class(
3029
final_path: epath.Path,
31-
) -> Type[atomicity_types.TemporaryPath]:
30+
) -> type[atomicity_types.TemporaryPath]:
3231
"""Returns the default temporary path class for a given sub-item path."""
3332
if step_lib.is_gcs_path(final_path):
3433
return atomicity.CommitFileTemporaryPath
@@ -38,7 +37,7 @@ def get_item_default_temporary_path_class(
3837

3938
def get_default_temporary_path_class(
4039
final_path: epath.Path,
41-
) -> Type[atomicity_types.TemporaryPath]:
40+
) -> type[atomicity_types.TemporaryPath]:
4241
"""Returns the default temporary path class for a given checkpoint path."""
4342
if step_lib.is_gcs_path(final_path):
4443
return atomicity.CommitFileTemporaryPath

checkpoint/orbax/checkpoint/_src/path/atomicity_types.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from __future__ import annotations
2222

2323
import typing
24-
from typing import Optional, Protocol
24+
from typing import Protocol
2525

2626
from etils import epath
2727
from orbax.checkpoint import options as options_lib
@@ -49,10 +49,10 @@ def from_final(
4949
cls,
5050
final_path: epath.Path,
5151
*,
52-
checkpoint_metadata_store: Optional[
53-
checkpoint_metadata.MetadataStore
54-
] = None,
55-
file_options: Optional[options_lib.FileOptions] = None,
52+
checkpoint_metadata_store: (
53+
checkpoint_metadata.MetadataStore | None
54+
) = None,
55+
file_options: options_lib.FileOptions | None = None,
5656
) -> TemporaryPath:
5757
"""Creates a TemporaryPath from a final path."""
5858
...

0 commit comments

Comments
 (0)