Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Experanto
Experanto is a Python package designed for interpolating recordings and stimuli in neuroscience experiments. It enables users to load single or multiple experiments and create efficient dataloaders for machine learning applications.

## Docs
[![Docs](https://readthedocs.org/projects/experanto/badge/?version=latest)](https://experanto.readthedocs.io/)

## Installation
To install Experanto, clone the repository locally and run:
```bash
Expand Down
2 changes: 1 addition & 1 deletion configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dataset:
chunk_size: 16
offset: 0.0 # in seconds
transforms:
normalization: "standardize"
normalization: "normalize_variance_only"
interpolation:
interpolation_mode: "nearest_neighbor"
filters:
Expand Down
2 changes: 1 addition & 1 deletion docs/source/concepts/demo_configs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Default YAML configuration
chunk_size: 16
offset: 0.0
transforms:
normalization: standardize
normalization: normalize_variance_only
interpolation:
interpolation_mode: nearest_neighbor
filters:
Expand Down
15 changes: 8 additions & 7 deletions experanto/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@

def get_multisession_dataloader(
paths: List[str],
configs: Union[DictConfig, Dict, List[Union[DictConfig, Dict]]] = None,
configs: Optional[Union[DictConfig, Dict, List[Union[DictConfig, Dict]]]] = None,
shuffle_keys: bool = False,
**kwargs,
) -> DataLoader:
) -> LongCycler:
"""
Create a multisession dataloader from a list of paths and corresponding configs.
Args:
Expand All @@ -44,6 +44,7 @@ def get_multisession_dataloader(
if isinstance(configs, (DictConfig, dict)):
configs = [configs] * len(paths)

assert configs is not None
dataloaders = {}
for i, (path, cfg) in enumerate(zip(paths, configs)):
# TODO use saved meta dict to find data key
Expand All @@ -53,22 +54,22 @@ def get_multisession_dataloader(
dataset_name = path.split("_gaze")[0].split("datasets/")[1]
else:
dataset_name = f"session_{i}"
dataset = ChunkDataset(path, **cfg.dataset)
dataset = ChunkDataset(path, **cfg["dataset"])
dataloaders[dataset_name] = MultiEpochsDataLoader(
dataset,
**cfg.dataloader,
**cfg["dataloader"],
)

return LongCycler(dataloaders)


def get_multisession_concat_dataloader(
paths: List[str],
configs: Union[Dict, List[Dict]] = None,
configs: Optional[Union[Dict, List[Dict]]] = None,
seed: Optional[int] = 0,
dataloader_config: Optional[Dict] = None,
**kwargs,
) -> "FastSessionDataLoader":
) -> Optional["FastSessionDataLoader"]:
"""
Creates a multi-session dataloader using SessionConcatDataset and SessionDataLoader.
Returns (session_key, batch) pairs during iteration.
Expand All @@ -86,7 +87,7 @@ def get_multisession_concat_dataloader(
"""
if configs is None and "config" in kwargs:
configs = kwargs.pop("config")

assert configs is not None
# Convert single config to list for uniform handling
if not isinstance(configs, list):
configs = [configs] * len(paths)
Expand Down
59 changes: 31 additions & 28 deletions experanto/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import importlib
import json
import os
from collections import namedtuple
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
Expand Down Expand Up @@ -52,15 +51,15 @@ def __init__(
self._sample_times = np.arange(
self.start_time, self.end_time, 1.0 / self.sampling_rate
)
self.DataPoint = namedtuple("DataPoint", self.device_names)

def __len__(self):
    """Number of complete chunks that fit into the sampled time grid."""
    n_samples = len(self._sample_times)
    return int(n_samples / self.chunk_size)

def __getitem__(self, idx):
s = idx * self.chunk_size
times = self._sample_times[s : s + self.chunk_size]
data, _ = self._experiment.interpolate(times)
data = self._experiment.interpolate(times, return_valid=False)
assert isinstance(data, dict)
phase_shifts = self._experiment.devices["responses"]._phase_shifts
timestamps_neurons = (times - times.min())[:, None] + phase_shifts[None, :]
data["timestamps"] = timestamps_neurons
Expand All @@ -75,8 +74,8 @@ class ChunkDataset(Dataset):
def __init__(
self,
root_folder: str,
global_sampling_rate: None,
global_chunk_size: None,
global_sampling_rate: Optional[float] = None,
global_chunk_size: Optional[int] = None,
add_behavior_as_channels: bool = False,
replace_nans_with_means: bool = False,
cache_data: bool = False,
Expand Down Expand Up @@ -123,7 +122,7 @@ def __init__(
chunk_size: null
offset: 0.1
transforms:
standardize: true
normalize_variance_only: true # old standardize: true
interpolation:
interpolation_mode: nearest_neighbor
eye_tracker:
Expand Down Expand Up @@ -157,7 +156,7 @@ def __init__(

self.add_behavior_as_channels = add_behavior_as_channels
self.replace_nans_with_means = replace_nans_with_means
self.sample_stride = self.modality_config.screen.sample_stride
self.sample_stride = self.modality_config.screen.sample_stride # type: ignore[union-attr]
self._experiment = Experiment(
root_folder,
modality_config,
Expand Down Expand Up @@ -265,8 +264,8 @@ def initialize_statistics(self) -> None:
if not isinstance(mode, str):
means = np.array(mode.get("means", means))
stds = np.array(mode.get("stds", stds))
if mode == "standardize":
# If modality should only be standarized, set means to 0.
if mode == "normalize_variance_only":
# If modality should only be adjusted by variance (old "standardize"), set means to 0.
means = np.zeros_like(means)
elif mode == "recompute_responses":
means = np.zeros_like(means)
Expand Down Expand Up @@ -307,9 +306,9 @@ def initialize_transforms(self):
for device_name in self.device_names:
if device_name == "screen":
add_channel = Lambda(self.add_channel_function)
transform_list = []
transform_list: List[Any] = []

for v in self.modality_config.screen.transforms.values():
for v in self.modality_config.screen.transforms.values(): # type: ignore[union-attr]
if isinstance(v, dict): # config dict
module = instantiate(v)
if isinstance(module, torch.nn.Module):
Expand All @@ -318,7 +317,7 @@ def initialize_transforms(self):
transform_list.insert(0, add_channel)
else:

transform_list = [ToTensor()]
transform_list: List[Any] = [ToTensor()]

# Normalization.
if self.modality_config[device_name].transforms.get("normalization", False):
Expand Down Expand Up @@ -370,7 +369,7 @@ def _get_callable_filter(self, filter_config):
}

# Call the factory function with its arguments to get the actual implementation function
implementation_func = factory_func(**args)
implementation_func = factory_func(**args) # type: ignore[reportCallIssue]
return implementation_func

except (ImportError, AttributeError, KeyError, TypeError) as e:
Expand All @@ -385,7 +384,7 @@ def _get_callable_filter(self, filter_config):
def get_valid_intervals_from_filters(
self, visualize: bool = False
) -> List[TimeInterval]:
valid_intervals = None
valid_intervals: Optional[List[TimeInterval]] = None
for modality in self.modality_config:
if "filters" in self.modality_config[modality]:
device = self._experiment.devices[modality]
Expand All @@ -394,7 +393,7 @@ def get_valid_intervals_from_filters(
].items():
# Get the final callable filter function
filter_function = self._get_callable_filter(filter_config)
valid_intervals_ = filter_function(device_=device)
valid_intervals_: List[TimeInterval] = filter_function(device_=device) # type: ignore[assignment]
if visualize:
print(f"modality: {modality}, filter: {filter_name}")
visualization_string = get_stats_for_valid_interval(
Expand All @@ -408,7 +407,7 @@ def get_valid_intervals_from_filters(
valid_intervals, valid_intervals_
)

return valid_intervals
return valid_intervals if valid_intervals is not None else []

def get_condition_mask_from_meta_conditions(
self, valid_conditions_sum_of_product: List[dict]
Expand All @@ -424,7 +423,7 @@ def get_condition_mask_from_meta_conditions(
Returns:
np.ndarray: Boolean mask indicating which trials satisfy at least one set of conditions.
"""
all_conditions = None
all_conditions: Optional[np.ndarray] = None
for valid_conditions_product in valid_conditions_sum_of_product:
conditions_of_product = None
for k, valid_condition in valid_conditions_product.items():
Expand All @@ -440,6 +439,8 @@ def get_condition_mask_from_meta_conditions(
all_conditions = conditions_of_product
else:
all_conditions |= conditions_of_product
if all_conditions is None:
return np.array([], dtype=bool)
return all_conditions

def get_screen_sample_mask_from_meta_conditions(
Expand Down Expand Up @@ -506,7 +507,7 @@ def get_screen_sample_mask_from_meta_conditions(

def get_full_valid_sample_times(
self, filter_for_valid_intervals: bool = True
) -> Iterable:
) -> np.ndarray:
"""
iterates through all sample times and checks if they could be used as
start times, eg if the next `self.chunk_sizes["screen"]` points are still valid
Expand All @@ -530,14 +531,14 @@ def get_full_valid_sample_times(
if not isinstance(valid_conditions, (list, tuple, ListConfig)):
valid_conditions = [valid_conditions]

valid_conditions = list(valid_conditions)

if self.modality_config["screen"]["include_blanks"]:
additional_valid_conditions = {"tier": "blank"}
valid_conditions.append(additional_valid_conditions)

sample_mask_from_meta_conditions = (
self.get_screen_sample_mask_from_meta_conditions(
chunk_size, valid_conditions, filter_for_valid_intervals
)
sample_mask_from_meta_conditions = self.get_screen_sample_mask_from_meta_conditions(
chunk_size, valid_conditions, filter_for_valid_intervals # type: ignore[arg-type]
)

final_mask = duration_mask & sample_mask_from_meta_conditions
Expand All @@ -549,7 +550,7 @@ def shuffle_valid_screen_times(self) -> None:
Shuffle valid screen times using the dataset's random number generator
for reproducibility.
"""
times = self._full_valid_sample_times
times = self._full_valid_sample_times_filtered
if self.seed is not None:
self._valid_screen_times = np.sort(
self._rng.choice(
Expand All @@ -563,7 +564,7 @@ def shuffle_valid_screen_times(self) -> None:
)
)

def get_data_key_from_root_folder(cls, root_folder):
def get_data_key_from_root_folder(self, root_folder):
"""
Extract a data key from the root folder path by checking for a meta.json file.

Expand Down Expand Up @@ -596,10 +597,10 @@ def get_data_key_from_root_folder(cls, root_folder):
data_key = f"{key['animal_id']}-{key['session']}-{key['scan_idx']}"
return data_key
if "dynamic" in root_folder:
dataset_name = path.split("dynamic")[1].split("-Video")[0]
dataset_name = root_folder.split("dynamic")[1].split("-Video")[0]
return dataset_name
elif "_gaze" in path:
dataset_name = path.split("_gaze")[0].split("datasets/")[1]
elif "_gaze" in root_folder:
dataset_name = root_folder.split("_gaze")[0].split("datasets/")[1]
return dataset_name
else:
print(
Expand Down Expand Up @@ -636,7 +637,9 @@ def __getitem__(self, idx) -> dict:
# scale everything back to truncated values
times = times.astype(np.float64) / self.scale_precision

data, _ = self._experiment.interpolate(times, device=device_name)
data = self._experiment.interpolate(
times, device=device_name, return_valid=False
)
out[device_name] = self.transforms[device_name](data).squeeze(
0
) # remove dim0 for response/eye_tracker/treadmill
Expand Down
31 changes: 24 additions & 7 deletions experanto/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import logging
import re
import warnings
from collections import namedtuple
from collections.abc import Sequence
from pathlib import Path
from typing import Optional, Union

import numpy as np
from hydra.utils import instantiate
Expand Down Expand Up @@ -75,7 +75,7 @@ def _load_devices(self) -> None:
warnings.warn(
"Falling back to original Interpolator creation logic.", UserWarning
)
dev = Interpolator.create(d, cache_data=self.cache_data, **interp_conf)
dev = Interpolator.create(d, cache_data=self.cache_data, **interp_conf) # type: ignore[arg-type]

self.devices[d.name] = dev
self.start_time = dev.start_time
Expand All @@ -86,16 +86,33 @@ def _load_devices(self) -> None:
def device_names(self):
return tuple(self.devices.keys())

def interpolate(
    self,
    times: np.ndarray,
    device: Union[str, "Interpolator", None] = None,
    return_valid: bool = False,
) -> Union[tuple[dict, dict], dict, tuple[np.ndarray, np.ndarray], np.ndarray]:
    """Interpolate every loaded device (or one named device) at the given times.

    Args:
        times: Sample times passed through to each device interpolator.
        device: None to interpolate all devices in ``self.devices``, or the
            name of a single device. NOTE(review): the annotation also
            advertises Interpolator instances, but any value that is neither
            None nor a str raises ValueError below — confirm intended contract.
        return_valid: If True, validity information from each interpolator is
            returned alongside the values.

    Returns:
        With ``device=None``: a dict mapping device name to interpolated
        values; when ``return_valid`` is True, a ``(values, valid)`` pair of
        such dicts. With a named device: whatever that device's
        ``interpolate`` returns (a ``(values, valid)`` pair when
        ``return_valid`` is True, else just the values).

    Raises:
        ValueError: If ``device`` is neither None nor a string.
        AssertionError: If a named device is not in ``self.devices``.
    """
    if device is None:
        values = {}
        valid = {}
        for name, interp in self.devices.items():
            res = interp.interpolate(times, return_valid=return_valid)
            if return_valid:
                values[name], valid[name] = res
            else:
                values[name] = res
        return (values, valid) if return_valid else values
    elif isinstance(device, str):
        assert device in self.devices, "Unknown device '{}'".format(device)
        return self.devices[device].interpolate(times, return_valid=return_valid)
    else:
        raise ValueError(f"Unsupported device type: {type(device)}")

def get_valid_range(self, device_name) -> tuple[float, float]:
    """Return the valid (start, end) time interval of the named device.

    Args:
        device_name: Key into ``self.devices``.

    Returns:
        The device's ``valid_interval`` as a plain tuple.

    Raises:
        KeyError: If ``device_name`` is not a loaded device.
    """
    return tuple(self.devices[device_name].valid_interval)
Loading