diff --git a/.pylintrc b/.pylintrc index 6fb2cf7..56ee834 100644 --- a/.pylintrc +++ b/.pylintrc @@ -394,7 +394,6 @@ logging-modules=logging # all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, # UNDEFINED. confidence=HIGH, - CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, UNDEFINED diff --git a/README.md b/README.md index a0d997b..2585c20 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ What I need from a ML configuration library... Python is good but not universal. Sometimes you train a ML model and use it on a different platform. So, you need your model configuration file importable by other programming languages. -2. Simple dynamic value and type checking with default values. +2. Simple dynamic value and type checking. If you are a beginner in a ML project, it is hard to guess the right values for your ML experiment. Therefore it is important to have some default values and know what range and type of input are expected for each field. @@ -48,7 +48,6 @@ What I need from a ML configuration library... ## 🚫 Limitations - `Union` type dataclass fields cannot be parsed from console arguments due to the type ambiguity. - `JSON` is the only supported serialization format, although the others can be easily integrated. -- `List`type with multiple item type annotations are not supported. (e.g. `List[int, str]`). - `dict` fields are parsed from console arguments as JSON str without type checking. (e.g `--val_dict '{"a":10, "b":100}'`). - `MISSING` fields cannot be avoided when parsing console arguments. @@ -62,7 +61,7 @@ from typing import List, Union from coqpit import MISSING, Coqpit, check_argument -@dataclass +@dataclass # Optional. Coqpit subclasses are auto decorated with dataclass class SimpleConfig(Coqpit): val_a: int = 10 val_b: int = None @@ -122,7 +121,6 @@ from coqpit import Coqpit, check_argument from typing import List, Union -@dataclass class SimpleConfig(Coqpit): val_a: int = 10 val_b: int = None @@ -136,7 +134,6 @@ class SimpleConfig(Coqpit): check_argument('val_c', c, restricted=True) -@dataclass class NestedConfig(Coqpit): val_d: int = 10 val_e: int = None diff --git a/coqpit/__init__.py b/coqpit/__init__.py index 36fc452..38d5537 100644 --- a/coqpit/__init__.py +++ b/coqpit/__init__.py @@ -1 +1 @@ -from coqpit.coqpit import MISSING, Coqpit, check_argument, dataclass +from coqpit.coqpit import MISSING, Coqpit, CoqpitTypeError, check, check_argument, dataclass diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index e214c8b..692a1e5 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -1,14 +1,38 @@ +# pylint: disable=too-many-lines import argparse +import dataclasses import functools +import inspect import json import operator import os +import sys +import types as types_native +import typing from collections.abc import MutableMapping from dataclasses import MISSING as _MISSING from dataclasses import Field, asdict, dataclass, fields, is_dataclass, replace from pathlib import Path from pprint import pprint -from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, get_type_hints +from typing import ( + Any, + Dict, + FrozenSet, + Generic, + List, + Literal, + Optional, + Set, + Tuple, + Type, + TypedDict, + TypeVar, + Union, + get_type_hints, +) + +import typing_extensions +from typing_extensions import TypeGuard T = TypeVar("T") MISSING: Any = "???" @@ -61,12 +85,17 @@ def is_dict(arg_type: Any) -> bool: Returns: bool: True if input type is `dict` """ + # pylint: disable=bare-except try: return arg_type is dict or arg_type is Dict or arg_type.__origin__ is dict - except AttributeError: + except: return False +def is_pep604_union(arg_type: Type[Any]) -> bool: + return sys.version_info >= (3, 10) and arg_type is types_native.UnionType # type: ignore + + def is_union(arg_type: Any) -> bool: """Check if the input type is `Union`. @@ -135,6 +164,215 @@ def _is_optional_field(field) -> bool: return type(None) in getattr(field.type, "__args__") +# ---------------------------------------------------------------------------- # +# Type Checking # +# ---------------------------------------------------------------------------- # + + +class Error(TypeError): + def __init__(self, arg_type: Type[Any], value: Any, path: Optional[List[str]] = None): + # pylint: disable=super-init-not-called + if type(self) == Error: # pylint: disable=unidiomatic-typecheck + raise ValueError(" [!] `Error` must not be instantiated directly") + self.arg_type = arg_type + self.value = value + self.path = path or [] + + def __str__(self) -> str: + raise NotImplementedError() + + +Result = Optional[Error] # returns error context + + +def _path_to_str(path: List[str]) -> str: + return " -> ".join(reversed(path)) + + +class CoqpitTypeError(Error): + def __init__( + self, + arg_type: Type[Any], + value: Any, + path: Optional[List[str]] = None, + exception: Optional[Any] = None, + ): + super().__init__(arg_type, value, path) + self.arg_type = arg_type + self.value = value + self.path = path or [] + self.exception = exception + + def __str__(self): + path = _path_to_str(self.path) + msg = ( + f" [!] Error in field '{path}'. Expected type {self.arg_type}, got {type(self.value)} (value: {self.value})" + ) + if self.exception is not None: + msg += f"\n{type(self.exception)}: {self.exception}" + return msg + + +def is_error(ret: Result) -> TypeGuard[Error]: + return ret is not None + + +def check_dict(value: Dict[Any, Any], ty: Type[Dict[Any, Any]]) -> Result: + args = typing_extensions.get_args(ty) + try: + # Allow Dict without type hints + ty_key = args[0] + ty_item = args[1] + except IndexError: + return None + for k, v in value.items(): + err = check(k, ty_key) + if is_error(err): + return err + err = check(v, ty_item) + if err is not None: + err.path.append(k) + return err + return None + + +def check_typeddict(value: Any, arg_type: Type[Type[Any]]) -> Result: + if not isinstance(value, dict): + return CoqpitTypeError(arg_type, value) + is_total: bool = arg_type.__total__ # type: ignore + for k, aty in typing.get_type_hints(arg_type).items(): + if k not in value: + if is_total: + return CoqpitTypeError(aty, value, [k]) + continue + v = value[k] + err = check(v, aty) + if err is not None: + err.path.append(k) + return err + return None + + +def check_dataclass(value: Any, arg_type: Type[Any]) -> Result: + if not dataclasses.is_dataclass(value): + return CoqpitTypeError(arg_type, value) + for k, aty in typing.get_type_hints(arg_type).items(): + v = getattr(value, k) + err = check(v, aty) + if err is not None: + err.path.append(k) + return err + return None + + +def check_container(value: Any, arg_type: Union[Type[List[Any]], Type[Set[Any]], Type[FrozenSet[Any]]]) -> Result: + # pytlint: disable=raise-missing-from + try: + ty_item = typing_extensions.get_args(arg_type)[0] + except IndexError as exc: + raise TypeError(f" [!] Unsupported container type: value {value} - type {arg_type}") from exc + for v in value: + err = check(v, ty_item) + if is_error(err): + return err + return None + + +def check_tuple(value: Any, arg_type: Type[Tuple[Any, ...]]) -> Result: + types = typing_extensions.get_args(arg_type) + if len(types) == 2 and types[1] == ...: + # arbitrary length tuple (e.g. Tuple[int, ...]) + for v in value: + err = check(v, types[0]) + if is_error(err): + return err + return None + + if len(value) != len(types): + return CoqpitTypeError(arg_type=arg_type, value=value) + for v, t in zip(value, types): + err = check(v, t) + if is_error(err): + return err + return None + + +def check_literal(value: Any, arg_type: Type[Any]) -> Result: + if all(value != t for t in typing_extensions.get_args(arg_type)): + return CoqpitTypeError(arg_type=arg_type, value=value) + return None + + +def check_union(value: Any, arg_type: Type[Any]) -> Result: + if any(not is_error(check(value, t)) for t in typing_extensions.get_args(arg_type)): + return None + return CoqpitTypeError(arg_type=arg_type, value=value) + + +def check_int(value: Any, arg_type: Type[Any]) -> Result: + if isinstance(value, bool) or not isinstance(value, arg_type): + # allow int value to be defined for float type for convenience + if isinstance(value, int) and arg_type == float: + return None + return CoqpitTypeError(arg_type=arg_type, value=value) + return None + + +def is_typeddict(arg_type: Type[Any]) -> TypeGuard[Type[TypedDict]]: # type: ignore + T = "_TypedDictMeta" # pylint: disable=redefined-outer-name + for mod in [typing, typing_extensions]: + if hasattr(mod, T) and isinstance(arg_type, getattr(mod, T)): + return True + return False + + +def check(value: Any, arg_type: Type[Any]) -> Result: + # pylint: disable=too-many-return-statements + # allow None for every type + if value is None: + return None + if value is arg_type: + return None + if not isinstance(value, type) and dataclasses.is_dataclass(arg_type): + # dataclass + return check_dataclass(value, arg_type) + if is_typeddict(arg_type): + # maybe use`typing.is_typeddict` + return check_typeddict(value, arg_type) + origin_type = typing_extensions.get_origin(arg_type) + if origin_type is not None: + # generics + err = check(value, origin_type) + if is_error(err): + return err + + if origin_type is list or origin_type is set or origin_type is frozenset: + err = check_container(value, arg_type) + elif origin_type is dict: + err = check_dict(value, arg_type) # type: ignore + elif origin_type is tuple: + err = check_tuple(value, arg_type) + elif origin_type is Literal: + err = check_literal(value, arg_type) + elif origin_type is Union or is_pep604_union(origin_type): + err = check_union(value, arg_type) + return err + + if isinstance(arg_type, type): + # concrete type + if is_pep604_union(arg_type): + pass + elif issubclass(arg_type, bool): + if not isinstance(value, arg_type): + return CoqpitTypeError(arg_type=arg_type, value=value) + elif issubclass(arg_type, int): # For boolean + return check_int(value, arg_type) + elif not isinstance(value, arg_type): + return CoqpitTypeError(arg_type=arg_type, value=value) + + return None + + def my_get_type_hints( cls, ): @@ -152,6 +390,11 @@ def my_get_type_hints( return r_dict +# ---------------------------------------------------------------------------- # +# Serialize / Deserialize # +# ---------------------------------------------------------------------------- # + + def _serialize(x): """Pick the right serialization for the datatype of the given input. @@ -284,51 +527,13 @@ def _deserialize(x: Any, field_type: Any) -> Any: return _deserialize_list(x, field_type) if is_union(field_type): return _deserialize_union(x, field_type) - if issubclass(field_type, Serializable): + if issubclass(field_type, (Coqpit, Serializable)): return field_type.deserialize_immutable(x) if is_primitive_type(field_type): return _deserialize_primitive_types(x, field_type) raise ValueError(f" [!] '{type(x)}' value type of '{x}' does not match '{field_type}' field type.") -# Recursive setattr (supports dotted attr names) -def rsetattr(obj, attr, val): - def _setitem(obj, attr, val): - return operator.setitem(obj, int(attr), val) - - pre, _, post = attr.rpartition(".") - setfunc = _setitem if post.isnumeric() else setattr - - return setfunc(rgetattr(obj, pre) if pre else obj, post, val) - - -# Recursive getattr (supports dotted attr names) -def rgetattr(obj, attr, *args): - def _getitem(obj, attr): - return operator.getitem(obj, int(attr), *args) - - def _getattr(obj, attr): - getfunc = _getitem if attr.isnumeric() else getattr - return getfunc(obj, attr, *args) - - return functools.reduce(_getattr, [obj] + attr.split(".")) - - -# Recursive setitem (supports dotted attr names) -def rsetitem(obj, attr, val): - pre, _, post = attr.rpartition(".") - return operator.setitem(rgetitem(obj, pre) if pre else obj, post, val) - - -# Recursive getitem (supports dotted attr names) -def rgetitem(obj, attr, *args): - def _getitem(obj, attr): - return operator.getitem(obj, int(attr) if attr.isnumeric() else attr, *args) - - return functools.reduce(_getitem, [obj] + attr.split(".")) - - -@dataclass class Serializable: """Gives serialization ability to any inheriting dataclass.""" @@ -417,7 +622,7 @@ def deserialize(self, data: dict) -> "Serializable": @classmethod def deserialize_immutable(cls, data: dict) -> "Serializable": - """Parse input dictionary and desrialize its fields to a dataclass. + """Parse input dictionary and deserialize its fields to a dataclass. Returns: Newly created deserialized object. @@ -454,6 +659,46 @@ def deserialize_immutable(cls, data: dict) -> "Serializable": # ---------------------------------------------------------------------------- # +def rsetattr(obj, attr, val): + "Recursive setattr (supports dotted attr names)" + + def _setitem(obj, attr, val): + return operator.setitem(obj, int(attr), val) + + pre, _, post = attr.rpartition(".") + setfunc = _setitem if post.isnumeric() else setattr + + return setfunc(rgetattr(obj, pre) if pre else obj, post, val) + + +def rgetattr(obj, attr, *args): + "Recursive getattr (supports dotted attr names)" + + def _getitem(obj, attr): + return operator.getitem(obj, int(attr), *args) + + def _getattr(obj, attr): + getfunc = _getitem if attr.isnumeric() else getattr + return getfunc(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split(".")) + + +def rsetitem(obj, attr, val): + "Recursive setitem (supports dotted attr names)" + pre, _, post = attr.rpartition(".") + return operator.setitem(rgetitem(obj, pre) if pre else obj, post, val) + + +def rgetitem(obj, attr, *args): + "Recursive getitem (supports dotted attr names)" + + def _getitem(obj, attr): + return operator.getitem(obj, int(attr) if attr.isnumeric() else attr, *args) + + return functools.reduce(_getitem, [obj] + attr.split(".")) + + def _get_help(field): try: field_help = field.metadata["help"] @@ -576,7 +821,13 @@ def parse_bool(x): # ---------------------------------------------------------------------------- # -@dataclass +def is_decorated_with_dataclass(cls): + source = inspect.getsource(cls) + lines = source.split("\n") + decorator_lines = [line.strip() for line in lines if line.strip().startswith("@")] + return any("dataclass" in line for line in decorator_lines) + + class Coqpit(Serializable, MutableMapping): """Coqpit base class to be inherited by any Coqpit dataclasses. It overrides Python `dict` interface and provides `dict` compatible API. @@ -591,8 +842,18 @@ def _is_initialized(self): at the initialization when no attribute has been defined.""" return "_initialized" in vars(self) and self._initialized + def __init_subclass__(cls, **kwargs) -> Any: + """Auto decorate subclasses with `dataclass` decorator if not already decorated.""" + # pylint: disable=self-cls-assignment + super().__init_subclass__(**kwargs) + if not is_decorated_with_dataclass(cls): + cls = dataclass(cls) + def __post_init__(self): self._initialized = True + err = check_dataclass(self, type(self)) + if err: + raise err try: self.check_values() except AttributeError: @@ -625,7 +886,8 @@ def __getattribute__(self, arg: str): # pylint: disable=no-self-use """Check if the mandatory field is defined when accessing it.""" value = super().__getattribute__(arg) if isinstance(value, str) and value == "???": - raise AttributeError(f" [!] MISSING field {arg} must be defined.") + # raise MISSINGError(arg) + raise ValueError(f" [!] Mandatory field '{arg}' is not defined.") return value def __contains__(self, arg: str): diff --git a/requirements.txt b/requirements.txt index d5673ad..1225b12 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ -dataclasses;python_version=='3.6' \ No newline at end of file +dataclasses;python_version=='3.6' +typing_extensions \ No newline at end of file diff --git a/tests/test_faulty_deserialization.json b/tests/test_faulty_deserialization.json new file mode 100644 index 0000000..9ac125c --- /dev/null +++ b/tests/test_faulty_deserialization.json @@ -0,0 +1,18 @@ +{ + "name": "Coqpit", + "size": 3, + "people": [ + { + "name": "Eren", + "age": "11" + }, + { + "name": "Geren", + "age": 12 + }, + { + "name": "Ceren", + "age": 15 + } + ] +} \ No newline at end of file diff --git a/tests/test_faulty_serialization.json b/tests/test_faulty_serialization.json new file mode 100644 index 0000000..891683f --- /dev/null +++ b/tests/test_faulty_serialization.json @@ -0,0 +1,14 @@ +{ + "list_of_list": [ + [ + 1, + 2, + 3 + ], + [ + 4, + 5, + 6 + ] + ] +} \ No newline at end of file diff --git a/tests/test_init_from_dict.py b/tests/test_init_from_dict.py index 8c26faf..91b270c 100644 --- a/tests/test_init_from_dict.py +++ b/tests/test_init_from_dict.py @@ -1,16 +1,14 @@ -from dataclasses import dataclass, field +from dataclasses import field from typing import List from coqpit import Coqpit -@dataclass class Person(Coqpit): name: str = None age: int = None -@dataclass class Reference(Coqpit): name: str = "Coqpit" size: int = 3 @@ -24,12 +22,12 @@ class Reference(Coqpit): people_ids: List[int] = field(default_factory=lambda: [1, 2, 3]) -@dataclass class WithRequired(Coqpit): name: str def test_new_from_dict(): + # pylint: disable=unsubscriptable-object ref_config = Reference(name="Fancy", size=3**10, people=[Person(name="Anonymous", age=42)]) new_config = Reference.new_from_dict( diff --git a/tests/test_merge_configs.py b/tests/test_merge_configs.py index 8f5a52d..8d7fac8 100644 --- a/tests/test_merge_configs.py +++ b/tests/test_merge_configs.py @@ -27,7 +27,7 @@ class Reference(Coqpit): val_e: int = 257 val_f: float = -10.21 val_g: str = "Coqpit is really great!" - val_same: int = 10.21 # duplicate fields are override by the merged Coqpit class. + val_same: float = 10.21 # duplicate fields are override by the merged Coqpit class. def test_config_merge(): diff --git a/tests/test_nested_configs.py b/tests/test_nested_configs.py index 0016ed9..a9e65a6 100644 --- a/tests/test_nested_configs.py +++ b/tests/test_nested_configs.py @@ -1,11 +1,10 @@ import os -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, field from typing import List, Union from coqpit import Coqpit, check_argument -@dataclass class SimpleConfig(Coqpit): val_a: int = 10 val_b: int = None @@ -21,7 +20,6 @@ def check_values( check_argument("val_c", c, restricted=True) -@dataclass class NestedConfig(Coqpit): val_d: int = 10 val_e: int = None diff --git a/tests/test_parse_argparse.py b/tests/test_parse_argparse.py index 8949df2..ee50dad 100644 --- a/tests/test_parse_argparse.py +++ b/tests/test_parse_argparse.py @@ -4,12 +4,10 @@ from coqpit.coqpit import Coqpit, check_argument -@dataclass class SimplerConfig(Coqpit): val_a: int = field(default=None, metadata={"help": "this is val_a"}) -@dataclass class SimpleConfig(Coqpit): val_a: int = field(default=10, metadata={"help": "this is val_a of SimpleConfig"}) val_b: int = field(default=None, metadata={"help": "this is val_b"}) diff --git a/tests/test_parse_known_argparse.py b/tests/test_parse_known_argparse.py index ee17fc0..a08c854 100644 --- a/tests/test_parse_known_argparse.py +++ b/tests/test_parse_known_argparse.py @@ -1,15 +1,13 @@ -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, field from typing import List from coqpit.coqpit import Coqpit, check_argument -@dataclass class SimplerConfig(Coqpit): val_a: int = field(default=None, metadata={"help": "this is val_a"}) -@dataclass class SimpleConfig(Coqpit): val_a: int = field(default=10, metadata={"help": "this is val_a of SimpleConfig"}) val_b: int = field(default=None, metadata={"help": "this is val_b"}) @@ -77,7 +75,7 @@ def test_parse_edited_argparse(): config.val_a = 333 config.val_b = 444 config.val_c = "this is different" - config.mylist_with_default[0].val_a = 777 + config.mylist_with_default[0].val_a = 777 # pylint: disable=unsubscriptable-object print(config.pprint()) # reference config that we like to match with the config above diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 528dab7..68c2fae 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,24 +1,21 @@ import os -from dataclasses import dataclass, field +from dataclasses import field from typing import List -from coqpit import Coqpit +from coqpit import Coqpit, CoqpitTypeError -@dataclass class Person(Coqpit): name: str = None age: int = None -@dataclass class Group(Coqpit): name: str = None size: int = None people: List[Person] = None -@dataclass class Reference(Coqpit): name: str = "Coqpit" size: int = 3 @@ -32,6 +29,7 @@ class Reference(Coqpit): def test_serizalization(): + # pylint: disable=unsubscriptable-object file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_serialization.json") ref_config = Reference() @@ -51,3 +49,15 @@ def test_serizalization(): assert ref_config.people[0].age == new_config.people[0].age assert ref_config.people[1].age == new_config.people[1].age assert ref_config.people[2].age == new_config.people[2].age + + +def test_faulty_deserialization(): + """Try to load a json file with a type mismatch""" + file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_faulty_deserialization.json") + + try: + ref_config = Reference() + ref_config.load_json(file_path) + assert False, "Should have failed" + except CoqpitTypeError: + pass diff --git a/tests/test_simple_config.py b/tests/test_simple_config.py index 9485ca6..de4789f 100644 --- a/tests/test_simple_config.py +++ b/tests/test_simple_config.py @@ -1,11 +1,10 @@ import os -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, field from typing import List, Union from coqpit.coqpit import MISSING, Coqpit, check_argument -@dataclass class SimpleConfig(Coqpit): val_a: int = 10 val_b: int = None @@ -18,7 +17,7 @@ class SimpleConfig(Coqpit): # optional field val_dict: dict = field(default_factory=lambda: {"val_aa": 10, "val_ss": "This is in a dict."}) # list of list - val_listoflist: List[List] = field(default_factory=lambda: [[1, 2], [3, 4]]) + val_listoflist: List[List[int]] = field(default_factory=lambda: [[1, 2], [3, 4]]) val_listofunion: List[List[Union[str, int, bool]]] = field( default_factory=lambda: [[1, 3], [1, "Hi!"], [True, False]] ) @@ -35,7 +34,13 @@ def check_values( def test_simple_config(): file_path = os.path.dirname(os.path.abspath(__file__)) - config = SimpleConfig() + + try: + config = SimpleConfig() + except ValueError as e: + print(" Mandatory field `val_k` is not set. Error: ", e) + + config = SimpleConfig(val_k=10) # try MISSING class argument try: diff --git a/tests/test_type_checking.py b/tests/test_type_checking.py new file mode 100644 index 0000000..0a2a9be --- /dev/null +++ b/tests/test_type_checking.py @@ -0,0 +1,163 @@ +import sys +import typing +from dataclasses import dataclass, field +from enum import Enum +from typing import Callable, Dict, List, Optional, Set, Tuple + +import pytest +from typing_extensions import TypedDict + +from coqpit import Coqpit +from coqpit.coqpit import check, check_dataclass, is_error, is_typeddict + + +def test_tuple(): + assert not is_error(check((1, 2, 3), Tuple[int, ...])) + assert is_error(check((1, "b"), Tuple[int, ...])) + + +def test_set(): + assert is_error(check({"foo", "bar", 1}, Set[str])) + + +def test_str(): + assert is_error(check("foo", List[str])) + + +def test_callable(): + assert is_error(check(1, Callable)) + + +@dataclass +class B: + a: int = 0 + b: Dict[str, int] = field(default_factory=dict) + + +@dataclass +class A: + a: int = 0 + b: str = "" + c: B = field(default_factory=B) + + +class C(Coqpit): + a: int = 0 + b: Dict[str, int] = field(default_factory=dict) + + +def test_error(): + a = A("foo") + err = check_dataclass(a, A) + assert is_error(err) + assert "a" in err.path + + +def test_error_dataclass(): + a = A(c=B(a="foo")) + err = check_dataclass(a, A) + assert is_error(err) + assert "c" in err.path + + +def test_error_dict_value(): + a = A(c=B(b={"foo": "bar"})) + err = check_dataclass(a, A) + assert is_error(err) + assert "c" in err.path + assert "b" in err.path + assert "foo" in err.path + + +def test_error_dict_key(): + a = A(c=B(b={1: 1})) + err = check_dataclass(a, A) + assert is_error(err) + assert "c" in err.path + assert "b" in err.path + assert 1 not in err.path + + +def test_bool(): + assert is_error(check(1, bool)) + assert not is_error(check(1, int)) + + assert not is_error(check(False, bool)) + assert is_error(check(False, int)) + + +class ENUM(Enum): + a = "a" + + +def test_enum(): + assert is_error(check("a", ENUM)) + assert not is_error(check(ENUM.a, ENUM)) + + +class TD(TypedDict): + a: str + b: int + c: Optional["TD"] + + +class TDPartial(TypedDict, total=False): + a: str + b: int + c: Optional["TDPartial"] + + +class TDList(TypedDict): + a: int + b: List[TDPartial] + + +@pytest.mark.parametrize( + "ty,value", + [ + (TDList, {"a": 1, "b": [{"a": "1"}]}), + (TD, {"a": "foo", "b": 1, "c": None}), + (TDPartial, {"a": "foo"}), + (TDPartial, {}), + (TDPartial, {"c": {"c": {"c": {"c": {"c": {}}}}}}), + ( + TD, + { + "a": "foo", + "b": 1, + "c": {"a": "bar", "b": 1, "c": {"a": "foo", "b": 2, "c": None}}, + }, + ), + ], +) +def test_typeddict(ty, value): + assert not is_error(check(value, ty)) + + +@pytest.mark.parametrize( + "ty,value", + [ + (TD, {"a": "foo", "b": 1}), + (TD, {"a": 1}), + (TDPartial, {"c": {"c": {"c": {"c": {"c": {"a": 10}}}}}}), + ( + TD, + { + "a": "foo", + "b": 1, + "c": {"a": "bar", "b": 1, "c": {"a": 1, "b": 2, "c": None}}, + }, + ), + ], +) +def test_typeddict_error(ty, value): + assert is_error(check(value, ty)) + + +if sys.version_info >= (3, 8, 0): + + class XTD(typing.TypedDict): + a: int + + def test_is_typeddict(): + assert is_typeddict(XTD)