diff --git a/eegdash/dataset/dataset.py b/eegdash/dataset/dataset.py index 664c4fc5e..e0f5def7a 100644 --- a/eegdash/dataset/dataset.py +++ b/eegdash/dataset/dataset.py @@ -14,7 +14,6 @@ from .. import downloader from ..bids_metadata import ( build_query_from_kwargs, - get_entities_from_record, merge_participants_fields, normalize_key, ) @@ -178,6 +177,23 @@ class EEGDashDataset(BaseConcatDataset, metaclass=NumpyDocstringInheritanceInitM Skipped recordings are flagged via ``ds._skipped`` so callers can filter them out with a list comprehension after iteration. + description_precedence : str, default "record" + Which source wins when the same field appears in both the record and + the embedded ``participant_tsv`` data: + + - ``"record"`` (default): the record-level value is kept. + - ``"participant_tsv"``: the ``participant_tsv`` value overwrites the + record value for conflicting fields. + + In both cases a ``debug``-level log is emitted when a conflict is + detected. + + .. note:: + When ``description_precedence="participant_tsv"``, a ``None`` + value in ``participant_tsv`` will overwrite a non-``None`` record + value for the same field. This is deliberate — choosing this mode + means trusting the ``participant_tsv`` source fully, including its + gaps. **kwargs : dict Additional keyword arguments serving two purposes: @@ -202,6 +218,7 @@ def __init__( auth_token: str | None = None, on_error: str = "raise", max_concurrency: int = 20, + description_precedence: str = "record", **kwargs, ): # Internal-only kwargs @@ -209,6 +226,13 @@ def __init__( self._dedupe_records: bool = kwargs.pop("_dedupe_records", False) self._on_error = on_error + _valid = {"record", "participant_tsv"} + if description_precedence not in _valid: + raise ValueError( + f"description_precedence must be one of {sorted(_valid)}, " + f"got {description_precedence!r}" + ) + self._description_precedence = description_precedence self.s3_bucket = s3_bucket self.database = database self.auth_token = auth_token @@ -246,15 +270,17 @@ def __init__( if records is not None: self.records = self._normalize_records(records) - - datasets = [ - EEGDashRaw( - record, - self.cache_dir, - **base_dataset_kwargs, + datasets = [] + for norm_record in self.records: + description = self._build_description(norm_record, description_fields) + datasets.append( + EEGDashRaw( + norm_record, + self.cache_dir, + description=description, + **base_dataset_kwargs, + ) ) - for record in self.records - ] elif not download: # only assume local data is complete if not downloading if not self.data_dir.exists(): raise ValueError( @@ -272,26 +298,20 @@ def __init__( datasets = [] for record in records: - # Start with entity values from filename (supports v1 and v2 formats) - desc: dict[str, Any] = get_entities_from_record(record) - + part_row: dict[str, Any] | None = None if bids_ds is not None: try: rel_from_dataset = Path(record["bidspath"]).relative_to( record["dataset"] ) # type: ignore[index] local_file = (self.data_dir / rel_from_dataset).as_posix() - part_row = bids_ds.subject_participant_tsv(local_file) - desc = merge_participants_fields( - description=desc, - participants_row=part_row - if isinstance(part_row, dict) - else None, - description_fields=description_fields, - ) + row = bids_ds.subject_participant_tsv(local_file) + part_row = row if isinstance(row, dict) else None except Exception: pass - + desc = self._build_description( + record, description_fields, participants_row=part_row + ) datasets.append( EEGDashRaw( record=record, @@ -561,6 +581,89 @@ def _find_local_bids_records( """ return discover_local_bids_records(dataset_root, filters) + def _build_description( + self, + record: dict[str, Any], + description_fields: list[str], + participants_row: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Build a description dict for a single record. + + Extracts values for each requested field from the record, then merges + participant data from either an explicit ``participants_row`` (offline + path, from a local ``participants.tsv``) or the embedded + ``participant_tsv`` key inside the record (online paths). Fields still + absent after the merge are set to ``None`` so the schema is always + complete. When both the record and participant data carry the same + field, precedence is determined by ``self._description_precedence``; a + ``debug``-level log is emitted when the values differ. + + Parameters + ---------- + record : dict + The metadata for a single record. + description_fields : list of str + The fields to include in the description. + participants_row : dict or None + Optional participant-level metadata to merge. If None, the method + will look for an embedded ``participant_tsv`` key in the record. + + Returns + ------- + dict + A dictionary containing the requested description fields for the record. + + """ + description: dict[str, Any] = {} + + for field_name in description_fields: + value = self._find_key_in_nested_dict(record, field_name) + if value is not None: + description[field_name] = value + + effective_part = participants_row + if effective_part is None: + embedded = self._find_key_in_nested_dict(record, "participant_tsv") + if isinstance(embedded, dict): + effective_part = embedded + + if isinstance(effective_part, dict): + norm_present = { + normalize_key(k): k for k, v in description.items() if v is not None + } + for part_key, part_val in effective_part.items(): + existing_field = norm_present.get(normalize_key(part_key)) + if ( + existing_field is not None + and description[existing_field] != part_val + ): + if self._description_precedence == "participant_tsv": + logger.debug( + "Field '%s': participant_tsv value %r overwrote record value %r.", + existing_field, + part_val, + description[existing_field], + ) + description[existing_field] = part_val + else: + logger.debug( + "Field '%s': record value %r kept over participant_tsv value %r.", + existing_field, + description[existing_field], + part_val, + ) + description = merge_participants_fields( + description=description, + participants_row=effective_part, + description_fields=description_fields, + ) + + # Ensure all requested fields are present; None for any that were not found + for field in description_fields: + description.setdefault(field, None) + + return description + def _find_key_in_nested_dict(self, data: Any, target_key: str) -> Any: """Recursively search for a key in nested dicts/lists. @@ -639,20 +742,7 @@ def _find_datasets( f"Record data_name: {record.get('data_name', 'unknown')}" ) - description: dict[str, Any] = {} - # Requested fields first (normalized matching) - for field_name in description_fields: - value = self._find_key_in_nested_dict(record, field_name) - if value is not None: - description[field_name] = value - # Merge all participants.tsv columns generically - part = self._find_key_in_nested_dict(record, "participant_tsv") - if isinstance(part, dict): - description = merge_participants_fields( - description=description, - participants_row=part, - description_fields=description_fields, - ) + description = self._build_description(record, description_fields) datasets.append( EEGDashRaw( record, diff --git a/tests/unit_tests/dataset/test_build_description.py b/tests/unit_tests/dataset/test_build_description.py new file mode 100644 index 000000000..e2198a51c --- /dev/null +++ b/tests/unit_tests/dataset/test_build_description.py @@ -0,0 +1,204 @@ +"""Tests for the unified _build_description helper and the three EEGDashDataset +initialization paths (records=, offline, query). + +Covers: +- Description parity across all three construction paths +- Configurable description_precedence (record vs participant_tsv) +""" + +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +from eegdash.dataset.dataset import EEGDashDataset + +# --------------------------------------------------------------------------- +# Lightweight EEGDashRaw stub +# --------------------------------------------------------------------------- + + +class _FakeRaw: + """Minimal stand-in for EEGDashRaw. + + Satisfies the two things BaseConcatDataset needs from each element: + - __len__ returning an integer + - description as a pd.Series + + Stores the description kwarg so tests can inspect it later. + """ + + def __init__(self, record, cache_dir=None, description=None, **kwargs): + _ = ( + cache_dir, + kwargs, + ) # accepted to match EEGDashRaw's signature; not needed in stub + self.record = record + self.description = pd.Series(description or {}, dtype=object) + + def __len__(self): + return 1 + + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def parity_record(tmp_path): + """A v2-format record suitable for all three construction paths.""" + return { + "dataset": "ds_parity", + "bidspath": "ds_parity/sub-01/eeg/sub-01_task-rest_eeg.set", + "bids_relpath": "sub-01/eeg/sub-01_task-rest_eeg.set", + "extension": ".set", + "entities": {"subject": "01", "task": "rest"}, + "entities_mne": {"subject": "01", "task": "rest"}, + "storage": {"backend": "local", "base": str(tmp_path / "ds_parity")}, + } + + +# --------------------------------------------------------------------------- +# 1. Path parity: records=, offline, and query paths produce identical descriptions +# --------------------------------------------------------------------------- + + +def test_dataset_initialization_path_parity(tmp_path, parity_record): + """All three EEGDashDataset construction paths must build identical descriptions. + + _FakeRaw is used instead of MagicMock so BaseConcatDataset can safely + call len() and access .description on each element. The description + passed to each _FakeRaw constructor is extracted and compared via + pandas.testing.assert_frame_equal. + """ + description_fields = ["subject", "task"] + (tmp_path / "ds_parity").mkdir(parents=True, exist_ok=True) + + with patch("eegdash.dataset.dataset.EEGDashRaw", _FakeRaw): + # -- Path 1: records= ------------------------------------------------ + ds_records = EEGDashDataset( + cache_dir=tmp_path, + records=[parity_record], + download=True, + description_fields=description_fields, + ) + + # -- Path 2: offline (download=False) --------------------------------- + # discover_local_bids_records returns the same record; EEGBIDSDataset + # is made to fail so no participant enrichment happens, keeping the + # result identical to what the records= path produces. + with ( + patch( + "eegdash.dataset.dataset.discover_local_bids_records", + return_value=[parity_record], + ), + patch( + "eegdash.dataset.dataset.EEGBIDSDataset", + side_effect=Exception("no bids"), + ), + ): + ds_offline = EEGDashDataset( + cache_dir=tmp_path, + dataset="ds_parity", + download=False, + description_fields=description_fields, + ) + + # -- Path 3: query (mocked API) --------------------------------------- + mock_api = MagicMock() + mock_api.find.return_value = [parity_record] + with patch("eegdash.dataset.dataset.validate_record", return_value=[]): + ds_query = EEGDashDataset( + cache_dir=tmp_path, + dataset="ds_parity", + eeg_dash_instance=mock_api, + download=True, + description_fields=description_fields, + ) + + # BaseConcatDataset.description builds pd.DataFrame([ds.description for ds in datasets]) + # _FakeRaw.description is a real pd.Series, so this works correctly. + pd.testing.assert_frame_equal( + ds_records.description, + ds_offline.description, + check_like=True, + obj="records= vs offline", + ) + pd.testing.assert_frame_equal( + ds_records.description, + ds_query.description, + check_like=True, + obj="records= vs query", + ) + + +# --------------------------------------------------------------------------- +# 2. description_precedence="participant_tsv" — participant_tsv values win +# --------------------------------------------------------------------------- + + +def test_build_description_participant_tsv_precedence(tmp_path): + """participant_tsv values overwrite conflicting record values when precedence='participant_tsv'. + + Also verifies that a None value in participant_tsv overwrites a non-None + record value — this is intentional when the caller trusts that source fully. + """ + _stub_record = { + "dataset": "ds_prec", + "bidspath": "ds_prec/a.set", + "bids_relpath": "a.set", + "extension": ".set", + "storage": {"backend": "local", "base": str(tmp_path)}, + } + with patch("eegdash.dataset.dataset.EEGDashRaw", _FakeRaw): + ds = EEGDashDataset( + cache_dir=tmp_path, + records=[_stub_record], + download=True, + description_precedence="participant_tsv", + ) + + record = { + "age": 30, + "participant_tsv": {"age": 99, "sex": "M"}, + } + desc = ds._build_description(record, description_fields=["age", "sex"]) + + assert desc["age"] == 99, ( + "participant_tsv value must win when precedence='participant_tsv'" + ) + assert desc["sex"] == "M" + + # None in participant_tsv overwrites a real record value (documented behaviour). + record_none = { + "age": 30, + "participant_tsv": {"age": None}, + } + desc_none = ds._build_description(record_none, description_fields=["age"]) + assert desc_none["age"] is None, ( + "None in participant_tsv must overwrite record value when precedence='participant_tsv'" + ) + + +# --------------------------------------------------------------------------- +# 3. Invalid description_precedence raises ValueError at construction +# --------------------------------------------------------------------------- + + +def test_dataset_invalid_description_precedence(tmp_path): + """An unsupported description_precedence value raises ValueError at construction.""" + _stub_record = { + "dataset": "ds_inv", + "bidspath": "ds_inv/a.set", + "bids_relpath": "a.set", + "extension": ".set", + "storage": {"backend": "local", "base": str(tmp_path)}, + } + with pytest.raises(ValueError, match="description_precedence must be one of"): + EEGDashDataset( + cache_dir=tmp_path, + records=[_stub_record], + download=True, + description_precedence="invalid_mode", + ) diff --git a/tests/unit_tests/dataset/test_dataset.py b/tests/unit_tests/dataset/test_dataset.py index dc7b73420..8378a7a5d 100644 --- a/tests/unit_tests/dataset/test_dataset.py +++ b/tests/unit_tests/dataset/test_dataset.py @@ -171,19 +171,11 @@ def test_datasets_init_gap(): # We can just try to instantiate with minimum args # Patching the CLASS method with patch( - "eegdash.dataset.dataset.EEGDashDataset._find_datasets", return_value=[] + "eegdash.dataset.dataset.EEGDashDataset._find_datasets", + return_value=[MagicMock()], ): - # But wait, if it returns empty list, init might raise "No datasets found" - # We need to see code. - # If I look above, standard logic is: datasets = self._find_datasets... if not datasets: raise ValueError - # So we must return a non-empty list - with patch( - "eegdash.dataset.dataset.EEGDashDataset._find_datasets", - return_value=[MagicMock()], - ): - ds = EEGDashDataset(query={}, cache_dir=".", dataset="ds001") - # query is set before _find_datasets check - assert ds.query == {"dataset": "ds001"} + ds = EEGDashDataset(query={}, cache_dir=".", dataset="ds001") + assert ds.query == {"dataset": "ds001"} def test_dataset_init_exception_gap(tmp_path): @@ -202,35 +194,25 @@ def test_dataset_init_exception_gap(tmp_path): {"path": "s2", "bidspath": "foo/baz", "dataset": "ds001"}, ] - # Try patching the class in the module with patch("eegdash.dataset.dataset.EEGDashRaw"): - # Mock get_entities_from_record (plural) called in dataset.py - with patch( - "eegdash.dataset.dataset.get_entities_from_record", - return_value={"sub": "01"}, - ): - # Mock participants_row_for_subject is NOT called directly, usage is: - # part_row = bids_ds.subject_participant_tsv(local_file) - # So we mock EEGBIDSDataset or its method? - # In the code: bids_ds = EEGBIDSDataset(...) - # We should patch EEGBIDSDataset class in dataset.py - with patch("eegdash.dataset.dataset.EEGBIDSDataset") as mock_bids_cls: - mock_bids = mock_bids_cls.return_value - mock_bids.subject_participant_tsv.return_value = {} - - # Mock merge_participants_fields to raise Exception - with patch( - "eegdash.dataset.dataset.merge_participants_fields", - side_effect=Exception("Boom"), - ): - ds = EEGDashDataset( - cache_dir=cache, - dataset="ds001", - check_files=False, - download=False, - ) - # Should swallow exception and continue - assert len(ds.datasets) == 2 + with patch("eegdash.dataset.dataset.EEGBIDSDataset") as mock_bids_cls: + mock_bids = mock_bids_cls.return_value + mock_bids.subject_participant_tsv.return_value = {} + + # merge_participants_fields raises inside _build_description's offline + # enrichment path; the per-record exception must be swallowed. + with patch( + "eegdash.dataset.dataset.merge_participants_fields", + side_effect=Exception("Boom"), + ): + ds = EEGDashDataset( + cache_dir=cache, + dataset="ds001", + check_files=False, + download=False, + ) + # Should swallow exception and continue + assert len(ds.datasets) == 2 def test_dataset_init_kwargs_gap(tmp_path):