Skip to content

Add numeric solver to synapses #1208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 30 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
43e6183
Numeric solver for synapse models
pnbabu May 15, 2024
cab344e
Merge remote-tracking branch 'upstream/master' into synapse_numeric
pnbabu May 15, 2024
a50f875
Add numeric solver for synapses
pnbabu Jun 2, 2024
88069e0
Merge remote-tracking branch 'upstream/master' into synapse_numeric
pnbabu Jun 19, 2024
0aa8ba1
Numeric solution
pnbabu Jun 20, 2024
fb0ec2d
Merge remote-tracking branch 'upstream/master' into synapse_numeric
pnbabu Dec 12, 2024
80faa93
Add numneric solver to synapses
pnbabu Dec 13, 2024
bc2c078
Merge remote-tracking branch 'upstream/master' into synapse_numeric
pnbabu Jan 17, 2025
78f1b1f
Merge remote-tracking branch 'upstream/master' into synapse_numeric
pnbabu Feb 20, 2025
bef5bcb
Merge remote-tracking branch 'upstream/master' into synapse_numeric
pnbabu Mar 20, 2025
0f92473
Merge remote-tracking branch 'upstream/master' into synapse_numeric
pnbabu Apr 24, 2025
129f4c9
Add numeric solver to synapse template
pnbabu May 7, 2025
5cc3bf9
Merge remote-tracking branch 'upstream/master' into synapse_numeric
pnbabu May 7, 2025
4c5b06c
Modify synapse templates
pnbabu May 8, 2025
ee59762
Modify synapse templates
pnbabu May 8, 2025
f59b15d
Merge remote-tracking branch 'upstream/master' into synapse_numeric
pnbabu Jul 4, 2025
07a45ec
Add test for non-linear synpase
pnbabu Jul 7, 2025
18be17f
Fix pycodestyle error
pnbabu Jul 7, 2025
d3d3a42
Fix tests
pnbabu Jul 7, 2025
940c9a4
Add synapse model inside a namespace
pnbabu Jul 7, 2025
b3dc52b
Fix test failures
pnbabu Jul 8, 2025
bda28b0
Modify templates to fix namespaces in synapse models
pnbabu Jul 9, 2025
6372da3
Fix templates
pnbabu Jul 9, 2025
af4e117
Merge remote-tracking branch 'upstream/master' into synapse_numeric
pnbabu Jul 9, 2025
637eb6c
Add copyright header
pnbabu Jul 9, 2025
7a85ffd
Fix test failure
pnbabu Jul 10, 2025
4f9b327
Add test
pnbabu Jul 10, 2025
8b02787
Merge remote-tracking branch 'upstream/master' into synapse_numeric
pnbabu Jul 14, 2025
0739a2e
Modify test
pnbabu Jul 16, 2025
65715a1
Fix pycodestyle
pnbabu Jul 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions pynestml/codegeneration/nest_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,14 +223,23 @@ def setup_printers(self):
self._gsl_variable_printer = GSLVariablePrinter(None)
if self.option_exists("nest_version") and (self.get_option("nest_version").startswith("2") or self.get_option("nest_version").startswith("v2")):
self._gsl_function_call_printer = NEST2GSLFunctionCallPrinter(None)
self._gsl_function_call_printer_no_origin = NEST2GSLFunctionCallPrinter(None)
else:
self._gsl_function_call_printer = NESTGSLFunctionCallPrinter(None)
self._gsl_function_call_printer_no_origin = NEST2GSLFunctionCallPrinter(None)

self._gsl_printer = CppExpressionPrinter(simple_expression_printer=CppSimpleExpressionPrinter(variable_printer=self._gsl_variable_printer,
constant_printer=self._constant_printer,
function_call_printer=self._gsl_function_call_printer))
self._gsl_function_call_printer._expression_printer = self._gsl_printer

self._gsl_variable_printer_no_origin = GSLVariablePrinter(None, with_origin=False)
self._gsl_printer_no_origin = CppExpressionPrinter(simple_expression_printer=CppSimpleExpressionPrinter(variable_printer=self._gsl_variable_printer_no_origin,
constant_printer=self._constant_printer,
function_call_printer=self._gsl_function_call_printer))
self._gsl_variable_printer_no_origin._expression_printer = self._gsl_printer_no_origin
self._gsl_function_call_printer_no_origin._expression_printer = self._gsl_printer_no_origin

# ODE-toolbox printers
self._ode_toolbox_variable_printer = ODEToolboxVariablePrinter(None)
self._ode_toolbox_function_call_printer = ODEToolboxFunctionCallPrinter(None)
Expand Down Expand Up @@ -521,6 +530,7 @@ def _get_model_namespace(self, astnode: ASTModel) -> Dict:
namespace["printer"] = self._nest_printer
namespace["printer_no_origin"] = self._printer_no_origin
namespace["gsl_printer"] = self._gsl_printer
namespace["gsl_printer_no_origin"] = self._gsl_printer_no_origin
namespace["nestml_printer"] = NESTMLPrinter()
namespace["type_symbol_printer"] = self._type_symbol_printer

