diff --git a/doc/modules/exporters.rst b/doc/modules/exporters.rst index 4819a7b33f..abcb3319c4 100644 --- a/doc/modules/exporters.rst +++ b/doc/modules/exporters.rst @@ -4,6 +4,48 @@ Exporters module The :py:mod:`spikeinterface.exporters` module includes functions to export SpikeInterface objects to other commonly used frameworks. +Exporting to Pynapple +--------------------- + +The Python package `Pynapple <https://pynapple.org/>`_ is often used for combining ephys +and behavioral data. It can be used to decode behavior, make tuning curves, compute spectrograms, and more! +The :py:func:`~spikeinterface.exporters.to_pynapple_tsgroup` function allows you to convert a +SortingAnalyzer to Pynapple's ``TsGroup`` object on the fly. + +**Note** : When creating the ``TsGroup``, we will use the underlying time support of the SortingAnalyzer. +How this works depends on your acquisition system. You can use the ``get_times`` method on a recording +(``my_recording.get_times()``) to find the time support of your recording. + +When constructed, if ``attach_unit_metadata`` is set to ``True``, any relevant unit information +is propagated to the ``TsGroup``. The ``to_pynapple_tsgroup`` function checks if unit locations, quality +metrics and template metrics have been computed. Whatever has been computed is attached to the +returned object. For more control, set ``attach_unit_metadata`` to ``False`` and attach metadata +using ``Pynapple``'s ``set_info`` method. + +The following code creates a ``TsGroup`` from a ``SortingAnalyzer``, then saves it using ``Pynapple``'s +save method. + +.. code-block:: python + + import spikeinterface as si + from spikeinterface.exporters import to_pynapple_tsgroup + + # load in an analyzer + analyzer = si.load_sorting_analyzer("path/to/analyzer") + + my_tsgroup = to_pynapple_tsgroup( + sorting_analyzer=analyzer, + attach_unit_metadata=True, + ) + + # Note: can add metadata using e.g. 
+ # my_tsgroup.set_info({'brain_region': ['MEC', 'MEC', ...]}) + + my_tsgroup.save("my_tsgroup_output.npz") + +If you have a multi-segment sorting, you need to pass the ``segment_index`` argument to the +``to_pynapple_tsgroup`` function. This way, you can generate one ``TsGroup`` per segment. + You can later concatenate these ``TsGroup``\ s using Pynapple's ``concatenate`` functionality. Exporting to Phy ---------------- diff --git a/pyproject.toml b/pyproject.toml index cb17ffb8e8..1e680fc17d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -161,6 +161,9 @@ test = [ # streaming templates "s3fs", + # exporters + "pynapple", + # tridesclous2 "numba<0.61.0;python_version<'3.13'", "numba>=0.61.0;python_version>='3.13'", diff --git a/src/spikeinterface/exporters/__init__.py b/src/spikeinterface/exporters/__init__.py index dd0d7b0755..97d0f64126 100644 --- a/src/spikeinterface/exporters/__init__.py +++ b/src/spikeinterface/exporters/__init__.py @@ -1,3 +1,4 @@ from .to_phy import export_to_phy from .report import export_report from .to_ibl import export_to_ibl_gui +from .to_pynapple import to_pynapple_tsgroup diff --git a/src/spikeinterface/exporters/tests/test_export_to_pynapple.py b/src/spikeinterface/exporters/tests/test_export_to_pynapple.py new file mode 100644 index 0000000000..a82bc95829 --- /dev/null +++ b/src/spikeinterface/exporters/tests/test_export_to_pynapple.py @@ -0,0 +1,78 @@ +import numpy as np +import pytest + +from spikeinterface.generation import generate_ground_truth_recording +from spikeinterface.core import create_sorting_analyzer, NumpySorting +from spikeinterface.exporters import to_pynapple_tsgroup + + +def test_export_analyzer_to_pynapple(): + """ + Checks to see `to_pynapple_tsgroup` works using a generated sorting analyzer. + Then checks that it works when units are not simply 0,1,2,3... . 
+ """ + + rec, sort = generate_ground_truth_recording(num_units=6) + analyzer = create_sorting_analyzer(sorting=sort, recording=rec) + + unit_ids = analyzer.unit_ids + int_unit_ids = np.array([int(unit_id) for unit_id in unit_ids]) + + a_TsGroup = to_pynapple_tsgroup(analyzer) + + assert np.all(a_TsGroup.index == int_unit_ids) + + subset_of_unit_ids = analyzer.unit_ids[[1, 3, 5]] + int_subset_of_unit_ids = np.array([int(unit_id) for unit_id in subset_of_unit_ids]) + subset_analyzer = analyzer.select_units(unit_ids=subset_of_unit_ids) + + a_sub_TsGroup = to_pynapple_tsgroup(subset_analyzer) + + assert np.all(a_sub_TsGroup.index == int_subset_of_unit_ids) + + # now test automatic metadata + subset_analyzer.compute(["random_spikes", "templates", "unit_locations"]) + a_sub_TsGroup_with_locations = to_pynapple_tsgroup(subset_analyzer) + assert a_sub_TsGroup_with_locations["x"] is not None + + subset_analyzer.compute({"noise_levels": {}, "quality_metrics": {"metric_names": ["snr"]}}) + a_sub_TsGroup_with_qm = to_pynapple_tsgroup(subset_analyzer) + assert a_sub_TsGroup_with_qm["snr"] is not None + + subset_analyzer.compute({"template_metrics": {"metric_names": ["half_width"]}}) + a_sub_TsGroup_with_tm = to_pynapple_tsgroup(subset_analyzer) + assert a_sub_TsGroup_with_tm["half_width"] is not None + + +def test_non_int_unit_ids(): + """ + Pynapple only accepts integer unit ids. If a user passes unit ids which are not castable to ints, + `to_pynapple_tsgroup` will set the index to (0,1,2...) and save the original unit_ids in the + `unit_id` column of the tsgroup metadata. 
+ """ + + # generate fake data with string unit ids + + max_sample = 1000 + num_spikes = 200 + num_units = 3 + + rng = np.random.default_rng(1205) + sample_index = np.sort(rng.choice(range(max_sample), size=num_spikes, replace=False)) + unit_index = rng.choice(range(num_units), size=num_spikes) + segment_index = np.zeros(shape=num_spikes).astype("int") + + spikes = np.zeros( + shape=(200), dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] + ) + spikes["sample_index"] = sample_index + spikes["unit_index"] = unit_index + spikes["segment_index"] = segment_index + + sorting = NumpySorting(spikes, sampling_frequency=30_000, unit_ids=["zero", "one", "two"]) + + # the str typed `unit_ids`` should raise a warning + with pytest.warns(UserWarning): + ts = to_pynapple_tsgroup(sorting, attach_unit_metadata=False) + + assert np.all(ts.metadata["unit_id"].values == np.array(["zero", "one", "two"])) diff --git a/src/spikeinterface/exporters/to_pynapple.py b/src/spikeinterface/exporters/to_pynapple.py new file mode 100644 index 0000000000..b81e5a521d --- /dev/null +++ b/src/spikeinterface/exporters/to_pynapple.py @@ -0,0 +1,91 @@ +from spikeinterface.core import SortingAnalyzer, BaseSorting +import numpy as np +from warnings import warn + + +def to_pynapple_tsgroup( + sorting_analyzer_or_sorting: SortingAnalyzer | BaseSorting, + attach_unit_metadata=True, + segment_index=None, +): + """ + Returns a pynapple TsGroup object based on spike train data. + + Parameters + ---------- + sorting_analyzer_or_sorting : SortingAnalyzer + A SortingAnalyzer object + attach_unit_metadata : bool, default: True + If True, any relevant available metadata is attached to the TsGroup. Will attach + `unit_locations`, `quality_metrics` and `template_metrics` if computed. If False, + no metadata is included. + segment_index : int | None, default: None + The segment index. Can be None if mono-segment sorting. 
+ + Returns + ------- + spike_train_TsGroup : pynapple.TsGroup + A TsGroup object from the pynapple package. + """ + from pynapple import TsGroup, Ts + import pandas as pd + + if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer): + sorting = sorting_analyzer_or_sorting.sorting + elif isinstance(sorting_analyzer_or_sorting, BaseSorting): + sorting = sorting_analyzer_or_sorting + else: + raise TypeError( + f"The `sorting_analyzer_or_sorting` argument must be a SortingAnalyzer or Sorting object, not a {type(sorting_analyzer_or_sorting)} type object." + ) + + unit_ids = sorting.unit_ids + + unit_ids_castable = True + try: + unit_ids_ints = [int(unit_id) for unit_id in unit_ids] + except ValueError: + warn_msg = "Pynapple requires integer unit ids, but `unit_ids` cannot be cast to int. " + warn_msg += "We will set the index of the TsGroup to [0,1,2,...] and attach the original " + warn_msg += "unit ids to the TsGroup as metadata with the name 'unit_id'." + warn(warn_msg) + unit_ids_ints = np.arange(len(unit_ids)) + unit_ids_castable = False + + spikes_trains = { + unit_id_int: sorting.get_unit_spike_train(unit_id=unit_id, return_times=True, segment_index=segment_index) + for unit_id_int, unit_id in zip(unit_ids_ints, unit_ids) + } + + metadata_list = [] + if not unit_ids_castable: + metadata_list.append(pd.DataFrame(unit_ids, columns=["unit_id"])) + + # Look for good metadata to add, if there is a sorting analyzer + if attach_unit_metadata and isinstance(sorting_analyzer_or_sorting, SortingAnalyzer): + + metadata_list = [] + if (unit_locations := sorting_analyzer_or_sorting.get_extension("unit_locations")) is not None: + array_of_unit_locations = unit_locations.get_data() + n_dims = np.shape(sorting_analyzer_or_sorting.get_extension("unit_locations").get_data())[1] + pd_of_unit_locations = pd.DataFrame( + array_of_unit_locations, columns=["x", "y", "z"][:n_dims], index=unit_ids + ) + metadata_list.append(pd_of_unit_locations) + if (quality_metrics := 
sorting_analyzer_or_sorting.get_extension("quality_metrics")) is not None: + metadata_list.append(quality_metrics.get_data()) + if (template_metrics := sorting_analyzer_or_sorting.get_extension("template_metrics")) is not None: + metadata_list.append(template_metrics.get_data()) + + if len(metadata_list) > 0: + metadata = pd.concat(metadata_list, axis=1) + metadata.index = unit_ids_ints + else: + metadata = None + + spike_train_tsgroup = TsGroup( + {unit_id: Ts(spike_train) for unit_id, spike_train in spikes_trains.items()}, + metadata=metadata, + ) + + return spike_train_tsgroup