diff --git a/README.md b/README.md index dec0fdf..ca98f27 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,9 @@ # Experanto Experanto is a Python package designed for interpolating recordings and stimuli in neuroscience experiments. It enables users to load single or multiple experiments and create efficient dataloaders for machine learning applications. +## Docs +[![Docs](https://readthedocs.org/projects/experanto/badge/?version=latest)](https://experanto.readthedocs.io/) + ## Installation To install Experanto, clone locally and run: ```bash diff --git a/configs/default.yaml b/configs/default.yaml index 6983f8d..7d9b2ac 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -30,7 +30,7 @@ dataset: chunk_size: 16 offset: 0.0 # in seconds transforms: - normalization: "standardize" + normalization: "normalize_variance_only" interpolation: interpolation_mode: "nearest_neighbor" filters: diff --git a/docs/source/concepts/demo_configs.rst b/docs/source/concepts/demo_configs.rst index 0203541..89dfdc2 100644 --- a/docs/source/concepts/demo_configs.rst +++ b/docs/source/concepts/demo_configs.rst @@ -47,7 +47,7 @@ Default YAML configuration chunk_size: 16 offset: 0.0 transforms: - normalization: standardize + normalization: normalize_variance_only interpolation: interpolation_mode: nearest_neighbor filters: diff --git a/experanto/dataloaders.py b/experanto/dataloaders.py index 8be2d68..37f344e 100644 --- a/experanto/dataloaders.py +++ b/experanto/dataloaders.py @@ -22,10 +22,10 @@ def get_multisession_dataloader( paths: List[str], - configs: Union[DictConfig, Dict, List[Union[DictConfig, Dict]]] = None, + configs: Optional[Union[DictConfig, Dict, List[Union[DictConfig, Dict]]]] = None, shuffle_keys: bool = False, **kwargs, -) -> DataLoader: +) -> LongCycler: """ Create a multisession dataloader from a list of paths and corresponding configs. Args: @@ -44,6 +44,7 @@ def get_multisession_dataloader( if isinstance(configs, (DictConfig, dict)): configs = [configs] * len(paths) + assert configs is not None dataloaders = {} for i, (path, cfg) in enumerate(zip(paths, configs)): # TODO use saved meta dict to find data key @@ -53,10 +54,10 @@ def get_multisession_dataloader( dataset_name = path.split("_gaze")[0].split("datasets/")[1] else: dataset_name = f"session_{i}" - dataset = ChunkDataset(path, **cfg.dataset) + dataset = ChunkDataset(path, **cfg["dataset"]) dataloaders[dataset_name] = MultiEpochsDataLoader( dataset, - **cfg.dataloader, + **cfg["dataloader"], ) return LongCycler(dataloaders) @@ -64,11 +65,11 @@ def get_multisession_dataloader( def get_multisession_concat_dataloader( paths: List[str], - configs: Union[Dict, List[Dict]] = None, + configs: Optional[Union[Dict, List[Dict]]] = None, seed: Optional[int] = 0, dataloader_config: Optional[Dict] = None, **kwargs, -) -> "FastSessionDataLoader": +) -> Optional["FastSessionDataLoader"]: """ Creates a multi-session dataloader using SessionConcatDataset and SessionDataLoader. Returns (session_key, batch) pairs during iteration. @@ -86,7 +87,7 @@ def get_multisession_concat_dataloader( """ if configs is None and "config" in kwargs: configs = kwargs.pop("config") - + assert configs is not None # Convert single config to list for uniform handling if not isinstance(configs, list): configs = [configs] * len(paths) diff --git a/experanto/datasets.py b/experanto/datasets.py index 92d103a..1045263 100644 --- a/experanto/datasets.py +++ b/experanto/datasets.py @@ -4,7 +4,6 @@ import importlib import json import os -from collections import namedtuple from collections.abc import Iterable from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -52,7 +51,6 @@ def __init__( self._sample_times = np.arange( self.start_time, self.end_time, 1.0 / self.sampling_rate ) - self.DataPoint = namedtuple("DataPoint", self.device_names) def __len__(self): return int(len(self._sample_times) / self.chunk_size) @@ -60,7 +58,8 @@ def __len__(self): def __getitem__(self, idx): s = idx * self.chunk_size times = self._sample_times[s : s + self.chunk_size] - data, _ = self._experiment.interpolate(times) + data = self._experiment.interpolate(times, return_valid=False) + assert isinstance(data, dict) phase_shifts = self._experiment.devices["responses"]._phase_shifts timestamps_neurons = (times - times.min())[:, None] + phase_shifts[None, :] data["timestamps"] = timestamps_neurons @@ -75,8 +74,8 @@ class ChunkDataset(Dataset): def __init__( self, root_folder: str, - global_sampling_rate: None, - global_chunk_size: None, + global_sampling_rate: Optional[float] = None, + global_chunk_size: Optional[int] = None, add_behavior_as_channels: bool = False, replace_nans_with_means: bool = False, cache_data: bool = False, @@ -123,7 +122,7 @@ def __init__( chunk_size: null offset: 0.1 transforms: - standardize: true + normalize_variance_only: true # old standardize: true interpolation: interpolation_mode: nearest_neighbor eye_tracker: @@ -157,7 +156,7 @@ def __init__( self.add_behavior_as_channels = add_behavior_as_channels self.replace_nans_with_means = replace_nans_with_means - self.sample_stride = self.modality_config.screen.sample_stride + self.sample_stride = self.modality_config.screen.sample_stride # type: ignore[union-attr] self._experiment = Experiment( root_folder, modality_config, @@ -265,8 +264,8 @@ def initialize_statistics(self) -> None: if not isinstance(mode, str): means = np.array(mode.get("means", means)) stds = np.array(mode.get("stds", stds)) - if mode == "standardize": - # If modality should only be standarized, set means to 0. + if mode == "normalize_variance_only": + # If modality should only be adjusted by variance (old "standardize"), set means to 0. means = np.zeros_like(means) elif mode == "recompute_responses": means = np.zeros_like(means) @@ -307,9 +306,9 @@ def initialize_transforms(self): for device_name in self.device_names: if device_name == "screen": add_channel = Lambda(self.add_channel_function) - transform_list = [] + transform_list: List[Any] = [] - for v in self.modality_config.screen.transforms.values(): + for v in self.modality_config.screen.transforms.values(): # type: ignore[union-attr] if isinstance(v, dict): # config dict module = instantiate(v) if isinstance(module, torch.nn.Module): @@ -318,7 +317,7 @@ def initialize_transforms(self): transform_list.insert(0, add_channel) else: - transform_list = [ToTensor()] + transform_list: List[Any] = [ToTensor()] # Normalization. if self.modality_config[device_name].transforms.get("normalization", False): @@ -370,7 +369,7 @@ def _get_callable_filter(self, filter_config): } # Call the factory function with its arguments to get the actual implementation function - implementation_func = factory_func(**args) + implementation_func = factory_func(**args) # type: ignore[reportCallIssue] return implementation_func except (ImportError, AttributeError, KeyError, TypeError) as e: @@ -385,7 +384,7 @@ def _get_callable_filter(self, filter_config): def get_valid_intervals_from_filters( self, visualize: bool = False ) -> List[TimeInterval]: - valid_intervals = None + valid_intervals: Optional[List[TimeInterval]] = None for modality in self.modality_config: if "filters" in self.modality_config[modality]: device = self._experiment.devices[modality] @@ -394,7 +393,7 @@ def get_valid_intervals_from_filters( ].items(): # Get the final callable filter function filter_function = self._get_callable_filter(filter_config) - valid_intervals_ = filter_function(device_=device) + valid_intervals_: List[TimeInterval] = filter_function(device_=device) # type: ignore[assignment] if visualize: print(f"modality: {modality}, filter: {filter_name}") visualization_string = get_stats_for_valid_interval( @@ -408,7 +407,7 @@ def get_valid_intervals_from_filters( valid_intervals, valid_intervals_ ) - return valid_intervals + return valid_intervals if valid_intervals is not None else [] def get_condition_mask_from_meta_conditions( self, valid_conditions_sum_of_product: List[dict] @@ -424,7 +423,7 @@ def get_condition_mask_from_meta_conditions( Returns: np.ndarray: Boolean mask indicating which trials satisfy at least one set of conditions. """ - all_conditions = None + all_conditions: Optional[np.ndarray] = None for valid_conditions_product in valid_conditions_sum_of_product: conditions_of_product = None for k, valid_condition in valid_conditions_product.items(): @@ -440,6 +439,8 @@ def get_condition_mask_from_meta_conditions( all_conditions = conditions_of_product else: all_conditions |= conditions_of_product + if all_conditions is None: + return np.array([], dtype=bool) return all_conditions def get_screen_sample_mask_from_meta_conditions( @@ -506,7 +507,7 @@ def get_screen_sample_mask_from_meta_conditions( def get_full_valid_sample_times( self, filter_for_valid_intervals: bool = True - ) -> Iterable: + ) -> np.ndarray: """ iterates through all sample times and checks if they could be used as start times, eg if the next `self.chunk_sizes["screen"]` points are still valid @@ -530,14 +531,14 @@ def get_full_valid_sample_times( if not isinstance(valid_conditions, (list, tuple, ListConfig)): valid_conditions = [valid_conditions] + valid_conditions = list(valid_conditions) + if self.modality_config["screen"]["include_blanks"]: additional_valid_conditions = {"tier": "blank"} valid_conditions.append(additional_valid_conditions) - sample_mask_from_meta_conditions = ( - self.get_screen_sample_mask_from_meta_conditions( - chunk_size, valid_conditions, filter_for_valid_intervals - ) + sample_mask_from_meta_conditions = self.get_screen_sample_mask_from_meta_conditions( + chunk_size, valid_conditions, filter_for_valid_intervals # type: ignore[arg-type] ) final_mask = duration_mask & sample_mask_from_meta_conditions @@ -549,7 +550,7 @@ def shuffle_valid_screen_times(self) -> None: Shuffle valid screen times using the dataset's random number generator for reproducibility. """ - times = self._full_valid_sample_times + times = self._full_valid_sample_times_filtered if self.seed is not None: self._valid_screen_times = np.sort( self._rng.choice( @@ -563,7 +564,7 @@ def shuffle_valid_screen_times(self) -> None: ) ) - def get_data_key_from_root_folder(cls, root_folder): + def get_data_key_from_root_folder(self, root_folder): """ Extract a data key from the root folder path by checking for a meta.json file. @@ -596,10 +597,10 @@ def get_data_key_from_root_folder(cls, root_folder): data_key = f"{key['animal_id']}-{key['session']}-{key['scan_idx']}" return data_key if "dynamic" in root_folder: - dataset_name = path.split("dynamic")[1].split("-Video")[0] + dataset_name = root_folder.split("dynamic")[1].split("-Video")[0] return dataset_name - elif "_gaze" in path: - dataset_name = path.split("_gaze")[0].split("datasets/")[1] + elif "_gaze" in root_folder: + dataset_name = root_folder.split("_gaze")[0].split("datasets/")[1] return dataset_name else: print( @@ -636,7 +637,9 @@ def __getitem__(self, idx) -> dict: # scale everything back to truncated values times = times.astype(np.float64) / self.scale_precision - data, _ = self._experiment.interpolate(times, device=device_name) + data = self._experiment.interpolate( + times, device=device_name, return_valid=False + ) out[device_name] = self.transforms[device_name](data).squeeze( 0 ) # remove dim0 for response/eye_tracker/treadmill diff --git a/experanto/experiment.py b/experanto/experiment.py index 0e2b505..e9df4cf 100644 --- a/experanto/experiment.py +++ b/experanto/experiment.py @@ -3,9 +3,9 @@ import logging import re import warnings -from collections import namedtuple from collections.abc import Sequence from pathlib import Path +from typing import Optional, Union import numpy as np from hydra.utils import instantiate @@ -75,7 +75,7 @@ def _load_devices(self) -> None: warnings.warn( "Falling back to original Interpolator creation logic.", UserWarning ) - dev = Interpolator.create(d, cache_data=self.cache_data, **interp_conf) + dev = Interpolator.create(d, cache_data=self.cache_data, **interp_conf) # type: ignore[arg-type] self.devices[d.name] = dev self.start_time = dev.start_time @@ -86,16 +86,33 @@ def _load_devices(self) -> None: def device_names(self): return tuple(self.devices.keys()) - def interpolate(self, times: slice, device=None) -> tuple[np.ndarray, np.ndarray]: + def interpolate( + self, + times: np.ndarray, + device: Union[str, Interpolator, None] = None, + return_valid: bool = False, + ) -> Union[tuple[dict, dict], dict, tuple[np.ndarray, np.ndarray], np.ndarray]: if device is None: values = {} valid = {} for d, interp in self.devices.items(): - values[d], valid[d] = interp.interpolate(times) + res = interp.interpolate(times, return_valid=return_valid) + if return_valid: + vals, vlds = res + values[d] = vals + valid[d] = vlds + else: + values[d] = res + if return_valid: + return values, valid + else: + return values elif isinstance(device, str): assert device in self.devices, "Unknown device '{}'".format(device) - values, valid = self.devices[device].interpolate(times) - return values, valid + res = self.devices[device].interpolate(times, return_valid=return_valid) + return res + else: + raise ValueError(f"Unsupported device type: {type(device)}") - def get_valid_range(self, device_name) -> tuple: + def get_valid_range(self, device_name) -> tuple[float, float]: return tuple(self.devices[device_name].valid_interval) diff --git a/experanto/interpolators.py b/experanto/interpolators.py index f240740..4f89ebb 100644 --- a/experanto/interpolators.py +++ b/experanto/interpolators.py @@ -7,6 +7,7 @@ import warnings from abc import abstractmethod from pathlib import Path +from typing import Union import cv2 import numpy as np @@ -30,7 +31,9 @@ def load_meta(self): return meta @abstractmethod - def interpolate(self, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + def interpolate( + self, times: np.ndarray, return_valid: bool = False + ) -> Union[tuple[np.ndarray, np.ndarray], np.ndarray]: ... # returns interpolated signal and boolean mask of valid samples @@ -66,6 +69,7 @@ def create(root_folder: str, cache_data: bool = False, **kwargs) -> "Interpolato ) def valid_times(self, times: np.ndarray) -> np.ndarray: + assert self.valid_interval is not None return self.valid_interval.intersect(times) def close(self): @@ -150,7 +154,9 @@ def normalize_data(self, data): data = data * self._precision return data - def interpolate(self, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + def interpolate( + self, times: np.ndarray, return_valid: bool = False + ) -> Union[tuple[np.ndarray, np.ndarray], np.ndarray]: valid = self.valid_times(times) valid_times = times[valid] @@ -158,7 +164,11 @@ def interpolate(self, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]: warnings.warn( "Sequence interpolation returns empty array, no valid times queried" ) - return np.empty((0, self._data.shape[1])), valid + return ( + (np.empty((0, self._data.shape[1])), valid) + if return_valid + else np.empty((0, self._data.shape[1])) + ) idx_lower = np.floor((valid_times - self.start_time) / self.time_delta).astype( int @@ -167,7 +177,7 @@ def interpolate(self, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]: if self.interpolation_mode == "nearest_neighbor": data = self._data[idx_lower] - return data, valid + return (data, valid) if return_valid else data elif self.interpolation_mode == "linear": idx_upper = idx_lower + 1 @@ -205,7 +215,7 @@ def interpolate(self, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]: # Replace NaNs with the column means directly np.copyto(interpolated, neuron_means, where=np.isnan(interpolated)) - return interpolated, valid + return (interpolated, valid) if return_valid else interpolated else: raise NotImplementedError( @@ -248,7 +258,9 @@ def __init__( + (np.min(self._phase_shifts) if len(self._phase_shifts) > 0 else 0), ) - def interpolate(self, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + def interpolate( + self, times: np.ndarray, return_valid: bool = False + ) -> Union[tuple[np.ndarray, np.ndarray], np.ndarray]: valid = self.valid_times(times) valid_times = times[valid] @@ -256,7 +268,11 @@ def interpolate(self, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]: warnings.warn( "Sequence interpolation returns empty array, no valid times queried" ) - return np.empty((0, self._data.shape[1])), valid + return ( + (np.empty((0, self._data.shape[1])), valid) + if return_valid + else np.empty((0, self._data.shape[1])) + ) idx_lower = np.floor( ( @@ -269,7 +285,7 @@ def interpolate(self, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]: if self.interpolation_mode == "nearest_neighbor": data = np.take_along_axis(self._data, idx_lower, axis=0) - return data, valid + return (data, valid) if return_valid else data elif self.interpolation_mode == "linear": idx_upper = idx_lower + 1 @@ -308,7 +324,7 @@ def interpolate(self, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]: # Replace NaNs with the column means directly np.copyto(interpolated, neuron_means, where=np.isnan(interpolated)) - return interpolated, valid + return (interpolated, valid) if return_valid else interpolated else: raise NotImplementedError( @@ -322,7 +338,7 @@ def __init__( root_folder: str, cache_data: bool = False, # New parameter rescale: bool = False, - rescale_size: typing.Optional[tuple(int, int)] = None, + rescale_size: typing.Optional[tuple[int, int]] = None, normalize: bool = False, **kwargs, ) -> None: @@ -400,7 +416,7 @@ def is_numbered_yml(file_name): with open(output_path, "w") as file: json.dump(all_data, file) - def read_combined_meta(self) -> None: + def read_combined_meta(self) -> tuple[list, list]: if not (self.root_folder / "combined_meta.json").exists(): print("Combining metadatas...") self._combine_metadatas() @@ -429,7 +445,9 @@ def _parse_trials(self) -> None: ) ) - def interpolate(self, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + def interpolate( + self, times: np.ndarray, return_valid: bool = False + ) -> Union[tuple[np.ndarray, np.ndarray], np.ndarray]: valid = self.valid_times(times) valid_times = times[valid] valid_times += 1e-4 # add small offset to avoid numerical issues @@ -467,9 +485,9 @@ def interpolate(self, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]: out[idx_for_this_file] = data[ idx[idx_for_this_file] - self._first_frame_idx[u_idx] ] - return out, valid + return (out, valid) if return_valid else out - def rescale_frame(self, frame: np.array) -> np.array: + def rescale_frame(self, frame: np.ndarray) -> np.ndarray: """ Changes the resolution of the image to this size. Returns: Rescaled image @@ -496,7 +514,9 @@ def __init__(self, root_folder: str, cache_data: bool = False, **kwargs): for label, filename in self.meta_labels.items() } - def interpolate(self, times: np.ndarray) -> np.ndarray: + def interpolate( + self, times: np.ndarray, return_valid: bool = False + ) -> Union[tuple[np.ndarray, np.ndarray], np.ndarray]: """ Interpolate time intervals for labeled events. @@ -567,13 +587,13 @@ def interpolate(self, times: np.ndarray) -> np.ndarray: mask = (valid_times >= start) & (valid_times < end) out[mask, i] = True - return out + return (out, valid) if return_valid else out class ScreenTrial: def __init__( self, - data_file_name: str, + data_file_name: Union[str, Path], meta_data: dict, image_size: tuple, first_frame_idx: int, @@ -593,18 +613,19 @@ def __init__( @staticmethod def create( - data_file_name: str, meta_data: dict, cache_data: bool = False + data_file_name: Union[str, Path], meta_data: dict, cache_data: bool = False ) -> "ScreenTrial": modality = meta_data.get("modality") + assert modality is not None class_name = modality.lower().capitalize() + "Trial" assert class_name in globals(), f"Unknown modality: {modality}" return globals()[class_name](data_file_name, meta_data, cache_data=cache_data) - def get_data_(self) -> np.array: + def get_data_(self) -> np.ndarray: """Base implementation for loading/generating data""" return np.load(self.data_file_name) - def get_data(self) -> np.array: + def get_data(self) -> np.ndarray: """Wrapper that handles caching""" if self._cached_data is not None: return self._cached_data @@ -651,7 +672,7 @@ def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None: cache_data=cache_data, ) - def get_data_(self) -> np.array: + def get_data_(self) -> np.ndarray: """Override base implementation to generate blank data""" return np.full((1,) + self.image_size, self.interleave_value, dtype=np.float32) @@ -669,6 +690,6 @@ def __init__(self, data_file_name, meta_data, cache_data: bool = False) -> None: cache_data=cache_data, ) - def get_data_(self) -> np.array: + def get_data_(self) -> np.ndarray: """Override base implementation to generate blank data""" return np.full((1,) + self.image_size, self.interleave_value, dtype=np.float32) diff --git a/experanto/intervals.py b/experanto/intervals.py index 1ab9737..6fd0d4e 100644 --- a/experanto/intervals.py +++ b/experanto/intervals.py @@ -1,5 +1,5 @@ import typing -from typing import List +from typing import List, Optional import numpy as np @@ -13,7 +13,7 @@ def __contains__(self, time): def find_intersection_between_two_intervals( self, other_interval: "TimeInterval" - ) -> "TimeInterval": + ) -> Optional["TimeInterval"]: start = max(self.start, other_interval.start) end = min(self.end, other_interval.end) if start <= end: @@ -93,7 +93,7 @@ def find_intersection_between_two_interval_arrays( def find_intersection_across_arrays_of_intervals( intervals_array: List[List[TimeInterval]], -) -> TimeInterval: +) -> List[TimeInterval]: common_interval_array = intervals_array[0] for interval_array in intervals_array[1:]: @@ -106,7 +106,7 @@ def find_intersection_across_arrays_of_intervals( def find_union_across_arrays_of_intervals( intervals_array: List[List[TimeInterval]], -) -> TimeInterval: +) -> List[TimeInterval]: union_array = [] for interval_array in intervals_array: union_array.extend(interval_array) diff --git a/experanto/utils.py b/experanto/utils.py index d289132..9551e56 100644 --- a/experanto/utils.py +++ b/experanto/utils.py @@ -13,7 +13,6 @@ from collections import defaultdict from copy import deepcopy from functools import partial -from itertools import cycle from typing import Any, Dict, Iterator, List, Optional, Tuple, Union # third-party libraries @@ -26,7 +25,7 @@ from .intervals import TimeInterval -def replace_nan_with_batch_mean(data: np.array) -> np.array: +def replace_nan_with_batch_mean(data: np.ndarray) -> np.ndarray: row, col = np.where(np.isnan(data)) for i, j in zip(row, col): new_value = np.nanmean(data[:, j]) @@ -111,13 +110,13 @@ def __init__( self.shuffle_each_epoch = shuffle_each_epoch def __len__(self): - return len(self.batch_sampler.sampler) + return len(self.batch_sampler) - def __iter__(self): + def __iter__(self): # type: ignore[override] if self.shuffle_each_epoch and hasattr( self.dataset, "shuffle_valid_screen_times" ): - self.dataset.shuffle_valid_screen_times() + self.dataset.shuffle_valid_screen_times() # type: ignore[union-attr] for i in range(len(self)): yield next(self.iterator) @@ -524,8 +523,8 @@ def set_state(self, state): # Restore RNG state for the main dataloader dataloader_rng_state = state.get("dataloader_rng_state") - if dataloader_rng_state is not None and self.rng is not None: - self.rng.set_state(dataloader_rng_state) + if dataloader_rng_state is not None and hasattr(self, "rng") and self.rng is not None: # type: ignore[attr-defined] + self.rng.set_state(dataloader_rng_state) # type: ignore[attr-defined] # Restore RNG state for the batch sampler batch_sampler_state = state.get("batch_sampler_state") diff --git a/tests/test_screen_interpolator.py b/tests/test_screen_interpolator.py index 25df521..90f4dcc 100644 --- a/tests/test_screen_interpolator.py +++ b/tests/test_screen_interpolator.py @@ -55,9 +55,55 @@ def test_nearest_neighbor_interpolation(duration, fps, image_frame_count, num_vi expected_indices = np.round((times - timestamps[0]) * fps).astype(int) expected_frames = raw_array[expected_indices] - interp, valid = interp_obj.interpolate(times=times) + interp, valid = interp_obj.interpolate(times=times, return_valid=True) assert times.shape == valid.shape, "All interpolated frames should be valid" assert np.allclose( interp, expected_frames, atol=1e-5 ), "Nearest neighbor interpolation mismatch" + + +def test_nearest_neighbor_interpolation_return_valid_false(): + with create_screen_data( + duration=10, + frame_shape=(32, 32), + fps=10.0, + image_frame_count=10, + num_videos=1, + ) as timestamps: + interp_obj = Interpolator.create("tests/screen_data") + assert isinstance(interp_obj, ScreenInterpolator), "Expected ScreenInterpolator" + + delta_t = 1.0 / 10.0 + times = timestamps[:-1] + 0.4 * delta_t + + result = interp_obj.interpolate(times=times, return_valid=False) + assert isinstance(result, np.ndarray), "Expected np.ndarray, not a tuple" + + interp, _ = interp_obj.interpolate(times=times, return_valid=True) + assert np.array_equal( + result, interp + ), "Data from return_valid=False should match data from return_valid=True" + + +def test_nearest_neighbor_interpolation_default_return_valid(): + with create_screen_data( + duration=10, + frame_shape=(32, 32), + fps=10.0, + image_frame_count=10, + num_videos=1, + ) as timestamps: + interp_obj = Interpolator.create("tests/screen_data") + assert isinstance(interp_obj, ScreenInterpolator), "Expected ScreenInterpolator" + + delta_t = 1.0 / 10.0 + times = timestamps[:-1] + 0.4 * delta_t + + result = interp_obj.interpolate(times=times) + assert isinstance(result, np.ndarray), "Expected np.ndarray, not a tuple" + + interp, _ = interp_obj.interpolate(times=times, return_valid=True) + assert np.array_equal( + result, interp + ), "Data from default (no return_valid) should match data from return_valid=True" diff --git a/tests/test_sequence_interpolator.py b/tests/test_sequence_interpolator.py index 4d045d5..b32c5e7 100644 --- a/tests/test_sequence_interpolator.py +++ b/tests/test_sequence_interpolator.py @@ -32,7 +32,7 @@ def test_nearest_neighbor_interpolation(n_signals, sampling_rate, use_mem_mapped times = timestamps[:DEFAULT_SEQUENCE_LENGTH] + 1e-9 interp, valid = seq_interp.interpolate( - times=times + times=times, return_valid=True ) # Add a small epsilon to avoid floating point errors assert times.shape == valid.shape, "All samples should be valid" assert ( @@ -66,7 +66,7 @@ def test_nearest_neighbor_interpolation_handles_nans(n_signals, keep_nans): times = timestamps[:DEFAULT_SEQUENCE_LENGTH] + 1e-9 interp, valid = seq_interp.interpolate( - times=times + times=times, return_valid=True ) # Add a small epsilon to avoid floating point errors assert times.shape == valid.shape, "All samples should be valid" assert np.array_equal( @@ -100,7 +100,7 @@ def test_nearest_neighbor_interpolation_with_inbetween_times(n_signals, sampling # timestamps multiplied by 0.8 should be floored to the same timestamp times = timestamps[:DEFAULT_SEQUENCE_LENGTH] + 0.8 * delta_t - interp, valid = seq_interp.interpolate(times=times) + interp, valid = seq_interp.interpolate(times=times, return_valid=True) assert times.shape == valid.shape, "All samples should be valid" assert ( interp == data[:DEFAULT_SEQUENCE_LENGTH] @@ -108,7 +108,7 @@ def test_nearest_neighbor_interpolation_with_inbetween_times(n_signals, sampling # timestamps multiplied by 1.2 should be floored to the next timestamp times = timestamps[:DEFAULT_SEQUENCE_LENGTH] + 1.2 * delta_t - interp, valid = seq_interp.interpolate(times=times) + interp, valid = seq_interp.interpolate(times=times, return_valid=True) assert times.shape == valid.shape, "All samples should be valid" assert ( interp == data[1 : DEFAULT_SEQUENCE_LENGTH + 1] @@ -138,7 +138,7 @@ def test_nearest_neighbor_interpolation_with_phase_shifts( times = ( timestamps[1 : DEFAULT_SEQUENCE_LENGTH + 1] + 1e-9 ) # Add a small epsilon to avoid floating point errors - interp, valid = seq_interp.interpolate(times=times) + interp, valid = seq_interp.interpolate(times=times, return_valid=True) assert times.shape == valid.shape, "All samples should be valid" assert ( interp == data[0:DEFAULT_SEQUENCE_LENGTH] @@ -156,7 +156,9 @@ def test_nearest_neighbor_interpolation_with_phase_shifts( for dt in np.linspace(0, 0.99) * delta_t: shifted_times = times + shift[i] + dt - interp, valid = seq_interp.interpolate(times=shifted_times) + interp, valid = seq_interp.interpolate( + times=shifted_times, return_valid=True + ) assert ( interp[:, i] == data[1 : DEFAULT_SEQUENCE_LENGTH + 1, i] ).all(), f"Data at {dt} does not match original data (use_mem_mapped={use_mem_mapped}, sampling_rate={sampling_rate}, shifts_per_signal={True})" @@ -164,7 +166,9 @@ def test_nearest_neighbor_interpolation_with_phase_shifts( for dt in np.linspace(1.0, 1.99) * delta_t: shifted_times = times + shift[i] + dt - interp, valid = seq_interp.interpolate(times=shifted_times) + interp, valid = seq_interp.interpolate( + times=shifted_times, return_valid=True + ) assert ( interp[:, i] == data[2 : DEFAULT_SEQUENCE_LENGTH + 2, i] ).all(), f"Data at {dt} does not match original data (use_mem_mapped={use_mem_mapped}, sampling_rate={sampling_rate}, shifts_per_signal={True})" @@ -193,7 +197,7 @@ def test_nearest_neighbor_interpolation_with_phase_shifts_handles_nans( times = ( timestamps[1 : DEFAULT_SEQUENCE_LENGTH + 1] + 1e-9 ) # Add a small epsilon to avoid floating point errors - interp, valid = seq_interp.interpolate(times=times) + interp, valid = seq_interp.interpolate(times=times, return_valid=True) assert times.shape == valid.shape, "All samples should be valid" assert np.array_equal( interp, data[0:DEFAULT_SEQUENCE_LENGTH], equal_nan=True @@ -242,7 +246,7 @@ def test_linear_interpolation( expected = y1 + ((times[:, np.newaxis] - t1) / (t2 - t1)) * (y2 - y1) if not keep_nans: np.copyto(expected, np.nanmean(expected, axis=0), where=np.isnan(expected)) - interp, valid = seq_interp.interpolate(times=times) + interp, valid = seq_interp.interpolate(times=times, return_valid=True) assert times.shape == valid.shape, "All samples should be valid" assert np.allclose( @@ -311,7 +315,9 @@ def test_linear_interpolation_with_phase_shifts( expected, np.nanmean(expected, axis=0), where=np.isnan(expected) ) - interp, valid = seq_interp.interpolate(times=shifted_times) + interp, valid = seq_interp.interpolate( + times=shifted_times, return_valid=True + ) valid_indices = np.where(valid)[0] if len(valid_indices) > 0: @@ -347,7 +353,7 @@ def test_interpolation_for_invalid_times(interpolation_mode, end_time, keep_nans seq_interp.interpolation_mode = interpolation_mode times = np.array([-5.0, -0.1, 0.1, 4.9, 5.0, 5.1, 10.0]) - interp, valid = seq_interp.interpolate(times=times) + interp, valid = seq_interp.interpolate(times=times, return_valid=True) expected_valid = ( np.where((times >= 0.0) & (times <= end_time))[0] if interpolation_mode == "nearest_neighbor" @@ -389,7 +395,7 @@ def test_interpolation_with_phase_shifts_for_invalid_times( seq_interp.interpolation_mode = interpolation_mode times = np.array([-5.0, -0.1, 0.1, 4.9, 4.9999999, 5.0, 5.0000001, 5.1, 10.0]) - interp, valid = seq_interp.interpolate(times=times) + interp, valid = seq_interp.interpolate(times=times, return_valid=True) assert ( np.where( (times >= np.min(phase_shifts)) @@ -425,16 +431,120 @@ def test_interpolation_for_empty_times(interpolation_mode, phase_shifts): UserWarning, match="Sequence interpolation returns empty array, no valid times queried", ): - interp, valid = seq_interp.interpolate(times=np.array([])) + interp, valid = seq_interp.interpolate( + times=np.array([]), return_valid=True + ) assert interp.shape[0] == 0, "No data expected" assert valid.shape[0] == 0, "No data expected" +def test_nearest_neighbor_interpolation_return_valid_false(): + with sequence_data_and_interpolator( + data_kwargs=dict( + n_signals=10, + use_mem_mapped=False, + t_end=5.0, + sampling_rate=10.0, + ) + ) as (timestamps, data, _, seq_interp): + times = timestamps[:DEFAULT_SEQUENCE_LENGTH] + 1e-9 + + result = seq_interp.interpolate(times=times, return_valid=False) + assert isinstance(result, np.ndarray), "Expected np.ndarray, not a tuple" + assert result.shape == ( + DEFAULT_SEQUENCE_LENGTH, + 10, + ), f"Expected shape ({DEFAULT_SEQUENCE_LENGTH}, 10), got {result.shape}" + + interp, _ = seq_interp.interpolate(times=times, return_valid=True) + assert np.array_equal( + result, interp + ), "Data from return_valid=False should match data from return_valid=True" + + +def test_nearest_neighbor_interpolation_default_return_valid(): + with sequence_data_and_interpolator( + data_kwargs=dict( + n_signals=10, + use_mem_mapped=False, + t_end=5.0, + sampling_rate=10.0, + ) + ) as (timestamps, data, _, seq_interp): + times = timestamps[:DEFAULT_SEQUENCE_LENGTH] + 1e-9 + + result = seq_interp.interpolate(times=times) + assert isinstance(result, np.ndarray), "Expected np.ndarray, not a tuple" + assert result.shape == ( + DEFAULT_SEQUENCE_LENGTH, + 10, + ), f"Expected shape ({DEFAULT_SEQUENCE_LENGTH}, 10), got {result.shape}" + + interp, _ = seq_interp.interpolate(times=times, return_valid=True) + assert np.array_equal( + result, interp + ), "Data from default (no return_valid) should match data from return_valid=True" + + +def test_linear_interpolation_return_valid_false(): + with sequence_data_and_interpolator( + data_kwargs=dict( + n_signals=10, + use_mem_mapped=False, + t_end=5.0, + sampling_rate=10.0, + ) + ) as (timestamps, data, _, seq_interp): + seq_interp.interpolation_mode = "linear" + + delta_t = 1.0 / 10.0 + times = timestamps[1 : DEFAULT_SEQUENCE_LENGTH + 1] + 0.5 * delta_t + + result = seq_interp.interpolate(times=times, return_valid=False) + assert isinstance(result, np.ndarray), "Expected np.ndarray, not a tuple" + assert result.shape == ( + DEFAULT_SEQUENCE_LENGTH, + 10, + ), f"Expected shape ({DEFAULT_SEQUENCE_LENGTH}, 10), got {result.shape}" + + interp, _ = seq_interp.interpolate(times=times, return_valid=True) + assert np.allclose( + result, interp, equal_nan=True + ), "Data from return_valid=False should match data from return_valid=True" + + +def test_linear_interpolation_default_return_valid(): + with sequence_data_and_interpolator( + data_kwargs=dict( + n_signals=10, + use_mem_mapped=False, + t_end=5.0, + sampling_rate=10.0, + ) + ) as (timestamps, data, _, seq_interp): + seq_interp.interpolation_mode = "linear" + + delta_t = 1.0 / 10.0 + times = timestamps[1 : DEFAULT_SEQUENCE_LENGTH + 1] + 0.5 * delta_t + + result = seq_interp.interpolate(times=times) + assert isinstance(result, np.ndarray), "Expected np.ndarray, not a tuple" + assert result.shape == ( + DEFAULT_SEQUENCE_LENGTH, + 10, + ), f"Expected shape ({DEFAULT_SEQUENCE_LENGTH}, 10), got {result.shape}" + + interp, _ = seq_interp.interpolate(times=times, return_valid=True) + assert np.allclose( + result, interp, equal_nan=True + ), "Data from default (no return_valid) should match data from return_valid=True" + + def test_interpolation_mode_not_implemented(): with sequence_data_and_interpolator() as (_, _, _, seq_interp): seq_interp.interpolation_mode = "unsupported_mode" with pytest.raises(NotImplementedError): - seq_interp.interpolate(np.array([0.0, 1.0, 2.0])) + seq_interp.interpolate(np.array([0.0, 1.0, 2.0]), return_valid=True) if __name__ == "__main__":