Skip to content
This repository was archived by the owner on Mar 25, 2025. It is now read-only.

Make "sparse" solver check if equations are linear. #860

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion nmodl/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from importlib import import_module

import sympy as sp
import itertools

# import known_functions through low-level mechanism because the ccode
# module is overwritten in sympy and contents of that submodule cannot be
Expand Down Expand Up @@ -272,6 +273,8 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):

eqs, state_vars, sympy_vars = _sympify_eqs(eq_strings, vars, constants)

linear = _is_linear(eqs, state_vars, sympy_vars)

custom_fcts = _get_custom_functions(function_calls)

jacobian = sp.Matrix(eqs).jacobian(state_vars)
Expand All @@ -291,7 +294,18 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):
# interweave
code = _interweave_eqs(vecFcode, vecJcode)

return code
return code, linear


def _is_linear(eqs, state_vars, sympy_vars):
for expr in eqs:
for (x, y) in itertools.combinations_with_replacement(state_vars, 2):
try:
if not sp.Eq(sp.diff(expr, x, y), 0):
return False
except TypeError:
return False
return True


def integrate2c(diff_string, dt_var, vars, use_pade_approx=False):
Expand Down
2 changes: 2 additions & 0 deletions src/pybind/pyembed.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ struct SolveNonLinearSystemExecutor: public PythonExecutor {
// output
// returns a vector of solutions, i.e. new statements to add to block:
std::vector<std::string> solutions;
// returns if the system is linear or not.
bool linear;
// may also return a python exception message:
std::string exception_message;

Expand Down
4 changes: 3 additions & 1 deletion src/pybind/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,22 @@ void SolveNonLinearSystemExecutor::operator()() {
from nmodl.ode import solve_non_lin_system
exception_message = ""
try:
solutions = solve_non_lin_system(equation_strings,
solutions, linear = solve_non_lin_system(equation_strings,
state_vars,
vars,
function_calls)
except Exception as e:
# if we fail, fail silently and return empty string
solutions = [""]
linear = False
new_local_vars = [""]
exception_message = str(e)
)",
py::globals(),
locals);
// returns a vector of solutions, i.e. new statements to add to block:
solutions = locals["solutions"].cast<std::vector<std::string>>();
linear = locals["linear"].cast<bool>();
// may also return a python exception message:
exception_message = locals["exception_message"].cast<std::string>();
}
Expand Down
10 changes: 8 additions & 2 deletions src/visitors/sympy_solver_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ void SympySolverVisitor::solve_non_linear_system(
(*solver)();
// returns a vector of solutions, i.e. new statements to add to block:
auto solutions = solver->solutions;
bool linear = solver->linear;
// may also return a python exception message:
auto exception_message = solver->exception_message;
pywrap::EmbeddedPythonLoader::get_instance().api()->destroy_nsls_executor(solver);
Expand All @@ -364,8 +365,13 @@ void SympySolverVisitor::solve_non_linear_system(
exception_message);
return;
}
logger->debug("SympySolverVisitor :: Constructing eigen newton solve block");
construct_eigen_solver_block(pre_solve_statements, solutions, false);
if (!linear) {
logger->debug("SympySolverVisitor :: Constructing eigen newton solve block");
}
else {
logger->debug("SympySolverVisitor :: Constructing eigen solve block");
}
construct_eigen_solver_block(pre_solve_statements, solutions, linear);
}

void SympySolverVisitor::visit_var_name(ast::VarName& node) {
Expand Down
Loading