Skip to content

Commit 6ec673e

Browse files
authored
Add literal validation (#14)
1 parent b768741 commit 6ec673e

File tree

3 files changed

+133
-10
lines changed

3 files changed

+133
-10
lines changed

assets/coverage.svg

Lines changed: 2 additions & 2 deletions
Loading

statica/validation.py

Lines changed: 93 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
"""
22
The backus naur grammar for types is as follows:
3-
T ::= Statica | int | float | str | None | (T1 | T2) | list[T] | set[T] | dict[T1, T2]
3+
T ::= Statica
4+
| int
5+
| float
6+
| str
7+
| None
8+
| (T1 | T2)
9+
| list[T]
10+
| set[T]
11+
| dict[T1, T2]
12+
| Literal[V1, ...]
13+
414
515
Where:
616
- Statica: A class that inherits from Statica
@@ -16,11 +26,37 @@
1626
from __future__ import annotations
1727

1828
from types import GenericAlias, UnionType
19-
from typing import Any
29+
from typing import Any, Literal, TypeGuard, Union
2030

2131
from statica.config import StaticaConfig, default_config
2232
from statica.exceptions import ConstraintValidationError, TypeValidationError
2333

34+
########################################################################################
35+
#### MARK: Types
36+
37+
38+
class LiteralGenericAlias:
39+
"""A type used in place of typing._LiteralGenericAlias to avoid private imports."""
40+
41+
__origin__ = Literal
42+
__args__: tuple[Any, ...]
43+
44+
45+
def is_literal_generic_alias(expected_type: Any) -> TypeGuard[LiteralGenericAlias]:
46+
return hasattr(expected_type, "__origin__") and expected_type.__origin__ is Literal
47+
48+
49+
class UnionGenericAlias:
50+
"""A type used in place of typing._UnionGenericAlias to avoid private imports."""
51+
52+
__origin__ = Union
53+
__args__: tuple[Any, ...]
54+
55+
56+
def is_union_generic_alias(expected_type: Any) -> TypeGuard[UnionGenericAlias]:
57+
return hasattr(expected_type, "__origin__") and expected_type.__origin__ is Union
58+
59+
2460
########################################################################################
2561
#### MARK: Type Validation
2662

@@ -35,16 +71,28 @@ def validate_or_raise(
3571
are already initialized Statica objects.
3672
"""
3773

38-
# Handle union types
74+
# Handle generic aliases if native python types, e.g. list[int], dict[str, int]
3975

40-
if isinstance(expected_type, UnionType):
41-
validate_type_union(value, expected_type, config)
76+
if isinstance(expected_type, GenericAlias):
77+
validate_type_generic_alias(value, expected_type, config)
4278
return
4379

44-
# Handle generic aliases
80+
# Handle parameterized generic types
4581

46-
if isinstance(expected_type, GenericAlias):
47-
validate_type_generic_alias(value, expected_type, config)
82+
if is_union_generic_alias(expected_type):
83+
validate_type_union_generic_alias(value, expected_type, config)
84+
return
85+
86+
# Handle Literal (e.g. Literal["a", "b"], with any number and type of values)
87+
88+
if is_literal_generic_alias(expected_type):
89+
validate_literal(value, expected_type)
90+
return
91+
92+
# Handle union types
93+
94+
if isinstance(expected_type, UnionType):
95+
validate_type_union(value, expected_type, config)
4896
return
4997

5098
# Handle all other types
@@ -59,6 +107,19 @@ def validate_or_raise(
59107
raise TypeValidationError(msg)
60108

61109

110+
def validate_literal(
111+
value: Any,
112+
expected_type: LiteralGenericAlias,
113+
) -> None:
114+
"""
115+
Validate that the value matches one of the literals in the expected_type.
116+
Throws TypeValidationError if the value is not one of the literals.
117+
"""
118+
if value not in expected_type.__args__:
119+
msg = f"expected one of {expected_type.__args__}, got '{value}'"
120+
raise TypeValidationError(msg)
121+
122+
62123
def validate_type_union(
63124
value: Any,
64125
expected_type: UnionType,
@@ -83,6 +144,30 @@ def validate_type_union(
83144
raise TypeValidationError(msg)
84145

85146

147+
def validate_type_union_generic_alias(
148+
value: Any,
149+
expected_type: UnionGenericAlias,
150+
config: StaticaConfig = default_config,
151+
) -> None:
152+
"""
153+
Validate that the value matches one of the types in the UnionGenericAlias.
154+
Throws TypeValidationError if the type does not match any of the union types.
155+
"""
156+
for sub_type in expected_type.__args__:
157+
try:
158+
validate_or_raise(value, sub_type, config)
159+
except TypeValidationError:
160+
continue # Try the next sub-type
161+
else:
162+
return # Exit if one of the sub-types matches
163+
164+
msg = config.type_error_message.format(
165+
expected_type=expected_type.__args__,
166+
found_type=type(value).__name__,
167+
)
168+
raise TypeValidationError(msg)
169+
170+
86171
def validate_type_generic_alias(
87172
value: Any,
88173
expected_type: GenericAlias,

tests/test_validation.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Literal
2+
13
import pytest
24

35
from statica import Field, Statica, TypeValidationError
@@ -98,3 +100,39 @@ class UnsupportedGeneric(Statica):
98100

99101
with pytest.raises(TypeValidationError):
100102
UnsupportedGeneric(data=frozenset([1, 2, 3]))
103+
104+
105+
def test_validate_literal() -> None:
106+
class LiteralTest(Statica):
107+
data: Literal["a", "b", "c"]
108+
number: Literal[1, 2, 3]
109+
110+
i1 = LiteralTest.from_map({"data": "a", "number": 1})
111+
assert i1.data == "a"
112+
assert i1.number == 1
113+
114+
with pytest.raises(TypeValidationError):
115+
LiteralTest.from_map({"data": "d", "number": 1})
116+
117+
with pytest.raises(TypeValidationError):
118+
LiteralTest.from_map({"data": "a", "number": 4})
119+
120+
121+
def test_validate_literal_optional() -> None:
122+
class LiteralTest(Statica):
123+
data: Literal["a", "b", "c"] | None
124+
number: Literal[1, 2, 3] | None
125+
126+
i1 = LiteralTest.from_map({"data": "a", "number": 1})
127+
assert i1.data == "a"
128+
assert i1.number == 1
129+
130+
i2 = LiteralTest.from_map({"data": None, "number": None})
131+
assert i2.data is None
132+
assert i2.number is None
133+
134+
with pytest.raises(TypeValidationError):
135+
LiteralTest.from_map({"data": "d", "number": 1})
136+
137+
with pytest.raises(TypeValidationError):
138+
LiteralTest.from_map({"data": "a", "number": 4})

0 commit comments

Comments
 (0)