diff --git a/.gitignore b/.gitignore
index 9c9fb96..2bc61f8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -163,3 +163,5 @@ cython_debug/
*.sif
*.bak
+
+docs/source/generated/
diff --git a/README.md b/README.md
index ca98f27..6c77fe8 100644
--- a/README.md
+++ b/README.md
@@ -1,31 +1,98 @@
# Experanto
+
Experanto is a Python package designed for interpolating recordings and stimuli in neuroscience experiments. It enables users to load single or multiple experiments and create efficient dataloaders for machine learning applications.
+## Features
+
+- **Unified Experiment Interface**: Load and query multi-modal neuroscience data (neural responses, eye tracking, treadmill, visual stimuli) through a single `Experiment` class
+- **Flexible Interpolation**: Interpolate data at arbitrary time points with support for linear and nearest-neighbor methods
+- **Multi-Session Support**: Combine data from multiple recording sessions into a single dataloader
+- **Configurable Preprocessing**: YAML-based configuration for sampling rates, normalization, transforms, and filtering
+- **PyTorch Integration**: Native PyTorch `Dataset` and `DataLoader` implementations optimized for training
+
## Docs
[](https://experanto.readthedocs.io/)
## Installation
-To install Experanto, clone locally and run:
+
```bash
-pip install -e /path_to/experanto
+git clone https://github.com/sensorium-competition/experanto.git
+cd experanto
+pip install -e .
```
-To replicate the `generate_sample` example, install:
+### Note
+
+To replicate the `generate_sample` example, use the following command (see [allen_exporter](https://github.com/sensorium-competition/allen-exporter)):
+
```bash
-pip install -e /path_to/allen_exporter
+pip install -e /path/to/allen_exporter
```
-(Repository: [allen_exporter](https://github.com/sensorium-competition/allen-exporter))
-To replicate the `sensorium_example`, also install the following with their dependencies:
+To replicate the `sensorium_example`, install [neuralpredictors](https://github.com/sinzlab/neuralpredictors) and [sensorium_2023](https://github.com/ecker-lab/sensorium_2023) as well:
+
```bash
-pip install -e /path_to/neuralpredictors
+pip install -e /path/to/neuralpredictors
+pip install -e /path/to/sensorium_2023
```
-(Repository: [neuralpredictors](https://github.com/sinzlab/neuralpredictors))
-```bash
-pip install -e /path_to/sensorium_2023
+## Quick Start
+
+### Loading an Experiment
+
+```python
+from experanto.experiment import Experiment
+
+# Load a single experiment
+exp = Experiment("/path/to/experiment")
+
+# Query data at specific time points
+import numpy as np
+times = np.linspace(0, 10, 100) # 100 time points over 10 seconds
+
+# Get interpolated data and a boolean mask with valid time points from all devices
+data, valid = exp.interpolate(times)
+
+# Or from a specific device
+responses, valid = exp.interpolate(times, device="responses")
+```
+
+### Configuration
+
+Experanto uses YAML configuration files. See `configs/default.yaml` for all options:
+
+```yaml
+dataset:
+ modality_config:
+ responses:
+ sampling_rate: 8
+ chunk_size: 16
+ transforms:
+ normalization: "standardize"
+ screen:
+ sampling_rate: 30
+ chunk_size: 60
+ transforms:
+ normalization: "normalize"
+
+dataloader:
+ batch_size: 16
+ num_workers: 2
```
-(Repository: [sensorium_2023](https://github.com/ecker-lab/sensorium_2023))
-Ensure you replace `/path_to/` with the actual path to the cloned repositories.
+## Documentation
+
+Full documentation is available at [Read the Docs](https://experanto.readthedocs.io/).
+
+- [Installation Guide](https://experanto.readthedocs.io/en/latest/concepts/installation.html)
+- [Getting Started](https://experanto.readthedocs.io/en/latest/concepts/getting_started.html)
+- [API Reference](https://experanto.readthedocs.io/en/latest/api.html)
+- [Configuration Options](https://experanto.readthedocs.io/en/latest/configuration.html)
+
+## Contributing
+
+Contributions are welcome! Please open an issue or submit a pull request on [GitHub](https://github.com/sensorium-competition/experanto).
+
+## License
+This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
diff --git a/docs/source/_static/.gitkeep b/docs/source/_static/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/docs/source/_templates/custom-class-template.rst b/docs/source/_templates/custom-class-template.rst
new file mode 100644
index 0000000..2a6ffd2
--- /dev/null
+++ b/docs/source/_templates/custom-class-template.rst
@@ -0,0 +1,31 @@
+{{ fullname | escape | underline}}
+
+.. currentmodule:: {{ module }}
+
+.. autoclass:: {{ objname }}
+ :members:
+ :undoc-members:
+ :show-inheritance:
+ :inherited-members:
+
+ {% block methods %}
+ {% if methods %}
+ .. rubric:: Methods
+
+ .. autosummary::
+ {% for item in methods %}
+ ~{{ name }}.{{ item }}
+ {%- endfor %}
+ {% endif %}
+ {% endblock %}
+
+ {% block attributes %}
+ {% if attributes %}
+ .. rubric:: Attributes
+
+ .. autosummary::
+ {% for item in attributes %}
+ ~{{ name }}.{{ item }}
+ {%- endfor %}
+ {% endif %}
+ {% endblock %}
diff --git a/docs/source/api.rst b/docs/source/api.rst
new file mode 100644
index 0000000..5d718a8
--- /dev/null
+++ b/docs/source/api.rst
@@ -0,0 +1,96 @@
+Classes and functions
+=====================
+
+This section documents all public classes and functions in Experanto.
+
+Core Classes
+------------
+
+.. autosummary::
+ :toctree: generated
+ :template: custom-class-template.rst
+ :nosignatures:
+
+ experanto.experiment.Experiment
+ experanto.datasets.ChunkDataset
+
+Interpolators
+-------------
+
+.. autosummary::
+ :toctree: generated
+ :template: custom-class-template.rst
+ :nosignatures:
+
+ experanto.interpolators.Interpolator
+ experanto.interpolators.SequenceInterpolator
+ experanto.interpolators.PhaseShiftedSequenceInterpolator
+ experanto.interpolators.ScreenInterpolator
+ experanto.interpolators.TimeIntervalInterpolator
+ experanto.interpolators.ScreenTrial
+ experanto.interpolators.ImageTrial
+ experanto.interpolators.VideoTrial
+ experanto.interpolators.BlankTrial
+ experanto.interpolators.InvalidTrial
+
+Time Intervals
+--------------
+
+.. autosummary::
+ :toctree: generated
+ :template: custom-class-template.rst
+ :nosignatures:
+
+ experanto.intervals.TimeInterval
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
+
+ experanto.intervals.uniquefy_interval_array
+ experanto.intervals.find_intersection_between_two_interval_arrays
+ experanto.intervals.find_intersection_across_arrays_of_intervals
+ experanto.intervals.find_union_across_arrays_of_intervals
+ experanto.intervals.find_complement_of_interval_array
+ experanto.intervals.get_stats_for_valid_interval
+
+Dataloaders
+-----------
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
+
+ experanto.dataloaders.get_multisession_dataloader
+ experanto.dataloaders.get_multisession_concat_dataloader
+
+Utilities
+---------
+
+.. autosummary::
+ :toctree: generated
+ :template: custom-class-template.rst
+ :nosignatures:
+
+ experanto.utils.LongCycler
+ experanto.utils.ShortCycler
+ experanto.utils.FastSessionDataLoader
+ experanto.utils.MultiEpochsDataLoader
+ experanto.utils.SessionConcatDataset
+ experanto.utils.SessionBatchSampler
+ experanto.utils.SessionSpecificSampler
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
+
+ experanto.utils.add_behavior_as_channels
+
+Filters
+-------
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
+
+ experanto.filters.common_filters.nan_filter
diff --git a/docs/source/concepts/demo_configs.rst b/docs/source/concepts/demo_configs.rst
index 89dfdc2..3d0b090 100644
--- a/docs/source/concepts/demo_configs.rst
+++ b/docs/source/concepts/demo_configs.rst
@@ -1,3 +1,5 @@
+.. _dataset_and_dataloader_configuration:
+
Dataset and dataloader configuration
====================================
@@ -114,3 +116,132 @@ You can change parameters programmatically:
cfg.dataset.modality_config.screen.include_blanks = True
cfg.dataset.modality_config.screen.valid_condition = {"tier": "train"}
cfg.dataloader.num_workers = 8
+
+
+Configuration options
+^^^^^^^^^^^^^^^^^^^^^
+
+Dataset options
+"""""""""""""""
+
+``global_sampling_rate``
+ Override sampling rate for all modalities. Set to ``None`` to use
+ per-modality rates.
+
+``global_chunk_size``
+ Override chunk size (number of time steps/data points) for all modalities.
+ Set to ``None`` to use per-modality sizes.
+
+ The time window covered by a chunk is ``chunk_size / sampling_rate``, so
+ whether ``global_sampling_rate`` is set determines how to choose it:
+
+ - **With** ``global_sampling_rate`` set: all modalities share the same
+ output rate, so a single ``global_chunk_size`` unambiguously gives every
+ modality the same time window.
+ - **Without** ``global_sampling_rate`` (per-modality rates active):
+ different modalities have different rates, so the same sample count
+ produces different durations. In this case, leave ``global_chunk_size``
+ as ``None`` and set ``chunk_size`` per modality instead.
+
+``add_behavior_as_channels``
+ If ``True``, concatenate behavioral data (e.g., eye tracker, treadmill) as
+ additional channels to the screen data.
+
+``replace_nans_with_means``
+ If ``True``, replace NaN values with the mean of non-NaN values.
+
+``cache_data``
+ If ``True``, cache interpolated data in memory for faster access.
+
+``out_keys``
+ List of modality keys to include in the output dictionary.
+
+``normalize_timestamps``
+ If ``True``, normalize timestamps to start from 0.
+
+Modality options
+""""""""""""""""
+
+Each modality (e.g., screen, responses, eye_tracker, treadmill) supports:
+
+``keep_nans``
+ Whether to keep NaN values in the output.
+
+``sampling_rate``
+ Controls the spacing of the time points that the dataset constructs and
+ passes to :meth:`~experanto.experiment.Experiment.interpolate`. Concretely,
+ each item in the dataset requests values at times
+ ``start, start + 1/sampling_rate, start + 2/sampling_rate, …``. The
+ interpolator then interpolates the stored raw samples at those points.
+
+``chunk_size``
+ Number of **time steps/data points** returned per item for this modality.
+ Internally, ``sampling_rate`` defines the spacing of the time points passed
+ to the interpolator, so the covered time window is:
+
+ .. math::
+
+ \text{duration (s)} = \frac{\text{chunk\_size}}{\text{sampling\_rate}}
+
+ Note that ``sampling_rate`` here controls the *spacing* of the time points
+ requested from the underlying experiment (see ``sampling_rate`` above). The
+ native acquisition rate of the signal does not matter (e.g., the interpolator simply looks up the stored values closest to each requested time).
+
+ When per-modality output rates differ, ``chunk_size`` must be set per
+ modality to cover the same time window. The default configuration keeps
+ all modalities at a 2-second window while using different output rates:
+
+ ============ ============= =========== ===========
+ Modality sampling_rate chunk_size Duration
+ ============ ============= =========== ===========
+ screen 30 Hz 60 2 s
+ eye_tracker 30 Hz 60 2 s
+ treadmill 30 Hz 60 2 s
+ responses 8 Hz 16 2 s
+ ============ ============= =========== ===========
+
+ If you unify all rates with ``global_sampling_rate``, use
+ ``global_chunk_size`` instead and this per-modality value is ignored.
+ In general: ``chunk_size = desired_duration_seconds * sampling_rate``.
+
+``offset``
+ Time offset in seconds applied to the time points constructed for this
+ modality. For example, if the screen is queried at times
+ ``[t, t + 1/sampling_rate, …]``, setting ``offset = 0.1`` on responses
+ means responses are queried at ``[t + 0.1, t + 0.1 + 1/sampling_rate, …]``.
+ Useful for aligning modalities with known temporal delays relative to the
+ screen stimulus.
+
+``transforms``
+ Dictionary of transforms to apply at the dataset level. This is modality
+ specific, i.e., not all modalities support the same set of transforms. Some
+ examples include ``"normalize"`` for sequences, such as eye_tracker,
+ and ``"standardize"`` for responses.
+
+ To understand how transforms are loaded and applied internally, refer to
+ :meth:`experanto.datasets.ChunkDataset.initialize_transforms`. If you need
+ to implement a custom transform, we recommend following the same pattern
+ used there. In particular, note how each entry in the ``transforms``
+ dictionary is checked and, when it is a config ``dict``, instantiated via
+ Hydra before being added to the transform pipeline.
+
+ You can point Experanto to any callable (function or class) by using
+ Hydra's ``_target_`` key, which triggers
+ `hydra.utils.instantiate `_
+ under the hood (e.g., ``_target_: my_package.my_module.MyTransform``).
+
+``interpolation``
+ Interpolation settings. This is modality specific, i.e., not all modalities
+ support the same set of interpolation methods. Some examples include
+ ``"rescale"`` for the screen and ``"interpolation_mode"`` (e.g.,
+ ``"nearest_neighbor"``) for sequences.
+
+``filters``
+ Dictionary of filter functions to apply to the data.
+
+Dataloader options
+""""""""""""""""""
+
+All standard ``torch.utils.data.DataLoader`` options are supported. See the
+`PyTorch DataLoader documentation `_
+for the full list of available parameters.
diff --git a/docs/source/concepts/demo_dataset.rst b/docs/source/concepts/demo_dataset.rst
index 90472d9..0825fef 100644
--- a/docs/source/concepts/demo_dataset.rst
+++ b/docs/source/concepts/demo_dataset.rst
@@ -1,17 +1,33 @@
-
.. _loading_dataset:
Loading a dataset object
========================
-Dataset objects organize experimental data (from **Experiment** class) for machine learning tasks, offering project-specific and configurable access for training and evaluation. They often serve as a source for creating **dataloaders**.
+Dataset objects organize experimental data (from the :class:`~experanto.experiment.Experiment` class) for machine learning tasks, offering project-specific and configurable access for training and evaluation. They often serve as a source for creating dataloaders (see :func:`~experanto.dataloaders.get_multisession_dataloader`).
+
+.. note::
+
+ The key distinction between :class:`~experanto.experiment.Experiment` and a
+ dataset object is one of **time discretization**.
+ :class:`~experanto.experiment.Experiment` is a low-level interface: you hand
+ it any array of time points and it returns values at those points via a
+ lookup into the raw stored data. :class:`~experanto.datasets.ChunkDataset`
+ is used on top of it and imposes a specific time structure. For each item,
+ it constructs a separate ``times`` array per modality using that modality's
+ configured ``sampling_rate`` and ``chunk_size`` (``times = start + np.arange(chunk_size) / sampling_rate``),
+ then calls :meth:`~experanto.experiment.Experiment.interpolate` for each
+ modality independently. This is how all modalities end up covering the same
+ time window with compatible shapes in a batch.
Key features of dataset objects
-------------------------------
Dataset objects provide several essential features:
-- **Sampling Rate**: Defines the frequency of equally spaced interpolation times across the entire experiment. This ensures consistency in temporal data alignment.
+- **Sampling Rate**: Defines the spacing of the time points that the dataset
+ constructs and hands to the underlying :class:`~experanto.experiment.Experiment`
+ for each item (``time_delta = 1 / sampling_rate``). The experiment then does
+ a lookup into the raw stored data at those points.
- **Chunk Size**: Determines the number of values returned when calling the ``__getitem__`` method. This is crucial, for example, for deep learning models that use 3D convolutions over time, where single elements or small chunk sizes are insufficient to capture meaningful temporal patterns.
- **Modality Configuration**: Specifies the details of the interpolation, including:
@@ -81,7 +97,7 @@ This will output something like:
Defining dataloaders
---------------------
-Once the dataset is verified, we can define **DataLoader** objects for training or other purposes. This allows easy batch processing during training:
+Once the dataset is verified, we can define `DataLoader `_ objects for training or other purposes. This allows easy batch processing during training:
.. code-block:: python
diff --git a/docs/source/concepts/demo_experiment.rst b/docs/source/concepts/demo_experiment.rst
index 8a83344..a4296ff 100644
--- a/docs/source/concepts/demo_experiment.rst
+++ b/docs/source/concepts/demo_experiment.rst
@@ -3,7 +3,20 @@
Loading a single experiment
===========================
-To load an experiment, we use the **Experiment** class. This is particularly useful for testing whether the formatting and interpolation behave as expected before loading multiple experiments into dataset objects.
+To load an experiment, we use the :class:`~experanto.experiment.Experiment`
+class. This class aggregates all modalities and their respective interpolators
+in a single object. Its main job is to unify the access to all modalities.
+
+:class:`~experanto.experiment.Experiment` accepts an arbitrary array of time
+points and returns the corresponding values for each modality by looking them
+up in the raw stored data (e.g., using nearest-neighbour or linear
+interpolation for sequences). When you need a regular sampling grid or
+fixed-length intervals (chunks) for training, a dataset object such as
+:class:`~experanto.datasets.ChunkDataset` should be used **on top of**
+:class:`~experanto.experiment.Experiment`. There you can define the time
+discretization (via ``sampling_rate`` and ``chunk_size``), construct the
+appropriate time points, and delegate the data retrieval to the
+underlying :class:`~experanto.experiment.Experiment`.
Loading an experiment
---------------------
@@ -34,8 +47,26 @@ All compatible modalities for the loaded experiment can be checked using:
Interpolating data
------------------
-Once the modalities are identified, we can interpolate their data.
-The following example interpolates a 20-second window with 2 frames per second, resulting in 40 images:
+Once the modalities are identified, we can interpolate their data using
+:meth:`~experanto.experiment.Experiment.interpolate`.
+
+:meth:`~experanto.experiment.Experiment.interpolate` accepts any 1-D array of
+time points and returns, for each modality, an array of shape
+``(len(times), n_signals)``. The number of returned points is always
+``len(times)``, regardless of the native acquisition rate of the modality.
+
+When you call :meth:`~experanto.experiment.Experiment.interpolate` **without**
+a ``device`` argument, every modality receives the *same* time array. This
+means modalities with low native rates can return repeated values for
+consecutive requested times that fall in the same native sample (nearest
+neighbour), while modalities with high native rates will effectively be
+sub-sampled (the behavior is interpolator-dependent). If you need different
+time densities per modality, call :meth:`~experanto.experiment.Experiment.interpolate`
+separately with a different ``times`` array for each ``device``. This is
+exactly what :class:`~experanto.datasets.ChunkDataset` does internally.
+
+The following example interpolates a 20-second window at 2 time points per
+second, resulting in 40 screen frames:
.. code-block:: python
diff --git a/docs/source/concepts/demo_multisession.rst b/docs/source/concepts/demo_multisession.rst
index 0c5675d..71a1b80 100644
--- a/docs/source/concepts/demo_multisession.rst
+++ b/docs/source/concepts/demo_multisession.rst
@@ -1,14 +1,14 @@
Loading multiple sessions
=========================
-To load multiple sessions at once, you can use the ``get_multisession_dataloader`` function from ``experanto.dataloaders``.
+To load multiple sessions at once, you can use :func:`~experanto.dataloaders.get_multisession_dataloader`.
This function takes:
- A list of paths pointing to your experiment directories
- A configuration dictionary, similar to the one used for loading a single dataset
-It returns a dictionary of ``MultiEpochsDataLoader`` objects, each corresponding to a session, loaded with the specified configurations.
+It returns a dictionary of :class:`~experanto.utils.MultiEpochsDataLoader` objects, each corresponding to a session, loaded with the specified configurations.
Example
-------
@@ -30,4 +30,4 @@ Example
# Load first two sessions
train_dl = get_multisession_dataloader(full_paths[:2], cfg)
-The returned ``train_dl`` is a dictionary containing two ``MultiEpochsDataLoader`` objects which can be used for training.
+The returned ``train_dl`` is a dictionary containing two :class:`~experanto.utils.MultiEpochsDataLoader` objects which can be used for training.
diff --git a/docs/source/concepts/installation.rst b/docs/source/concepts/installation.rst
index 7ab765f..4090a07 100644
--- a/docs/source/concepts/installation.rst
+++ b/docs/source/concepts/installation.rst
@@ -11,5 +11,5 @@ The package works on top of `jupyter/datascience-notebooks`, but the minimum req
to install the package, clone it into a local repository and run::
- pip -e install /path_to_folder/experanto
+ pip install -e /path/to/folder/experanto
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 3b658cc..5a60ef9 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -3,6 +3,12 @@
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
+import os
+import sys
+
+# Add the project root to the path so autodoc can find the modules
+sys.path.insert(0, os.path.abspath("../.."))
+
# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
@@ -14,11 +20,79 @@
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
-extensions = []
+extensions = [
+ "sphinx.ext.autodoc",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.napoleon",
+ "sphinx.ext.mathjax",
+ "sphinx.ext.intersphinx",
+ "sphinx.ext.viewcode",
+]
+
+# -- Autosummary settings ----------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/extensions/autosummary.html
+
+autosummary_generate = True
+autosummary_generate_overwrite = True
+autosummary_imported_members = False
templates_path = ["_templates"]
exclude_patterns = []
+# -- Napoleon settings (NumPy-style docstrings) ------------------------------
+# https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html
+
+napoleon_google_docstring = False
+napoleon_numpy_docstring = True
+napoleon_include_init_with_doc = True
+napoleon_include_private_with_doc = False
+napoleon_include_special_with_doc = True
+napoleon_use_admonition_for_examples = False
+napoleon_use_admonition_for_notes = False
+napoleon_use_admonition_for_references = False
+napoleon_use_ivar = False
+napoleon_use_param = True
+napoleon_use_rtype = True
+napoleon_type_aliases = None
+
+# -- Autodoc settings --------------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html
+
+autodoc_default_options = {
+ "members": True,
+ "member-order": "bysource",
+ "special-members": "__init__",
+ "undoc-members": True,
+ "exclude-members": "__weakref__",
+}
+autodoc_typehints = "description"
+autodoc_typehints_description_target = "documented"
+
+# Mock heavy imports that aren't needed for documentation generation
+autodoc_mock_imports = [
+ "torch",
+ "torchvision",
+ "numpy",
+ "scipy",
+ "cv2",
+ "pandas",
+ "hydra",
+ "omegaconf",
+ "jaxtyping",
+ "plotly",
+ "optree",
+ "rootutils",
+]
+
+# -- Intersphinx settings (cross-references to external docs) ----------------
+# https://www.sphinx-doc.org/en/master/usage/extensions/intersphinx.html
+
+intersphinx_mapping = {
+ "python": ("https://docs.python.org/3", None),
+ "numpy": ("https://numpy.org/doc/stable/", None),
+ "scipy": ("https://docs.scipy.org/doc/scipy/", None),
+ "torch": ("https://pytorch.org/docs/stable/", None),
+}
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 5886f44..628db63 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -3,7 +3,7 @@
**Experanto** is a Python package designed for interpolating recordings and stimuli in neuroscience experiments. It enables users to load single or multiple experiments and create efficient dataloaders for machine learning applications.
-Issues with the package can be submited at our `GitHub Issues page `_.
+Issues with the package can be submitted at our `GitHub Issues page `_.
------------
@@ -17,10 +17,16 @@ Issues with the package can be submited at our `GitHub Issues page >> from experanto.configs import DEFAULT_CONFIG
+>>> from omegaconf import OmegaConf
+>>> print(OmegaConf.to_yaml(DEFAULT_CONFIG))
+
+Customizing configuration:
+
+>>> from experanto.configs import DEFAULT_MODALITY_CONFIG
+>>> cfg = DEFAULT_MODALITY_CONFIG.copy()
+>>> cfg.screen.sampling_rate = 60
+>>> cfg.responses.chunk_size = 32
+
+See Also
+--------
+ChunkDataset : Uses these configurations for data loading.
+"""
+
from pathlib import Path
from hydra import compose, initialize, initialize_config_dir
diff --git a/experanto/dataloaders.py b/experanto/dataloaders.py
index 37f344e..f891bd5 100644
--- a/experanto/dataloaders.py
+++ b/experanto/dataloaders.py
@@ -26,15 +26,46 @@ def get_multisession_dataloader(
shuffle_keys: bool = False,
**kwargs,
) -> LongCycler:
- """
- Create a multisession dataloader from a list of paths and corresponding configs.
- Args:
- paths (List[str]): List of paths to the datasets.
- configs (Union[DictConfig, Dict, List[Union[DictConfig, Dict]]]): Configuration for each dataset.
- If a single config is provided, it will be applied to all datasets.
- If a list is provided, it should match the length of paths.
- shuffle_keys (bool): Whether to shuffle the keys of the dataloaders.
- **kwargs: Additional keyword arguments for dataset and dataloader configuration.
+ """Create a multi-session dataloader from multiple experiment paths.
+
+ By default, creates a :class:`ChunkDataset` for each path and wraps them in
+ a :class:`LongCycler` that yields ``(session_key, batch)`` pairs.
+ The cycler continues until the longest session is exhausted.
+
+ Parameters
+ ----------
+ paths : list of str
+ Paths to experiment directories.
+ configs : dict, DictConfig, list, optional
+ Configuration for each dataset. If a single config is provided,
+ it will be applied to all datasets. If a list is provided, it
+ should match the length of ``paths``. Each config should have
+ ``dataset`` and ``dataloader`` keys.
+ shuffle_keys : bool, default=False
+ Whether to shuffle the order of session keys.
+ **kwargs
+ Additional keyword arguments. Supports ``config`` as an alias
+ for ``configs``.
+
+ Returns
+ -------
+ LongCycler
+ A dataloader-like object that yields ``(session_key, batch)`` tuples.
+ Iterates until the longest session is exhausted.
+
+ See Also
+ --------
+ get_multisession_concat_dataloader : Alternative that concatenates sessions.
+ LongCycler : The underlying multi-session iterator.
+
+ Examples
+ --------
+ >>> from experanto.dataloaders import get_multisession_dataloader
+ >>> from experanto.configs import DEFAULT_CONFIG
+ >>> paths = ['/path/to/exp1', '/path/to/exp2']
+ >>> loader = get_multisession_dataloader(paths, configs=DEFAULT_CONFIG)
+ >>> for session_key, batch in loader:
+ ... print(f"Session: {session_key}, batch shape: {batch['responses'].shape}")
"""
if configs is None and "config" in kwargs:
@@ -70,20 +101,51 @@ def get_multisession_concat_dataloader(
dataloader_config: Optional[Dict] = None,
**kwargs,
) -> Optional["FastSessionDataLoader"]:
- """
- Creates a multi-session dataloader using SessionConcatDataset and SessionDataLoader.
- Returns (session_key, batch) pairs during iteration.
-
- Args:
- paths: List of paths to dataset files
- configs: Configuration for datasets (single config or list of configs)
- seed: Random seed for reproducibility
- num_workers: Number of worker processes for data loading
- prefetch_factor: Prefetch factor for data loading
- **kwargs: Additional arguments
-
- Returns:
- SessionDataLoader instance or None if no valid datasets found
+ """Create a concatenated multi-session dataloader.
+
+ Unlike :func:`get_multisession_dataloader`, this function concatenates
+ all sessions into a single dataset and uses batch sampling to ensure
+ each batch contains samples from only one session. This is more
+ memory-efficient and provides better shuffling across sessions.
+
+ Parameters
+ ----------
+ paths : list of str
+ Paths to experiment directories.
+ configs : dict or list of dict, optional
+ Configuration for each dataset. If a single config is provided,
+ it will be applied to all datasets. Each config should have
+ ``dataset`` and ``dataloader`` keys.
+ seed : int, default=0
+ Random seed for reproducibility. Each dataset gets a deterministic
+ seed derived from this value and its path hash.
+ dataloader_config : dict, optional
+ Configuration for the dataloader (batch_size, num_workers, etc.).
+ If None, uses the dataloader config from the first config.
+ **kwargs
+ Additional keyword arguments. Supports ``config`` as an alias
+ for ``configs``.
+
+ Returns
+ -------
+ FastSessionDataLoader or None
+ A dataloader that yields ``(session_key, batch)`` tuples.
+ Returns None if no valid datasets could be created.
+
+ See Also
+ --------
+ get_multisession_dataloader : Alternative using separate dataloaders.
+ FastSessionDataLoader : The underlying dataloader implementation.
+ SessionConcatDataset : Dataset that concatenates multiple sessions.
+
+ Examples
+ --------
+ >>> from experanto.dataloaders import get_multisession_concat_dataloader
+ >>> from experanto.configs import DEFAULT_CONFIG
+ >>> paths = ['/path/to/exp1', '/path/to/exp2', '/path/to/exp3']
+ >>> loader = get_multisession_concat_dataloader(paths, configs=DEFAULT_CONFIG)
+ >>> for session_key, batch in loader:
+ ... print(f"Session: {session_key}")
"""
if configs is None and "config" in kwargs:
configs = kwargs.pop("config")
@@ -98,7 +160,6 @@ def get_multisession_concat_dataloader(
start_time = time.time()
for i, (path, cfg) in enumerate(zip(paths, configs)):
-
# Create dataset with deterministic seed
path_hash = hash(path) % 10000
dataset_seed = seed + path_hash if seed is not None else None
diff --git a/experanto/datasets.py b/experanto/datasets.py
index 1045263..7044749 100644
--- a/experanto/datasets.py
+++ b/experanto/datasets.py
@@ -71,26 +71,77 @@ def __getitem__(self, idx):
class ChunkDataset(Dataset):
- def __init__(
- self,
- root_folder: str,
- global_sampling_rate: Optional[float] = None,
- global_chunk_size: Optional[int] = None,
- add_behavior_as_channels: bool = False,
- replace_nans_with_means: bool = False,
- cache_data: bool = False,
- out_keys: Optional[Iterable] = None,
- normalize_timestamps: bool = True,
- modality_config: dict = DEFAULT_MODALITY_CONFIG,
- seed: Optional[int] = None,
- safe_interval_threshold: float = 0.5,
- interpolate_precision: int = 5,
- ) -> None:
- """
- interpolate_precision: number of digits after the dot to keep, without it we might get different numbers from interpolation
-
- The full modality config is a nested dictionary.
- The following is an example of a modality config for a screen, responses, eye_tracker, and treadmill:
+ """PyTorch Dataset for chunked experiment data.
+
+ This dataset loads an experiment and provides temporally-chunked samples
+ suitable for training neural networks. Each sample contains synchronized
+ data from all modalities (e.g., screen, responses, eye_tracker, treadmill)
+ at a given time window.
+
+ Parameters
+ ----------
+ root_folder : str
+ Path to the experiment directory containing modality subfolders.
+ global_sampling_rate : float, optional
+ Sampling rate (Hz) applied to all modalities. If None, uses
+ per-modality rates from ``modality_config``.
+ global_chunk_size : int, optional
+ Number of samples per chunk for all modalities. If None, uses
+ per-modality sizes from ``modality_config``.
+ add_behavior_as_channels : bool, default=False
+ If True, concatenates behavioral data as additional image channels.
+ Deprecated: use separate modality outputs instead.
+ replace_nans_with_means : bool, default=False
+ If True, replaces NaN values with column means.
+ cache_data : bool, default=False
+ If True, keeps loaded data in memory for faster access.
+ out_keys : iterable, optional
+ Which modalities to include in output. Defaults to all modalities
+ plus 'timestamps'.
+ normalize_timestamps : bool, default=True
+ If True, normalizes timestamps relative to recording start.
+ modality_config : dict
+ Configuration for each modality including sampling rates, transforms,
+ filters, and interpolation settings. See Notes for structure.
+ seed : int, optional
+ Random seed for reproducible shuffling of valid time points.
+ safe_interval_threshold : float, default=0.5
+ Safety margin (in seconds) to exclude from edges of valid intervals.
+ interpolate_precision : int, default=5
+ Number of decimal places for time precision. Prevents floating-point
+ accumulation errors during interpolation.
+
+
+ Attributes
+ ----------
+ data_key : str
+ Unique identifier for this dataset, extracted from metadata.
+ device_names : tuple
+ Names of loaded modalities.
+ start_time : float
+ Start of valid time range (after applying safety threshold).
+ end_time : float
+ End of valid time range (after applying safety threshold).
+
+ See Also
+ --------
+ Experiment : Lower-level interface for data access.
+ get_multisession_dataloader : Load multiple datasets.
+ experanto.configs : Default configuration values.
+
+ Notes
+ -----
+ The dataset handles:
+
+ - Per-modality sampling rates and chunk sizes
+ - Time offset alignment between modalities
+ - Data normalization and transforms
+ - Filtering based on trial conditions and data quality
+ - Reproducible random sampling with seeds
+
+ The ``modality_config`` is a nested dictionary with per-modality settings. The following is an example of a modality config for a screen, responses, eye_tracker, and treadmill:
+
+ .. code-block:: yaml
screen:
sampling_rate: null
@@ -141,7 +192,41 @@ def __init__(
normalize: true
interpolation:
interpolation_mode: nearest_neighbor
- """
+
+ Examples
+ --------
+ >>> from experanto.datasets import ChunkDataset
+ >>> from experanto.configs import DEFAULT_MODALITY_CONFIG
+ >>> dataset = ChunkDataset(
+ ... '/path/to/experiment',
+ ... global_sampling_rate=30,
+ ... global_chunk_size=60,
+ ... modality_config=DEFAULT_MODALITY_CONFIG,
+ ... )
+ >>> len(dataset)
+ 1000
+ >>> sample = dataset[0]
+ >>> sample['screen'].shape
+ torch.Size([1, 60, 144, 256])
+ >>> sample['responses'].shape
+ torch.Size([16, 500])
+ """
+
+ def __init__(
+ self,
+ root_folder: str,
+ global_sampling_rate: Optional[float] = None,
+ global_chunk_size: Optional[int] = None,
+ add_behavior_as_channels: bool = False,
+ replace_nans_with_means: bool = False,
+ cache_data: bool = False,
+ out_keys: Optional[Iterable] = None,
+ normalize_timestamps: bool = True,
+ modality_config: dict = DEFAULT_MODALITY_CONFIG,
+ seed: Optional[int] = None,
+ safe_interval_threshold: float = 0.5,
+ interpolate_precision: int = 5,
+ ) -> None:
self.root_folder = Path(root_folder)
self.data_key = self.get_data_key_from_root_folder(root_folder)
self.interpolate_precision = interpolate_precision
@@ -234,10 +319,10 @@ def _read_trials(self) -> None:
]
def initialize_statistics(self) -> None:
- """
- Initializes the statistics for each device based on the modality config.
- :return:
- instantiates self._statistics with the mean and std for each device
+ """Initialize normalization statistics for each device.
+
+ Loads mean and standard deviation values from each device's meta folder
+ and stores them in ``self._statistics`` for use during data transforms.
"""
self._statistics = {}
for device_name in self.device_names:
@@ -298,10 +383,7 @@ def add_channel_function(x):
return torch.from_numpy(x)
def initialize_transforms(self):
- """
- Initializes the transforms for each device based on the modality config.
- :return:
- """
+ """Initialize data transforms for each device based on modality config."""
transforms = {}
for device_name in self.device_names:
if device_name == "screen":
@@ -316,7 +398,6 @@ def initialize_transforms(self):
transform_list.insert(0, add_channel)
else:
-
transform_list: List[Any] = [ToTensor()]
# Normalization.
@@ -332,15 +413,22 @@ def initialize_transforms(self):
return transforms
def _get_callable_filter(self, filter_config):
- """
- Helper function to get a callable filter function from either a config or an already instantiated callable.
+ """Return a callable filter function from config or an existing callable.
+
+ Notes
+ -----
Handles partial instantiation using hydra.utils.instantiate.
- Args:
- filter_config: Either a config dict/DictConfig or a callable function
+ Parameters
+ ----------
+ filter_config : dict, DictConfig, or callable
+ Either a config dictionary/DictConfig specifying a filter (with
+        ``_target_``), or an already-instantiated callable filter function.
- Returns:
- callable: The final filter function ready to be called with device_
+ Returns
+ -------
+ callable
+            The final filter function ready to be called with ``device_``.
"""
# Check if it's already a callable (function)
if callable(filter_config):
@@ -412,16 +500,25 @@ def get_valid_intervals_from_filters(
def get_condition_mask_from_meta_conditions(
self, valid_conditions_sum_of_product: List[dict]
) -> np.ndarray:
- """Creates a boolean mask for trials that satisfy any of the given condition combinations.
-
- Args:
- valid_conditions_sum_of_product: List of dictionaries, where each dictionary represents a set of
- conditions that should be satisfied together (AND). Multiple dictionaries are combined with OR.
- Example: [{'tier': 'train', 'stim_type': 'natural'}, {'tier': 'blank'}] matches trials that
- are either (train AND natural) OR blank.
-
- Returns:
- np.ndarray: Boolean mask indicating which trials satisfy at least one set of conditions.
+ """Create a boolean mask for trials satisfying given conditions.
+
+ Parameters
+ ----------
+ valid_conditions_sum_of_product : list of dict
+ Condition dictionaries combined with OR logic, where conditions
+ within each dictionary use AND logic.
+
+ Returns
+ -------
+ np.ndarray
+ Boolean mask indicating which trials satisfy at least one set of
+ conditions.
+
+ Notes
+ -----
+ For example,
+ ``[{'tier': 'train', 'stim_type': 'natural'}, {'tier': 'blank'}]``
+ matches trials that are either (train AND natural) OR blank.
"""
all_conditions: Optional[np.ndarray] = None
for valid_conditions_product in valid_conditions_sum_of_product:
@@ -449,15 +546,22 @@ def get_screen_sample_mask_from_meta_conditions(
valid_conditions_sum_of_product: List[dict],
filter_for_valid_intervals: bool = True,
) -> np.ndarray:
- """Creates a boolean mask indicating which screen samples satisfy the given conditions.
-
- Args:
- satisfy_for_next: Number of consecutive samples that must satisfy conditions
- valid_conditions_sum_of_product: List of condition dictionaries combined with OR logic,
- where conditions within each dictionary use AND logic
-
- Returns:
- Boolean array matching screen sample times, True where conditions are met
+ """Create a boolean mask for screen samples satisfying given conditions.
+
+ Parameters
+ ----------
+ satisfy_for_next : int
+ Number of consecutive samples that must satisfy conditions.
+ valid_conditions_sum_of_product : list of dict
+ Condition dictionaries combined with OR logic, where conditions
+ within each dictionary use AND logic.
+ filter_for_valid_intervals : bool, default=True
+ Whether to apply interval-based filtering.
+
+ Returns
+ -------
+ numpy.ndarray
+ Boolean array matching screen sample times, True where conditions are met.
"""
all_conditions = self.get_condition_mask_from_meta_conditions(
valid_conditions_sum_of_product
@@ -508,12 +612,21 @@ def get_screen_sample_mask_from_meta_conditions(
def get_full_valid_sample_times(
self, filter_for_valid_intervals: bool = True
) -> np.ndarray:
- """
- iterates through all sample times and checks if they could be used as
- start times, eg if the next `self.chunk_sizes["screen"]` points are still valid
- based on the previous meta condition filtering
- :returns:
- valid_times: np.array of valid starting points
+ """Get all valid chunk starting times based on meta conditions.
+
+ Iterates through sample times and checks if they can be used as chunk
+ start times (i.e., the next ``chunk_size`` points are all valid based
+ on the previous meta condition filtering).
+
+ Parameters
+ ----------
+ filter_for_valid_intervals : bool, default=True
+ Whether to apply interval-based filtering.
+
+ Returns
+ -------
+ numpy.ndarray
+ Array of valid starting time points.
"""
# Calculate all possible end indices
@@ -565,14 +678,20 @@ def shuffle_valid_screen_times(self) -> None:
)
def get_data_key_from_root_folder(self, root_folder):
- """
- Extract a data key from the root folder path by checking for a meta.json file.
+ """Extract a data key from the root folder path.
+
+ Checks for a meta.json file and extracts the data_key or scan_key.
- Args:
- root_folder (str or Path): Path to the root folder containing dataset
+ Parameters
+ ----------
+ root_folder : str or Path
+ Path to the root folder containing the dataset.
- Returns:
- str: The extracted data key or folder name if meta.json doesn't exist or lacks data_key
+ Returns
+ -------
+ str
+ The extracted data key, or folder name if meta.json doesn't
+ exist or lacks data_key.
"""
# Convert Path object to string if necessary
root_folder = str(root_folder)
@@ -617,7 +736,28 @@ def get_data_key_from_root_folder(self, root_folder):
def __len__(self):
return len(self._valid_screen_times)
- def __getitem__(self, idx) -> dict:
+ def __getitem__(self, idx: int) -> dict:
+ """Return a single data sample at the given index.
+
+ Parameters
+ ----------
+ idx : int
+ Index of the sample to retrieve.
+
+ Returns
+ -------
+ dict
+ Dictionary containing data for each modality in ``out_keys``, e.g.:
+
+ - ``'screen'``: torch.Tensor of shape ``(C, T, H, W)``
+ - ``'responses'``: torch.Tensor of shape ``(T, N_neurons)``
+ - ``'eye_tracker'``: torch.Tensor of shape ``(T, N_features)``
+ - ``'treadmill'``: torch.Tensor of shape ``(T, N_features)``
+ - ``'timestamps'``: dict mapping modality names to time arrays
+
+ Where ``T`` is the chunk size (may differ per modality),
+ ``C`` is channels, ``H`` is height, ``W`` is width.
+ """
out = {}
timestamps = {}
s = self._valid_screen_times[idx]
diff --git a/experanto/experiment.py b/experanto/experiment.py
index e9df4cf..552737a 100644
--- a/experanto/experiment.py
+++ b/experanto/experiment.py
@@ -18,18 +18,56 @@
class Experiment:
+ """High-level interface for loading and querying neuroscience experiments.
+
+ An Experiment represents a single recording session containing multiple
+ modalities (e.g., visual stimuli, neural responses, behavioral data).
+ Each modality is loaded as an Interpolator, allowing data to be queried
+ at arbitrary time points.
+
+ Parameters
+ ----------
+ root_folder : str
+ Path to the experiment directory. Should contain subdirectories
+ for each modality (e.g., ``screen/``, ``responses/``, ``eye_tracker/``).
+ modality_config : dict, optional
+ Configuration dictionary specifying interpolation and processing
+ settings for each modality. See :mod:`experanto.configs` for the
+ default configuration structure.
+ cache_data : bool, default=False
+ If True, loads all trial data into memory for faster access.
+ Useful for smaller datasets or when memory is not a constraint.
+
+ Attributes
+ ----------
+ devices : dict
+ Dictionary mapping device names to their :class:`Interpolator` instances.
+ start_time : float
+ Earliest valid timestamp across all devices.
+ end_time : float
+ Latest valid timestamp across all devices.
+
+ See Also
+ --------
+ ChunkDataset : Higher-level interface for ML training.
+ Interpolator : Base class for modality-specific interpolators.
+
+ Examples
+ --------
+ >>> from experanto.experiment import Experiment
+ >>> exp = Experiment('/path/to/experiment')
+ >>> exp.device_names
+ ('screen', 'responses', 'eye_tracker')
+ >>> times = np.linspace(0, 10, 100)
+    >>> data, valid = exp.interpolate(times, device='responses', return_valid=True)
+ """
+
def __init__(
self,
root_folder: str,
modality_config: dict = DEFAULT_MODALITY_CONFIG,
cache_data: bool = False,
) -> None:
- """
- root_folder: path to the data folder
- interp_config: dict for configuring interpolators, like
- interp_config = {"screen": {...}, {"eye_tracker": {...}, }
- cache_data: if True, loads and keeps all trial data in memory
- """
self.root_folder = Path(root_folder)
self.devices = dict()
self.start_time = np.inf
@@ -73,7 +111,8 @@ def _load_devices(self) -> None:
else:
# Default back to original logic
warnings.warn(
- "Falling back to original Interpolator creation logic.", UserWarning
+ "Falling back to original Interpolator creation logic.",
+ UserWarning,
)
dev = Interpolator.create(d, cache_data=self.cache_data, **interp_conf) # type: ignore[arg-type]
@@ -92,6 +131,39 @@ def interpolate(
device: Union[str, Interpolator, None] = None,
return_valid: bool = False,
) -> Union[tuple[dict, dict], dict, tuple[np.ndarray, np.ndarray], np.ndarray]:
+ """Interpolate data from one or all devices at specified time points.
+
+ Parameters
+ ----------
+ times : array_like
+ 1D array of time points (in seconds) at which to interpolate.
+        device : str or Interpolator, optional
+            Device to interpolate: a device name or an Interpolator instance.
+            If None, interpolates all devices and returns dictionaries.
+
+ Returns
+ -------
+ values : numpy.ndarray or dict
+ If ``device`` is specified, returns the interpolated data array.
+ Otherwise, returns a dict mapping device names to their data arrays.
+ valid : numpy.ndarray or dict
+            Boolean mask(s) indicating which time points were valid; only
+            returned when ``return_valid=True``. Same structure as ``values``.
+
+ Examples
+ --------
+ Interpolate a single device:
+
+        >>> data, valid = exp.interpolate(times, device='responses', return_valid=True)
+ >>> data.shape
+ (100, 500) # 100 time points, 500 neurons
+
+ Interpolate all devices:
+
+        >>> data, valid = exp.interpolate(times, return_valid=True)
+ >>> data.keys()
+ dict_keys(['screen', 'responses', 'eye_tracker'])
+ """
if device is None:
values = {}
valid = {}
@@ -114,5 +186,18 @@ def interpolate(
else:
raise ValueError(f"Unsupported device type: {type(device)}")
- def get_valid_range(self, device_name) -> tuple[float, float]:
+ def get_valid_range(self, device_name: str) -> tuple[float, float]:
+ """Get the valid time range for a specific device.
+
+ Parameters
+ ----------
+ device_name : str
+ Name of the device (e.g., 'screen', 'responses').
+
+ Returns
+ -------
+ tuple
+ A tuple ``(start_time, end_time)`` representing the valid
+ time interval in seconds.
+ """
return tuple(self.devices[device_name].valid_interval)
diff --git a/experanto/filters/common_filters.py b/experanto/filters/common_filters.py
index 137ef4a..d74ef34 100644
--- a/experanto/filters/common_filters.py
+++ b/experanto/filters/common_filters.py
@@ -9,33 +9,48 @@
def nan_filter(vicinity=0.05):
+ """Create a filter that excludes time regions around NaN values.
+
+ Returns a closure that, given a :class:`~experanto.interpolators.SequenceInterpolator`,
+ identifies all time points containing NaN in any channel and marks a
+ symmetric window of ``vicinity`` seconds around each as invalid.
+
+ Parameters
+ ----------
+ vicinity : float, optional
+ Half-width of the exclusion window in seconds around each NaN
+ time point. Default is 0.05.
+
+ Returns
+ -------
+ callable
+ A function that takes a
+ :class:`~experanto.interpolators.SequenceInterpolator` and returns
+ a list of :class:`~experanto.intervals.TimeInterval` representing
+ the valid (NaN-free) portions of the recording.
+ """
+
def implementation(device_: SequenceInterpolator):
- # requests SequenceInterpolator as uses time_delta internally
- # and other interpolators don't have it
+ # Requires a SequenceInterpolator since it relies on time_delta,
+ # which other interpolator types do not expose.
time_delta = device_.time_delta
start_time = device_.start_time
end_time = device_.end_time
- data = device_._data # (T, n_neurons)
+ data = device_._data # (T, n_features)
- # detect nans
- nan_mask = np.isnan(data) # (T, n_neurons)
+ nan_mask = np.isnan(data) # (T, n_features)
nan_mask = np.any(nan_mask, axis=1) # (T,)
-
- # Find indices where nan_mask is True
nan_indices = np.where(nan_mask)[0]
- # Create invalid TimeIntervals around each nan point
invalid_intervals = []
- vicinity_seconds = vicinity # vicinity is already in seconds
for idx in nan_indices:
time_point = start_time + idx * time_delta
- interval_start = max(start_time, time_point - vicinity_seconds)
- interval_end = min(end_time, time_point + vicinity_seconds)
+ interval_start = max(start_time, time_point - vicinity)
+ interval_end = min(end_time, time_point + vicinity)
invalid_intervals.append(TimeInterval(interval_start, interval_end))
# Merge overlapping invalid intervals
invalid_intervals = uniquefy_interval_array(invalid_intervals)
-
# Find the complement of invalid intervals to get valid intervals
valid_intervals = find_complement_of_interval_array(
start_time, end_time, invalid_intervals
diff --git a/experanto/interpolators.py b/experanto/interpolators.py
index 48ecaa2..cc20938 100644
--- a/experanto/interpolators.py
+++ b/experanto/interpolators.py
@@ -18,6 +18,40 @@
class Interpolator:
+ """Abstract base class for time series interpolation.
+
+ Interpolators load data from a modality folder and map time points to
+ data values. Each modality (e.g., screen, responses, eye_tracker,
+ treadmill) is assigned to a separate interpolator object belonging to
+ one of the Interpolator subclasses (e.g., SequenceInterpolator,
+ ScreenInterpolator, etc.), but multiple modalities can belong to the same
+ class, such as treadmill and eye_tracker both being assigned to the
+ SequenceInterpolator subclass.
+
+ Parameters
+ ----------
+ root_folder : str
+ Path to the modality directory containing data and metadata files.
+
+ Attributes
+ ----------
+ root_folder : pathlib.Path
+ Path to the modality directory.
+ start_time : float
+ Earliest timestamp in the data.
+ end_time : float
+ Latest timestamp in the data.
+ valid_interval : TimeInterval
+ Time range for which interpolation is valid.
+
+ See Also
+ --------
+ SequenceInterpolator : For time series data (responses, behaviors).
+ ScreenInterpolator : For visual stimuli (images, videos).
+ TimeIntervalInterpolator : For labeled time intervals (e.g., train/test splits).
+ Experiment : High-level interface that manages multiple interpolators.
+ """
+
def __init__(self, root_folder: str) -> None:
self.root_folder = Path(root_folder)
self.start_time = None
@@ -34,8 +68,8 @@ def load_meta(self):
def interpolate(
self, times: np.ndarray, return_valid: bool = False
) -> Union[tuple[np.ndarray, np.ndarray], np.ndarray]:
+ """Map an array of time points to interpolated data values."""
...
- # returns interpolated signal and boolean mask of valid samples
def __contains__(self, times: np.ndarray):
return np.any(self.valid_times(times))
@@ -48,6 +82,30 @@ def __exit__(self, *exc):
@staticmethod
def create(root_folder: str, cache_data: bool = False, **kwargs) -> "Interpolator":
+ """Factory method to create the appropriate interpolator for a modality.
+
+ Reads the ``meta.yml`` file in the folder to determine the modality type
+ and instantiates the corresponding interpolator subclass.
+
+ Parameters
+ ----------
+ root_folder : str
+ Path to the modality directory.
+ cache_data : bool, default=False
+ If True, loads all data into memory for faster access.
+ **kwargs
+ Additional arguments passed to the interpolator constructor.
+
+ Returns
+ -------
+ Interpolator
+ An instance of the appropriate interpolator subclass.
+
+ Raises
+ ------
+ ValueError
+ If the modality type is not supported.
+ """
with open(Path(root_folder) / "meta.yml", "r") as file:
meta_data = yaml.safe_load(file)
modality = meta_data.get("modality")
@@ -79,6 +137,50 @@ def close(self):
class SequenceInterpolator(Interpolator):
+ """Interpolator for 1D time series data (neural responses, behaviors).
+
+ Handles regularly-sampled time series stored as memory-mapped or NumPy
+ arrays. Supports nearest-neighbor and linear interpolation modes.
+
+ Parameters
+ ----------
+ root_folder : str
+ Path to the modality directory containing ``data.mem`` or ``data.npy``.
+ cache_data : bool, default=False
+ If True, loads memory-mapped data into RAM for faster access.
+ keep_nans : bool, default=False
+ If False, replaces NaN values with column means during interpolation.
+ interpolation_mode : str, default='nearest_neighbor'
+ Interpolation method: ``'nearest_neighbor'`` or ``'linear'``.
+ normalize : bool, default=False
+ If True, normalizes data using stored mean/std statistics.
+ normalize_subtract_mean : bool, default=False
+ If True, subtracts mean during normalization.
+ normalize_std_threshold : float, optional
+ Minimum std threshold to prevent division by near-zero values.
+ **kwargs
+ Additional keyword arguments (ignored).
+
+ Attributes
+ ----------
+ sampling_rate : float
+ Original sampling rate of the data in Hz.
+ time_delta : float
+ Time between samples (1 / sampling_rate).
+ n_signals : int
+ Number of signals (e.g., neurons, behavior channels).
+
+ Notes
+ -----
+ For linear interpolation, values are computed as:
+
+ .. math::
+
+ y(t) = y_0 \\cdot \\frac{t_1 - t}{t_1 - t_0} + y_1 \\cdot \\frac{t - t_0}{t_1 - t_0},
+
+ where :math:`t_0` and :math:`t_1` are the surrounding sample times.
+ """
+
def __init__(
self,
root_folder: str,
@@ -90,10 +192,6 @@ def __init__(
normalize_std_threshold: typing.Optional[float] = None, # or 0.01
**kwargs,
) -> None:
- """
- interpolation_mode - nearest neighbor or linear
- keep_nans - if we keep nans in linear interpolation
- """
super().__init__(root_folder)
meta = self.load_meta()
self.keep_nans = keep_nans
@@ -228,6 +326,25 @@ def close(self) -> None:
class PhaseShiftedSequenceInterpolator(SequenceInterpolator):
+ """Sequence interpolator with per-signal phase shifts.
+
+ Extends :class:`SequenceInterpolator` to handle signals recorded with
+ different phase offsets (e.g., neurons with different response latencies).
+ Each signal is interpolated at its own phase-shifted time.
+
+ Parameters
+ ----------
+ root_folder : str
+ Path to the modality directory. Must contain ``meta/phase_shifts.npy``.
+ **kwargs
+ All parameters from :class:`SequenceInterpolator`.
+
+ Attributes
+ ----------
+ _phase_shifts : numpy.ndarray
+ Per-signal phase shift values in seconds.
+ """
+
def __init__(
self,
root_folder: str,
@@ -333,6 +450,43 @@ def interpolate(
class ScreenInterpolator(Interpolator):
+ """Interpolator for visual stimuli (images and videos).
+
+ Handles frame-based visual data organized as trials. Each trial can be
+ a single image, a video sequence, or a blank screen. Frames are indexed
+ by timestamp and retrieved on demand.
+
+ Parameters
+ ----------
+ root_folder : str
+ Path to the screen modality directory containing ``timestamps.npy``,
+ ``data/`` folder with trial files, and ``meta/`` folder with metadata.
+ cache_data : bool, default=False
+ If True, loads all trial data into memory for faster access.
+ rescale : bool, default=False
+ If True, rescales frames to ``rescale_size``.
+ rescale_size : tuple of int, optional
+ Target size ``(height, width)`` for rescaling. If None, uses the
+ native image size from metadata.
+ normalize : bool, default=False
+ If True, normalizes frames using stored mean/std statistics.
+ **kwargs
+ Additional keyword arguments (ignored).
+
+ Attributes
+ ----------
+ timestamps : numpy.ndarray
+ Array of frame timestamps.
+ trials : list of ScreenTrial
+ List of trial objects containing frame data.
+
+ See Also
+ --------
+ ImageTrial : Single-frame stimuli.
+ VideoTrial : Multi-frame video stimuli.
+ BlankTrial : Blank/gray screen stimuli.
+ """
+
def __init__(
self,
root_folder: str,
@@ -342,10 +496,6 @@ def __init__(
normalize: bool = False,
**kwargs,
) -> None:
- """
- rescale would rescale images to the _image_size if true
- cache_data: if True, loads and keeps all trial data in memory
- """
super().__init__(root_folder)
self.timestamps = np.load(self.root_folder / "timestamps.npy")
self.start_time = self.timestamps[0]
@@ -488,9 +638,17 @@ def interpolate(
return (out, valid) if return_valid else out
def rescale_frame(self, frame: np.ndarray) -> np.ndarray:
- """
- Changes the resolution of the image to this size.
- Returns: Rescaled image
+ """Rescale frame to the configured image size.
+
+ Parameters
+ ----------
+ frame : np.ndarray
+ Input image frame.
+
+ Returns
+ -------
+ np.ndarray
+ Rescaled image as float32.
"""
return cv2.resize(frame, self._image_size, interpolation=cv2.INTER_AREA).astype(
np.float32
@@ -498,6 +656,48 @@ def rescale_frame(self, frame: np.ndarray) -> np.ndarray:
class TimeIntervalInterpolator(Interpolator):
+ """Interpolator for labeled time intervals.
+
+ Maps time points to boolean membership in labeled intervals. Given a
+ set of time points, returns a boolean array indicating whether each
+ point falls within any interval for each label.
+
+ Labels and their intervals are defined in the ``meta.yml`` file under
+ the ``labels`` key. Each label points to a ``.npy`` file containing an
+ array of shape ``(n, 2)``, where each row is a ``[start, end)``
+ half-open time interval. Typical labels include ``'train'``,
+ ``'validation'``, ``'test'``, ``'saccade'``, ``'gaze'``, or
+ ``'target'``.
+
+ The half-open convention means a timestamp *t* is considered inside an
+ interval when ``start <= t < end``. Intervals where ``start > end``
+ are treated as invalid and trigger a warning.
+
+ Parameters
+ ----------
+ root_folder : str
+ Path to the modality directory containing ``meta.yml`` and the
+ ``.npy`` interval files referenced by its ``labels`` mapping.
+ cache_data : bool, default=False
+ If True, loads all interval arrays into memory at init time.
+ **kwargs
+ Additional keyword arguments (ignored).
+
+ Attributes
+ ----------
+ meta_labels : dict
+ Mapping from label names to ``.npy`` filenames.
+
+ Notes
+ -----
+ - Only time points within the valid interval (as defined by
+ ``start_time`` and ``end_time`` in ``meta.yml``) are considered;
+ others are filtered out.
+ - The ``interpolate`` method returns an array of shape
+ ``(n_valid_times, n_labels)`` where ``out[i, j]`` is True if the
+ *i*-th valid time falls within any interval for the *j*-th label.
+ """
+
def __init__(self, root_folder: str, cache_data: bool = False, **kwargs):
super().__init__(root_folder)
self.cache_data = cache_data
@@ -517,41 +717,6 @@ def __init__(self, root_folder: str, cache_data: bool = False, **kwargs):
def interpolate(
self, times: np.ndarray, return_valid: bool = False
) -> Union[tuple[np.ndarray, np.ndarray], np.ndarray]:
- """
- Interpolate time intervals for labeled events.
-
- Given a set of time points and a set of labeled intervals (defined in the
- `meta.yml` file), this method returns a boolean array indicating, for each
- time point, whether it falls within any interval for each label.
-
- The method uses half-open intervals [start, end), where a timestamp t is
- considered to fall within an interval if start <= t < end. This means the
- start time is inclusive and the end time is exclusive.
-
- Parameters
- ----------
- times : np.ndarray
- Array of time points to be checked against the labeled intervals.
-
- Returns
- -------
- out : np.ndarray of bool, shape (len(valid_times), n_labels)
- Boolean array where each row corresponds to a valid time point and each
- column corresponds to a label. `out[i, j]` is True if the i-th valid
- time falls within any interval for the j-th label, and False otherwise.
-
- Notes
- -----
- - The labels and their corresponding intervals are defined in the `meta.yml`
- file under the `labels` key. Each label points to a `.npy` file containing
- an array of shape (n, 2), where each row is a [start, end) time interval.
- - Typical labels might include 'train', 'validation', 'test', 'saccade',
- 'gaze', or 'target'.
- - Only time points within the valid interval (as defined by start_time and
- end_time in meta.yml) are considered; others are filtered out.
- - Intervals where start > end are considered invalid and will trigger a
- warning.
- """
valid = self.valid_times(times)
valid_times = times[valid]
@@ -591,6 +756,27 @@ def interpolate(
class ScreenTrial:
+ """Base class for visual stimulus trials.
+
+ Represents a single trial (stimulus presentation) in a screen recording.
+ Subclasses handle different trial types: images, videos, and blanks.
+
+ Parameters
+ ----------
+ data_file_name : str
+ Path to the data file for this trial.
+ meta_data : dict
+ Metadata dictionary for the trial.
+ image_size : tuple
+ Frame dimensions ``(height, width)`` or ``(height, width, channels)``.
+ first_frame_idx : int
+ Index of the first frame in the global timestamp array.
+ num_frames : int
+ Number of frames in this trial.
+ cache_data : bool, default=False
+ If True, loads and caches data on initialization.
+ """
+
def __init__(
self,
data_file_name: Union[str, Path],
@@ -613,7 +799,9 @@ def __init__(
@staticmethod
def create(
- data_file_name: Union[str, Path], meta_data: dict, cache_data: bool = False
+ data_file_name: Union[str, Path],
+ meta_data: dict,
+ cache_data: bool = False,
) -> "ScreenTrial":
modality = meta_data.get("modality")
assert modality is not None
@@ -636,6 +824,8 @@ def get_meta(self, property: str):
class ImageTrial(ScreenTrial):
+ """Trial containing a single static image."""
+
def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None:
super().__init__(
data_file_name,
@@ -648,6 +838,8 @@ def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None:
class VideoTrial(ScreenTrial):
+ """Trial containing a multi-frame video sequence."""
+
def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None:
super().__init__(
data_file_name,
@@ -660,6 +852,8 @@ def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None:
class BlankTrial(ScreenTrial):
+ """Trial containing a blank/gray screen (inter-stimulus interval)."""
+
def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None:
self.interleave_value = meta_data.get("interleave_value")
@@ -678,6 +872,8 @@ def get_data_(self) -> np.ndarray:
class InvalidTrial(ScreenTrial):
+ """Placeholder for invalid or corrupted trials."""
+
def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None:
self.interleave_value = meta_data.get("interleave_value")
diff --git a/experanto/intervals.py b/experanto/intervals.py
index 6fd0d4e..1607985 100644
--- a/experanto/intervals.py
+++ b/experanto/intervals.py
@@ -5,6 +5,24 @@
class TimeInterval(typing.NamedTuple):
+ """A time interval represented by start and end times.
+
+ Parameters
+ ----------
+ start : float
+ Start time in seconds.
+ end : float
+ End time in seconds.
+
+ Examples
+ --------
+ >>> interval = TimeInterval(0.0, 10.0)
+    >>> 5.0 in interval  # NOTE(review): True only if __contains__ is defined; plain NamedTuple 'in' is tuple membership — verify
+ True
+ >>> 15.0 in interval
+ False
+ """
+
start: float
end: float
@@ -28,18 +46,18 @@ def intersect(self, times: np.ndarray) -> np.ndarray:
return np.where((times >= self.start) & (times <= self.end))[0]
-def uniquefy_interval_array(
- interval_array: List[TimeInterval],
-) -> List[TimeInterval]:
- """
- Takes an array of TimeIntervals and returns a new array where no intervals overlap.
- If intervals overlap or are adjacent, they are merged into a single interval.
+def uniquefy_interval_array(interval_array: List[TimeInterval]) -> List[TimeInterval]:
+ """Merge overlapping or adjacent intervals into non-overlapping intervals.
- Args:
- interval_array: List of TimeInterval objects
+ Parameters
+ ----------
+ interval_array : list of TimeInterval
+ Input intervals, may overlap or be adjacent.
- Returns:
- List of non-overlapping TimeInterval objects sorted by start time
+ Returns
+ -------
+ list of TimeInterval
+ Non-overlapping intervals sorted by start time.
"""
if not interval_array:
return []
@@ -66,6 +84,20 @@ def uniquefy_interval_array(
def find_intersection_between_two_interval_arrays(
interval_array_1: List[TimeInterval], interval_array_2: List[TimeInterval]
) -> List[TimeInterval]:
+ """Find the intersection of two interval arrays.
+
+ Parameters
+ ----------
+ interval_array_1 : list of TimeInterval
+ First set of intervals.
+ interval_array_2 : list of TimeInterval
+ Second set of intervals.
+
+ Returns
+ -------
+ list of TimeInterval
+ Intervals where both input arrays overlap.
+ """
# Sort both arrays by start time
sorted_1 = sorted(interval_array_1, key=lambda x: x.start)
sorted_2 = sorted(interval_array_2, key=lambda x: x.start)
@@ -94,6 +126,18 @@ def find_intersection_between_two_interval_arrays(
def find_intersection_across_arrays_of_intervals(
intervals_array: List[List[TimeInterval]],
) -> List[TimeInterval]:
+ """Find the common intersection across multiple interval arrays.
+
+ Parameters
+ ----------
+ intervals_array : list of list of TimeInterval
+        Multiple sets of intervals; must be non-empty (the first array seeds the intersection).
+
+ Returns
+ -------
+ list of TimeInterval
+ Intervals where all input arrays overlap.
+ """
common_interval_array = intervals_array[0]
for interval_array in intervals_array[1:]:
@@ -107,6 +151,18 @@ def find_intersection_across_arrays_of_intervals(
def find_union_across_arrays_of_intervals(
intervals_array: List[List[TimeInterval]],
) -> List[TimeInterval]:
+ """Find the union of multiple interval arrays.
+
+ Parameters
+ ----------
+ intervals_array : list of list of TimeInterval
+ Multiple sets of intervals.
+
+ Returns
+ -------
+ list of TimeInterval
+ Merged non-overlapping intervals covering all input intervals.
+ """
union_array = []
for interval_array in intervals_array:
union_array.extend(interval_array)
@@ -116,17 +172,21 @@ def find_union_across_arrays_of_intervals(
def find_complement_of_interval_array(
start: float, end: float, interval_array: List[TimeInterval]
) -> List[TimeInterval]:
- """
- Finds the complement of an interval array within a given range [start, end].
- Returns intervals that represent the gaps not covered by the input intervals.
-
- Args:
- start: Start time of the range
- end: End time of the range
- interval_array: List of TimeInterval objects
-
- Returns:
- List of TimeInterval objects representing the complement
+ """Find gaps not covered by intervals within a range.
+
+ Parameters
+ ----------
+ start : float
+ Start of the range.
+ end : float
+ End of the range.
+ interval_array : list of TimeInterval
+ Intervals to find the complement of.
+
+ Returns
+ -------
+ list of TimeInterval
+ Intervals representing uncovered gaps in ``[start, end]``.
"""
if not interval_array:
return [TimeInterval(start, end)]
@@ -157,16 +217,22 @@ def find_complement_of_interval_array(
def get_stats_for_valid_interval(
intervals: List[TimeInterval], start_time: float, end_time: float
) -> str:
- """
- Calculates and returns statistics about valid and invalid time intervals within a given range.
-
- Args:
- intervals: List of TimeInterval objects representing the valid periods.
- start_time: The beginning of the total time range to consider.
- end_time: The end of the total time range to consider.
-
- Returns:
- A string summarizing the statistics of valid and invalid intervals.
+ """Calculate statistics about valid and invalid intervals within a range.
+
+ Parameters
+ ----------
+ intervals : list of TimeInterval
+ Valid time intervals.
+ start_time : float
+ Start of the analysis range.
+ end_time : float
+ End of the analysis range.
+
+ Returns
+ -------
+ str
+ Formatted string with statistics (duration, mean, std) for both
+ valid and invalid intervals.
"""
total_duration = end_time - start_time
if total_duration <= 0:
diff --git a/experanto/utils.py b/experanto/utils.py
index 9551e56..6192bf3 100644
--- a/experanto/utils.py
+++ b/experanto/utils.py
@@ -34,21 +34,20 @@ def replace_nan_with_batch_mean(data: np.ndarray) -> np.ndarray:
def add_behavior_as_channels(data: dict[str, torch.Tensor]) -> dict:
- """
- Adds behavioral data as additional channels to screen data.
-
- Input:
- data = {
- 'screen': torch.Tensor: (c, t, h, w)
- 'eye_tracker': torch.Tensor: (t, c_eye) or (t, h, w)
- 'treadmill': torch.Tensor: (t, c_tread) or (t, h, w)
- }
-
- Output:
- data = {
- 'screen': torch.Tensor: (c+behavior_channels, t, h, w) - contiguous
- ...
- }
+ """Add behavioral data as additional channels to screen data.
+
+ Parameters
+ ----------
+ data : dict
+ Dictionary with keys 'screen', 'eye_tracker', 'treadmill'.
+ Screen shape: ``(C, T, H, W)``.
+ Behavior shapes: ``(T, C_behavior)`` or ``(T, H, W)``.
+
+ Returns
+ -------
+ dict
+ Modified dictionary with behavior concatenated to screen channels.
+ Screen shape becomes ``(C + behavior_channels, T, H, W)``.
"""
screen = data["screen"] # Already contiguous, shape (c, t, h, w)
c, t, h, w = screen.shape
@@ -152,9 +151,28 @@ def __len__(self):
class LongCycler:
- """
- Cycles through trainloaders until the loader with largest size is exhausted.
- Needed for dataloaders of unequal size (as in the monkey data).
+ """Cycle through multiple dataloaders until the longest is exhausted.
+
+ Useful for training with multiple sessions of unequal size. Cycles through
+ all loaders, yielding ``(session_key, batch)`` pairs. Shorter loaders are
+ recycled until the longest loader completes one full epoch.
+
+ Parameters
+ ----------
+ loaders : dict
+ Dictionary mapping session keys to DataLoader instances.
+
+ Attributes
+ ----------
+ max_batches : int
+ Number of batches in the longest loader.
+
+ Examples
+ --------
+ >>> loaders = {'session_1': loader1, 'session_2': loader2}
+ >>> cycler = LongCycler(loaders)
+ >>> for session_key, batch in cycler:
+ ... print(f"Processing {session_key}")
"""
def __init__(self, loaders):
@@ -175,9 +193,20 @@ def __len__(self):
class ShortCycler:
- """
- Cycles through trainloaders until the loader with smallest size is exhausted.
- Needed for dataloaders of unequal size (as in the monkey data).
+ """Cycle through multiple dataloaders until the shortest is exhausted.
+
+ Similar to :class:`LongCycler`, but stops when the smallest loader
+ completes one epoch. No recycling occurs.
+
+ Parameters
+ ----------
+ loaders : dict
+ Dictionary mapping session keys to DataLoader instances.
+
+ Attributes
+ ----------
+ min_batches : int
+ Number of batches in the shortest loader.
"""
def __init__(self, loaders):
@@ -245,7 +274,10 @@ def __init__(self, datasets, session_names=None):
for i, dataset in enumerate(datasets):
session_name = session_names[i]
session_size = len(dataset)
- self.session_indices[session_name] = (start_idx, start_idx + session_size)
+ self.session_indices[session_name] = (
+ start_idx,
+ start_idx + session_size,
+ )
start_idx += session_size
def __len__(self):
@@ -298,15 +330,21 @@ class SessionBatchSampler(Sampler):
"""
def __init__(self, dataset, batch_size, drop_last=False, shuffle=False, seed=None):
- """
- Initialize session batch sampler.
-
- Args:
- dataset: The SessionConcatDataset to sample from
- batch_size: Number of samples per batch
- drop_last: Whether to drop the last batch if it's smaller than batch_size
- shuffle: Whether to shuffle samples within each session
- seed: Random seed for reproducibility
+ """Initialize session batch sampler.
+
+ Parameters
+ ----------
+ dataset : SessionConcatDataset
+ The dataset to sample from.
+ batch_size : int
+ Number of samples per batch.
+ drop_last : bool, optional
+ Whether to drop the last batch if smaller than batch_size.
+ Default is False.
+ shuffle : bool, optional
+ Whether to shuffle samples within each session. Default is False.
+ seed : int, optional
+ Random seed for reproducibility.
"""
self.dataset = dataset
self.batch_size = batch_size
@@ -386,12 +424,45 @@ def set_state(self, state):
class FastSessionDataLoader:
- """
- An optimized dataloader that ensures:
- 1. Each session appears exactly once before repeating
- 2. The epoch ends when the longest session is exhausted
- 3. Perfect alignment between sessions and batches is maintained
- 4. State is properly tracked and can be restored
+ """Optimized multi-session dataloader with state tracking.
+
+ Provides efficient data loading across multiple sessions with guarantees:
+
+ - Each session appears exactly once before repeating
+ - Epoch ends when the longest session is exhausted
+ - Perfect alignment between sessions and batches is maintained
+ - State is properly tracked and can be restored
+
+ Parameters
+ ----------
+ dataset : SessionConcatDataset
+ Concatenated dataset with session tracking.
+ batch_size : int, default=1
+ Number of samples per batch.
+ shuffle : bool, default=False
+ Whether to shuffle samples within each session.
+ num_workers : int, default=0
+ Number of worker processes for data loading.
+ pin_memory : bool, default=False
+ Whether to pin memory for GPU transfer.
+ drop_last : bool, default=False
+ Whether to drop incomplete batches.
+ seed : int, optional
+ Random seed for reproducibility.
+ **kwargs
+ Additional arguments passed to underlying DataLoaders.
+
+ Attributes
+ ----------
+ session_names : list
+ Names of all sessions in the dataset.
+ batches_per_session : dict
+ Number of batches in each session.
+
+ See Also
+ --------
+ SessionConcatDataset : Dataset that tracks session membership.
+ LongCycler : Simpler alternative without state tracking.
"""
def __init__(
@@ -405,18 +476,6 @@ def __init__(
seed=None,
**kwargs,
):
- """
- Initialize optimized session dataloader.
-
- Args:
- dataset: The SessionConcatDataset to load from
- batch_size: Number of samples per batch
- shuffle: Whether to shuffle indices within sessions
- num_workers: Number of worker processes for data loading
- pin_memory: Whether to pin memory in GPU
- drop_last: Whether to drop the last batch if smaller than batch_size
- seed: Random seed for reproducibility
- """
# Store dataset and parameters
self.dataset = dataset
self.batch_size = batch_size
@@ -602,7 +661,6 @@ def __iter__(self):
# Continue until we've gone through one full epoch
# (i.e., until the longest session is exhausted)
while active_sessions and position_in_epoch < self.max_batches_per_session:
-
# Create a cycle order of sessions
cycle_order = self.batch_sampler.get_session_cycle()
@@ -665,15 +723,21 @@ class SessionSpecificSampler(Sampler):
"""
def __init__(self, indices, batch_size, drop_last=False, shuffle=False, seed=None):
- """
- Initialize session-specific sampler.
-
- Args:
- indices: List of dataset indices belonging to this session
- batch_size: Number of samples per batch
- drop_last: Whether to drop the last batch if smaller than batch_size
- shuffle: Whether to shuffle indices
- seed: Random seed for reproducibility
+ """Initialize session-specific sampler.
+
+ Parameters
+ ----------
+ indices : list
+ Dataset indices belonging to this session.
+ batch_size : int
+ Number of samples per batch.
+ drop_last : bool, optional
+ Whether to drop the last batch if smaller than batch_size.
+ Default is False.
+ shuffle : bool, optional
+ Whether to shuffle indices. Default is False.
+ seed : int, optional
+ Random seed for reproducibility.
"""
self.indices = list(indices) # Make a copy to avoid modification issues
self.batch_size = batch_size
diff --git a/pyproject.toml b/pyproject.toml
index 391d610..5f968c8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "experanto"
-version = "0.1"
+version = "0.1.0"
description = "Python package to interpolate recordings and stimuli of neuroscience experiments"
readme = "README.md"
requires-python = ">=3.9"
@@ -9,10 +9,10 @@ requires-python = ">=3.9"
[tool.setuptools.packages.find]
where = ["."]
-include = ["experanto*", "configs"]
+include = ["experanto*"]
[tool.setuptools.package-data]
-"configs" = ["*.yaml"]
+experanto = ["../configs/*.yaml"]  # NOTE(review): package-data globs are relative to the package dir; ".." is not supported — verify these YAMLs ship in the wheel
[project.urls]
Homepage = "https://github.com/sensorium-competition/experanto"
@@ -23,7 +23,7 @@ build-backend = "setuptools.build_meta"
[tool.black]
line-length = 88
-target-version = ["py312"]
+target-version = ["py39"]
include = '\.pyi?$'
exclude = '''
/(