diff --git a/CHANGELOG.md b/CHANGELOG.md index abc6da25..c5b79791 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- (ProblemDefinition) Refactor to use `pydantic` for data validation and parsing. +- (ProblemDefinition) `directory_path` argument in `__init__` is deprecated, use `path` instead. + ### Fixes ### Removed diff --git a/docs/source/core_concepts/problem_definition.md b/docs/source/core_concepts/problem_definition.md index 2db59b19..79fcd2b4 100644 --- a/docs/source/core_concepts/problem_definition.md +++ b/docs/source/core_concepts/problem_definition.md @@ -10,6 +10,8 @@ title: Problem definition - outputs: list of FeatureIdentifiers - split: arbitrary named splits (train/val/test, etc.) stored as JSON +`ProblemDefinition` is built on top of [Pydantic](https://docs.pydantic.dev/), providing robust data validation. Authorized tasks and score functions are strictly enforced. + Typical usage: ```python diff --git a/src/plaid/problem_definition.py b/src/plaid/problem_definition.py index 1e49f219..d34cc36b 100644 --- a/src/plaid/problem_definition.py +++ b/src/plaid/problem_definition.py @@ -21,11 +21,18 @@ import csv import json import logging +import warnings from pathlib import Path -from typing import Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Union import yaml from packaging.version import Version +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, +) import plaid from plaid.constants import AUTHORIZED_SCORE_FUNCTIONS, AUTHORIZED_TASKS @@ -42,13 +49,46 @@ # %% Classes -class ProblemDefinition(object): +class ProblemDefinition(BaseModel): """Gathers all necessary informations to define a learning problem.""" + model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore") + + name: Optional[str] = Field(None, description="Name of the problem") + version: Optional[Union[str, Version]] = Field( + default_factory=lambda: Version(plaid.__version__), + description="Version of the problem definition", + ) + task: Optional[str] = Field(None, description="Task of the problem") + score_function: Optional[str] = Field(None, description="Score function used") + + in_features_identifiers: List[Union[str, FeatureIdentifier]] = Field( + default_factory=list + ) + out_features_identifiers: List[Union[str, FeatureIdentifier]] = Field( + default_factory=list + ) + constant_features_identifiers: List[str] = Field(default_factory=list) + + # Legacy fields + in_scalars_names: List[str] = Field(default_factory=list) + out_scalars_names: List[str] = Field(default_factory=list) + in_timeseries_names: List[str] = Field(default_factory=list) + out_timeseries_names: List[str] = Field(default_factory=list) + in_fields_names: List[str] = Field(default_factory=list) + out_fields_names: List[str] = Field(default_factory=list) + in_meshes_names: List[str] = Field(default_factory=list) + out_meshes_names: List[str] = Field(default_factory=list) + + split: Optional[Dict[str, Any]] = Field(None) + train_split: Optional[Dict[str, Dict[str, Any]]] = Field(None) + test_split: Optional[Dict[str, Dict[str, Any]]] = Field(None) + def __init__( self, path: Optional[Union[str, Path]] = None, directory_path: Optional[Union[str, Path]] = None, + **data: Any, ) -> None: """Initialize an empty :class:`ProblemDefinition `. @@ -57,6 +97,7 @@ def __init__( Args: path (Union[str,Path], optional): The path from which to load PLAID problem definition files. directory_path (Union[str,Path], optional): Deprecated, use `path` instead. + **data: Additional arguments to initialize the Pydantic model. Example: .. code-block:: python @@ -73,40 +114,123 @@ def __init__( print(problem_definition) >>> ProblemDefinition(input_scalars_names=['s_1'], output_scalars_names=['s_2'], input_meshes_names=['mesh'], task='regression') """ - self._name: str = None - self._version: Union[Version] = Version(plaid.__version__) - self._task: str = None - self._score_function: str = None - self.in_features_identifiers: Sequence[Union[str, FeatureIdentifier]] = [] - self.out_features_identifiers: Sequence[Union[str, FeatureIdentifier]] = [] - self.constant_features_identifiers: list[str] = [] - self.in_scalars_names: list[str] = [] - self.out_scalars_names: list[str] = [] - self.in_timeseries_names: list[str] = [] - self.out_timeseries_names: list[str] = [] - self.in_fields_names: list[str] = [] - self.out_fields_names: list[str] = [] - self.in_meshes_names: list[str] = [] - self.out_meshes_names: list[str] = [] - self._split: Optional[dict[str, IndexType]] = None - self._train_split: Optional[dict[str, dict[str, IndexType]]] = None - self._test_split: Optional[dict[str, dict[str, IndexType]]] = None + super().__init__(**data) if directory_path is not None: if path is not None: raise ValueError( "Arguments `path` and `directory_path` cannot be both set. Use only `path` as `directory_path` is deprecated." ) - else: - path = directory_path - logger.warning( - "DeprecationWarning: 'directory_path' is deprecated, use 'path' instead." - ) + warnings.warn( + "`directory_path` is deprecated, use `path` instead.", + DeprecationWarning, + ) + path = directory_path if path is not None: path = Path(path) self._load_from_dir_(path) + @field_validator("task") + @classmethod + def validate_task(cls, v: Optional[str]) -> Optional[str]: + """Validate that the task is among the authorized tasks.""" + if v is not None and v not in AUTHORIZED_TASKS: + raise ValueError( + f"{v} not among authorized tasks. Maybe you want to try among: {AUTHORIZED_TASKS}" + ) + return v + + @field_validator("score_function") + @classmethod + def validate_score_function(cls, v: Optional[str]) -> Optional[str]: + """Validate that the score function is among the authorized score functions.""" + if v is not None and v not in AUTHORIZED_SCORE_FUNCTIONS: + raise ValueError( + f"{v} not among authorized score functions. Maybe you want to try among: {AUTHORIZED_SCORE_FUNCTIONS}" + ) + return v + + @property + def _name(self) -> Optional[str]: + return self.name + + @_name.setter + def _name(self, value: Optional[str]): + self.name = value + + @property + def _version(self) -> Optional[Union[str, Version]]: + return self.version + + @_version.setter + def _version(self, value: Optional[Union[str, Version]]): + self.version = value + + @property + def _task(self) -> Optional[str]: + return self.task + + @_task.setter + def _task(self, value: Optional[str]): + self.task = value + + @property + def _score_function(self) -> Optional[str]: + return self.score_function + + @_score_function.setter + def _score_function(self, value: Optional[str]): + self.score_function = value + + @property + def _in_features_identifiers(self) -> List[Union[str, FeatureIdentifier]]: + return self.in_features_identifiers + + @_in_features_identifiers.setter + def _in_features_identifiers(self, value: List[Union[str, FeatureIdentifier]]): + self.in_features_identifiers = value + + @property + def _out_features_identifiers(self) -> List[Union[str, FeatureIdentifier]]: + return self.out_features_identifiers + + @_out_features_identifiers.setter + def _out_features_identifiers(self, value: List[Union[str, FeatureIdentifier]]): + self.out_features_identifiers = value + + @property + def _constant_features_identifiers(self) -> List[str]: + return self.constant_features_identifiers + + @_constant_features_identifiers.setter + def _constant_features_identifiers(self, value: List[str]): + self.constant_features_identifiers = value + + @property + def _split(self) -> Optional[Dict[str, Any]]: + return self.split + + @_split.setter + def _split(self, value: Optional[Dict[str, Any]]): + self.split = value + + @property + def _train_split(self) -> Optional[Dict[str, Dict[str, Any]]]: + return self.train_split + + @_train_split.setter + def _train_split(self, value: Optional[Dict[str, Dict[str, Any]]]): + self.train_split = value + + @property + def _test_split(self) -> Optional[Dict[str, Dict[str, Any]]]: + return self.test_split + + @_test_split.setter + def _test_split(self, value: Optional[Dict[str, Dict[str, Any]]]): + self.test_split = value + # -------------------------------------------------------------------------# def get_name(self) -> str: """Get the name. None if not defined. @@ -156,7 +280,7 @@ def set_task(self, task: str) -> None: elif task in AUTHORIZED_TASKS: self._task = task else: - raise TypeError( + raise ValueError( f"{task} not among authorized tasks. Maybe you want to try among: {AUTHORIZED_TASKS}" ) @@ -182,8 +306,8 @@ def set_score_function(self, score_function: str) -> None: elif score_function in AUTHORIZED_SCORE_FUNCTIONS: self._score_function = score_function else: - raise TypeError( - f"{score_function} not among authorized tasks. Maybe you want to try among: {AUTHORIZED_SCORE_FUNCTIONS}" + raise ValueError( + f"{score_function} not among authorized score functions. Maybe you want to try among: {AUTHORIZED_SCORE_FUNCTIONS}" ) # -------------------------------------------------------------------------# diff --git a/tests/test_problem_definition.py b/tests/test_problem_definition.py index 5dcfbfd6..7b8029ea 100644 --- a/tests/test_problem_definition.py +++ b/tests/test_problem_definition.py @@ -13,8 +13,10 @@ import pytest from packaging.version import Version +from pydantic import ValidationError import plaid +from plaid.constants import AUTHORIZED_SCORE_FUNCTIONS, AUTHORIZED_TASKS from plaid.containers import FeatureIdentifier from plaid.problem_definition import ProblemDefinition @@ -144,7 +146,7 @@ def test_version(self, problem_definition): # -------------------------------------------------------------------------# def test_task(self, problem_definition): # Unauthorized task - with pytest.raises(TypeError): + with pytest.raises(ValueError): problem_definition.set_task("ighyurgv") problem_definition.set_task("classification") with pytest.raises(ValueError): @@ -155,7 +157,7 @@ def test_task(self, problem_definition): # -------------------------------------------------------------------------# def test_score_function(self, problem_definition): # Unauthorized task - with pytest.raises(TypeError): + with pytest.raises(ValueError): problem_definition.set_score_function("ighyurgv") problem_definition.set_score_function("RRMSE") with pytest.raises(ValueError): @@ -757,3 +759,41 @@ def test_extract_problem_definition_from_identifiers(self, problem_definition): assert sub_problem_definition.get_task() == "regression" assert sub_problem_definition.get_name() == "regression_1" assert sub_problem_definition.get_split() == {"train": [0, 1], "test": [2, 3]} + + def test_init_invalid_task(self): + """Test initialization with an invalid task to trigger Pydantic validator.""" + with pytest.raises(ValidationError) as excinfo: + ProblemDefinition(task="invalid_task_name") + # Check that the underlying error is TypeError (as raised by validator) or just check msg + assert "not among authorized tasks" in str(excinfo.value) + + def test_init_invalid_score_function(self): + """Test initialization with an invalid score function.""" + with pytest.raises(ValidationError) as excinfo: + ProblemDefinition(score_function="invalid_score_function") + assert "not among authorized score functions" in str(excinfo.value) + + def test_init_valid_task(self): + """Test initialization with a valid task.""" + task = list(AUTHORIZED_TASKS)[0] + pd = ProblemDefinition(task=task) + assert pd.task == task + + def test_init_valid_score_function(self): + """Test initialization with a valid score function.""" + sf = list(AUTHORIZED_SCORE_FUNCTIONS)[0] + pd = ProblemDefinition(score_function=sf) + assert pd.score_function == sf + + def test_legacy_properties(self): + """Test legacy property access via Pydantic model.""" + pd = ProblemDefinition() + pd._name = "test_name" + assert pd.name == "test_name" + assert pd._name == "test_name" + + pd._task = list(AUTHORIZED_TASKS)[0] + assert pd.task == list(AUTHORIZED_TASKS)[0] + + pd._score_function = list(AUTHORIZED_SCORE_FUNCTIONS)[0] + assert pd.score_function == list(AUTHORIZED_SCORE_FUNCTIONS)[0]