Expand Down Expand Up @@ -666,6 +676,9 @@ def _get_synapse_model_namespace(self, synapse: ASTModel) -> Dict:
expr_ast.accept(ASTSymbolTableVisitor())
namespace["numeric_update_expressions"][sym] = expr_ast

ASTUtils.assign_numeric_non_numeric_state_variables(synapse, namespace["numeric_state_variables"],
namespace["numeric_update_expressions"] if "numeric_update_expressions" in namespace.keys() else None, namespace["update_expressions"] if "update_expressions" in namespace.keys() else None)

namespace["spike_updates"] = synapse.spike_updates

# special case for NEST delay variable (state or parameter)
Expand Down
3 changes: 0 additions & 3 deletions pynestml/codegeneration/nest_code_generator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ def print_symbol_origin(cls, variable_symbol: VariableSymbol, variable: ASTVaria
if variable_symbol.block_type == BlockType.INTERNALS:
return "V_.%s"

if variable_symbol.block_type == BlockType.INPUT:
return "B_.%s"

return ""

@classmethod
Expand Down
52 changes: 29 additions & 23 deletions pynestml/codegeneration/printers/gsl_variable_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
from pynestml.codegeneration.nest_code_generator_utils import NESTCodeGeneratorUtils
from pynestml.codegeneration.nest_unit_converter import NESTUnitConverter
from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter
from pynestml.codegeneration.printers.expression_printer import ExpressionPrinter
from pynestml.meta_model.ast_variable import ASTVariable
from pynestml.symbols.predefined_units import PredefinedUnits
from pynestml.symbols.symbol import SymbolKind
Expand All @@ -33,46 +35,42 @@ class GSLVariablePrinter(CppVariablePrinter):
Variable printer for C++ syntax and using the GSL (GNU Scientific Library) API from inside the ``extern "C"`` stepping function.
"""

def print_variable(self, node: ASTVariable) -> str:
def __init__(self, expression_printer: ExpressionPrinter, with_origin: bool = True, ):
super().__init__(expression_printer)
self.with_origin = with_origin

def print_variable(self, variable: ASTVariable) -> str:
"""
Converts a single name reference to a gsl processable format.
:param node: a single variable
:param variable: a single variable
:return: a gsl processable format of the variable
"""
assert isinstance(node, ASTVariable)
symbol = node.get_scope().resolve_to_symbol(node.get_complete_name(), SymbolKind.VARIABLE)
assert isinstance(variable, ASTVariable)
symbol = variable.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE)

if symbol is None:
# test if variable name can be resolved to a type
if PredefinedUnits.is_unit(node.get_complete_name()):
return str(NESTUnitConverter.get_factor(PredefinedUnits.get_unit(node.get_complete_name()).get_unit()))
if PredefinedUnits.is_unit(variable.get_complete_name()):
return str(
NESTUnitConverter.get_factor(PredefinedUnits.get_unit(variable.get_complete_name()).get_unit()))

code, message = Messages.get_could_not_resolve(node.get_name())
code, message = Messages.get_could_not_resolve(variable.get_name())
Logger.log_message(log_level=LoggingLevel.ERROR, code=code, message=message,
error_position=node.get_source_position())
error_position=variable.get_source_position())
return ""

if node.is_delay_variable():
return self._print_delay_variable(node)
if variable.is_delay_variable():
return self._print_delay_variable(variable)

if symbol.is_state() and not symbol.is_inline_expression:
if "_is_numeric" in dir(node) and node._is_numeric:
if "_is_numeric" in dir(variable) and variable._is_numeric:
# ode_state[] here is---and must be---the state vector supplied by the integrator, not the state vector in the node, node.S_.ode_state[].
return "ode_state[State_::" + CppVariablePrinter._print_cpp_name(node.get_complete_name()) + "]"

# non-ODE state symbol
return "node.S_." + CppVariablePrinter._print_cpp_name(node.get_complete_name())

if symbol.is_parameters():
return "node.P_." + super().print_variable(node)

if symbol.is_internals():
return "node.V_." + super().print_variable(node)
return "ode_state[State_::" + CppVariablePrinter._print_cpp_name(variable.get_complete_name()) + "]"

if symbol.is_input():
return "node.B_." + self._print_buffer_value(node)
return "node.B_." + self._print_buffer_value(variable)

raise Exception("Unknown node type")
return self._print(variable, symbol, with_origin=self.with_origin)

def _print_delay_variable(self, variable: ASTVariable) -> str:
"""
Expand Down Expand Up @@ -104,3 +102,11 @@ def _print_buffer_value(self, variable: ASTVariable) -> str:
return "spike_inputs_grid_sum_[node." + var_name + " - node.MIN_SPIKE_RECEPTOR]"

return variable_symbol.get_symbol_name() + '_grid_sum_'

def _print(self, variable, symbol, with_origin: bool = True):
variable_name = CppVariablePrinter._print_cpp_name(variable.get_complete_name())

if with_origin:
return "node." + NESTCodeGeneratorUtils.print_symbol_origin(symbol, variable) % variable_name

return "node." + variable_name
Loading
Loading