-
Notifications
You must be signed in to change notification settings - Fork 2
Add SONATA point-process cell support #51
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
626e54f
bd23ea4
17195f0
3a323c0
988664b
c9f0db8
e3b12e9
7adda77
c203180
3185fbb
63a1b59
dc95223
8249daa
cba681e
708b303
45b18aa
e74a00b
f0264f2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,198 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import logging | ||
| from pathlib import Path | ||
| from typing import Optional | ||
|
|
||
| from bluecellulab.cell import Cell | ||
| from bluecellulab.circuit.simulation_access import get_synapse_replay_spikes | ||
| from bluecellulab.exceptions import BluecellulabError | ||
| from bluecellulab.circuit import SynapseProperty | ||
| from neuron import h | ||
| import numpy as np | ||
|
|
||
| from bluecellulab.circuit.node_id import CellId | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class BasePointProcessCell(Cell): | ||
| """Base class for NEURON artificial point processes (IntFire1/2/...).""" | ||
|
|
||
| def __init__(self, cell_id: Optional[CellId]) -> None: | ||
| if cell_id is None: | ||
| raise ValueError("PointProcessCell requires valid cell_id") | ||
| self.cell_id = cell_id | ||
|
|
||
| self._spike_times = h.Vector() | ||
| self._spike_detector: Optional[h.NetCon] = None | ||
| self.pointcell = None # type: ignore[assignment] | ||
| self.synapses: dict = {} | ||
| self.connections: dict = {} | ||
|
|
||
| @property | ||
| def hoc_cell(self): | ||
| return self.pointcell | ||
|
|
||
| def init_callbacks(self): | ||
| pass | ||
|
|
||
| def connect_to_circuit(self, proxy) -> None: | ||
| self._circuit_proxy = proxy | ||
|
|
||
| def delete(self) -> None: | ||
| # Stop recording | ||
| if self._spike_detector is not None: | ||
| # NetCon will be GC'd when no Python refs remain | ||
| self._spike_detector = None | ||
| if self._spike_times is not None: | ||
| self._spike_times = None | ||
|
|
||
| # Drop pointer to underlying NEURON object | ||
| self.pointcell = None | ||
|
|
||
| def get_spike_times(self) -> list[float]: | ||
| return list(self._spike_times) | ||
|
|
||
| def create_netcon_spikedetector( | ||
| self, | ||
| sec, # ignored for artificial cells | ||
| location=None, # ignored for artificial cells | ||
| threshold: float = 0.0, | ||
| ) -> h.NetCon: | ||
| if self.pointcell is None: | ||
| raise ValueError("attempting to create netcon without valid pointprocess") | ||
| nc = h.NetCon(self.pointcell.pointcell, None) | ||
| nc.threshold = threshold # harmless for artificial cells | ||
| return nc | ||
|
|
||
| def is_recording_spikes(self, location=None, threshold: float | None = None) -> bool: | ||
| return self._spike_detector is not None | ||
|
|
||
| def start_recording_spikes(self, sec, location=None, threshold: float = 0.0) -> None: | ||
| if self._spike_detector is not None: | ||
| return | ||
| if self.pointcell is None: | ||
| raise ValueError("attempting to record spikes without valid pointprocess") | ||
| self._spike_times = h.Vector() | ||
| self._spike_detector = h.NetCon(self.pointcell.pointcell, None) | ||
| self._spike_detector.threshold = threshold | ||
| self._spike_detector.record(self._spike_times) | ||
|
|
||
|
|
||
| class HocPointProcessCell(BasePointProcessCell): | ||
| """Point process that wraps an arbitrary HOC/mod artificial mechanism.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| cell_id: Optional[CellId], | ||
| mechanism_name: str, | ||
| spike_threshold: float = 1.0, | ||
| ) -> None: | ||
| super().__init__(cell_id) | ||
|
|
||
| try: | ||
| mech_cls = getattr(h, mechanism_name) | ||
| except AttributeError as exc: | ||
| raise BluecellulabError( | ||
| f"Point mechanism '{mechanism_name}' not found in NEURON. " | ||
| "Make sure the mod/hoc files are compiled and loaded." | ||
| ) from exc | ||
|
|
||
| if cell_id is None: | ||
| raise ValueError("call to create pointprocess mechanism without valid cell_id") | ||
| point = mech_cls(cell_id.id) | ||
|
|
||
| self.pointcell = point | ||
| self.start_recording_spikes(None, None, threshold=spike_threshold) | ||
|
|
||
| def add_synapse_replay(self, stimulus, spike_threshold: float, spike_location: str) -> None: | ||
| """SONATA-style spike replay for point processes. | ||
|
|
||
| This is a simplified analogue of Cell.add_synapse_replay, but | ||
| instead of mapping spikes to individual synapses, we directly | ||
| connect each presynaptic node_id's spike train to this | ||
| artificial cell via VecStim → NetCon. | ||
| """ | ||
| file_path = Path(stimulus.spike_file).expanduser() | ||
|
|
||
| if not file_path.is_absolute(): | ||
| config_dir = stimulus.config_dir | ||
| if config_dir is not None: | ||
| file_path = Path(config_dir) / file_path | ||
|
|
||
| file_path = file_path.resolve() | ||
|
|
||
| if not file_path.exists(): | ||
| raise FileNotFoundError(f"Spike file not found: {str(file_path)}") | ||
|
|
||
| synapse_spikes = get_synapse_replay_spikes(str(file_path)) | ||
|
|
||
| if not hasattr(self, "_replay_vecs"): | ||
| self._replay_vecs: list[h.Vector] = [] | ||
| if not hasattr(self, "_replay_vecstims"): | ||
| self._replay_vecstims: list[h.VecStim] = [] | ||
| if not hasattr(self, "_replay_netcons"): | ||
| self._replay_netcons: list[h.NetCon] = [] | ||
|
Comment on lines
+131
to
+136
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could |
||
|
|
||
| for pre_node_id, spikes in synapse_spikes.items(): | ||
| delay = getattr(stimulus, "delay", 0.0) or 0.0 | ||
| duration = getattr(stimulus, "duration", np.inf) | ||
|
|
||
| spikes_of_interest = spikes[ | ||
| (spikes >= delay) & (spikes <= duration) | ||
| ] | ||
| if spikes_of_interest.size == 0: | ||
| continue | ||
|
|
||
| vec = h.Vector(spikes_of_interest) | ||
| vs = h.VecStim() | ||
| vs.play(vec) | ||
|
|
||
| if self.pointcell is None: | ||
| raise ValueError("attempting to add replay spikes with valid pointprocess") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should the message be |
||
| nc = h.NetCon(vs, self.pointcell.pointcell) | ||
| # Use stimulus weight if available, otherwise default to 1.0 | ||
| weight = getattr(stimulus, "weight", 1.0) | ||
| nc.weight[0] = weight | ||
| nc.delay = 0.0 # delay already baked into spike times | ||
|
|
||
| self._replay_vecs.append(vec) | ||
| self._replay_vecstims.append(vs) | ||
| self._replay_netcons.append(nc) | ||
|
|
||
| logger.debug( | ||
| f"Added replay connection from pre_node_id={pre_node_id} " | ||
| f"to point neuron {self.cell_id}" | ||
| ) | ||
|
|
||
| def add_replay_synapse(self, syn_id, syn_description, syn_connection_parameters, condition_parameters, | ||
| popids, extracellular_calcium): | ||
| """For Point Neurons, the replay simply queues events directly to the | ||
| point obj.""" | ||
| from bluecellulab.point.point_connection import PointProcessConnection | ||
| from bluecellulab.point.connection_params import PointProcessConnParameters | ||
|
|
||
| # syn_connection_parameters should only have 1 element, PointProcessConnection will confirm | ||
| point_params = PointProcessConnParameters(syn_description[SynapseProperty.PRE_GID], syn_description[SynapseProperty.PRE_GID], | ||
| syn_description[SynapseProperty.AXONAL_DELAY]) | ||
|
Comment on lines
+177
to
+178
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
|
|
||
| self.pointConn = PointProcessConnection([point_params]) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is assigned but never stored in |
||
|
|
||
|
|
||
| def mechanism_name_from_model_template(template_path: str, model_template: str) -> str: | ||
| """Translate SONATA model_template into a NEURON mechanism name. | ||
|
|
||
| Examples: | ||
| 'hoc:AllenPointCell' -> 'AllenPointCell' | ||
| 'nrn:IntFire1' -> 'IntFire1' | ||
| 'AllenPointCell' -> 'AllenPointCell' | ||
| """ | ||
| mt = str(model_template).strip() | ||
| if ":" in mt: | ||
| prefix, name = mt.split(":", 1) | ||
| prefix = prefix.lower() | ||
| if prefix in ("hoc", "nrn"): | ||
| h.load_file(template_path) | ||
| return name | ||
| return mt | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -182,7 +182,11 @@ def _select_edge_pop_names(self, projections) -> list[str]: | |
| def extract_synapses( | ||
| self, cell_id: CellId, projections: Optional[list[str] | str] | ||
| ) -> pd.DataFrame: | ||
| """Extract the synapses.""" | ||
| """Extract the synapses. Checks available fields to determine which are | ||
| present in the edge file to determine the properties to extract. | ||
|
|
||
| If projections is None, all the synapses are extracted. | ||
| """ | ||
| snap_node_id = CircuitNodeId(cell_id.population_name, cell_id.id) | ||
| edges = self._circuit.edges | ||
|
|
||
|
|
@@ -200,7 +204,10 @@ def extract_synapses( | |
|
|
||
| # remove optional properties if they are not present | ||
| for optional_property in [SynapseProperty.U_HILL_COEFFICIENT, | ||
| SynapseProperty.CONDUCTANCE_RATIO]: | ||
| SynapseProperty.CONDUCTANCE_RATIO, | ||
| SynapseProperty.AFFERENT_SECTION_POS, | ||
| SynapseProperty.POST_SEGMENT_ID, | ||
| SynapseProperty.POST_SEGMENT_OFFSET]: | ||
| if optional_property.to_snap() not in edge_population.property_names: | ||
| edge_properties.remove(optional_property) | ||
|
|
||
|
|
@@ -211,6 +218,20 @@ def extract_synapses( | |
| ): | ||
| edge_properties += list(SynapseProperties.plasticity) | ||
|
|
||
| # check for allen instance - replace the entire edge_properties list as appropriate | ||
| # properties for allen point/chemical neuron connection type edges | ||
| if len(edge_population.property_names) < 10: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you document in comments why 10 was chosen (e.g., "Allen point/chemical edges have 5–8 properties vs 15+ for BBP-style edges")? I think it may be fine for now. However, can we rely on the number |
||
| if all( | ||
| x in edge_population.property_names | ||
| for x in SynapseProperties.allen_point | ||
| ): | ||
| edge_properties = list(SynapseProperties.allen_point) | ||
| if all( | ||
| x in edge_population.property_names | ||
| for x in SynapseProperties.allen_chemical | ||
| ): | ||
| edge_properties = list(SynapseProperties.allen_chemical) | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| snap_properties = properties_to_snap(edge_properties) | ||
| synapses: pd.DataFrame = edge_population.get(afferent_edges, snap_properties) | ||
| column_names = list(synapses.columns) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,6 +29,7 @@ class SynapseProperty(Enum): | |
| PRE_GID = "pre_gid" | ||
| AXONAL_DELAY = "axonal_delay" | ||
| POST_SECTION_ID = "post_section_id" | ||
| POST_SECTION_POS = "post_section_pos" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| POST_SEGMENT_ID = "post_segment_id" | ||
| POST_SEGMENT_OFFSET = "post_segment_offset" | ||
| G_SYNX = "g_synx" | ||
|
|
@@ -83,6 +84,14 @@ class SynapseProperties: | |
| "volume_CR", "rho0_GB", "Use_d_TM", "Use_p_TM", "gmax_d_AMPA", | ||
| "gmax_p_AMPA", "theta_d", "theta_p" | ||
| ) | ||
| allen_chemical = ( | ||
| "afferent_section_id", "afferent_section_pos", "conductance", "delay", "tau1", "tau2", "erev", | ||
| "@source_node" | ||
| ) | ||
| allen_point = ( | ||
| "afferent_section_id", "afferent_section_pos", "conductance", "delay", | ||
| "@source_node" | ||
| ) | ||
|
|
||
|
|
||
| snap_to_synproperty = MappingProxyType({ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -66,6 +66,8 @@ | |
| from bluecellulab.simulation.modifications import apply_modifications | ||
| from bluecellulab.synapse.synapse_types import SynapseID | ||
|
|
||
| from bluecellulab.cell.point_process import BasePointProcessCell, HocPointProcessCell, mechanism_name_from_model_template | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
|
|
@@ -433,7 +435,32 @@ def _add_stimuli( | |
| except ValueError: | ||
| pass | ||
|
|
||
| all_point_processes = all( | ||
| isinstance(cell, BasePointProcessCell) for cell in self.cells.values() | ||
| ) | ||
|
|
||
| for stimulus in stimuli_entries: | ||
|
|
||
| # 1) SynapseReplay: works for both morpho cells and point processes | ||
| if isinstance(stimulus, circuit_stimulus_definitions.SynapseReplay): | ||
| for cell_id, cell in self.cells.items(): | ||
| if self.circuit_access.target_contains_cell(stimulus.target, cell_id): | ||
| if hasattr(cell, "add_synapse_replay"): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| print("Adding SynapseReplay to cell", cell_id) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| cell.add_synapse_replay( | ||
| stimulus, self.spike_threshold, self.spike_location | ||
| ) | ||
| logger.debug( | ||
| f"Added SynapseReplay {stimulus} to point/morpho cell {cell_id}" | ||
| ) | ||
| # No section/compartment logic needed for SynapseReplay | ||
| continue | ||
|
|
||
| # 2) Other stimuli: require morphology | ||
| # If all cells are point processes, skip these stimuli entirely. | ||
| if all_point_processes: | ||
| continue | ||
|
|
||
| # Build a unified list of (cell_id, section, segx, section_name) targets | ||
| targets: list[tuple] = [] | ||
|
|
||
|
|
@@ -444,6 +471,9 @@ def _add_stimuli( | |
| stimulus.node_set | ||
| ) | ||
| for cell_id in self.cells: | ||
| # Skip point processes: they have no soma | ||
| if isinstance(self.cells[cell_id], BasePointProcessCell): | ||
| continue | ||
| if cell_id not in gids_of_target: | ||
| continue | ||
| sec = self.cells[cell_id].soma | ||
|
|
@@ -584,21 +614,18 @@ def _add_cell_synapses( | |
| logger.warning( | ||
| f"No presynaptic cells found for gid {cell_id}, no synapses added" | ||
| ) | ||
|
|
||
| else: | ||
| for idx, syn_description in syn_descriptions.iterrows(): | ||
| popids = ( | ||
| syn_description["source_popid"], | ||
| syn_description["target_popid"], | ||
| ) | ||
|
|
||
| self._instantiate_synapse( | ||
| cell_id=cell_id, | ||
| syn_id=idx, # type: ignore | ||
| syn_description=syn_description, | ||
| add_minis=add_minis, | ||
| popids=popids, | ||
|
|
||
| ) | ||
| logger.info(f"Added {syn_descriptions} synapses for gid {cell_id}") | ||
| if add_minis: | ||
|
|
@@ -1164,6 +1191,25 @@ def fetch_cell_kwargs(self, cell_id: CellId) -> dict: | |
|
|
||
| def create_cell_from_circuit(self, cell_id: CellId) -> bluecellulab.Cell: | ||
| """Create a Cell object from the circuit.""" | ||
| if self.circuit_format == CircuitFormat.SONATA: | ||
| try: | ||
| info = self.circuit_access.fetch_cell_info(cell_id) # type: ignore[attr-defined] | ||
| except AttributeError: | ||
| info = pd.Series() | ||
|
|
||
| model_type = str(info.get("model_type", "")).lower() | ||
| model_template = str(info.get("model_template", "")) | ||
|
|
||
| if model_type == "point_process": | ||
| mech_name = mechanism_name_from_model_template(template_path=self.circuit_access.emodel_path(cell_id), model_template=model_template) | ||
|
|
||
| # TODO (later): parse dynamics_params and feed param_overrides | ||
| return HocPointProcessCell( | ||
| cell_id=cell_id, | ||
| mechanism_name=mech_name, | ||
| spike_threshold=self.spike_threshold, | ||
| ) | ||
|
|
||
| cell_kwargs = self.fetch_cell_kwargs(cell_id) | ||
| return bluecellulab.Cell( | ||
| template_path=cell_kwargs["template_path"], | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| from __future__ import annotations | ||
| from dataclasses import dataclass | ||
|
|
||
|
|
||
| @dataclass | ||
| class PointProcessConnParameters: | ||
| """Point-neuron connection parameters (Allen-style / Neurodamus mirror).""" | ||
|
|
||
| sgid: int # source gid | ||
| delay: float # ms | ||
| weight: float # NetCon weight | ||
|
|
||
| # isec: int = -1 | ||
| # ipt: int = -1 | ||
| # offset: float = 0.5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This inherits from
Cellbut never callssuper().__init__(), so none ofCell's attributes(self.recordings, self.soma, self.cell, self.record_dt, etc.)are initialized. This means inherited methods likeget_time(),get_recording(), etc., will crash with AttributeError. The bare except Exception: continue blocks added in utils.py may catch the errors though. Maybe initialising the minimum required attributes so inherited methods degrade gracefully instead of relying on catch-all exception handlers.