Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 125 additions & 35 deletions eegdash/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from .. import downloader
from ..bids_metadata import (
build_query_from_kwargs,
get_entities_from_record,
merge_participants_fields,
normalize_key,
)
Expand Down Expand Up @@ -178,6 +177,23 @@ class EEGDashDataset(BaseConcatDataset, metaclass=NumpyDocstringInheritanceInitM

Skipped recordings are flagged via ``ds._skipped`` so callers can
filter them out with a list comprehension after iteration.
description_precedence : str, default "record"
Which source wins when the same field appears in both the record and
the embedded ``participant_tsv`` data:

- ``"record"`` (default): the record-level value is kept.
- ``"participant_tsv"``: the ``participant_tsv`` value overwrites the
record value for conflicting fields.

In both cases a ``debug``-level log is emitted when a conflict is
detected.

.. note::
When ``description_precedence="participant_tsv"``, a ``None``
value in ``participant_tsv`` will overwrite a non-``None`` record
value for the same field. This is deliberate — choosing this mode
means trusting the ``participant_tsv`` source fully, including its
gaps.
**kwargs : dict
Additional keyword arguments serving two purposes:

Expand All @@ -202,13 +218,21 @@ def __init__(
auth_token: str | None = None,
on_error: str = "raise",
max_concurrency: int = 20,
description_precedence: str = "record",
**kwargs,
):
# Internal-only kwargs
suppress_comp_warning = kwargs.pop("_suppress_comp_warning", False)
self._dedupe_records: bool = kwargs.pop("_dedupe_records", False)

self._on_error = on_error
_valid = {"record", "participant_tsv"}
if description_precedence not in _valid:
raise ValueError(
f"description_precedence must be one of {sorted(_valid)}, "
f"got {description_precedence!r}"
)
self._description_precedence = description_precedence
self.s3_bucket = s3_bucket
self.database = database
self.auth_token = auth_token
Expand Down Expand Up @@ -246,15 +270,17 @@ def __init__(

if records is not None:
self.records = self._normalize_records(records)

datasets = [
EEGDashRaw(
record,
self.cache_dir,
**base_dataset_kwargs,
datasets = []
for norm_record in self.records:
description = self._build_description(norm_record, description_fields)
datasets.append(
EEGDashRaw(
norm_record,
self.cache_dir,
description=description,
**base_dataset_kwargs,
)
)
for record in self.records
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for record in self.records
datasets = [
EEGDashRaw(
record,
self.cache_dir,
description=self._build_description(record, description_fields),
**base_dataset_kwargs,
)
for record in self.records
]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This formulation also keeps the diff minimal.

]
elif not download: # only assume local data is complete if not downloading
if not self.data_dir.exists():
raise ValueError(
Expand All @@ -272,26 +298,20 @@ def __init__(

datasets = []
for record in records:
# Start with entity values from filename (supports v1 and v2 formats)
desc: dict[str, Any] = get_entities_from_record(record)

part_row: dict[str, Any] | None = None
if bids_ds is not None:
try:
rel_from_dataset = Path(record["bidspath"]).relative_to(
record["dataset"]
) # type: ignore[index]
local_file = (self.data_dir / rel_from_dataset).as_posix()
part_row = bids_ds.subject_participant_tsv(local_file)
desc = merge_participants_fields(
description=desc,
participants_row=part_row
if isinstance(part_row, dict)
else None,
description_fields=description_fields,
)
row = bids_ds.subject_participant_tsv(local_file)
part_row = row if isinstance(row, dict) else None
except Exception:
pass

desc = self._build_description(
record, description_fields, participants_row=part_row
)
datasets.append(
EEGDashRaw(
record=record,
Expand Down Expand Up @@ -561,6 +581,89 @@ def _find_local_bids_records(
"""
return discover_local_bids_records(dataset_root, filters)

def _build_description(
self,
record: dict[str, Any],
description_fields: list[str],
participants_row: dict[str, Any] | None = None,
) -> dict[str, Any]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is more verbose than necessary for what the function does, and I would move it to an auxiliary place, e.g. a utilities module. That way, the dataset object does not expose this function, which is only used once.

"""Build a description dict for a single record.

Extracts values for each requested field from the record, then merges
participant data from either an explicit ``participants_row`` (offline
path, from a local ``participants.tsv``) or the embedded
``participant_tsv`` key inside the record (online paths). Fields still
absent after the merge are set to ``None`` so the schema is always
complete. When both the record and participant data carry the same
field, precedence is determined by ``self._description_precedence``; a
``debug``-level log is emitted when the values differ.
Comment on lines +584 to +599
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can compact this function much more, too.


Parameters
----------
record : dict
The metadata for a single record.
description_fields : list of str
The fields to include in the description.
participants_row : dict or None
Optional participant-level metadata to merge. If None, the method
will look for an embedded ``participant_tsv`` key in the record.

Returns
-------
dict
A dictionary containing the requested description fields for the record.

"""
description: dict[str, Any] = {}

for field_name in description_fields:
value = self._find_key_in_nested_dict(record, field_name)
if value is not None:
description[field_name] = value

effective_part = participants_row
if effective_part is None:
embedded = self._find_key_in_nested_dict(record, "participant_tsv")
if isinstance(embedded, dict):
effective_part = embedded

if isinstance(effective_part, dict):
norm_present = {
normalize_key(k): k for k, v in description.items() if v is not None
}
for part_key, part_val in effective_part.items():
existing_field = norm_present.get(normalize_key(part_key))
if (
existing_field is not None
and description[existing_field] != part_val
):
if self._description_precedence == "participant_tsv":
logger.debug(
"Field '%s': participant_tsv value %r overwrote record value %r.",
existing_field,
part_val,
description[existing_field],
)
description[existing_field] = part_val
else:
logger.debug(
"Field '%s': record value %r kept over participant_tsv value %r.",
existing_field,
description[existing_field],
part_val,
)
description = merge_participants_fields(
description=description,
participants_row=effective_part,
description_fields=description_fields,
)

# Ensure all requested fields are present; None for any that were not found
for field in description_fields:
description.setdefault(field, None)

return description

def _find_key_in_nested_dict(self, data: Any, target_key: str) -> Any:
"""Recursively search for a key in nested dicts/lists.

Expand Down Expand Up @@ -639,20 +742,7 @@ def _find_datasets(
f"Record data_name: {record.get('data_name', 'unknown')}"
)

description: dict[str, Any] = {}
# Requested fields first (normalized matching)
for field_name in description_fields:
value = self._find_key_in_nested_dict(record, field_name)
if value is not None:
description[field_name] = value
# Merge all participants.tsv columns generically
part = self._find_key_in_nested_dict(record, "participant_tsv")
if isinstance(part, dict):
description = merge_participants_fields(
description=description,
participants_row=part,
description_fields=description_fields,
)
description = self._build_description(record, description_fields)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice simplification

datasets.append(
EEGDashRaw(
record,
Expand Down
Loading
Loading