Skip to content

Update to curation format v2 #157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ classifiers = [
]

dependencies = [
"spikeinterface[full]>=0.102.3",
"spikeinterface[full]>=0.103.0",
"markdown"
]

Expand Down
48 changes: 29 additions & 19 deletions spikeinterface_gui/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
import spikeinterface.qualitymetrics
from spikeinterface.core.sorting_tools import spike_vector_to_indices
from spikeinterface.core.core_tools import check_json
from spikeinterface.curation import validate_curation_dict
from spikeinterface.widgets.utils import make_units_table_from_analyzer

from .curation_tools import adding_group, default_label_definitions, empty_curation_data
from .curation_tools import add_merge, default_label_definitions, empty_curation_data

spike_dtype =[('sample_index', 'int64'), ('unit_index', 'int64'),
('channel_index', 'int64'), ('segment_index', 'int64'),
Expand All @@ -26,7 +27,6 @@
color_mode='color_by_unit',
)

# TODO handle return_scaled
from spikeinterface.widgets.sorting_summary import _default_displayed_unit_properties


Expand All @@ -53,7 +53,7 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
self.analyzer = analyzer
assert self.analyzer.get_extension("random_spikes") is not None

self.return_scaled = True
self.return_in_uV = self.analyzer.return_in_uV
self.save_on_compute = save_on_compute

self.verbose = verbose
Expand Down Expand Up @@ -319,7 +319,17 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
if curation_data is None:
self.curation_data = empty_curation_data.copy()
else:
self.curation_data = curation_data
# validate the curation data
format_version = curation_data.get("format_version", None)
# assume version 2 if not present
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I would not do this. This will be a mess with previous curation files.
Let's be strict and accept only dicts that do have the format_version — this is part of the spec, no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No curation format version is needed for validation!

What we can do is try 2 and then 1. Ok?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add it to the test curation. The controller should fail if it is not provided.

if format_version is None:
raise ValueError("Curation data format version is missing and is required in the curation data.")
try:
validate_curation_dict(curation_data)
self.curation_data = curation_data
except Exception as e:
print(f"Invalid curation data. Initializing with empty curation data.\nError: {e}")
self.curation_data = empty_curation_data.copy()

self.has_default_quality_labels = False
if "label_definitions" not in self.curation_data:
Expand Down Expand Up @@ -538,7 +548,7 @@ def get_traces(self, trace_source='preprocessed', **kargs):
elif trace_source == 'raw':
raise NotImplemented
# TODO get with parent recording the non process recording
kargs['return_scaled'] = self.return_scaled
kargs['return_in_uV'] = self.return_in_uV
traces = rec.get_traces(**kargs)
# put in cache for next call
self._traces_cached[cache_key] = traces
Expand Down Expand Up @@ -668,7 +678,7 @@ def curation_can_be_saved(self):

def construct_final_curation(self):
d = dict()
d["format_version"] = "1"
d["format_version"] = "2"
d["unit_ids"] = self.unit_ids.tolist()
d.update(self.curation_data.copy())
return d
Expand Down Expand Up @@ -699,14 +709,14 @@ def make_manual_delete_if_possible(self, removed_unit_ids):
if not self.curation:
return

all_merged_units = sum(self.curation_data["merge_unit_groups"], [])
all_merged_units = sum([m["unit_ids"] for m in self.curation_data["merges"]], [])
for unit_id in removed_unit_ids:
if unit_id in self.curation_data["removed_units"]:
if unit_id in self.curation_data["removed"]:
continue
# TODO: check if unit is already in a merge group
if unit_id in all_merged_units:
continue
self.curation_data["removed_units"].append(unit_id)
self.curation_data["removed"].append(unit_id)
if self.verbose:
print(f"Unit {unit_id} is removed from the curation data")

Expand All @@ -718,10 +728,10 @@ def make_manual_restore(self, restore_unit_ids):
return

for unit_id in restore_unit_ids:
if unit_id in self.curation_data["removed_units"]:
if unit_id in self.curation_data["removed"]:
if self.verbose:
print(f"Unit {unit_id} is restored from the curation data")
self.curation_data["removed_units"].remove(unit_id)
self.curation_data["removed"].remove(unit_id)

