Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- (ProblemDefinition) Refactor to use `pydantic` for data validation and parsing.
- (ProblemDefinition) `directory_path` argument in `__init__` is deprecated, use `path` instead.

### Fixes

### Removed
Expand Down
2 changes: 2 additions & 0 deletions docs/source/core_concepts/problem_definition.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ title: Problem definition
- outputs: list of FeatureIdentifiers
- split: arbitrary named splits (train/val/test, etc.) stored as JSON

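`ProblemDefinition` is built on top of [Pydantic](https://docs.pydantic.dev/), which validates the data when a problem definition is constructed. Tasks and score functions are checked against the authorized values at construction time and by `set_task` / `set_score_function`.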
Copy link

Copilot AI Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docs claim tasks/score functions are "strictly enforced", but the current implementation can bypass Pydantic validation when values are assigned after initialization (e.g., when loading from disk) unless assignment validation is enabled. Either adjust the wording or ensure the implementation validates assignments during load.

Suggested change
`ProblemDefinition` is built on top of [Pydantic](https://docs.pydantic.dev/), providing robust data validation. Authorized tasks and score functions are strictly enforced.
`ProblemDefinition` is built on top of [Pydantic](https://docs.pydantic.dev/), providing robust data validation when defining and updating problems through its public API. Authorized tasks and score functions are validated accordingly.

Copilot uses AI. Check for mistakes.

Typical usage:

```python
Expand Down
180 changes: 152 additions & 28 deletions src/plaid/problem_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,18 @@
import csv
import json
import logging
import warnings
from pathlib import Path
from typing import Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Union

import yaml
Comment on lines +26 to 28
Copy link

Copilot AI Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ruff is configured with Pyflakes (F401) and isort (I) (see ruff.toml). This import block is currently unsorted, and PrivateAttr/model_validator appear to be unused in this file; this will fail lint. Reorder the typing/pydantic imports and drop unused names (or start using them).

Copilot uses AI. Check for mistakes.
from packaging.version import Version
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
)

import plaid
from plaid.constants import AUTHORIZED_SCORE_FUNCTIONS, AUTHORIZED_TASKS
Expand All @@ -42,13 +49,46 @@
# %% Classes


class ProblemDefinition(object):
class ProblemDefinition(BaseModel):
"""Gathers all necessary information to define a learning problem."""

model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
Copy link

Copilot AI Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ProblemDefinition is mutating validated fields during _load_from_dir_/_initialize_from_problem_infos_dict via property setters after super().__init__(). With the current model_config, assignments won’t trigger Pydantic validators, so invalid task/score_function values loaded from disk can slip through. Consider enabling validate_assignment=True (or refactor loading to use model_validate) to make validation effective beyond initial construction.

Suggested change
model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="ignore",
validate_assignment=True,
)

Copilot uses AI. Check for mistakes.

name: Optional[str] = Field(None, description="Name of the problem")
version: Optional[Union[str, Version]] = Field(
default_factory=lambda: Version(plaid.__version__),
description="Version of the problem definition",
)
task: Optional[str] = Field(None, description="Task of the problem")
score_function: Optional[str] = Field(None, description="Score function used")

in_features_identifiers: List[Union[str, FeatureIdentifier]] = Field(
default_factory=list
)
out_features_identifiers: List[Union[str, FeatureIdentifier]] = Field(
default_factory=list
)
constant_features_identifiers: List[str] = Field(default_factory=list)

# Legacy fields
in_scalars_names: List[str] = Field(default_factory=list)
out_scalars_names: List[str] = Field(default_factory=list)
in_timeseries_names: List[str] = Field(default_factory=list)
out_timeseries_names: List[str] = Field(default_factory=list)
in_fields_names: List[str] = Field(default_factory=list)
out_fields_names: List[str] = Field(default_factory=list)
in_meshes_names: List[str] = Field(default_factory=list)
out_meshes_names: List[str] = Field(default_factory=list)

split: Optional[Dict[str, Any]] = Field(None)
train_split: Optional[Dict[str, Dict[str, Any]]] = Field(None)
test_split: Optional[Dict[str, Dict[str, Any]]] = Field(None)

def __init__(
self,
path: Optional[Union[str, Path]] = None,
directory_path: Optional[Union[str, Path]] = None,
**data: Any,
) -> None:
"""Initialize an empty :class:`ProblemDefinition <plaid.problem_definition.ProblemDefinition>`.

Expand All @@ -57,6 +97,7 @@ def __init__(
Args:
path (Union[str,Path], optional): The path from which to load PLAID problem definition files.
directory_path (Union[str,Path], optional): Deprecated, use `path` instead.
**data: Additional arguments to initialize the Pydantic model.

Example:
.. code-block:: python
Expand All @@ -73,40 +114,123 @@ def __init__(
print(problem_definition)
>>> ProblemDefinition(input_scalars_names=['s_1'], output_scalars_names=['s_2'], input_meshes_names=['mesh'], task='regression')
"""
self._name: str = None
self._version: Union[Version] = Version(plaid.__version__)
self._task: str = None
self._score_function: str = None
self.in_features_identifiers: Sequence[Union[str, FeatureIdentifier]] = []
self.out_features_identifiers: Sequence[Union[str, FeatureIdentifier]] = []
self.constant_features_identifiers: list[str] = []
self.in_scalars_names: list[str] = []
self.out_scalars_names: list[str] = []
self.in_timeseries_names: list[str] = []
self.out_timeseries_names: list[str] = []
self.in_fields_names: list[str] = []
self.out_fields_names: list[str] = []
self.in_meshes_names: list[str] = []
self.out_meshes_names: list[str] = []
self._split: Optional[dict[str, IndexType]] = None
self._train_split: Optional[dict[str, dict[str, IndexType]]] = None
self._test_split: Optional[dict[str, dict[str, IndexType]]] = None
super().__init__(**data)

if directory_path is not None:
if path is not None:
raise ValueError(
"Arguments `path` and `directory_path` cannot be both set. Use only `path` as `directory_path` is deprecated."
)
else:
path = directory_path
logger.warning(
"DeprecationWarning: 'directory_path' is deprecated, use 'path' instead."
)
warnings.warn(
"`directory_path` is deprecated, use `path` instead.",
DeprecationWarning,
)
Comment on lines +125 to +127
Copy link

Copilot AI Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The raise line is indented with an extra space, which violates the repo’s Ruff E111 indentation rule (indentation not a multiple of four). Align the indentation inside this if block.

Copilot uses AI. Check for mistakes.
path = directory_path

if path is not None:
path = Path(path)
self._load_from_dir_(path)

Comment on lines 122 to 133
Copy link

Copilot AI Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New behavior: task and score_function can now be validated via Pydantic on model construction. There are existing tests for set_task/set_score_function, but none covering ProblemDefinition(task=...) / ProblemDefinition(score_function=...) success and failure cases. Add tests to lock in the new validation path.

Copilot uses AI. Check for mistakes.
@field_validator("task")
@classmethod
def validate_task(cls, v: Optional[str]) -> Optional[str]:
Comment on lines +134 to +136
Copy link

Copilot AI Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same indentation issue here: the raise statement has indentation that isn’t a multiple of four, which will be flagged by Ruff (E111).

Copilot uses AI. Check for mistakes.
"""Validate that the task is among the authorized tasks."""
Comment on lines +135 to +137
Copy link

Copilot AI Feb 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error message refers to "authorized tasks" but this validator is for score_function. Update the message to mention authorized score functions (and ideally include AUTHORIZED_SCORE_FUNCTIONS).

Copilot uses AI. Check for mistakes.
if v is not None and v not in AUTHORIZED_TASKS:
raise ValueError(
f"{v} not among authorized tasks. Maybe you want to try among: {AUTHORIZED_TASKS}"
)
return v

@field_validator("score_function")
@classmethod
def validate_score_function(cls, v: Optional[str]) -> Optional[str]:
"""Validate that the score function is among the authorized score functions."""
if v is not None and v not in AUTHORIZED_SCORE_FUNCTIONS:
raise ValueError(
f"{v} not among authorized score functions. Maybe you want to try among: {AUTHORIZED_SCORE_FUNCTIONS}"
)
return v

@property
def _name(self) -> Optional[str]:
return self.name

@_name.setter
def _name(self, value: Optional[str]):
self.name = value

@property
def _version(self) -> Optional[Union[str, Version]]:
return self.version

@_version.setter
def _version(self, value: Optional[Union[str, Version]]):
self.version = value

@property
def _task(self) -> Optional[str]:
return self.task

@_task.setter
def _task(self, value: Optional[str]):
self.task = value

@property
def _score_function(self) -> Optional[str]:
return self.score_function

@_score_function.setter
def _score_function(self, value: Optional[str]):
self.score_function = value

@property
def _in_features_identifiers(self) -> List[Union[str, FeatureIdentifier]]:
return self.in_features_identifiers

@_in_features_identifiers.setter
def _in_features_identifiers(self, value: List[Union[str, FeatureIdentifier]]):
self.in_features_identifiers = value

@property
def _out_features_identifiers(self) -> List[Union[str, FeatureIdentifier]]:
return self.out_features_identifiers

@_out_features_identifiers.setter
def _out_features_identifiers(self, value: List[Union[str, FeatureIdentifier]]):
self.out_features_identifiers = value

@property
def _constant_features_identifiers(self) -> List[str]:
return self.constant_features_identifiers

@_constant_features_identifiers.setter
def _constant_features_identifiers(self, value: List[str]):
self.constant_features_identifiers = value

@property
def _split(self) -> Optional[Dict[str, Any]]:
return self.split

@_split.setter
def _split(self, value: Optional[Dict[str, Any]]):
self.split = value

@property
def _train_split(self) -> Optional[Dict[str, Dict[str, Any]]]:
return self.train_split

@_train_split.setter
def _train_split(self, value: Optional[Dict[str, Dict[str, Any]]]):
self.train_split = value

@property
def _test_split(self) -> Optional[Dict[str, Dict[str, Any]]]:
return self.test_split

@_test_split.setter
def _test_split(self, value: Optional[Dict[str, Dict[str, Any]]]):
self.test_split = value

# -------------------------------------------------------------------------#
def get_name(self) -> str:
"""Get the name. None if not defined.
Expand Down Expand Up @@ -156,7 +280,7 @@ def set_task(self, task: str) -> None:
elif task in AUTHORIZED_TASKS:
self._task = task
else:
raise TypeError(
raise ValueError(
f"{task} not among authorized tasks. Maybe you want to try among: {AUTHORIZED_TASKS}"
)

Expand All @@ -182,8 +306,8 @@ def set_score_function(self, score_function: str) -> None:
elif score_function in AUTHORIZED_SCORE_FUNCTIONS:
self._score_function = score_function
else:
raise TypeError(
f"{score_function} not among authorized tasks. Maybe you want to try among: {AUTHORIZED_SCORE_FUNCTIONS}"
raise ValueError(
f"{score_function} not among authorized score functions. Maybe you want to try among: {AUTHORIZED_SCORE_FUNCTIONS}"
)

# -------------------------------------------------------------------------#
Expand Down
44 changes: 42 additions & 2 deletions tests/test_problem_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@

import pytest
from packaging.version import Version
from pydantic import ValidationError

import plaid
from plaid.constants import AUTHORIZED_SCORE_FUNCTIONS, AUTHORIZED_TASKS
from plaid.containers import FeatureIdentifier
from plaid.problem_definition import ProblemDefinition

Expand Down Expand Up @@ -144,7 +146,7 @@ def test_version(self, problem_definition):
# -------------------------------------------------------------------------#
def test_task(self, problem_definition):
# Unauthorized task
with pytest.raises(TypeError):
with pytest.raises(ValueError):
problem_definition.set_task("ighyurgv")
problem_definition.set_task("classification")
with pytest.raises(ValueError):
Expand All @@ -155,7 +157,7 @@ def test_task(self, problem_definition):
# -------------------------------------------------------------------------#
def test_score_function(self, problem_definition):
# Unauthorized task
with pytest.raises(TypeError):
with pytest.raises(ValueError):
problem_definition.set_score_function("ighyurgv")
problem_definition.set_score_function("RRMSE")
with pytest.raises(ValueError):
Expand Down Expand Up @@ -757,3 +759,41 @@ def test_extract_problem_definition_from_identifiers(self, problem_definition):
assert sub_problem_definition.get_task() == "regression"
assert sub_problem_definition.get_name() == "regression_1"
assert sub_problem_definition.get_split() == {"train": [0, 1], "test": [2, 3]}

def test_init_invalid_task(self):
"""Test initialization with an invalid task to trigger Pydantic validator."""
with pytest.raises(ValidationError) as excinfo:
ProblemDefinition(task="invalid_task_name")
# The validator raises ValueError, which Pydantic wraps in a ValidationError; check the message.
assert "not among authorized tasks" in str(excinfo.value)

def test_init_invalid_score_function(self):
"""Test initialization with an invalid score function."""
with pytest.raises(ValidationError) as excinfo:
ProblemDefinition(score_function="invalid_score_function")
assert "not among authorized score functions" in str(excinfo.value)

def test_init_valid_task(self):
"""Test initialization with a valid task."""
task = list(AUTHORIZED_TASKS)[0]
pd = ProblemDefinition(task=task)
assert pd.task == task

def test_init_valid_score_function(self):
"""Test initialization with a valid score function."""
sf = list(AUTHORIZED_SCORE_FUNCTIONS)[0]
pd = ProblemDefinition(score_function=sf)
assert pd.score_function == sf

def test_legacy_properties(self):
"""Test legacy property access via Pydantic model."""
pd = ProblemDefinition()
pd._name = "test_name"
assert pd.name == "test_name"
assert pd._name == "test_name"

pd._task = list(AUTHORIZED_TASKS)[0]
assert pd.task == list(AUTHORIZED_TASKS)[0]

pd._score_function = list(AUTHORIZED_SCORE_FUNCTIONS)[0]
assert pd.score_function == list(AUTHORIZED_SCORE_FUNCTIONS)[0]
Loading