Skip to content

Commit 368640e

Browse files
authored
Merge pull request #13 from coqui-ai/copy-support
Copy support
2 parents 500d31f + eaccbde commit 368640e

File tree

4 files changed

+32
-15
lines changed

4 files changed

+32
-15
lines changed

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.0.9
1+
0.0.10

coqpit/coqpit.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
from collections.abc import MutableMapping
77
from dataclasses import MISSING as _MISSING
8-
from dataclasses import Field, asdict, dataclass, fields, is_dataclass
8+
from dataclasses import Field, asdict, dataclass, fields, is_dataclass, replace
99
from pathlib import Path
1010
from pprint import pprint
1111
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, get_type_hints
@@ -558,20 +558,8 @@ def _is_initialized(self):
558558
at the initialization when no attribute has been defined."""
559559
return "_initialized" in vars(self) and self._initialized
560560

561-
def __setattr__(self, name: str, value: Any) -> None:
562-
if self._is_initialized() and issubclass(type(value), Coqpit):
563-
self.__fields__[name].type = type(value)
564-
return super().__setattr__(name, value)
565-
566-
def __set_fields(self):
567-
"""Create a list of fields defined at the object initialization"""
568-
self.__fields__ = {} # pylint: disable=attribute-defined-outside-init
569-
for field in fields(self):
570-
self.__fields__[field.name] = field
571-
572561
def __post_init__(self):
573562
self._initialized = True
574-
self.__set_fields()
575563
try:
576564
self.check_values()
577565
except AttributeError:
@@ -583,7 +571,7 @@ def __iter__(self):
583571
return iter(asdict(self))
584572

585573
def __len__(self):
586-
return len(self.__fields__)
574+
return len(fields(self))
587575

588576
def __setitem__(self, arg: str, value: Any):
589577
setattr(self, arg, value)
@@ -645,6 +633,9 @@ def check_values(self):
645633
def has(self, arg: str) -> bool:
646634
return arg in vars(self)
647635

636+
def copy(self):
637+
return replace(self)
638+
648639
def update(self, new: dict, allow_new=False) -> None:
649640
"""Update Coqpit fields by the input ```dict```.
650641

tests/test_copying.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import copy
2+
from dataclasses import dataclass
3+
4+
from coqpit.coqpit import Coqpit
5+
6+
7+
@dataclass
8+
class SimpleConfig(Coqpit):
9+
val_a: int = 10
10+
11+
12+
def test_copying():
13+
config = SimpleConfig()
14+
15+
config_new = config.copy()
16+
config_new.val_a = 1234
17+
assert config.val_a != config_new.val_a
18+
19+
config_new = copy.copy(config)
20+
config_new.val_a = 4321
21+
assert config.val_a != config_new.val_a
22+
23+
config_new = copy.deepcopy(config)
24+
config_new.val_a = 4321
25+
assert config.val_a != config_new.val_a

tests/test_simple_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def test_simple_config():
5656
print(dict(**config))
5757

5858
# value assignment by mapping
59+
# TODO: MAYBE this should raise an errorby the value check.
5960
config["val_a"] = -999
6061
print(config["val_a"])
6162
assert config.val_a == -999

0 commit comments

Comments
 (0)