def make_manual_merge_if_possible(self, merge_unit_ids):
"""
Expand All @@ -740,22 +750,22 @@ def make_manual_merge_if_possible(self, merge_unit_ids):
return False

for unit_id in merge_unit_ids:
if unit_id in self.curation_data["removed_units"]:
if unit_id in self.curation_data["removed"]:
return False
merged_groups = adding_group(self.curation_data["merge_unit_groups"], merge_unit_ids)
self.curation_data["merge_unit_groups"] = merged_groups

new_merges = add_merge(self.curation_data["merges"], merge_unit_ids)
self.curation_data["merges"] = new_merges
if self.verbose:
print(f"Merged unit group: {merge_unit_ids}")
print(f"Merged unit group: {[str(u) for u in merge_unit_ids]}")
return True

def make_manual_restore_merge(self, merge_group_indices):
if not self.curation:
return
merge_groups_to_remove = [self.curation_data["merge_unit_groups"][merge_group_index] for merge_group_index in merge_group_indices]
for merge_group in merge_groups_to_remove:
for merge_index in merge_group_indices:
if self.verbose:
print(f"Unmerged merge group {merge_group}")
self.curation_data["merge_unit_groups"].remove(merge_group)
print(f"Unmerged merge group {self.curation_data['merge_unit_groups'][merge_index]['unit_ids']}")
self.curation_data["merges"].pop(merge_index)

def get_curation_label_definitions(self):
# give only label definition with exclusive
Expand Down
32 changes: 17 additions & 15 deletions spikeinterface_gui/curation_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,29 @@

empty_curation_data = {
"manual_labels": [],
"merge_unit_groups": [],
"removed_units": []
"merges": [],
"splits": [],
"removes": []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

format_version="2"

}

def adding_group(previous_groups, new_group):
def add_merge(previous_merges, new_merge_unit_ids):
# this is to ensure that np.str_ types are rendered as str
to_merge = [np.array(new_group).tolist()]
to_merge = [np.array(new_merge_unit_ids).tolist()]
unchanged = []
for c_prev in previous_groups:
for c_prev in previous_merges:
is_unaffected = True

for c_new in new_group:
if c_new in c_prev:
c_prev_unit_ids = c_prev["unit_ids"]
for c_new in new_merge_unit_ids:
if c_new in c_prev_unit_ids:
is_unaffected = False
to_merge.append(c_prev)
to_merge.append(c_prev_unit_ids)
break

if is_unaffected:
unchanged.append(c_prev)
new_merge_group = [sum(to_merge, [])]
new_merge_group.extend(unchanged)
# Ensure the unicity
new_merge_group = [list(set(gp)) for gp in new_merge_group]
return new_merge_group
unchanged.append(c_prev_unit_ids)

new_merge_units = [sum(to_merge, [])]
new_merge_units.extend(unchanged)
# Ensure the uniqueness
new_merges = [{"unit_ids": list(set(gp))} for gp in new_merge_units]
return new_merges
12 changes: 6 additions & 6 deletions spikeinterface_gui/curationview.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _qt_make_layout(self):
def _qt_refresh(self):
from .myqt import QT
# Merged
merged_units = self.controller.curation_data["merge_unit_groups"]
merged_units = [m["unit_ids"] for m in self.controller.curation_data["merges"]]
self.table_merge.clear()
self.table_merge.setRowCount(len(merged_units))
self.table_merge.setColumnCount(1)
Expand All @@ -115,7 +115,7 @@ def _qt_refresh(self):
self.table_merge.resizeColumnToContents(i)

## deleted
removed_units = self.controller.curation_data["removed_units"]
removed_units = self.controller.curation_data["removed"]
self.table_delete.clear()
self.table_delete.setRowCount(len(removed_units))
self.table_delete.setColumnCount(1)
Expand Down Expand Up @@ -161,7 +161,7 @@ def _qt_on_item_selection_changed_merge(self):

dtype = self.controller.unit_ids.dtype
ind = self.table_merge.selectedIndexes()[0].row()
visible_unit_ids = self.controller.curation_data["merge_unit_groups"][ind]
visible_unit_ids = [m["unit_ids"] for m in self.controller.curation_data["merges"]][ind]
visible_unit_ids = [dtype.type(unit_id) for unit_id in visible_unit_ids]
self.controller.set_visible_unit_ids(visible_unit_ids)
self.notify_unit_visibility_changed()
Expand All @@ -170,7 +170,7 @@ def _qt_on_item_selection_changed_delete(self):
if len(self.table_delete.selectedIndexes()) == 0:
return
ind = self.table_delete.selectedIndexes()[0].row()
unit_id = self.controller.curation_data["removed_units"][ind]
unit_id = self.controller.curation_data["removed"][ind]
self.controller.set_all_unit_visibility_off()
# convert to the correct type
unit_id = self.controller.unit_ids.dtype.type(unit_id)
Expand Down Expand Up @@ -332,7 +332,7 @@ def _panel_make_layout(self):
def _panel_refresh(self):
import pandas as pd
# Merged
merged_units = self.controller.curation_data["merge_unit_groups"]
merged_units = [m["unit_ids"] for m in self.controller.curation_data["merges"]]

# for visualization, we make all row entries strings
merged_units_str = []
Expand All @@ -345,7 +345,7 @@ def _panel_refresh(self):
self.table_merge.selection = []

## deleted
removed_units = self.controller.curation_data["removed_units"]
removed_units = self.controller.curation_data["removed"]
removed_units = [str(unit_id) for unit_id in removed_units]
df = pd.DataFrame({"deleted_unit_id": removed_units})
self.table_delete.value = df
Expand Down
9 changes: 3 additions & 6 deletions spikeinterface_gui/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,13 +538,10 @@ def instantiate_analyzer_and_recording(analyzer_path=None, recording_path=None,
try:
recording = si.load(recording_path)
if recording_type == "raw":
from spikeinterface.preprocessing.pipeline import (
get_preprocessing_dict_from_analyzer,
apply_preprocessing_pipeline,
)
import spikeinterface.preprocessing as spre

preprocessing_pipeline = get_preprocessing_dict_from_analyzer(analyzer_path)
recording_processed = apply_preprocessing_pipeline(
preprocessing_pipeline = spre.get_preprocessing_dict_from_analyzer(analyzer_path)
recording_processed = spre.apply_preprocessing_pipeline(
recording,
preprocessing_pipeline,
)
Expand Down
10 changes: 9 additions & 1 deletion spikeinterface_gui/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def run_mainwindow(
address="localhost",
port=0,
panel_start_server_kwargs=None,
panel_window_servable=True,
verbose=False,
):
"""
Expand Down Expand Up @@ -75,6 +76,10 @@ def run_mainwindow(
- `{'dev': True}` to enable development mode (default is False).
- `{'autoreload': True}` to enable autoreload of the server when files change
(default is False).
panel_window_servable: bool, default: True
For "web" mode only. If True, the Panel app is made servable.
This is useful when embedding the GUI in another Panel app. In that case,
the `panel_window_servable` should be set to False.
verbose: bool, default: False
If True, print some information in the console
"""
Expand Down Expand Up @@ -130,7 +135,10 @@ def run_mainwindow(
elif backend == "panel":
from .backend_panel import PanelMainWindow, start_server
win = PanelMainWindow(controller, layout_preset=layout_preset, layout=layout)
win.main_layout.servable(title='SpikeInterface GUI')

if start_app or panel_window_servable:
win.main_layout.servable(title='SpikeInterface GUI')

if start_app:
panel_start_server_kwargs = panel_start_server_kwargs or {}
_ = start_server(win, address=address, port=port, **panel_start_server_kwargs)
Expand Down
2 changes: 1 addition & 1 deletion spikeinterface_gui/mergeview.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def get_table_data(self, include_deleted=False):
unit_ids = list(self.controller.unit_ids)
for group_ids in self.proposed_merge_unit_groups:
if not include_deleted and self.controller.curation:
deleted_unit_ids = self.controller.curation_data["removed_units"]
deleted_unit_ids = self.controller.curation_data["removed"]
if any(unit_id in deleted_unit_ids for unit_id in group_ids):
continue

Expand Down
4 changes: 2 additions & 2 deletions spikeinterface_gui/tests/test_mainwindow_qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_launcher(verbose=True):
if __name__ == '__main__':
if not test_folder.is_dir():
setup_module()
# win = test_mainwindow(start_app=True, verbose=True, curation=True)
win = test_mainwindow(start_app=True, verbose=True, curation=True)
# win = test_mainwindow(start_app=True, verbose=True, curation=False)

test_launcher(verbose=True)
# test_launcher(verbose=True)
5 changes: 3 additions & 2 deletions spikeinterface_gui/tests/testingtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def make_analyzer_folder(test_folder, case="small", unit_dtype="str"):
def make_curation_dict(analyzer):
unit_ids = analyzer.unit_ids.tolist()
curation_dict = {
"format_version": "2",
"unit_ids": unit_ids,
"label_definitions": {
"quality":{
Expand All @@ -153,8 +154,8 @@ def make_curation_dict(analyzer):
{'unit_id': unit_ids[2], "putative_type": ["exitatory"]},
{'unit_id': unit_ids[3], "quality": ["noise"], "putative_type": ["inhibitory"]},
],
"merge_unit_groups": [unit_ids[:3], unit_ids[3:5]],
"removed_units": unit_ids[5:8],
"merges": [{"unit_ids": unit_ids[:3]}, {"unit_ids": unit_ids[3:5]}],
"removed": unit_ids[5:8],
}
return curation_dict

Expand Down
1 change: 1 addition & 0 deletions spikeinterface_gui/waveformview.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def get_spike_waveform(self, ind):
trace_source='preprocessed',
segment_index=seg_num,
start_frame=peak_ind - nbefore, end_frame=peak_ind + nafter,
return_in_uV=self.controller.return_in_uV
)
return wf, width

Expand Down