diff --git a/.gitignore b/.gitignore index 9c9fb96..2bc61f8 100644 --- a/.gitignore +++ b/.gitignore @@ -163,3 +163,5 @@ cython_debug/ *.sif *.bak + +docs/source/generated/ diff --git a/README.md b/README.md index ca98f27..6c77fe8 100644 --- a/README.md +++ b/README.md @@ -1,31 +1,98 @@ # Experanto + Experanto is a Python package designed for interpolating recordings and stimuli in neuroscience experiments. It enables users to load single or multiple experiments and create efficient dataloaders for machine learning applications. +## Features + +- **Unified Experiment Interface**: Load and query multi-modal neuroscience data (neural responses, eye tracking, treadmill, visual stimuli) through a single `Experiment` class +- **Flexible Interpolation**: Interpolate data at arbitrary time points with support for linear and nearest-neighbor methods +- **Multi-Session Support**: Combine data from multiple recording sessions into a single dataloader +- **Configurable Preprocessing**: YAML-based configuration for sampling rates, normalization, transforms, and filtering +- **PyTorch Integration**: Native PyTorch `Dataset` and `DataLoader` implementations optimized for training + ## Docs [![Docs](https://readthedocs.org/projects/experanto/badge/?version=latest)](https://experanto.readthedocs.io/) ## Installation -To install Experanto, clone locally and run: + ```bash -pip install -e /path_to/experanto +git clone https://github.com/sensorium-competition/experanto.git +cd experanto +pip install -e . 
``` -To replicate the `generate_sample` example, install: +### Note + +To replicate the `generate_sample` example, use the following command (see [allen_exporter](https://github.com/sensorium-competition/allen-exporter)): + ```bash -pip install -e /path_to/allen_exporter +pip install -e /path/to/allen_exporter ``` -(Repository: [allen_exporter](https://github.com/sensorium-competition/allen-exporter)) -To replicate the `sensorium_example`, also install the following with their dependencies: +To replicate the `sensorium_example` (see [sensorium_2023](https://github.com/ecker-lab/sensorium_2023)), install neuralpredictors (see [neuralpredictors](https://github.com/sinzlab/neuralpredictors)) as well: + ```bash -pip install -e /path_to/neuralpredictors +pip install -e /path/to/neuralpredictors +pip install -e /path/to/sensorium_2023 ``` -(Repository: [neuralpredictors](https://github.com/sinzlab/neuralpredictors)) -```bash -pip install -e /path_to/sensorium_2023 +## Quick Start + +### Loading an Experiment + +```python +from experanto.experiment import Experiment + +# Load a single experiment +exp = Experiment("/path/to/experiment") + +# Query data at specific time points +import numpy as np +times = np.linspace(0, 10, 100) # 100 time points over 10 seconds + +# Get interpolated data and a boolean mask with valid time points from all devices +data, valid = exp.interpolate(times) + +# Or from a specific device +responses, valid = exp.interpolate(times, device="responses") +``` + +### Configuration + +Experanto uses YAML configuration files. 
See `configs/default.yaml` for all options: + +```yaml +dataset: + modality_config: + responses: + sampling_rate: 8 + chunk_size: 16 + transforms: + normalization: "standardize" + screen: + sampling_rate: 30 + chunk_size: 60 + transforms: + normalization: "normalize" + +dataloader: + batch_size: 16 + num_workers: 2 ``` -(Repository: [sensorium_2023](https://github.com/ecker-lab/sensorium_2023)) -Ensure you replace `/path_to/` with the actual path to the cloned repositories. +## Documentation + +Full documentation is available at [Read the Docs](https://experanto.readthedocs.io/). + +- [Installation Guide](https://experanto.readthedocs.io/en/latest/concepts/installation.html) +- [Getting Started](https://experanto.readthedocs.io/en/latest/concepts/getting_started.html) +- [API Reference](https://experanto.readthedocs.io/en/latest/api.html) +- [Configuration Options](https://experanto.readthedocs.io/en/latest/configuration.html) + +## Contributing + +Contributions are welcome! Please open an issue or submit a pull request on [GitHub](https://github.com/sensorium-competition/experanto). + +## License +This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. diff --git a/docs/source/_static/.gitkeep b/docs/source/_static/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docs/source/_templates/custom-class-template.rst b/docs/source/_templates/custom-class-template.rst new file mode 100644 index 0000000..2a6ffd2 --- /dev/null +++ b/docs/source/_templates/custom-class-template.rst @@ -0,0 +1,31 @@ +{{ fullname | escape | underline}} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + + {% block methods %} + {% if methods %} + .. rubric:: Methods + + .. autosummary:: + {% for item in methods %} + ~{{ name }}.{{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block attributes %} + {% if attributes %} + .. 
rubric:: Attributes + + .. autosummary:: + {% for item in attributes %} + ~{{ name }}.{{ item }} + {%- endfor %} + {% endif %} + {% endblock %} diff --git a/docs/source/api.rst b/docs/source/api.rst new file mode 100644 index 0000000..5d718a8 --- /dev/null +++ b/docs/source/api.rst @@ -0,0 +1,96 @@ +Classes and functions +===================== + +This section documents all public classes and functions in Experanto. + +Core Classes +------------ + +.. autosummary:: + :toctree: generated + :template: custom-class-template.rst + :nosignatures: + + experanto.experiment.Experiment + experanto.datasets.ChunkDataset + +Interpolators +------------- + +.. autosummary:: + :toctree: generated + :template: custom-class-template.rst + :nosignatures: + + experanto.interpolators.Interpolator + experanto.interpolators.SequenceInterpolator + experanto.interpolators.PhaseShiftedSequenceInterpolator + experanto.interpolators.ScreenInterpolator + experanto.interpolators.TimeIntervalInterpolator + experanto.interpolators.ScreenTrial + experanto.interpolators.ImageTrial + experanto.interpolators.VideoTrial + experanto.interpolators.BlankTrial + experanto.interpolators.InvalidTrial + +Time Intervals +-------------- + +.. autosummary:: + :toctree: generated + :template: custom-class-template.rst + :nosignatures: + + experanto.intervals.TimeInterval + +.. autosummary:: + :toctree: generated + :nosignatures: + + experanto.intervals.uniquefy_interval_array + experanto.intervals.find_intersection_between_two_interval_arrays + experanto.intervals.find_intersection_across_arrays_of_intervals + experanto.intervals.find_union_across_arrays_of_intervals + experanto.intervals.find_complement_of_interval_array + experanto.intervals.get_stats_for_valid_interval + +Dataloaders +----------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + experanto.dataloaders.get_multisession_dataloader + experanto.dataloaders.get_multisession_concat_dataloader + +Utilities +--------- + +.. 
autosummary:: + :toctree: generated + :template: custom-class-template.rst + :nosignatures: + + experanto.utils.LongCycler + experanto.utils.ShortCycler + experanto.utils.FastSessionDataLoader + experanto.utils.MultiEpochsDataLoader + experanto.utils.SessionConcatDataset + experanto.utils.SessionBatchSampler + experanto.utils.SessionSpecificSampler + +.. autosummary:: + :toctree: generated + :nosignatures: + + experanto.utils.add_behavior_as_channels + +Filters +------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + experanto.filters.common_filters.nan_filter diff --git a/docs/source/concepts/demo_configs.rst b/docs/source/concepts/demo_configs.rst index 89dfdc2..3d0b090 100644 --- a/docs/source/concepts/demo_configs.rst +++ b/docs/source/concepts/demo_configs.rst @@ -1,3 +1,5 @@ +.. _dataset_and_dataloader_configuration: + Dataset and dataloader configuration ==================================== @@ -114,3 +116,132 @@ You can change parameters programmatically: cfg.dataset.modality_config.screen.include_blanks = True cfg.dataset.modality_config.screen.valid_condition = {"tier": "train"} cfg.dataloader.num_workers = 8 + + +Configuration options +^^^^^^^^^^^^^^^^^^^^^ + +Dataset options +""""""""""""""" + +``global_sampling_rate`` + Override sampling rate for all modalities. Set to ``None`` to use + per-modality rates. + +``global_chunk_size`` + Override chunk size (number of time steps/data points) for all modalities. + Set to ``None`` to use per-modality sizes. + + The time window covered by a chunk is ``chunk_size / sampling_rate``, so + the ``global_sampling_rate`` should be taken into account: + + - **With** ``global_sampling_rate`` set: all modalities share the same + output rate, so a single ``global_chunk_size`` unambiguously gives every + modality the same time window. + - **Without** ``global_sampling_rate`` (per-modality rates active): + different modalities have different rates, so the same sample count + produces different durations. 
In this case, leave ``global_chunk_size`` + as ``None`` and set ``chunk_size`` per modality instead. + +``add_behavior_as_channels`` + If ``True``, concatenate behavioral data (e.g., eye tracker, treadmill) as + additional channels to the screen data. + +``replace_nans_with_means`` + If ``True``, replace NaN values with the mean of non-NaN values. + +``cache_data`` + If ``True``, cache interpolated data in memory for faster access. + +``out_keys`` + List of modality keys to include in the output dictionary. + +``normalize_timestamps`` + If ``True``, normalize timestamps to start from 0. + +Modality options +"""""""""""""""" + +Each modality (e.g., screen, responses, eye_tracker, treadmill) supports: + +``keep_nans`` + Whether to keep NaN values in the output. + +``sampling_rate`` + Controls the spacing of the time points that the dataset constructs and + passes to :meth:`~experanto.experiment.Experiment.interpolate`. Concretely, + each item in the dataset requests values at times + ``start, start + 1/sampling_rate, start + 2/sampling_rate, …``. The + interpolator then interpolates the stored raw samples at those points. + +``chunk_size`` + Number of **time steps/data points** returned per item for this modality. + Internally, ``sampling_rate`` defines the spacing of the time points passed + to the interpolator, so the covered time window is: + + .. math:: + + \text{duration (s)} = \frac{\text{chunk\_size}}{\text{sampling\_rate}} + + Note that ``sampling_rate`` here controls the *spacing* of the time points + requested from the underlying experiment (see ``sampling_rate`` above). The + native acquisition rate of the signal does not matter (e.g., the interpolator simply looks up the stored value closest to each requested time). + + When per-modality output rates differ, ``chunk_size`` must be set per + modality to cover the same time window. The default configuration keeps
The default configuration keeps + all modalities at a 2-second window while using different output rates: + + ============ ============= =========== =========== + Modality sampling_rate chunk_size Duration + ============ ============= =========== =========== + screen 30 Hz 60 2 s + eye_tracker 30 Hz 60 2 s + treadmill 30 Hz 60 2 s + responses 8 Hz 16 2 s + ============ ============= =========== =========== + + If you unify all rates with ``global_sampling_rate``, use + ``global_chunk_size`` instead and this per-modality value is ignored. + In general: ``chunk_size = desired_duration_seconds * sampling_rate``. + +``offset`` + Time offset in seconds applied to the time points constructed for this + modality. For example, if the screen is queried at times + ``[t, t + 1/sampling_rate, …]``, setting ``offset = 0.1`` on responses + means responses are queried at ``[t + 0.1, t + 0.1 + 1/sampling_rate, …]``. + Useful for aligning modalities with known temporal delays relative to the + screen stimulus. + +``transforms`` + Dictionary of transforms to apply at the dataset level. This is modality + specific, i.e., not all modalities support the same set of transforms. Some + examples include ``"normalize"`` for sequences, such as eye_tracker, + and ``"standardize"`` for responses. + + To understand how transforms are loaded and applied internally, refer to + :meth:`experanto.datasets.ChunkDataset.initialize_transforms`. If you need + to implement a custom transform, we recommend following the same pattern + used there. In particular, note how each entry in the ``transforms`` + dictionary is checked and, when it is a config ``dict``, instantiated via + Hydra before being added to the transform pipeline. + + You can point Experanto to any callable (function or class) by using + Hydra's ``_target_`` key, which triggers + `hydra.utils.instantiate `_ + under the hood (e.g., ``_target_: my_package.my_module.MyTransform``). + +``interpolation`` + Interpolation settings. 
This is modality specific, i.e., not all modalities + support the same set of interpolation methods. Some examples include + ``"rescale"`` for the screen and ``"interpolation_mode"`` (e.g., + ``"nearest_neighbor"``) for sequences. + +``filters`` + Dictionary of filter functions to apply to the data. + +Dataloader options +"""""""""""""""""" + +All standard ``torch.utils.data.DataLoader`` options are supported. See the +`PyTorch DataLoader documentation `_ +for the full list of available parameters. diff --git a/docs/source/concepts/demo_dataset.rst b/docs/source/concepts/demo_dataset.rst index 90472d9..0825fef 100644 --- a/docs/source/concepts/demo_dataset.rst +++ b/docs/source/concepts/demo_dataset.rst @@ -1,17 +1,33 @@ - .. _loading_dataset: Loading a dataset object ======================== -Dataset objects organize experimental data (from **Experiment** class) for machine learning tasks, offering project-specific and configurable access for training and evaluation. They often serve as a source for creating **dataloaders**. +Dataset objects organize experimental data (from the :class:`~experanto.experiment.Experiment` class) for machine learning tasks, offering project-specific and configurable access for training and evaluation. They often serve as a source for creating dataloaders (see :func:`~experanto.dataloaders.get_multisession_dataloader`). + +.. note:: + + The key distinction between :class:`~experanto.experiment.Experiment` and a + dataset object is one of **time discretization**. + :class:`~experanto.experiment.Experiment` is a low-level interface: you hand + it any array of time points and it returns values at those points via a + lookup into the raw stored data. :class:`~experanto.datasets.ChunkDataset` + is used on top of it and imposes a specific time structure. 
For each item, + it constructs a separate ``times`` array per modality using that modality's + configured ``sampling_rate`` and ``chunk_size`` (``times = start + np.arange(chunk_size) / sampling_rate``), + then calls :meth:`~experanto.experiment.Experiment.interpolate` for each + modality independently. This is how all modalities end up covering the same + time window with compatible shapes in a batch. Key features of dataset objects ------------------------------- Dataset objects provide several essential features: -- **Sampling Rate**: Defines the frequency of equally spaced interpolation times across the entire experiment. This ensures consistency in temporal data alignment. +- **Sampling Rate**: Defines the spacing of the time points that the dataset + constructs and hands to the underlying :class:`~experanto.experiment.Experiment` + for each item (``time_delta = 1 / sampling_rate``). The experiment then does + a lookup into the raw stored data at those points. - **Chunk Size**: Determines the number of values returned when calling the ``__getitem__`` method. This is crucial, for example, for deep learning models that use 3D convolutions over time, where single elements or small chunk sizes are insufficient to capture meaningful temporal patterns. - **Modality Configuration**: Specifies the details of the interpolation, including: @@ -81,7 +97,7 @@ This will output something like: Defining dataloaders --------------------- -Once the dataset is verified, we can define **DataLoader** objects for training or other purposes. This allows easy batch processing during training: +Once the dataset is verified, we can define `DataLoader `_ objects for training or other purposes. This allows easy batch processing during training: .. 
code-block:: python diff --git a/docs/source/concepts/demo_experiment.rst b/docs/source/concepts/demo_experiment.rst index 8a83344..a4296ff 100644 --- a/docs/source/concepts/demo_experiment.rst +++ b/docs/source/concepts/demo_experiment.rst @@ -3,7 +3,20 @@ Loading a single experiment =========================== -To load an experiment, we use the **Experiment** class. This is particularly useful for testing whether the formatting and interpolation behave as expected before loading multiple experiments into dataset objects. +To load an experiment, we use the :class:`~experanto.experiment.Experiment` +class. This class aggregates all modalities and their respective interpolators +in a single object. Its main job is to unify the access to all modalities. + +:class:`~experanto.experiment.Experiment` accepts an arbitrary array of time +points and returns the corresponding values for each modality by looking them +up in the raw stored data (e.g., using nearest-neighbour or linear +interpolation for sequences). When you need a regular sampling grid or +fixed-length intervals (chunks) for training, a dataset object such as +:class:`~experanto.datasets.ChunkDataset` should be used **on top of** +:class:`~experanto.experiment.Experiment`. There you can define the time +discretization (via ``sampling_rate`` and ``chunk_size``), construct the +appropriate time points, and delegate the data retrieval to the +underlying :class:`~experanto.experiment.Experiment`. Loading an experiment --------------------- @@ -34,8 +47,26 @@ All compatible modalities for the loaded experiment can be checked using: Interpolating data ------------------ -Once the modalities are identified, we can interpolate their data. -The following example interpolates a 20-second window with 2 frames per second, resulting in 40 images: +Once the modalities are identified, we can interpolate their data using +:meth:`~experanto.experiment.Experiment.interpolate`. 
+ +:meth:`~experanto.experiment.Experiment.interpolate` accepts any 1-D array of +time points and returns, for each modality, an array of shape +``(len(times), n_signals)``. The number of returned points is always +``len(times)``, regardless of the native acquisition rate of the modality. + +When you call :meth:`~experanto.experiment.Experiment.interpolate` **without** +a ``device`` argument, every modality receives the *same* time array. This +means modalities with low native rates can return repeated values for +consecutive requested times that fall in the same native sample (nearest +neighbour), while modalities with high native rates will effectively be +sub-sampled (the behavior is interpolator-dependent). If you need different +time densities per modality, call :meth:`~experanto.experiment.Experiment.interpolate` +separately with a different ``times`` array for each ``device``. This is +exactly what :class:`~experanto.datasets.ChunkDataset` does internally. + +The following example interpolates a 20-second window at 2 time points per +second, resulting in 40 screen frames: .. code-block:: python diff --git a/docs/source/concepts/demo_multisession.rst b/docs/source/concepts/demo_multisession.rst index 0c5675d..71a1b80 100644 --- a/docs/source/concepts/demo_multisession.rst +++ b/docs/source/concepts/demo_multisession.rst @@ -1,14 +1,14 @@ Loading multiple sessions ========================= -To load multiple sessions at once, you can use the ``get_multisession_dataloader`` function from ``experanto.dataloaders``. +To load multiple sessions at once, you can use :func:`~experanto.dataloaders.get_multisession_dataloader`. This function takes: - A list of paths pointing to your experiment directories - A configuration dictionary, similar to the one used for loading a single dataset -It returns a dictionary of ``MultiEpochsDataLoader`` objects, each corresponding to a session, loaded with the specified configurations. 
+It returns a dictionary of :class:`~experanto.utils.MultiEpochsDataLoader` objects, each corresponding to a session, loaded with the specified configurations. Example ------- @@ -30,4 +30,4 @@ Example # Load first two sessions train_dl = get_multisession_dataloader(full_paths[:2], cfg) -The returned ``train_dl`` is a dictionary containing two ``MultiEpochsDataLoader`` objects which can be used for training. +The returned ``train_dl`` is a dictionary containing two :class:`~experanto.utils.MultiEpochsDataLoader` objects which can be used for training. diff --git a/docs/source/concepts/installation.rst b/docs/source/concepts/installation.rst index 7ab765f..4090a07 100644 --- a/docs/source/concepts/installation.rst +++ b/docs/source/concepts/installation.rst @@ -11,5 +11,5 @@ The package works on top of `jupyter/datascience-notebooks`, but the minimum req to install the package, clone it into a local repository and run:: - pip -e install /path_to_folder/experanto + pip install -e /path/to/folder/experanto diff --git a/docs/source/conf.py b/docs/source/conf.py index 3b658cc..5a60ef9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -3,6 +3,12 @@ # For the full list of built-in configuration values, see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html +import os +import sys + +# Add the project root to the path so autodoc can find the modules +sys.path.insert(0, os.path.abspath("../..")) + # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information @@ -14,11 +20,79 @@ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -extensions = [] +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "sphinx.ext.mathjax", + "sphinx.ext.intersphinx", + "sphinx.ext.viewcode", +] + +# 
-- Autosummary settings ---------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/extensions/autosummary.html + +autosummary_generate = True +autosummary_generate_overwrite = True +autosummary_imported_members = False templates_path = ["_templates"] exclude_patterns = [] +# -- Napoleon settings (NumPy-style docstrings) ------------------------------ +# https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html + +napoleon_google_docstring = False +napoleon_numpy_docstring = True +napoleon_include_init_with_doc = True +napoleon_include_private_with_doc = False +napoleon_include_special_with_doc = True +napoleon_use_admonition_for_examples = False +napoleon_use_admonition_for_notes = False +napoleon_use_admonition_for_references = False +napoleon_use_ivar = False +napoleon_use_param = True +napoleon_use_rtype = True +napoleon_type_aliases = None + +# -- Autodoc settings -------------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html + +autodoc_default_options = { + "members": True, + "member-order": "bysource", + "special-members": "__init__", + "undoc-members": True, + "exclude-members": "__weakref__", +} +autodoc_typehints = "description" +autodoc_typehints_description_target = "documented" + +# Mock heavy imports that aren't needed for documentation generation +autodoc_mock_imports = [ + "torch", + "torchvision", + "numpy", + "scipy", + "cv2", + "pandas", + "hydra", + "omegaconf", + "jaxtyping", + "plotly", + "optree", + "rootutils", +] + +# -- Intersphinx settings (cross-references to external docs) ---------------- +# https://www.sphinx-doc.org/en/master/usage/extensions/intersphinx.html + +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "scipy": ("https://docs.scipy.org/doc/scipy/", None), + "torch": ("https://pytorch.org/docs/stable/", None), +} # -- Options for HTML 
output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output diff --git a/docs/source/index.rst b/docs/source/index.rst index 5886f44..628db63 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -3,7 +3,7 @@ **Experanto** is a Python package designed for interpolating recordings and stimuli in neuroscience experiments. It enables users to load single or multiple experiments and create efficient dataloaders for machine learning applications. -Issues with the package can be submited at our `GitHub Issues page `_. +Issues with the package can be submitted at our `GitHub Issues page `_. ------------ @@ -17,10 +17,16 @@ Issues with the package can be submited at our `GitHub Issues page >> from experanto.configs import DEFAULT_CONFIG +>>> from omegaconf import OmegaConf +>>> print(OmegaConf.to_yaml(DEFAULT_CONFIG)) + +Customizing configuration: + +>>> from experanto.configs import DEFAULT_MODALITY_CONFIG +>>> cfg = DEFAULT_MODALITY_CONFIG.copy() +>>> cfg.screen.sampling_rate = 60 +>>> cfg.responses.chunk_size = 32 + +See Also +-------- +ChunkDataset : Uses these configurations for data loading. +""" + from pathlib import Path from hydra import compose, initialize, initialize_config_dir diff --git a/experanto/dataloaders.py b/experanto/dataloaders.py index 37f344e..f891bd5 100644 --- a/experanto/dataloaders.py +++ b/experanto/dataloaders.py @@ -26,15 +26,46 @@ def get_multisession_dataloader( shuffle_keys: bool = False, **kwargs, ) -> LongCycler: - """ - Create a multisession dataloader from a list of paths and corresponding configs. - Args: - paths (List[str]): List of paths to the datasets. - configs (Union[DictConfig, Dict, List[Union[DictConfig, Dict]]]): Configuration for each dataset. - If a single config is provided, it will be applied to all datasets. - If a list is provided, it should match the length of paths. 
- shuffle_keys (bool): Whether to shuffle the keys of the dataloaders. - **kwargs: Additional keyword arguments for dataset and dataloader configuration. + """Create a multi-session dataloader from multiple experiment paths. + + By default, creates a :class:`ChunkDataset` for each path and wraps them in + a :class:`LongCycler` that yields ``(session_key, batch)`` pairs. + The cycler continues until the longest session is exhausted. + + Parameters + ---------- + paths : list of str + Paths to experiment directories. + configs : dict, DictConfig, list, optional + Configuration for each dataset. If a single config is provided, + it will be applied to all datasets. If a list is provided, it + should match the length of ``paths``. Each config should have + ``dataset`` and ``dataloader`` keys. + shuffle_keys : bool, default=False + Whether to shuffle the order of session keys. + **kwargs + Additional keyword arguments. Supports ``config`` as an alias + for ``configs``. + + Returns + ------- + LongCycler + A dataloader-like object that yields ``(session_key, batch)`` tuples. + Iterates until the longest session is exhausted. + + See Also + -------- + get_multisession_concat_dataloader : Alternative that concatenates sessions. + LongCycler : The underlying multi-session iterator. + + Examples + -------- + >>> from experanto.dataloaders import get_multisession_dataloader + >>> from experanto.configs import DEFAULT_CONFIG + >>> paths = ['/path/to/exp1', '/path/to/exp2'] + >>> loader = get_multisession_dataloader(paths, configs=DEFAULT_CONFIG) + >>> for session_key, batch in loader: + ... print(f"Session: {session_key}, batch shape: {batch['responses'].shape}") """ if configs is None and "config" in kwargs: @@ -70,20 +101,51 @@ def get_multisession_concat_dataloader( dataloader_config: Optional[Dict] = None, **kwargs, ) -> Optional["FastSessionDataLoader"]: - """ - Creates a multi-session dataloader using SessionConcatDataset and SessionDataLoader. 
- Returns (session_key, batch) pairs during iteration. - - Args: - paths: List of paths to dataset files - configs: Configuration for datasets (single config or list of configs) - seed: Random seed for reproducibility - num_workers: Number of worker processes for data loading - prefetch_factor: Prefetch factor for data loading - **kwargs: Additional arguments - - Returns: - SessionDataLoader instance or None if no valid datasets found + """Create a concatenated multi-session dataloader. + + Unlike :func:`get_multisession_dataloader`, this function concatenates + all sessions into a single dataset and uses batch sampling to ensure + each batch contains samples from only one session. This is more + memory-efficient and provides better shuffling across sessions. + + Parameters + ---------- + paths : list of str + Paths to experiment directories. + configs : dict or list of dict, optional + Configuration for each dataset. If a single config is provided, + it will be applied to all datasets. Each config should have + ``dataset`` and ``dataloader`` keys. + seed : int, default=0 + Random seed for reproducibility. Each dataset gets a deterministic + seed derived from this value and its path hash. + dataloader_config : dict, optional + Configuration for the dataloader (batch_size, num_workers, etc.). + If None, uses the dataloader config from the first config. + **kwargs + Additional keyword arguments. Supports ``config`` as an alias + for ``configs``. + + Returns + ------- + FastSessionDataLoader or None + A dataloader that yields ``(session_key, batch)`` tuples. + Returns None if no valid datasets could be created. + + See Also + -------- + get_multisession_dataloader : Alternative using separate dataloaders. + FastSessionDataLoader : The underlying dataloader implementation. + SessionConcatDataset : Dataset that concatenates multiple sessions. 
+ + Examples + -------- + >>> from experanto.dataloaders import get_multisession_concat_dataloader + >>> from experanto.configs import DEFAULT_CONFIG + >>> paths = ['/path/to/exp1', '/path/to/exp2', '/path/to/exp3'] + >>> loader = get_multisession_concat_dataloader(paths, configs=DEFAULT_CONFIG) + >>> for session_key, batch in loader: + ... print(f"Session: {session_key}") """ if configs is None and "config" in kwargs: configs = kwargs.pop("config") @@ -98,7 +160,6 @@ def get_multisession_concat_dataloader( start_time = time.time() for i, (path, cfg) in enumerate(zip(paths, configs)): - # Create dataset with deterministic seed path_hash = hash(path) % 10000 dataset_seed = seed + path_hash if seed is not None else None diff --git a/experanto/datasets.py b/experanto/datasets.py index 1045263..7044749 100644 --- a/experanto/datasets.py +++ b/experanto/datasets.py @@ -71,26 +71,77 @@ def __getitem__(self, idx): class ChunkDataset(Dataset): - def __init__( - self, - root_folder: str, - global_sampling_rate: Optional[float] = None, - global_chunk_size: Optional[int] = None, - add_behavior_as_channels: bool = False, - replace_nans_with_means: bool = False, - cache_data: bool = False, - out_keys: Optional[Iterable] = None, - normalize_timestamps: bool = True, - modality_config: dict = DEFAULT_MODALITY_CONFIG, - seed: Optional[int] = None, - safe_interval_threshold: float = 0.5, - interpolate_precision: int = 5, - ) -> None: - """ - interpolate_precision: number of digits after the dot to keep, without it we might get different numbers from interpolation - - The full modality config is a nested dictionary. - The following is an example of a modality config for a screen, responses, eye_tracker, and treadmill: + """PyTorch Dataset for chunked experiment data. + + This dataset loads an experiment and provides temporally-chunked samples + suitable for training neural networks. 
Each sample contains synchronized + data from all modalities (e.g., screen, responses, eye_tracker, treadmill) + at a given time window. + + Parameters + ---------- + root_folder : str + Path to the experiment directory containing modality subfolders. + global_sampling_rate : float, optional + Sampling rate (Hz) applied to all modalities. If None, uses + per-modality rates from ``modality_config``. + global_chunk_size : int, optional + Number of samples per chunk for all modalities. If None, uses + per-modality sizes from ``modality_config``. + add_behavior_as_channels : bool, default=False + If True, concatenates behavioral data as additional image channels. + Deprecated: use separate modality outputs instead. + replace_nans_with_means : bool, default=False + If True, replaces NaN values with column means. + cache_data : bool, default=False + If True, keeps loaded data in memory for faster access. + out_keys : iterable, optional + Which modalities to include in output. Defaults to all modalities + plus 'timestamps'. + normalize_timestamps : bool, default=True + If True, normalizes timestamps relative to recording start. + modality_config : dict + Configuration for each modality including sampling rates, transforms, + filters, and interpolation settings. See Notes for structure. + seed : int, optional + Random seed for reproducible shuffling of valid time points. + safe_interval_threshold : float, default=0.5 + Safety margin (in seconds) to exclude from edges of valid intervals. + interpolate_precision : int, default=5 + Number of decimal places for time precision. Prevents floating-point + accumulation errors during interpolation. + + + Attributes + ---------- + data_key : str + Unique identifier for this dataset, extracted from metadata. + device_names : tuple + Names of loaded modalities. + start_time : float + Start of valid time range (after applying safety threshold). + end_time : float + End of valid time range (after applying safety threshold). 
+ + See Also + -------- + Experiment : Lower-level interface for data access. + get_multisession_dataloader : Load multiple datasets. + experanto.configs : Default configuration values. + + Notes + ----- + The dataset handles: + + - Per-modality sampling rates and chunk sizes + - Time offset alignment between modalities + - Data normalization and transforms + - Filtering based on trial conditions and data quality + - Reproducible random sampling with seeds + + The ``modality_config`` is a nested dictionary with per-modality settings. The following is an example of a modality config for a screen, responses, eye_tracker, and treadmill: + + .. code-block:: yaml screen: sampling_rate: null @@ -141,7 +192,41 @@ def __init__( normalize: true interpolation: interpolation_mode: nearest_neighbor - """ + + Examples + -------- + >>> from experanto.datasets import ChunkDataset + >>> from experanto.configs import DEFAULT_MODALITY_CONFIG + >>> dataset = ChunkDataset( + ... '/path/to/experiment', + ... global_sampling_rate=30, + ... global_chunk_size=60, + ... modality_config=DEFAULT_MODALITY_CONFIG, + ... 
) + >>> len(dataset) + 1000 + >>> sample = dataset[0] + >>> sample['screen'].shape + torch.Size([1, 60, 144, 256]) + >>> sample['responses'].shape + torch.Size([16, 500]) + """ + + def __init__( + self, + root_folder: str, + global_sampling_rate: Optional[float] = None, + global_chunk_size: Optional[int] = None, + add_behavior_as_channels: bool = False, + replace_nans_with_means: bool = False, + cache_data: bool = False, + out_keys: Optional[Iterable] = None, + normalize_timestamps: bool = True, + modality_config: dict = DEFAULT_MODALITY_CONFIG, + seed: Optional[int] = None, + safe_interval_threshold: float = 0.5, + interpolate_precision: int = 5, + ) -> None: self.root_folder = Path(root_folder) self.data_key = self.get_data_key_from_root_folder(root_folder) self.interpolate_precision = interpolate_precision @@ -234,10 +319,10 @@ def _read_trials(self) -> None: ] def initialize_statistics(self) -> None: - """ - Initializes the statistics for each device based on the modality config. - :return: - instantiates self._statistics with the mean and std for each device + """Initialize normalization statistics for each device. + + Loads mean and standard deviation values from each device's meta folder + and stores them in ``self._statistics`` for use during data transforms. """ self._statistics = {} for device_name in self.device_names: @@ -298,10 +383,7 @@ def add_channel_function(x): return torch.from_numpy(x) def initialize_transforms(self): - """ - Initializes the transforms for each device based on the modality config. - :return: - """ + """Initialize data transforms for each device based on modality config.""" transforms = {} for device_name in self.device_names: if device_name == "screen": @@ -316,7 +398,6 @@ def initialize_transforms(self): transform_list.insert(0, add_channel) else: - transform_list: List[Any] = [ToTensor()] # Normalization. 
@@ -332,15 +413,22 @@ def initialize_transforms(self): return transforms def _get_callable_filter(self, filter_config): - """ - Helper function to get a callable filter function from either a config or an already instantiated callable. + """Return a callable filter function from config or an existing callable. + + Notes + ----- Handles partial instantiation using hydra.utils.instantiate. - Args: - filter_config: Either a config dict/DictConfig or a callable function + Parameters + ---------- + filter_config : dict, DictConfig, or callable + Either a config dictionary/DictConfig specifying a filter (with + '_target_'), or an already-instantiated callable filter function. - Returns: - callable: The final filter function ready to be called with device_ + Returns + ------- + callable + The final filter function ready to be called with `device_`. """ # Check if it's already a callable (function) if callable(filter_config): @@ -412,16 +500,25 @@ def get_valid_intervals_from_filters( def get_condition_mask_from_meta_conditions( self, valid_conditions_sum_of_product: List[dict] ) -> np.ndarray: - """Creates a boolean mask for trials that satisfy any of the given condition combinations. - - Args: - valid_conditions_sum_of_product: List of dictionaries, where each dictionary represents a set of - conditions that should be satisfied together (AND). Multiple dictionaries are combined with OR. - Example: [{'tier': 'train', 'stim_type': 'natural'}, {'tier': 'blank'}] matches trials that - are either (train AND natural) OR blank. - - Returns: - np.ndarray: Boolean mask indicating which trials satisfy at least one set of conditions. + """Create a boolean mask for trials satisfying given conditions. + + Parameters + ---------- + valid_conditions_sum_of_product : list of dict + Condition dictionaries combined with OR logic, where conditions + within each dictionary use AND logic.
+ + Returns + ------- + np.ndarray + Boolean mask indicating which trials satisfy at least one set of + conditions. + + Notes + ----- + For example, + ``[{'tier': 'train', 'stim_type': 'natural'}, {'tier': 'blank'}]`` + matches trials that are either (train AND natural) OR blank. """ all_conditions: Optional[np.ndarray] = None for valid_conditions_product in valid_conditions_sum_of_product: @@ -449,15 +546,22 @@ def get_screen_sample_mask_from_meta_conditions( valid_conditions_sum_of_product: List[dict], filter_for_valid_intervals: bool = True, ) -> np.ndarray: - """Creates a boolean mask indicating which screen samples satisfy the given conditions. - - Args: - satisfy_for_next: Number of consecutive samples that must satisfy conditions - valid_conditions_sum_of_product: List of condition dictionaries combined with OR logic, - where conditions within each dictionary use AND logic - - Returns: - Boolean array matching screen sample times, True where conditions are met + """Create a boolean mask for screen samples satisfying given conditions. + + Parameters + ---------- + satisfy_for_next : int + Number of consecutive samples that must satisfy conditions. + valid_conditions_sum_of_product : list of dict + Condition dictionaries combined with OR logic, where conditions + within each dictionary use AND logic. + filter_for_valid_intervals : bool, default=True + Whether to apply interval-based filtering. + + Returns + ------- + numpy.ndarray + Boolean array matching screen sample times, True where conditions are met. 
""" all_conditions = self.get_condition_mask_from_meta_conditions( valid_conditions_sum_of_product @@ -508,12 +612,21 @@ def get_screen_sample_mask_from_meta_conditions( def get_full_valid_sample_times( self, filter_for_valid_intervals: bool = True ) -> np.ndarray: - """ - iterates through all sample times and checks if they could be used as - start times, eg if the next `self.chunk_sizes["screen"]` points are still valid - based on the previous meta condition filtering - :returns: - valid_times: np.array of valid starting points + """Get all valid chunk starting times based on meta conditions. + + Iterates through sample times and checks if they can be used as chunk + start times (i.e., the next ``chunk_size`` points are all valid based + on the previous meta condition filtering). + + Parameters + ---------- + filter_for_valid_intervals : bool, default=True + Whether to apply interval-based filtering. + + Returns + ------- + numpy.ndarray + Array of valid starting time points. """ # Calculate all possible end indices @@ -565,14 +678,20 @@ def shuffle_valid_screen_times(self) -> None: ) def get_data_key_from_root_folder(self, root_folder): - """ - Extract a data key from the root folder path by checking for a meta.json file. + """Extract a data key from the root folder path. + + Checks for a meta.json file and extracts the data_key or scan_key. - Args: - root_folder (str or Path): Path to the root folder containing dataset + Parameters + ---------- + root_folder : str or Path + Path to the root folder containing the dataset. - Returns: - str: The extracted data key or folder name if meta.json doesn't exist or lacks data_key + Returns + ------- + str + The extracted data key, or folder name if meta.json doesn't + exist or lacks data_key. 
""" # Convert Path object to string if necessary root_folder = str(root_folder) @@ -617,7 +736,28 @@ def get_data_key_from_root_folder(self, root_folder): def __len__(self): return len(self._valid_screen_times) - def __getitem__(self, idx) -> dict: + def __getitem__(self, idx: int) -> dict: + """Return a single data sample at the given index. + + Parameters + ---------- + idx : int + Index of the sample to retrieve. + + Returns + ------- + dict + Dictionary containing data for each modality in ``out_keys``, e.g.: + + - ``'screen'``: torch.Tensor of shape ``(C, T, H, W)`` + - ``'responses'``: torch.Tensor of shape ``(T, N_neurons)`` + - ``'eye_tracker'``: torch.Tensor of shape ``(T, N_features)`` + - ``'treadmill'``: torch.Tensor of shape ``(T, N_features)`` + - ``'timestamps'``: dict mapping modality names to time arrays + + Where ``T`` is the chunk size (may differ per modality), + ``C`` is channels, ``H`` is height, ``W`` is width. + """ out = {} timestamps = {} s = self._valid_screen_times[idx] diff --git a/experanto/experiment.py b/experanto/experiment.py index e9df4cf..552737a 100644 --- a/experanto/experiment.py +++ b/experanto/experiment.py @@ -18,18 +18,56 @@ class Experiment: + """High-level interface for loading and querying neuroscience experiments. + + An Experiment represents a single recording session containing multiple + modalities (e.g., visual stimuli, neural responses, behavioral data). + Each modality is loaded as an Interpolator, allowing data to be queried + at arbitrary time points. + + Parameters + ---------- + root_folder : str + Path to the experiment directory. Should contain subdirectories + for each modality (e.g., ``screen/``, ``responses/``, ``eye_tracker/``). + modality_config : dict, optional + Configuration dictionary specifying interpolation and processing + settings for each modality. See :mod:`experanto.configs` for the + default configuration structure. 
+ cache_data : bool, default=False + If True, loads all trial data into memory for faster access. + Useful for smaller datasets or when memory is not a constraint. + + Attributes + ---------- + devices : dict + Dictionary mapping device names to their :class:`Interpolator` instances. + start_time : float + Earliest valid timestamp across all devices. + end_time : float + Latest valid timestamp across all devices. + + See Also + -------- + ChunkDataset : Higher-level interface for ML training. + Interpolator : Base class for modality-specific interpolators. + + Examples + -------- + >>> from experanto.experiment import Experiment + >>> exp = Experiment('/path/to/experiment') + >>> exp.device_names + ('screen', 'responses', 'eye_tracker') + >>> times = np.linspace(0, 10, 100) + >>> data, valid = exp.interpolate(times, device='responses') + """ + def __init__( self, root_folder: str, modality_config: dict = DEFAULT_MODALITY_CONFIG, cache_data: bool = False, ) -> None: - """ - root_folder: path to the data folder - interp_config: dict for configuring interpolators, like - interp_config = {"screen": {...}, {"eye_tracker": {...}, } - cache_data: if True, loads and keeps all trial data in memory - """ self.root_folder = Path(root_folder) self.devices = dict() self.start_time = np.inf @@ -73,7 +111,8 @@ def _load_devices(self) -> None: else: # Default back to original logic warnings.warn( - "Falling back to original Interpolator creation logic.", UserWarning + "Falling back to original Interpolator creation logic.", + UserWarning, ) dev = Interpolator.create(d, cache_data=self.cache_data, **interp_conf) # type: ignore[arg-type] @@ -92,6 +131,39 @@ def interpolate( device: Union[str, Interpolator, None] = None, return_valid: bool = False, ) -> Union[tuple[dict, dict], dict, tuple[np.ndarray, np.ndarray], np.ndarray]: + """Interpolate data from one or all devices at specified time points. 
+ + Parameters + ---------- + times : array_like + 1D array of time points (in seconds) at which to interpolate. + device : str, optional + Name of a specific device to interpolate. If None, interpolates + all devices and returns dictionaries. + + Returns + ------- + values : numpy.ndarray or dict + If ``device`` is specified, returns the interpolated data array. + Otherwise, returns a dict mapping device names to their data arrays. + valid : numpy.ndarray or dict + Boolean mask(s) indicating which time points were valid. + Same structure as ``values``. + + Examples + -------- + Interpolate a single device: + + >>> data, valid = exp.interpolate(times, device='responses') + >>> data.shape + (100, 500) # 100 time points, 500 neurons + + Interpolate all devices: + + >>> data, valid = exp.interpolate(times) + >>> data.keys() + dict_keys(['screen', 'responses', 'eye_tracker']) + """ if device is None: values = {} valid = {} @@ -114,5 +186,18 @@ def interpolate( else: raise ValueError(f"Unsupported device type: {type(device)}") - def get_valid_range(self, device_name) -> tuple[float, float]: + def get_valid_range(self, device_name: str) -> tuple[float, float]: + """Get the valid time range for a specific device. + + Parameters + ---------- + device_name : str + Name of the device (e.g., 'screen', 'responses'). + + Returns + ------- + tuple + A tuple ``(start_time, end_time)`` representing the valid + time interval in seconds. + """ return tuple(self.devices[device_name].valid_interval) diff --git a/experanto/filters/common_filters.py b/experanto/filters/common_filters.py index 137ef4a..d74ef34 100644 --- a/experanto/filters/common_filters.py +++ b/experanto/filters/common_filters.py @@ -9,33 +9,48 @@ def nan_filter(vicinity=0.05): + """Create a filter that excludes time regions around NaN values. 
+ + Returns a closure that, given a :class:`~experanto.interpolators.SequenceInterpolator`, + identifies all time points containing NaN in any channel and marks a + symmetric window of ``vicinity`` seconds around each as invalid. + + Parameters + ---------- + vicinity : float, optional + Half-width of the exclusion window in seconds around each NaN + time point. Default is 0.05. + + Returns + ------- + callable + A function that takes a + :class:`~experanto.interpolators.SequenceInterpolator` and returns + a list of :class:`~experanto.intervals.TimeInterval` representing + the valid (NaN-free) portions of the recording. + """ + def implementation(device_: SequenceInterpolator): - # requests SequenceInterpolator as uses time_delta internally - # and other interpolators don't have it + # Requires a SequenceInterpolator since it relies on time_delta, + # which other interpolator types do not expose. time_delta = device_.time_delta start_time = device_.start_time end_time = device_.end_time - data = device_._data # (T, n_neurons) + data = device_._data # (T, n_features) - # detect nans - nan_mask = np.isnan(data) # (T, n_neurons) + nan_mask = np.isnan(data) # (T, n_features) nan_mask = np.any(nan_mask, axis=1) # (T,) - - # Find indices where nan_mask is True nan_indices = np.where(nan_mask)[0] - # Create invalid TimeIntervals around each nan point invalid_intervals = [] - vicinity_seconds = vicinity # vicinity is already in seconds for idx in nan_indices: time_point = start_time + idx * time_delta - interval_start = max(start_time, time_point - vicinity_seconds) - interval_end = min(end_time, time_point + vicinity_seconds) + interval_start = max(start_time, time_point - vicinity) + interval_end = min(end_time, time_point + vicinity) invalid_intervals.append(TimeInterval(interval_start, interval_end)) # Merge overlapping invalid intervals invalid_intervals = uniquefy_interval_array(invalid_intervals) - # Find the complement of invalid intervals to get valid intervals 
valid_intervals = find_complement_of_interval_array( start_time, end_time, invalid_intervals diff --git a/experanto/interpolators.py b/experanto/interpolators.py index 48ecaa2..cc20938 100644 --- a/experanto/interpolators.py +++ b/experanto/interpolators.py @@ -18,6 +18,40 @@ class Interpolator: + """Abstract base class for time series interpolation. + + Interpolators load data from a modality folder and map time points to + data values. Each modality (e.g., screen, responses, eye_tracker, + treadmill) is assigned to a separate interpolator object belonging to + one of the Interpolator subclasses (e.g., SequenceInterpolator, + ScreenInterpolator, etc.), but multiple modalities can belong to the same + class, such as treadmill and eye_tracker both being assigned to the + SequenceInterpolator subclass. + + Parameters + ---------- + root_folder : str + Path to the modality directory containing data and metadata files. + + Attributes + ---------- + root_folder : pathlib.Path + Path to the modality directory. + start_time : float + Earliest timestamp in the data. + end_time : float + Latest timestamp in the data. + valid_interval : TimeInterval + Time range for which interpolation is valid. + + See Also + -------- + SequenceInterpolator : For time series data (responses, behaviors). + ScreenInterpolator : For visual stimuli (images, videos). + TimeIntervalInterpolator : For labeled time intervals (e.g., train/test splits). + Experiment : High-level interface that manages multiple interpolators. + """ + def __init__(self, root_folder: str) -> None: self.root_folder = Path(root_folder) self.start_time = None @@ -34,8 +68,8 @@ def load_meta(self): def interpolate( self, times: np.ndarray, return_valid: bool = False ) -> Union[tuple[np.ndarray, np.ndarray], np.ndarray]: + """Map an array of time points to interpolated data values.""" ... 
- # returns interpolated signal and boolean mask of valid samples def __contains__(self, times: np.ndarray): return np.any(self.valid_times(times)) @@ -48,6 +82,30 @@ def __exit__(self, *exc): @staticmethod def create(root_folder: str, cache_data: bool = False, **kwargs) -> "Interpolator": + """Factory method to create the appropriate interpolator for a modality. + + Reads the ``meta.yml`` file in the folder to determine the modality type + and instantiates the corresponding interpolator subclass. + + Parameters + ---------- + root_folder : str + Path to the modality directory. + cache_data : bool, default=False + If True, loads all data into memory for faster access. + **kwargs + Additional arguments passed to the interpolator constructor. + + Returns + ------- + Interpolator + An instance of the appropriate interpolator subclass. + + Raises + ------ + ValueError + If the modality type is not supported. + """ with open(Path(root_folder) / "meta.yml", "r") as file: meta_data = yaml.safe_load(file) modality = meta_data.get("modality") @@ -79,6 +137,50 @@ def close(self): class SequenceInterpolator(Interpolator): + """Interpolator for 1D time series data (neural responses, behaviors). + + Handles regularly-sampled time series stored as memory-mapped or NumPy + arrays. Supports nearest-neighbor and linear interpolation modes. + + Parameters + ---------- + root_folder : str + Path to the modality directory containing ``data.mem`` or ``data.npy``. + cache_data : bool, default=False + If True, loads memory-mapped data into RAM for faster access. + keep_nans : bool, default=False + If False, replaces NaN values with column means during interpolation. + interpolation_mode : str, default='nearest_neighbor' + Interpolation method: ``'nearest_neighbor'`` or ``'linear'``. + normalize : bool, default=False + If True, normalizes data using stored mean/std statistics. + normalize_subtract_mean : bool, default=False + If True, subtracts mean during normalization. 
+ normalize_std_threshold : float, optional + Minimum std threshold to prevent division by near-zero values. + **kwargs + Additional keyword arguments (ignored). + + Attributes + ---------- + sampling_rate : float + Original sampling rate of the data in Hz. + time_delta : float + Time between samples (1 / sampling_rate). + n_signals : int + Number of signals (e.g., neurons, behavior channels). + + Notes + ----- + For linear interpolation, values are computed as: + + .. math:: + + y(t) = y_0 \\cdot \\frac{t_1 - t}{t_1 - t_0} + y_1 \\cdot \\frac{t - t_0}{t_1 - t_0}, + + where :math:`t_0` and :math:`t_1` are the surrounding sample times. + """ + def __init__( self, root_folder: str, @@ -90,10 +192,6 @@ def __init__( normalize_std_threshold: typing.Optional[float] = None, # or 0.01 **kwargs, ) -> None: - """ - interpolation_mode - nearest neighbor or linear - keep_nans - if we keep nans in linear interpolation - """ super().__init__(root_folder) meta = self.load_meta() self.keep_nans = keep_nans @@ -228,6 +326,25 @@ def close(self) -> None: class PhaseShiftedSequenceInterpolator(SequenceInterpolator): + """Sequence interpolator with per-signal phase shifts. + + Extends :class:`SequenceInterpolator` to handle signals recorded with + different phase offsets (e.g., neurons with different response latencies). + Each signal is interpolated at its own phase-shifted time. + + Parameters + ---------- + root_folder : str + Path to the modality directory. Must contain ``meta/phase_shifts.npy``. + **kwargs + All parameters from :class:`SequenceInterpolator`. + + Attributes + ---------- + _phase_shifts : numpy.ndarray + Per-signal phase shift values in seconds. + """ + def __init__( self, root_folder: str, @@ -333,6 +450,43 @@ def interpolate( class ScreenInterpolator(Interpolator): + """Interpolator for visual stimuli (images and videos). + + Handles frame-based visual data organized as trials. Each trial can be + a single image, a video sequence, or a blank screen. 
Frames are indexed + by timestamp and retrieved on demand. + + Parameters + ---------- + root_folder : str + Path to the screen modality directory containing ``timestamps.npy``, + ``data/`` folder with trial files, and ``meta/`` folder with metadata. + cache_data : bool, default=False + If True, loads all trial data into memory for faster access. + rescale : bool, default=False + If True, rescales frames to ``rescale_size``. + rescale_size : tuple of int, optional + Target size ``(height, width)`` for rescaling. If None, uses the + native image size from metadata. + normalize : bool, default=False + If True, normalizes frames using stored mean/std statistics. + **kwargs + Additional keyword arguments (ignored). + + Attributes + ---------- + timestamps : numpy.ndarray + Array of frame timestamps. + trials : list of ScreenTrial + List of trial objects containing frame data. + + See Also + -------- + ImageTrial : Single-frame stimuli. + VideoTrial : Multi-frame video stimuli. + BlankTrial : Blank/gray screen stimuli. + """ + def __init__( self, root_folder: str, @@ -342,10 +496,6 @@ def __init__( normalize: bool = False, **kwargs, ) -> None: - """ - rescale would rescale images to the _image_size if true - cache_data: if True, loads and keeps all trial data in memory - """ super().__init__(root_folder) self.timestamps = np.load(self.root_folder / "timestamps.npy") self.start_time = self.timestamps[0] @@ -488,9 +638,17 @@ def interpolate( return (out, valid) if return_valid else out def rescale_frame(self, frame: np.ndarray) -> np.ndarray: - """ - Changes the resolution of the image to this size. - Returns: Rescaled image + """Rescale frame to the configured image size. + + Parameters + ---------- + frame : np.ndarray + Input image frame. + + Returns + ------- + np.ndarray + Rescaled image as float32. 
""" return cv2.resize(frame, self._image_size, interpolation=cv2.INTER_AREA).astype( np.float32 @@ -498,6 +656,48 @@ def rescale_frame(self, frame: np.ndarray) -> np.ndarray: class TimeIntervalInterpolator(Interpolator): + """Interpolator for labeled time intervals. + + Maps time points to boolean membership in labeled intervals. Given a + set of time points, returns a boolean array indicating whether each + point falls within any interval for each label. + + Labels and their intervals are defined in the ``meta.yml`` file under + the ``labels`` key. Each label points to a ``.npy`` file containing an + array of shape ``(n, 2)``, where each row is a ``[start, end)`` + half-open time interval. Typical labels include ``'train'``, + ``'validation'``, ``'test'``, ``'saccade'``, ``'gaze'``, or + ``'target'``. + + The half-open convention means a timestamp *t* is considered inside an + interval when ``start <= t < end``. Intervals where ``start > end`` + are treated as invalid and trigger a warning. + + Parameters + ---------- + root_folder : str + Path to the modality directory containing ``meta.yml`` and the + ``.npy`` interval files referenced by its ``labels`` mapping. + cache_data : bool, default=False + If True, loads all interval arrays into memory at init time. + **kwargs + Additional keyword arguments (ignored). + + Attributes + ---------- + meta_labels : dict + Mapping from label names to ``.npy`` filenames. + + Notes + ----- + - Only time points within the valid interval (as defined by + ``start_time`` and ``end_time`` in ``meta.yml``) are considered; + others are filtered out. + - The ``interpolate`` method returns an array of shape + ``(n_valid_times, n_labels)`` where ``out[i, j]`` is True if the + *i*-th valid time falls within any interval for the *j*-th label. 
+ """ + def __init__(self, root_folder: str, cache_data: bool = False, **kwargs): super().__init__(root_folder) self.cache_data = cache_data @@ -517,41 +717,6 @@ def __init__(self, root_folder: str, cache_data: bool = False, **kwargs): def interpolate( self, times: np.ndarray, return_valid: bool = False ) -> Union[tuple[np.ndarray, np.ndarray], np.ndarray]: - """ - Interpolate time intervals for labeled events. - - Given a set of time points and a set of labeled intervals (defined in the - `meta.yml` file), this method returns a boolean array indicating, for each - time point, whether it falls within any interval for each label. - - The method uses half-open intervals [start, end), where a timestamp t is - considered to fall within an interval if start <= t < end. This means the - start time is inclusive and the end time is exclusive. - - Parameters - ---------- - times : np.ndarray - Array of time points to be checked against the labeled intervals. - - Returns - ------- - out : np.ndarray of bool, shape (len(valid_times), n_labels) - Boolean array where each row corresponds to a valid time point and each - column corresponds to a label. `out[i, j]` is True if the i-th valid - time falls within any interval for the j-th label, and False otherwise. - - Notes - ----- - - The labels and their corresponding intervals are defined in the `meta.yml` - file under the `labels` key. Each label points to a `.npy` file containing - an array of shape (n, 2), where each row is a [start, end) time interval. - - Typical labels might include 'train', 'validation', 'test', 'saccade', - 'gaze', or 'target'. - - Only time points within the valid interval (as defined by start_time and - end_time in meta.yml) are considered; others are filtered out. - - Intervals where start > end are considered invalid and will trigger a - warning. 
- """ valid = self.valid_times(times) valid_times = times[valid] @@ -591,6 +756,27 @@ def interpolate( class ScreenTrial: + """Base class for visual stimulus trials. + + Represents a single trial (stimulus presentation) in a screen recording. + Subclasses handle different trial types: images, videos, and blanks. + + Parameters + ---------- + data_file_name : str + Path to the data file for this trial. + meta_data : dict + Metadata dictionary for the trial. + image_size : tuple + Frame dimensions ``(height, width)`` or ``(height, width, channels)``. + first_frame_idx : int + Index of the first frame in the global timestamp array. + num_frames : int + Number of frames in this trial. + cache_data : bool, default=False + If True, loads and caches data on initialization. + """ + def __init__( self, data_file_name: Union[str, Path], @@ -613,7 +799,9 @@ def __init__( @staticmethod def create( - data_file_name: Union[str, Path], meta_data: dict, cache_data: bool = False + data_file_name: Union[str, Path], + meta_data: dict, + cache_data: bool = False, ) -> "ScreenTrial": modality = meta_data.get("modality") assert modality is not None @@ -636,6 +824,8 @@ def get_meta(self, property: str): class ImageTrial(ScreenTrial): + """Trial containing a single static image.""" + def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None: super().__init__( data_file_name, @@ -648,6 +838,8 @@ def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None: class VideoTrial(ScreenTrial): + """Trial containing a multi-frame video sequence.""" + def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None: super().__init__( data_file_name, @@ -660,6 +852,8 @@ def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None: class BlankTrial(ScreenTrial): + """Trial containing a blank/gray screen (inter-stimulus interval).""" + def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None: 
self.interleave_value = meta_data.get("interleave_value") @@ -678,6 +872,8 @@ def get_data_(self) -> np.ndarray: class InvalidTrial(ScreenTrial): + """Placeholder for invalid or corrupted trials.""" + def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None: self.interleave_value = meta_data.get("interleave_value") diff --git a/experanto/intervals.py b/experanto/intervals.py index 6fd0d4e..1607985 100644 --- a/experanto/intervals.py +++ b/experanto/intervals.py @@ -5,6 +5,24 @@ class TimeInterval(typing.NamedTuple): + """A time interval represented by start and end times. + + Parameters + ---------- + start : float + Start time in seconds. + end : float + End time in seconds. + + Examples + -------- + >>> interval = TimeInterval(0.0, 10.0) + >>> 5.0 in interval + True + >>> 15.0 in interval + False + """ + start: float end: float @@ -28,18 +46,18 @@ def intersect(self, times: np.ndarray) -> np.ndarray: return np.where((times >= self.start) & (times <= self.end))[0] -def uniquefy_interval_array( - interval_array: List[TimeInterval], -) -> List[TimeInterval]: - """ - Takes an array of TimeIntervals and returns a new array where no intervals overlap. - If intervals overlap or are adjacent, they are merged into a single interval. +def uniquefy_interval_array(interval_array: List[TimeInterval]) -> List[TimeInterval]: + """Merge overlapping or adjacent intervals into non-overlapping intervals. - Args: - interval_array: List of TimeInterval objects + Parameters + ---------- + interval_array : list of TimeInterval + Input intervals, may overlap or be adjacent. - Returns: - List of non-overlapping TimeInterval objects sorted by start time + Returns + ------- + list of TimeInterval + Non-overlapping intervals sorted by start time. 
""" if not interval_array: return [] @@ -66,6 +84,20 @@ def uniquefy_interval_array( def find_intersection_between_two_interval_arrays( interval_array_1: List[TimeInterval], interval_array_2: List[TimeInterval] ) -> List[TimeInterval]: + """Find the intersection of two interval arrays. + + Parameters + ---------- + interval_array_1 : list of TimeInterval + First set of intervals. + interval_array_2 : list of TimeInterval + Second set of intervals. + + Returns + ------- + list of TimeInterval + Intervals where both input arrays overlap. + """ # Sort both arrays by start time sorted_1 = sorted(interval_array_1, key=lambda x: x.start) sorted_2 = sorted(interval_array_2, key=lambda x: x.start) @@ -94,6 +126,18 @@ def find_intersection_between_two_interval_arrays( def find_intersection_across_arrays_of_intervals( intervals_array: List[List[TimeInterval]], ) -> List[TimeInterval]: + """Find the common intersection across multiple interval arrays. + + Parameters + ---------- + intervals_array : list of list of TimeInterval + Multiple sets of intervals. + + Returns + ------- + list of TimeInterval + Intervals where all input arrays overlap. + """ common_interval_array = intervals_array[0] for interval_array in intervals_array[1:]: @@ -107,6 +151,18 @@ def find_intersection_across_arrays_of_intervals( def find_union_across_arrays_of_intervals( intervals_array: List[List[TimeInterval]], ) -> List[TimeInterval]: + """Find the union of multiple interval arrays. + + Parameters + ---------- + intervals_array : list of list of TimeInterval + Multiple sets of intervals. + + Returns + ------- + list of TimeInterval + Merged non-overlapping intervals covering all input intervals. 
+ """ union_array = [] for interval_array in intervals_array: union_array.extend(interval_array) @@ -116,17 +172,21 @@ def find_union_across_arrays_of_intervals( def find_complement_of_interval_array( start: float, end: float, interval_array: List[TimeInterval] ) -> List[TimeInterval]: - """ - Finds the complement of an interval array within a given range [start, end]. - Returns intervals that represent the gaps not covered by the input intervals. - - Args: - start: Start time of the range - end: End time of the range - interval_array: List of TimeInterval objects - - Returns: - List of TimeInterval objects representing the complement + """Find gaps not covered by intervals within a range. + + Parameters + ---------- + start : float + Start of the range. + end : float + End of the range. + interval_array : list of TimeInterval + Intervals to find the complement of. + + Returns + ------- + list of TimeInterval + Intervals representing uncovered gaps in ``[start, end]``. """ if not interval_array: return [TimeInterval(start, end)] @@ -157,16 +217,22 @@ def find_complement_of_interval_array( def get_stats_for_valid_interval( intervals: List[TimeInterval], start_time: float, end_time: float ) -> str: - """ - Calculates and returns statistics about valid and invalid time intervals within a given range. - - Args: - intervals: List of TimeInterval objects representing the valid periods. - start_time: The beginning of the total time range to consider. - end_time: The end of the total time range to consider. - - Returns: - A string summarizing the statistics of valid and invalid intervals. + """Calculate statistics about valid and invalid intervals within a range. + + Parameters + ---------- + intervals : list of TimeInterval + Valid time intervals. + start_time : float + Start of the analysis range. + end_time : float + End of the analysis range. + + Returns + ------- + str + Formatted string with statistics (duration, mean, std) for both + valid and invalid intervals. 
""" total_duration = end_time - start_time if total_duration <= 0: diff --git a/experanto/utils.py b/experanto/utils.py index 9551e56..6192bf3 100644 --- a/experanto/utils.py +++ b/experanto/utils.py @@ -34,21 +34,20 @@ def replace_nan_with_batch_mean(data: np.ndarray) -> np.ndarray: def add_behavior_as_channels(data: dict[str, torch.Tensor]) -> dict: - """ - Adds behavioral data as additional channels to screen data. - - Input: - data = { - 'screen': torch.Tensor: (c, t, h, w) - 'eye_tracker': torch.Tensor: (t, c_eye) or (t, h, w) - 'treadmill': torch.Tensor: (t, c_tread) or (t, h, w) - } - - Output: - data = { - 'screen': torch.Tensor: (c+behavior_channels, t, h, w) - contiguous - ... - } + """Add behavioral data as additional channels to screen data. + + Parameters + ---------- + data : dict + Dictionary with keys 'screen', 'eye_tracker', 'treadmill'. + Screen shape: ``(C, T, H, W)``. + Behavior shapes: ``(T, C_behavior)`` or ``(T, H, W)``. + + Returns + ------- + dict + Modified dictionary with behavior concatenated to screen channels. + Screen shape becomes ``(C + behavior_channels, T, H, W)``. """ screen = data["screen"] # Already contiguous, shape (c, t, h, w) c, t, h, w = screen.shape @@ -152,9 +151,28 @@ def __len__(self): class LongCycler: - """ - Cycles through trainloaders until the loader with largest size is exhausted. - Needed for dataloaders of unequal size (as in the monkey data). + """Cycle through multiple dataloaders until the longest is exhausted. + + Useful for training with multiple sessions of unequal size. Cycles through + all loaders, yielding ``(session_key, batch)`` pairs. Shorter loaders are + recycled until the longest loader completes one full epoch. + + Parameters + ---------- + loaders : dict + Dictionary mapping session keys to DataLoader instances. + + Attributes + ---------- + max_batches : int + Number of batches in the longest loader. 
+ + Examples + -------- + >>> loaders = {'session_1': loader1, 'session_2': loader2} + >>> cycler = LongCycler(loaders) + >>> for session_key, batch in cycler: + ... print(f"Processing {session_key}") """ def __init__(self, loaders): @@ -175,9 +193,20 @@ def __len__(self): class ShortCycler: - """ - Cycles through trainloaders until the loader with smallest size is exhausted. - Needed for dataloaders of unequal size (as in the monkey data). + """Cycle through multiple dataloaders until the shortest is exhausted. + + Similar to :class:`LongCycler`, but stops when the smallest loader + completes one epoch. No recycling occurs. + + Parameters + ---------- + loaders : dict + Dictionary mapping session keys to DataLoader instances. + + Attributes + ---------- + min_batches : int + Number of batches in the shortest loader. """ def __init__(self, loaders): @@ -245,7 +274,10 @@ def __init__(self, datasets, session_names=None): for i, dataset in enumerate(datasets): session_name = session_names[i] session_size = len(dataset) - self.session_indices[session_name] = (start_idx, start_idx + session_size) + self.session_indices[session_name] = ( + start_idx, + start_idx + session_size, + ) start_idx += session_size def __len__(self): @@ -298,15 +330,21 @@ class SessionBatchSampler(Sampler): """ def __init__(self, dataset, batch_size, drop_last=False, shuffle=False, seed=None): - """ - Initialize session batch sampler. - - Args: - dataset: The SessionConcatDataset to sample from - batch_size: Number of samples per batch - drop_last: Whether to drop the last batch if it's smaller than batch_size - shuffle: Whether to shuffle samples within each session - seed: Random seed for reproducibility + """Initialize session batch sampler. + + Parameters + ---------- + dataset : SessionConcatDataset + The dataset to sample from. + batch_size : int + Number of samples per batch. + drop_last : bool, optional + Whether to drop the last batch if smaller than batch_size. + Default is False. 
+ shuffle : bool, optional + Whether to shuffle samples within each session. Default is False. + seed : int, optional + Random seed for reproducibility. """ self.dataset = dataset self.batch_size = batch_size @@ -386,12 +424,45 @@ def set_state(self, state): class FastSessionDataLoader: - """ - An optimized dataloader that ensures: - 1. Each session appears exactly once before repeating - 2. The epoch ends when the longest session is exhausted - 3. Perfect alignment between sessions and batches is maintained - 4. State is properly tracked and can be restored + """Optimized multi-session dataloader with state tracking. + + Provides efficient data loading across multiple sessions with guarantees: + + - Each session appears exactly once before repeating + - Epoch ends when the longest session is exhausted + - Perfect alignment between sessions and batches is maintained + - State is properly tracked and can be restored + + Parameters + ---------- + dataset : SessionConcatDataset + Concatenated dataset with session tracking. + batch_size : int, default=1 + Number of samples per batch. + shuffle : bool, default=False + Whether to shuffle samples within each session. + num_workers : int, default=0 + Number of worker processes for data loading. + pin_memory : bool, default=False + Whether to pin memory for GPU transfer. + drop_last : bool, default=False + Whether to drop incomplete batches. + seed : int, optional + Random seed for reproducibility. + **kwargs + Additional arguments passed to underlying DataLoaders. + + Attributes + ---------- + session_names : list + Names of all sessions in the dataset. + batches_per_session : dict + Number of batches in each session. + + See Also + -------- + SessionConcatDataset : Dataset that tracks session membership. + LongCycler : Simpler alternative without state tracking. """ def __init__( @@ -405,18 +476,6 @@ def __init__( seed=None, **kwargs, ): - """ - Initialize optimized session dataloader. 
- - Args: - dataset: The SessionConcatDataset to load from - batch_size: Number of samples per batch - shuffle: Whether to shuffle indices within sessions - num_workers: Number of worker processes for data loading - pin_memory: Whether to pin memory in GPU - drop_last: Whether to drop the last batch if smaller than batch_size - seed: Random seed for reproducibility - """ # Store dataset and parameters self.dataset = dataset self.batch_size = batch_size @@ -602,7 +661,6 @@ def __iter__(self): # Continue until we've gone through one full epoch # (i.e., until the longest session is exhausted) while active_sessions and position_in_epoch < self.max_batches_per_session: - # Create a cycle order of sessions cycle_order = self.batch_sampler.get_session_cycle() @@ -665,15 +723,21 @@ class SessionSpecificSampler(Sampler): """ def __init__(self, indices, batch_size, drop_last=False, shuffle=False, seed=None): - """ - Initialize session-specific sampler. - - Args: - indices: List of dataset indices belonging to this session - batch_size: Number of samples per batch - drop_last: Whether to drop the last batch if smaller than batch_size - shuffle: Whether to shuffle indices - seed: Random seed for reproducibility + """Initialize session-specific sampler. + + Parameters + ---------- + indices : list + Dataset indices belonging to this session. + batch_size : int + Number of samples per batch. + drop_last : bool, optional + Whether to drop the last batch if smaller than batch_size. + Default is False. + shuffle : bool, optional + Whether to shuffle indices. Default is False. + seed : int, optional + Random seed for reproducibility. 
""" self.indices = list(indices) # Make a copy to avoid modification issues self.batch_size = batch_size diff --git a/pyproject.toml b/pyproject.toml index 391d610..5f968c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "experanto" -version = "0.1" +version = "0.1.0" description = "Python package to interpolate recordings and stimuli of neuroscience experiments" readme = "README.md" requires-python = ">=3.9" @@ -9,10 +9,10 @@ requires-python = ">=3.9" [tool.setuptools.packages.find] where = ["."] -include = ["experanto*", "configs"] +include = ["experanto*"] [tool.setuptools.package-data] -"configs" = ["*.yaml"] +experanto = ["../configs/*.yaml"] [project.urls] Homepage = "https://github.com/sensorium-competition/experanto" @@ -23,7 +23,7 @@ build-backend = "setuptools.build_meta" [tool.black] line-length = 88 -target-version = ["py312"] +target-version = ["py39"] include = '\.pyi?$' exclude = ''' /(