diff --git a/src/convert_dataset.py b/src/convert_dataset.py index 17df9773..e773b76c 100644 --- a/src/convert_dataset.py +++ b/src/convert_dataset.py @@ -7,7 +7,7 @@ import platform import warnings from argparse import ArgumentParser, Namespace -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum from typing import Dict, Iterable, Optional, Union @@ -72,9 +72,9 @@ class DataSplitConstants: @dataclass class DatasetConstants: - chars_per_sample: int - chars_per_token: int - splits = {} + chars_per_sample: Optional[int] = None + chars_per_token: Optional[int] = None + splits: dict[str, DataSplitConstants] = field(default_factory=dict) def __iter__(self): for _, v in self.splits.items():