Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
1dae16f
added support for compressed data formats and added tests for jpg ima…
schewskone Jun 2, 2025
f2b906f
style: auto-format with black and isort
invalid-email-address Jun 2, 2025
406b9a9
added torchcodec to requirements
schewskone Jun 2, 2025
34c88c7
added ffmpeg to github actions setup
schewskone Jun 2, 2025
902f3ca
Merge branch 'sensorium-competition:main' into compressed_data
schewskone Jun 2, 2025
c009980
replaced improper usage of torchvision functions for transforms on no…
schewskone Jun 13, 2025
8a5314c
style: auto-format with black and isort
invalid-email-address Jun 13, 2025
76ded2d
added option to read stimuli via image_names instead of index_naming…
schewskone Jul 2, 2025
ad1f7cb
style: auto-format with black and isort
invalid-email-address Jul 2, 2025
9a525b6
adjusted variable namings
schewskone Jul 2, 2025
98c9924
Merge remote-tracking branch 'origin/ToTensor_fix' into compressed_data
schewskone Jul 3, 2025
721d554
interpolation no longer returns valid argument
schewskone Jul 3, 2025
df018f7
style: auto-format with black and isort
invalid-email-address Jul 3, 2025
c07759d
added formatting function that handles multi channel data and added ve…
schewskone Jul 4, 2025
a350f82
formatting
schewskone Jul 4, 2025
e95bd0d
Merge branch 'compressed_data' of github.com:schewskone/experanto int…
schewskone Jul 4, 2025
ff24676
cleared notebook output
schewskone Jul 4, 2025
6fc95cf
changed shape to (T,C,H,W), added channels and image_names to metafil…
schewskone Jul 4, 2025
8da38d7
style
schewskone Jul 4, 2025
f87e878
Merge branch 'main' into compressed_data. Merge gitignore from recent…
schewskone Jul 4, 2025
af94841
recombined normalization in datasets transforms and cleaned condition…
schewskone Jul 4, 2025
d18a9db
added unsqueeze to transforms to make torchvision transform function o…
schewskone Aug 3, 2025
7c8325e
style: auto-format with black and isort
invalid-email-address Aug 3, 2025
8805869
merged pr#84
schewskone Nov 12, 2025
1da304c
changed screen interpolation to enable sliced decoding instead of ful…
schewskone Jan 27, 2026
0b75107
style: auto-format with black and isort
invalid-email-address Jan 27, 2026
16a2a14
merged custom_interpolator from main
schewskone Jan 27, 2026
b3b0079
revised workflow setup and adjusted bug where indexing was performed …
schewskone Jan 27, 2026
137295c
style: auto-format with black and isort
invalid-email-address Jan 27, 2026
5cdc810
removed debug prints
schewskone Jan 27, 2026
128efde
code review with copilot
schewskone Jan 29, 2026
512a42b
style: auto-format with black and isort
invalid-email-address Jan 29, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ jobs:
with:
python-version: '3.12.8'

- name: Install FFmpeg
run: |
sudo apt-get update
sudo apt-get install -y ffmpeg

- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,6 @@ cython_debug/

*.sif
*.bak

