Skip to content

Commit b57565c

Browse files
authored
Merge pull request #32 from coqui-ai/revert-15-allow-file-objects
Revert "Allow file-like objects when saving and loading"
2 parents dabaf67 + fb748a8 commit b57565c

File tree

2 files changed

+18
-45
lines changed

2 files changed

+18
-45
lines changed

coqpit/coqpit.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import argparse
2-
import contextlib
32
import functools
43
import json
54
import operator
@@ -697,35 +696,28 @@ def to_json(self) -> str:
697696
"""Returns a JSON string representation."""
698697
return json.dumps(asdict(self), indent=4, default=_coqpit_json_default)
699698

700-
def save_json(self, file_name: Union[str, Path, Any]) -> None:
699+
def save_json(self, file_name: str) -> None:
701700
"""Save Coqpit to a json file.
702701
703702
Args:
704-
file_name (str, Path or file-like object): path to the output json file or a file-like object to write to.
703+
file_name (str): path to the output json file.
705704
"""
706-
if isinstance(file_name, (Path, str)):
707-
opened = open(file_name, "w", encoding="utf8")
708-
else:
709-
opened = contextlib.nullcontext(file_name)
710-
with opened as f:
705+
with open(file_name, "w", encoding="utf8") as f:
711706
json.dump(asdict(self), f, indent=4)
712707

713-
def load_json(self, file_name: Union[str, Path, Any]) -> None:
708+
def load_json(self, file_name: str) -> None:
714709
"""Load a json file and update matching config fields with type checking.
715710
Non-matching parameters in the json file are ignored.
716711
717712
Args:
718-
file_name (str, Path or file-like object): Path to the json file or a file-like object to read from.
713+
file_name (str): path to the json file.
719714
720715
Returns:
721716
Coqpit: new Coqpit with updated config fields.
722717
"""
723-
if isinstance(file_name, (Path, str)):
724-
opened = open(file_name, "r", encoding="utf8")
725-
else:
726-
opened = contextlib.nullcontext(file_name)
727-
with opened as f:
728-
dump_dict = json.load(f)
718+
with open(file_name, "r", encoding="utf8") as f:
719+
input_str = f.read()
720+
dump_dict = json.loads(input_str)
729721
# TODO: this looks stupid 💆
730722
self = self.deserialize(dump_dict) # pylint: disable=self-cls-assignment
731723
self.check_values()

tests/test_serialization.py

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,6 @@ class Reference(Coqpit):
3131
)
3232

3333

34-
def assert_equal(a: Reference, b: Reference):
35-
assert len(a) == len(b)
36-
assert a.name == b.name
37-
assert a.size == b.size
38-
assert a.people[0].name == b.people[0].name
39-
assert a.people[1].name == b.people[1].name
40-
assert a.people[2].name == b.people[2].name
41-
assert a.people[0].age == b.people[0].age
42-
assert a.people[1].age == b.people[1].age
43-
assert a.people[2].age == b.people[2].age
44-
45-
4634
def test_serizalization():
4735
file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_serialization.json")
4836

@@ -53,20 +41,13 @@ def test_serizalization():
5341
new_config.load_json(file_path)
5442
new_config.pprint()
5543

56-
assert_equal(ref_config, new_config)
57-
58-
59-
def test_serizalization_fileobject():
60-
"""Test serialization to and from file-like objects instead of paths"""
61-
file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_serialization_file.json")
62-
63-
ref_config = Reference()
64-
with open(file_path, "w", encoding="utf-8") as f:
65-
ref_config.save_json(f)
66-
67-
new_config = Group()
68-
with open(file_path, "r", encoding="utf-8") as f:
69-
new_config.load_json(f)
70-
new_config.pprint()
71-
72-
assert_equal(ref_config, new_config)
44+
# check values
45+
assert len(ref_config) == len(new_config)
46+
assert ref_config.name == new_config.name
47+
assert ref_config.size == new_config.size
48+
assert ref_config.people[0].name == new_config.people[0].name
49+
assert ref_config.people[1].name == new_config.people[1].name
50+
assert ref_config.people[2].name == new_config.people[2].name
51+
assert ref_config.people[0].age == new_config.people[0].age
52+
assert ref_config.people[1].age == new_config.people[1].age
53+
assert ref_config.people[2].age == new_config.people[2].age

0 commit comments

Comments
 (0)