Skip to content

Commit 9774eda

Browse files
committed
avoid private import
1 parent 579f856 commit 9774eda

File tree

1 file changed

+40
-18
lines changed

1 file changed

+40
-18
lines changed

statica/validation.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,37 @@
1616
from __future__ import annotations
1717

1818
from types import GenericAlias, UnionType
19-
from typing import (
20-
Any,
21-
_LiteralGenericAlias, # type: ignore[attr-defined]
22-
_UnionGenericAlias, # type: ignore[attr-defined]
23-
)
19+
from typing import Any, Literal, TypeGuard, Union
2420

2521
from statica.config import StaticaConfig, default_config
2622
from statica.exceptions import ConstraintValidationError, TypeValidationError
2723

24+
########################################################################################
25+
#### MARK: Types
26+
27+
28+
class LiteralGenericAlias:
29+
"""A type used in place of typing._LiteralGenericAlias to avoid private imports."""
30+
31+
__origin__ = Literal
32+
__args__: tuple[Any, ...]
33+
34+
35+
def is_literal_generic_alias(expected_type: Any) -> TypeGuard[LiteralGenericAlias]:
36+
return hasattr(expected_type, "__origin__") and expected_type.__origin__ is Literal
37+
38+
39+
class UnionGenericAlias:
40+
"""A type used in place of typing._UnionGenericAlias to avoid private imports."""
41+
42+
__origin__ = Union
43+
__args__: tuple[Any, ...]
44+
45+
46+
def is_union_generic_alias(expected_type: Any) -> TypeGuard[UnionGenericAlias]:
47+
return hasattr(expected_type, "__origin__") and expected_type.__origin__ is Union
48+
49+
2850
########################################################################################
2951
#### MARK: Type Validation
3052

@@ -39,28 +61,28 @@ def validate_or_raise(
3961
are already initialized Statica objects.
4062
"""
4163

42-
# Handle union types
64+
# Handle generic aliases if native python types, e.g. list[int], dict[str, int]
4365

44-
if isinstance(expected_type, UnionType):
45-
validate_type_union(value, expected_type, config)
66+
if isinstance(expected_type, GenericAlias):
67+
validate_type_generic_alias(value, expected_type, config)
4668
return
4769

48-
# Handle union generic aliases
70+
# Handle parameterized generic types
4971

50-
if isinstance(expected_type, _UnionGenericAlias):
72+
if is_union_generic_alias(expected_type):
5173
validate_type_union_generic_alias(value, expected_type, config)
5274
return
5375

54-
# Handle generic aliases
76+
# Handle Literal (e.g. Literal["a", "b"], with any number and type of values)
5577

56-
if isinstance(expected_type, GenericAlias):
57-
validate_type_generic_alias(value, expected_type, config)
78+
if is_literal_generic_alias(expected_type):
79+
validate_literal(value, expected_type)
5880
return
5981

60-
# Handle Literal (e.g. Literal["a", "b"], with any number and type of values)
82+
# Handle union types
6183

62-
if isinstance(expected_type, _LiteralGenericAlias):
63-
validate_literal(value, expected_type)
84+
if isinstance(expected_type, UnionType):
85+
validate_type_union(value, expected_type, config)
6486
return
6587

6688
# Handle all other types
@@ -77,7 +99,7 @@ def validate_or_raise(
7799

78100
def validate_literal(
79101
value: Any,
80-
expected_type: _LiteralGenericAlias,
102+
expected_type: LiteralGenericAlias,
81103
) -> None:
82104
"""
83105
Validate that the value matches one of the literals in the expected_type.
@@ -114,7 +136,7 @@ def validate_type_union(
114136

115137
def validate_type_union_generic_alias(
116138
value: Any,
117-
expected_type: _UnionGenericAlias,
139+
expected_type: UnionGenericAlias,
118140
config: StaticaConfig = default_config,
119141
) -> None:
120142
"""

0 commit comments

Comments
 (0)