|
1 | 1 | import argparse |
2 | | -import contextlib |
3 | 2 | import functools |
4 | 3 | import json |
5 | 4 | import operator |
@@ -697,35 +696,28 @@ def to_json(self) -> str: |
697 | 696 | """Returns a JSON string representation.""" |
698 | 697 | return json.dumps(asdict(self), indent=4, default=_coqpit_json_default) |
699 | 698 |
|
700 | | - def save_json(self, file_name: Union[str, Path, Any]) -> None: |
| 699 | + def save_json(self, file_name: str) -> None: |
701 | 700 | """Save Coqpit to a json file. |
702 | 701 |
|
703 | 702 | Args: |
704 | | - file_name (str, Path or file-like object): path to the output json file or a file-like object to write to. |
| 703 | + file_name (str): path to the output json file. |
705 | 704 | """ |
706 | | - if isinstance(file_name, (Path, str)): |
707 | | - opened = open(file_name, "w", encoding="utf8") |
708 | | - else: |
709 | | - opened = contextlib.nullcontext(file_name) |
710 | | - with opened as f: |
| 705 | + with open(file_name, "w", encoding="utf8") as f: |
711 | 706 | json.dump(asdict(self), f, indent=4) |
712 | 707 |
|
713 | | - def load_json(self, file_name: Union[str, Path, Any]) -> None: |
| 708 | + def load_json(self, file_name: str) -> None: |
714 | 709 | """Load a json file and update matching config fields with type checking. |
715 | 710 | Non-matching parameters in the json file are ignored. |
716 | 711 |
|
717 | 712 | Args: |
718 | | - file_name (str, Path or file-like object): Path to the json file or a file-like object to read from. |
| 713 | + file_name (str): path to the json file. |
719 | 714 |
|
720 | 715 | Returns: |
721 | 716 | Coqpit: new Coqpit with updated config fields. |
722 | 717 | """ |
723 | | - if isinstance(file_name, (Path, str)): |
724 | | - opened = open(file_name, "r", encoding="utf8") |
725 | | - else: |
726 | | - opened = contextlib.nullcontext(file_name) |
727 | | - with opened as f: |
728 | | - dump_dict = json.load(f) |
| 718 | + with open(file_name, "r", encoding="utf8") as f: |
| 719 | + input_str = f.read() |
| 720 | + dump_dict = json.loads(input_str) |
729 | 721 | # TODO: this looks stupid 💆 |
730 | 722 | self = self.deserialize(dump_dict) # pylint: disable=self-cls-assignment |
731 | 723 | self.check_values() |
|
0 commit comments