# local workflow testing
bin/
215 changes: 215 additions & 0 deletions examples/allen_demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "5dc8ad24",
"metadata": {},
"source": [
"### Showcase of interpolation on Experiment level for encoded videos and images in greyscale and RGB."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4aecde6",
"metadata": {},
"outputs": [],
"source": [
"# import dependencies\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.animation as animation\n",
"from IPython.display import HTML"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9bc17dbb",
"metadata": {},
"outputs": [],
"source": [
"# get the experiment class from experanto \n",
"from experanto.experiment import Experiment\n",
"\n",
"# set experiment folder as root\n",
"root_folder = '../../allen-exporter/data/allen_data/experiment_951980471'\n",
"\n",
"# initialize experiment object\n",
"e = Experiment(root_folder)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2447e233",
"metadata": {},
"outputs": [],
"source": [
"times = np.arange(310., 340., 0.5)\n",
"video = e.interpolate(times, device=\"screen\") # shape: (T, C, H, W)\n",
"\n",
"# Clip and convert to uint8 to avoid matplotlib clipping warnings\n",
"video = np.clip(video, 0, 255).astype(np.uint8)\n",
"video = video.transpose(0, 2, 3, 1) # (T, H, W, C)\n",
"\n",
"n_frames, height, width, channels = video.shape\n",
"print(f\"Video shape: {video.shape}\")\n",
"\n",
"# Handle grayscale vs color\n",
"is_grayscale = (channels == 1)\n",
"if is_grayscale:\n",
" video = video[..., 0] # Now shape: (T, H, W)\n",
"\n",
"fig, ax = plt.subplots()\n",
"\n",
"# Initialize with appropriate cmap\n",
"if is_grayscale:\n",
" img = ax.imshow(video[0], cmap='gray', vmin=0, vmax=255)\n",
"else:\n",
" img = ax.imshow(video[0])\n",
"\n",
"ax.axis('off')\n",
"\n",
"def update(frame):\n",
" img.set_array(video[frame])\n",
" ax.set_title(f'Frame {frame}')\n",
" return [img]\n",
"\n",
"ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=50, blit=True)\n",
"\n",
"plt.close(fig)\n",
"HTML(ani.to_jshtml())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "129ae368",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"video = e.interpolate(times, device=\"screen\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d58eaaeb",
"metadata": {},
"outputs": [],
"source": [
"times = np.arange(4484., 4500., 0.5)\n",
"video = e.interpolate(times, device=\"screen\") # shape: (T, C, H, W)\n",
"\n",
"# Clip and convert to uint8 to avoid matplotlib clipping warnings\n",
"video = np.clip(video, 0, 255).astype(np.uint8)\n",
"video = video.transpose(0, 2, 3, 1) # (T, H, W, C)\n",
"\n",
"n_frames, height, width, channels = video.shape\n",
"print(f\"Video shape: {video.shape}\")\n",
"\n",
"# Handle grayscale vs color\n",
"is_grayscale = (channels == 1)\n",
"if is_grayscale:\n",
" video = video[..., 0] # Now shape: (T, H, W)\n",
"\n",
"fig, ax = plt.subplots()\n",
"\n",
"# Initialize with appropriate cmap\n",
"if is_grayscale:\n",
" img = ax.imshow(video[0], cmap='gray', vmin=0, vmax=255)\n",
"else:\n",
" img = ax.imshow(video[0])\n",
"\n",
"ax.axis('off')\n",
"\n",
"def update(frame):\n",
" img.set_array(video[frame])\n",
" ax.set_title(f'Frame {frame}')\n",
" return [img]\n",
"\n",
"ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=50, blit=True)\n",
"\n",
"plt.close(fig)\n",
"HTML(ani.to_jshtml())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a74eb40",
"metadata": {},
"outputs": [],
"source": [
"%%timeit\n",
"video = e.interpolate(times, device=\"screen\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba9c51eb",
"metadata": {},
"outputs": [],
"source": [
"# Download a test_rgb_image.jpg (1200x1900!) put it into the folder with other stimuli and run this to test rgb interpolation in the third cell.\n",
"# My test image : https://wallpapercave.com/download/1900x1200-wallpapers-wp9725873\n",
"# Also adjust number_channels to 3 inside the default.yaml config file\n",
"\n",
"import cv2\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Load the JPEG image (in BGR format)\n",
"img = cv2.imread('../../../allen-exporter/data/allen_data/experiment_951980471/stimuli/test_rgb_image.jpg')\n",
"\n",
"# Optional: convert from BGR to RGB if you want\n",
"img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
"\n",
"# Plot the image\n",
"plt.imshow(img_rgb)\n",
"plt.axis('off') # hide axes\n",
"plt.show()\n",
"\n",
"# Save the image as a .npy file\n",
"np.save('../../../allen-exporter/data/allen_data/experiment_951980471/stimuli/im065.npy', img_rgb)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c52e9e3",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "experanto",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
12 changes: 8 additions & 4 deletions experanto/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,11 +316,14 @@ def initialize_transforms(self):
transform_list.append(module)

transform_list.insert(0, add_channel)
else:

transform_list = [ToTensor()]
else:
transform_list = [
lambda x: torch.from_numpy(x)
.float()
.unsqueeze(0) # Add C dim to use torchvision transform
]

# Normalization.
if self.modality_config[device_name].transforms.get("normalization", False):
transform_list.append(
torchvision.transforms.Normalize(
Expand All @@ -330,6 +333,7 @@ def initialize_transforms(self):
)

transforms[device_name] = Compose(transform_list)

return transforms

def _get_callable_filter(self, filter_config):
Expand Down Expand Up @@ -636,7 +640,7 @@ def __getitem__(self, idx) -> dict:
# scale everything back to truncated values
times = times.astype(np.float64) / self.scale_precision

data, _ = self._experiment.interpolate(times, device=device_name)
data = self._experiment.interpolate(times, device=device_name)
out[device_name] = self.transforms[device_name](data).squeeze(
0
) # remove dim0 for response/eye_tracker/treadmill
Expand Down
9 changes: 4 additions & 5 deletions experanto/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,15 @@ def _load_devices(self) -> None:
def device_names(self):
return tuple(self.devices.keys())

def interpolate(self, times: slice, device=None) -> tuple[np.ndarray, np.ndarray]:
def interpolate(self, times: slice, device=None) -> np.ndarray:
if device is None:
values = {}
valid = {}
for d, interp in self.devices.items():
values[d], valid[d] = interp.interpolate(times)
values[d] = interp.interpolate(times)
elif isinstance(device, str):
assert device in self.devices, "Unknown device '{}'".format(device)
values, valid = self.devices[device].interpolate(times)
return values, valid
values = self.devices[device].interpolate(times)
return values

def get_valid_range(self, device_name) -> tuple:
return tuple(self.devices[device_name].valid_interval)
Loading