Add to_pynapple_tsgroup function (#4074)
Merged: samuelgarcia merged 21 commits into SpikeInterface:main from chrishalcrow:export-to-pynapple on Jul 29, 2025.
Commits (21):
- 8c48252 initial go (chrishalcrow)
- 24d437d initial commit (chrishalcrow)
- 4bb5c51 remove pynapple from core deps (chrishalcrow)
- f000f8c Merge branch 'main' into export-to-pynapple (chrishalcrow)
- b97d0f8 respond zach (chrishalcrow)
- c112436 Merge branch 'export-to-pynapple' of https://github.com/chrishalcrow/… (chrishalcrow)
- e3fc302 remove pandas from test (chrishalcrow)
- 263a5a8 docs (chrishalcrow)
- d59557c add multisegment stuff (chrishalcrow)
- 5e0031b Apply suggestions from code review (chrishalcrow)
- c9fb8e5 Merge branch 'main' into export-to-pynapple (chrishalcrow)
- f36c14e make metadata a bool choice (chrishalcrow)
- 9c0c557 Merge branch 'main' into export-to-pynapple (chrishalcrow)
- 12351ff remove metadata from tests (chrishalcrow)
- fc37984 Merge branch 'export-to-pynapple' of https://github.com/chrishalcrow/… (chrishalcrow)
- c4b5633 deal with str unit ids (chrishalcrow)
- 1fb0abe remove pynapple from extractors (chrishalcrow)
- ab825cd Merge branch 'main' into export-to-pynapple (chrishalcrow)
- 5e1e0db respond to Zach (chrishalcrow)
- c48f28c Merge branch 'main' into export-to-pynapple (chrishalcrow)
- 0519eef Merge branch 'main' into export-to-pynapple (alejoe91)
Changes to the exporters `__init__.py` (1 addition):

```diff
@@ -1,3 +1,4 @@
 from .to_phy import export_to_phy
 from .report import export_report
 from .to_ibl import export_to_ibl_gui
+from .to_pynapple import to_pynapple_tsgroup
```
src/spikeinterface/exporters/tests/test_export_to_pynapple.py (78 additions, 0 deletions):
```python
import numpy as np
import pytest

from spikeinterface.generation import generate_ground_truth_recording
from spikeinterface.core import create_sorting_analyzer, NumpySorting
from spikeinterface.exporters import to_pynapple_tsgroup


def test_export_analyzer_to_pynapple():
    """
    Checks that `to_pynapple_tsgroup` works using a generated sorting analyzer,
    then checks that it works when unit ids are not simply 0, 1, 2, 3...
    """

    rec, sort = generate_ground_truth_recording(num_units=6)
    analyzer = create_sorting_analyzer(sorting=sort, recording=rec)

    unit_ids = analyzer.unit_ids
    int_unit_ids = np.array([int(unit_id) for unit_id in unit_ids])

    a_TsGroup = to_pynapple_tsgroup(analyzer)

    assert np.all(a_TsGroup.index == int_unit_ids)

    subset_of_unit_ids = analyzer.unit_ids[[1, 3, 5]]
    int_subset_of_unit_ids = np.array([int(unit_id) for unit_id in subset_of_unit_ids])
    subset_analyzer = analyzer.select_units(unit_ids=subset_of_unit_ids)

    a_sub_TsGroup = to_pynapple_tsgroup(subset_analyzer)

    assert np.all(a_sub_TsGroup.index == int_subset_of_unit_ids)

    # now test automatic metadata
    subset_analyzer.compute(["random_spikes", "templates", "unit_locations"])
    a_sub_TsGroup_with_locations = to_pynapple_tsgroup(subset_analyzer)
    assert a_sub_TsGroup_with_locations["x"] is not None

    subset_analyzer.compute({"noise_levels": {}, "quality_metrics": {"metric_names": ["snr"]}})
    a_sub_TsGroup_with_qm = to_pynapple_tsgroup(subset_analyzer)
    assert a_sub_TsGroup_with_qm["snr"] is not None

    subset_analyzer.compute({"template_metrics": {"metric_names": ["half_width"]}})
    a_sub_TsGroup_with_tm = to_pynapple_tsgroup(subset_analyzer)
    assert a_sub_TsGroup_with_tm["half_width"] is not None


def test_non_int_unit_ids():
    """
    Pynapple only accepts integer unit ids. If a user passes unit ids which are not castable to ints,
    `to_pynapple_tsgroup` will set the index to (0, 1, 2, ...) and save the original unit ids in the
    `unit_id` column of the TsGroup metadata.
    """

    # generate fake data with string unit ids
    max_sample = 1000
    num_spikes = 200
    num_units = 3

    rng = np.random.default_rng(1205)
    sample_index = np.sort(rng.choice(range(max_sample), size=num_spikes, replace=False))
    unit_index = rng.choice(range(num_units), size=num_spikes)
    segment_index = np.zeros(shape=num_spikes).astype("int")

    spikes = np.zeros(
        shape=(200), dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]
    )
    spikes["sample_index"] = sample_index
    spikes["unit_index"] = unit_index
    spikes["segment_index"] = segment_index

    sorting = NumpySorting(spikes, sampling_frequency=30_000, unit_ids=["zero", "one", "two"])

    # the str-typed `unit_ids` should raise a warning
    with pytest.warns(UserWarning):
        ts = to_pynapple_tsgroup(sorting, attach_unit_metadata=False)

    assert np.all(ts.metadata["unit_id"].values == np.array(["zero", "one", "two"]))
```
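The fake data in `test_non_int_unit_ids` is built as a NumPy structured array with one row per spike (fields `sample_index`, `unit_index`, `segment_index`), which is the spike-vector layout SpikeInterface sortings use internally. A minimal standalone sketch of that construction, using only NumPy (parameter values mirror the test; the variable names here are illustrative, not part of the PR):

```python
import numpy as np

# Assumed parameters, mirroring the test above
max_sample, num_spikes, num_units = 1000, 200, 3

rng = np.random.default_rng(1205)
# Unique, sorted sample indices: one spike per sample, in time order
sample_index = np.sort(rng.choice(max_sample, size=num_spikes, replace=False))
unit_index = rng.choice(num_units, size=num_spikes)

# Structured array: one row per spike, sorted by sample time
spikes = np.zeros(
    num_spikes,
    dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")],
)
spikes["sample_index"] = sample_index
spikes["unit_index"] = unit_index  # segment_index stays 0 (mono-segment)

# Sample indices convert to seconds by dividing by the sampling rate (30 kHz here)
spike_times_s = spikes["sample_index"] / 30_000
```

Because `replace=False` guarantees unique samples and the array is sorted, the spike times come out strictly increasing, which is what downstream time-series containers expect.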
`to_pynapple.py` (new file, 91 additions):

```python
from spikeinterface.core import SortingAnalyzer, BaseSorting
import numpy as np
from warnings import warn


def to_pynapple_tsgroup(
    sorting_analyzer_or_sorting: SortingAnalyzer | BaseSorting,
    attach_unit_metadata=True,
    segment_index=None,
):
    """
    Returns a pynapple TsGroup object based on spike train data.

    Parameters
    ----------
    sorting_analyzer_or_sorting : SortingAnalyzer | BaseSorting
        A SortingAnalyzer or Sorting object.
    attach_unit_metadata : bool, default: True
        If True, any relevant available metadata is attached to the TsGroup. Will attach
        `unit_locations`, `quality_metrics` and `template_metrics` if computed. If False,
        no metadata is included.
    segment_index : int | None, default: None
        The segment index. Can be None for a mono-segment sorting.

    Returns
    -------
    spike_train_TsGroup : pynapple.TsGroup
        A TsGroup object from the pynapple package.
    """
    from pynapple import TsGroup, Ts
    import pandas as pd

    if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):
        sorting = sorting_analyzer_or_sorting.sorting
    elif isinstance(sorting_analyzer_or_sorting, BaseSorting):
        sorting = sorting_analyzer_or_sorting
    else:
        raise TypeError(
            f"The `sorting_analyzer_or_sorting` argument must be a SortingAnalyzer or Sorting "
            f"object, not a {type(sorting_analyzer_or_sorting)} type object."
        )

    unit_ids = sorting.unit_ids

    unit_ids_castable = True
    try:
        unit_ids_ints = [int(unit_id) for unit_id in unit_ids]
    except ValueError:
        warn_msg = "Pynapple requires integer unit ids, but `unit_ids` cannot be cast to int. "
        warn_msg += "We will set the index of the TsGroup to [0,1,2,...] and attach the original "
        warn_msg += "unit ids to the TsGroup as metadata with the name 'unit_id'."
        warn(warn_msg)
        unit_ids_ints = np.arange(len(unit_ids))
        unit_ids_castable = False

    spikes_trains = {
        unit_id_int: sorting.get_unit_spike_train(unit_id=unit_id, return_times=True, segment_index=segment_index)
        for unit_id_int, unit_id in zip(unit_ids_ints, unit_ids)
    }

    metadata_list = []
    if not unit_ids_castable:
        # Index by the original unit ids so this frame aligns with any metric frames below
        metadata_list.append(pd.DataFrame(unit_ids, columns=["unit_id"], index=unit_ids))

    # Look for good metadata to add, if there is a sorting analyzer
    if attach_unit_metadata and isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):

        if (unit_locations := sorting_analyzer_or_sorting.get_extension("unit_locations")) is not None:
            array_of_unit_locations = unit_locations.get_data()
            n_dims = array_of_unit_locations.shape[1]
            pd_of_unit_locations = pd.DataFrame(
                array_of_unit_locations, columns=["x", "y", "z"][:n_dims], index=unit_ids
            )
            metadata_list.append(pd_of_unit_locations)
        if (quality_metrics := sorting_analyzer_or_sorting.get_extension("quality_metrics")) is not None:
            metadata_list.append(quality_metrics.get_data())
        if (template_metrics := sorting_analyzer_or_sorting.get_extension("template_metrics")) is not None:
            metadata_list.append(template_metrics.get_data())

    if len(metadata_list) > 0:
        metadata = pd.concat(metadata_list, axis=1)
        metadata.index = unit_ids_ints
    else:
        metadata = None

    spike_train_tsgroup = TsGroup(
        {unit_id: Ts(spike_train) for unit_id, spike_train in spikes_trains.items()},
        metadata=metadata,
    )

    return spike_train_tsgroup
```
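The unit-id handling at the heart of the function (cast to int when possible, otherwise fall back to a positional 0..n-1 index and keep the originals as metadata) can be exercised in isolation, without pynapple or spikeinterface installed. `resolve_unit_ids` is a hypothetical helper name for this sketch, not part of the PR:

```python
import numpy as np
from warnings import warn


def resolve_unit_ids(unit_ids):
    """Return (int_ids, castable), mirroring the try/except in to_pynapple_tsgroup.

    Hypothetical standalone helper; the real function inlines this logic.
    """
    try:
        # Works for ints and for numeric strings like "3"
        return [int(unit_id) for unit_id in unit_ids], True
    except ValueError:
        # Non-numeric ids (e.g. "unit_a"): fall back to a positional index
        warn("unit_ids not castable to int; falling back to a 0..n-1 index")
        return list(np.arange(len(unit_ids))), False
```

With this sketch, `resolve_unit_ids(["0", "7"])` keeps the original numbering, while `resolve_unit_ids(["zero", "one", "two"])` warns and returns the positional index, matching the behavior `test_non_int_unit_ids` checks for.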
Review discussion:

Comment: Another strategy that wouldn't use the try-except would be …
Comment: Never mind. This only works if the ids are strings -- which maybe in the future they will be ;P
Reply: Yeah, I thought about this too and ended up try/excepting. `all([isinstance(unit_id, int) or unit_id.isdigit() for unit_id in unit_ids])` works but is a bit gross. I'd vote to keep the try/except for now.
Comment: The costs are minimal, but based on my reading a try-except is slightly faster than an if-else if you succeed most of the time, and quite a bit slower if you except often. That being said, even a 10x slowdown of one step isn't really that meaningful, so now that you've added the specific except I'm okay with this. Thanks for humoring me :)