42 changes: 42 additions & 0 deletions doc/modules/exporters.rst
@@ -4,6 +4,48 @@ Exporters module
The :py:mod:`spikeinterface.exporters` module includes functions to export SpikeInterface objects to other commonly
used frameworks.

Exporting to Pynapple
---------------------

The Python package `Pynapple <https://pynapple.org/>`_ is often used to combine ephys
and behavioral data. It can be used to decode behavior, make tuning curves, compute spectrograms, and more!
The :py:func:`~spikeinterface.exporters.to_pynapple_tsgroup` function allows you to convert a
``SortingAnalyzer`` to Pynapple's ``TsGroup`` object on the fly.

**Note**: when creating the ``TsGroup``, we use the underlying time support of the ``SortingAnalyzer``.
How this is determined depends on your acquisition system. You can use the ``get_times`` method of a recording
(``my_recording.get_times()``) to find the time support of your recording.
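
As a quick check, you can inspect the time support the ``TsGroup`` will inherit. A minimal
sketch, assuming the ``analyzer`` from the example below has its recording attached:

.. code-block:: python

    # `get_times` returns the timestamp (in seconds) of every sample
    times = analyzer.recording.get_times()
    print(times[0], times[-1])  # start and end of the time support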

If ``attach_unit_metadata`` is set to ``True``, any relevant unit information
is propagated to the ``TsGroup``. The ``to_pynapple_tsgroup`` function checks whether unit locations, quality
metrics and template metrics have been computed; whatever has been computed is attached to the
returned object. For more control, set ``attach_unit_metadata`` to ``False`` and attach metadata
yourself using Pynapple's ``set_info`` method.

The following code creates a ``TsGroup`` from a ``SortingAnalyzer``, then saves it using Pynapple's
``save`` method.

.. code-block:: python

    import spikeinterface as si
    from spikeinterface.exporters import to_pynapple_tsgroup

    # load in an analyzer
    analyzer = si.load_sorting_analyzer("path/to/analyzer")

    my_tsgroup = to_pynapple_tsgroup(
        sorting_analyzer_or_sorting=analyzer,
        attach_unit_metadata=True,
    )

    # Note: can add metadata using e.g.
    # my_tsgroup.set_info({'brain_region': ['MEC', 'MEC', ...]})

    my_tsgroup.save("my_tsgroup_output.npz")

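If you attach metadata yourself, here is a minimal sketch using Pynapple's ``set_info``
(the ``brain_region`` labels are hypothetical placeholders):

.. code-block:: python

    import numpy as np

    # one label per unit; replace with your own per-unit annotations
    my_tsgroup.set_info(brain_region=np.array(["MEC"] * len(my_tsgroup)))
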
If you have a multi-segment sorting, you need to pass the ``segment_index`` argument to the
``to_pynapple_tsgroup`` function. This way, you can generate one ``TsGroup`` per segment,
as in the sketch below. You can later concatenate these ``TsGroup`` objects using Pynapple's
concatenation functionality.
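
A minimal sketch of the per-segment pattern, assuming the ``analyzer`` from above holds a
multi-segment sorting:

.. code-block:: python

    # one TsGroup per segment; concatenation is then done on the Pynapple side
    tsgroups = [
        to_pynapple_tsgroup(sorting_analyzer_or_sorting=analyzer, segment_index=segment_index)
        for segment_index in range(analyzer.get_num_segments())
    ]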

Exporting to Phy
----------------
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -161,6 +161,9 @@ test = [
# streaming templates
"s3fs",

# exporters
"pynapple",

# tridesclous2
"numba<0.61.0;python_version<'3.13'",
"numba>=0.61.0;python_version>='3.13'",
1 change: 1 addition & 0 deletions src/spikeinterface/exporters/__init__.py
@@ -1,3 +1,4 @@
from .to_phy import export_to_phy
from .report import export_report
from .to_ibl import export_to_ibl_gui
from .to_pynapple import to_pynapple_tsgroup
78 changes: 78 additions & 0 deletions src/spikeinterface/exporters/tests/test_export_to_pynapple.py
@@ -0,0 +1,78 @@
import numpy as np
import pytest

from spikeinterface.generation import generate_ground_truth_recording
from spikeinterface.core import create_sorting_analyzer, NumpySorting
from spikeinterface.exporters import to_pynapple_tsgroup


def test_export_analyzer_to_pynapple():
    """
    Checks that `to_pynapple_tsgroup` works using a generated sorting analyzer,
    then checks that it works when unit ids are not simply 0,1,2,3... .
    """

    rec, sort = generate_ground_truth_recording(num_units=6)
    analyzer = create_sorting_analyzer(sorting=sort, recording=rec)

    unit_ids = analyzer.unit_ids
    int_unit_ids = np.array([int(unit_id) for unit_id in unit_ids])

    a_TsGroup = to_pynapple_tsgroup(analyzer)

    assert np.all(a_TsGroup.index == int_unit_ids)

    subset_of_unit_ids = analyzer.unit_ids[[1, 3, 5]]
    int_subset_of_unit_ids = np.array([int(unit_id) for unit_id in subset_of_unit_ids])
    subset_analyzer = analyzer.select_units(unit_ids=subset_of_unit_ids)

    a_sub_TsGroup = to_pynapple_tsgroup(subset_analyzer)

    assert np.all(a_sub_TsGroup.index == int_subset_of_unit_ids)

    # now test automatic metadata
    subset_analyzer.compute(["random_spikes", "templates", "unit_locations"])
    a_sub_TsGroup_with_locations = to_pynapple_tsgroup(subset_analyzer)
    assert a_sub_TsGroup_with_locations["x"] is not None

    subset_analyzer.compute({"noise_levels": {}, "quality_metrics": {"metric_names": ["snr"]}})
    a_sub_TsGroup_with_qm = to_pynapple_tsgroup(subset_analyzer)
    assert a_sub_TsGroup_with_qm["snr"] is not None

    subset_analyzer.compute({"template_metrics": {"metric_names": ["half_width"]}})
    a_sub_TsGroup_with_tm = to_pynapple_tsgroup(subset_analyzer)
    assert a_sub_TsGroup_with_tm["half_width"] is not None


def test_non_int_unit_ids():
    """
    Pynapple only accepts integer unit ids. If a user passes unit ids which are not castable to ints,
    `to_pynapple_tsgroup` will set the index to (0,1,2...) and save the original unit ids in the
    `unit_id` column of the TsGroup metadata.
    """

    # generate fake data with string unit ids
    max_sample = 1000
    num_spikes = 200
    num_units = 3

    rng = np.random.default_rng(1205)
    sample_index = np.sort(rng.choice(range(max_sample), size=num_spikes, replace=False))
    unit_index = rng.choice(range(num_units), size=num_spikes)
    segment_index = np.zeros(shape=num_spikes).astype("int")

    spikes = np.zeros(
        shape=(num_spikes), dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]
    )
    spikes["sample_index"] = sample_index
    spikes["unit_index"] = unit_index
    spikes["segment_index"] = segment_index

    sorting = NumpySorting(spikes, sampling_frequency=30_000, unit_ids=["zero", "one", "two"])

    # the str-typed `unit_ids` should raise a warning
    with pytest.warns(UserWarning):
        ts = to_pynapple_tsgroup(sorting, attach_unit_metadata=False)

    assert np.all(ts.metadata["unit_id"].values == np.array(["zero", "one", "two"]))
91 changes: 91 additions & 0 deletions src/spikeinterface/exporters/to_pynapple.py
@@ -0,0 +1,91 @@
from spikeinterface.core import SortingAnalyzer, BaseSorting
import numpy as np
from warnings import warn


def to_pynapple_tsgroup(
    sorting_analyzer_or_sorting: SortingAnalyzer | BaseSorting,
    attach_unit_metadata=True,
    segment_index=None,
):
    """
    Returns a pynapple TsGroup object based on spike train data.

    Parameters
    ----------
    sorting_analyzer_or_sorting : SortingAnalyzer | BaseSorting
        A SortingAnalyzer or Sorting object
    attach_unit_metadata : bool, default: True
        If True, any relevant available metadata is attached to the TsGroup. Will attach
        `unit_locations`, `quality_metrics` and `template_metrics` if computed. If False,
        no metadata is included.
    segment_index : int | None, default: None
        The segment index. Can be None for a mono-segment sorting.

    Returns
    -------
    spike_train_TsGroup : pynapple.TsGroup
        A TsGroup object from the pynapple package.
    """
    from pynapple import TsGroup, Ts
    import pandas as pd

    if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):
        sorting = sorting_analyzer_or_sorting.sorting
    elif isinstance(sorting_analyzer_or_sorting, BaseSorting):
        sorting = sorting_analyzer_or_sorting
    else:
        raise TypeError(
            f"The `sorting_analyzer_or_sorting` argument must be a SortingAnalyzer or Sorting object, not a {type(sorting_analyzer_or_sorting)} type object."
        )

    unit_ids = sorting.unit_ids

    unit_ids_castable = True
    try:
        unit_ids_ints = [int(unit_id) for unit_id in unit_ids]
[Review thread]

Member: Another strategy that wouldn't use the try-except would be

    unit_ids_castable = all([unit_id.isdigit() for unit_id in unit_ids])
    if unit_ids_castable:
        unit_ids_ints = [int(unit_id) for unit_id in unit_ids]
    else:
        xx

@zm711 (Member), Jul 28, 2025: Never mind. This only works if the ids are strings --which maybe in the future they will be ;P

@chrishalcrow (Member, Author), Jul 29, 2025: Yeah, I thought about this too and ended up try/excepting.
`all([isinstance(unit_id, int) or unit_id.isdigit() for unit_id in unit_ids])` works but is a bit gross.
I'd vote to keep the try/except for now.

Member: The costs are minimal, but based on my reading a try-except is slightly faster than an if-else if you succeed most of the time, but is quite a bit slower if you except often. That being said, even a 10x slowdown of one step isn't really that meaningful, so now that you added in the specific except I'm okay with this. Thanks for humoring me :)
    except ValueError:
        warn_msg = "Pynapple requires integer unit ids, but `unit_ids` cannot be cast to int. "
        warn_msg += "We will set the index of the TsGroup to [0,1,2,...] and attach the original "
        warn_msg += "unit ids to the TsGroup as metadata with the name 'unit_id'."
        warn(warn_msg)
        unit_ids_ints = np.arange(len(unit_ids))
        unit_ids_castable = False

    spikes_trains = {
        unit_id_int: sorting.get_unit_spike_train(unit_id=unit_id, return_times=True, segment_index=segment_index)
        for unit_id_int, unit_id in zip(unit_ids_ints, unit_ids)
    }

    metadata_list = []
    if not unit_ids_castable:
        # index by unit_ids so this aligns with extension metadata when concatenated below
        metadata_list.append(pd.DataFrame(unit_ids, columns=["unit_id"], index=unit_ids))

    # Look for good metadata to add, if there is a sorting analyzer
    if attach_unit_metadata and isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):

        if (unit_locations := sorting_analyzer_or_sorting.get_extension("unit_locations")) is not None:
            array_of_unit_locations = unit_locations.get_data()
            n_dims = array_of_unit_locations.shape[1]
            pd_of_unit_locations = pd.DataFrame(
                array_of_unit_locations, columns=["x", "y", "z"][:n_dims], index=unit_ids
            )
            metadata_list.append(pd_of_unit_locations)
        if (quality_metrics := sorting_analyzer_or_sorting.get_extension("quality_metrics")) is not None:
            metadata_list.append(quality_metrics.get_data())
        if (template_metrics := sorting_analyzer_or_sorting.get_extension("template_metrics")) is not None:
            metadata_list.append(template_metrics.get_data())

    if len(metadata_list) > 0:
        metadata = pd.concat(metadata_list, axis=1)
        metadata.index = unit_ids_ints
    else:
        metadata = None

    spike_train_tsgroup = TsGroup(
        {unit_id: Ts(spike_train) for unit_id, spike_train in spikes_trains.items()},
        metadata=metadata,
    )

    return spike_train_tsgroup