Skip to content
This repository was archived by the owner on Mar 25, 2025. It is now read-only.

Commit d868bd3

Browse files
Make "sparse" solver check if equations are linear.
If the system is linear, then newtons method always converges in exactly one iteration. When using the sparse solver on linear systems omit the newtons iteration and solve directly. This should make the resulting code run marginally faster by skipping the check for convergence. Currently the check for convergence is implemented as "error = sqrt(|F|^2)".
1 parent cde5dbf commit d868bd3

File tree

5 files changed

+50
-42
lines changed

5 files changed

+50
-42
lines changed

nmodl/ode.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):
272272

273273
eqs, state_vars, sympy_vars = _sympify_eqs(eq_strings, vars, constants)
274274

275+
linear = _is_linear(eqs, state_vars, sympy_vars)
276+
275277
custom_fcts = _get_custom_functions(function_calls)
276278

277279
jacobian = sp.Matrix(eqs).jacobian(state_vars)
@@ -291,7 +293,19 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):
291293
# interweave
292294
code = _interweave_eqs(vecFcode, vecJcode)
293295

294-
return code
296+
return code, linear
297+
298+
299+
def _is_linear(eqs, state_vars, sympy_vars):
300+
for expr in eqs:
301+
for x in state_vars:
302+
for y in state_vars:
303+
try:
304+
if not sp.Eq(sp.diff(expr, x, y), 0):
305+
return False
306+
except TypeError:
307+
return False
308+
return True
295309

296310

297311
def integrate2c(diff_string, dt_var, vars, use_pade_approx=False):

src/pybind/pyembed.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ struct SolveNonLinearSystemExecutor: public PythonExecutor {
5656
// output
5757
// returns a vector of solutions, i.e. new statements to add to block:
5858
std::vector<std::string> solutions;
59+
// returns if the system is linear or not.
60+
bool linear;
5961
// may also return a python exception message:
6062
std::string exception_message;
6163

src/pybind/wrapper.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,22 @@ void SolveNonLinearSystemExecutor::operator()() {
6666
from nmodl.ode import solve_non_lin_system
6767
exception_message = ""
6868
try:
69-
solutions = solve_non_lin_system(equation_strings,
69+
solutions, linear = solve_non_lin_system(equation_strings,
7070
state_vars,
7171
vars,
7272
function_calls)
7373
except Exception as e:
7474
# if we fail, fail silently and return empty string
7575
solutions = [""]
76+
linear = False
7677
new_local_vars = [""]
7778
exception_message = str(e)
7879
)",
7980
py::globals(),
8081
locals);
8182
// returns a vector of solutions, i.e. new statements to add to block:
8283
solutions = locals["solutions"].cast<std::vector<std::string>>();
84+
linear = locals["linear"].cast<bool>();
8385
// may also return a python exception message:
8486
exception_message = locals["exception_message"].cast<std::string>();
8587
}

src/visitors/sympy_solver_visitor.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ void SympySolverVisitor::solve_non_linear_system(
356356
(*solver)();
357357
// returns a vector of solutions, i.e. new statements to add to block:
358358
auto solutions = solver->solutions;
359+
bool linear = solver->linear;
359360
// may also return a python exception message:
360361
auto exception_message = solver->exception_message;
361362
pywrap::EmbeddedPythonLoader::get_instance().api()->destroy_nsls_executor(solver);
@@ -364,8 +365,13 @@ void SympySolverVisitor::solve_non_linear_system(
364365
exception_message);
365366
return;
366367
}
367-
logger->debug("SympySolverVisitor :: Constructing eigen newton solve block");
368-
construct_eigen_solver_block(pre_solve_statements, solutions, false);
368+
if (!linear) {
369+
logger->debug("SympySolverVisitor :: Constructing eigen newton solve block");
370+
}
371+
else {
372+
logger->debug("SympySolverVisitor :: Constructing eigen solve block");
373+
}
374+
construct_eigen_solver_block(pre_solve_statements, solutions, linear);
369375
}
370376

371377
void SympySolverVisitor::visit_var_name(ast::VarName& node) {

test/unit/visitor/sympy_solver.cpp

Lines changed: 22 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
619619
)";
620620
std::string expected_result = R"(
621621
DERIVATIVE states {
622-
EIGEN_NEWTON_SOLVE[1]{
622+
EIGEN_LINEAR_SOLVE[1]{
623623
LOCAL old_m
624624
}{
625625
IF (mInf == 1) {
@@ -628,7 +628,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
628628
old_m = m
629629
}{
630630
nmodl_eigen_x[0] = m
631-
}{
632631
nmodl_eigen_f[0] = (-nmodl_eigen_x[0]*dt+dt*mInf+mTau*(-nmodl_eigen_x[0]+old_m))/mTau
633632
nmodl_eigen_j[0] = -(dt+mTau)/mTau
634633
}{
@@ -659,15 +658,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
659658
})";
660659
std::string expected_result = R"(
661660
DERIVATIVE states {
662-
EIGEN_NEWTON_SOLVE[2]{
661+
EIGEN_LINEAR_SOLVE[2]{
663662
LOCAL a, b, old_y, old_x
664663
}{
665664
old_y = y
666665
old_x = x
667666
}{
668667
nmodl_eigen_x[0] = x
669668
nmodl_eigen_x[1] = y
670-
}{
671669
nmodl_eigen_f[0] = -nmodl_eigen_x[1]+a*dt+old_y
672670
nmodl_eigen_j[0] = 0
673671
nmodl_eigen_j[2] = -1.0
@@ -703,15 +701,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
703701
})";
704702
std::string expected_result = R"(
705703
DERIVATIVE states {
706-
EIGEN_NEWTON_SOLVE[2]{
704+
EIGEN_LINEAR_SOLVE[2]{
707705
LOCAL a, b, old_M_1, old_M_0
708706
}{
709707
old_M_1 = M[1]
710708
old_M_0 = M[0]
711709
}{
712710
nmodl_eigen_x[0] = M[0]
713711
nmodl_eigen_x[1] = M[1]
714-
}{
715712
nmodl_eigen_f[0] = -nmodl_eigen_x[1]+a*dt+old_M_1
716713
nmodl_eigen_j[0] = 0
717714
nmodl_eigen_j[2] = -1.0
@@ -748,15 +745,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
748745
})";
749746
std::string expected_result = R"(
750747
DERIVATIVE states {
751-
EIGEN_NEWTON_SOLVE[2]{
748+
EIGEN_LINEAR_SOLVE[2]{
752749
LOCAL a, b, old_x, old_y
753750
}{
754751
old_x = x
755752
old_y = y
756753
}{
757754
nmodl_eigen_x[0] = x
758755
nmodl_eigen_x[1] = y
759-
}{
760756
nmodl_eigen_f[0] = -nmodl_eigen_x[0]+a*dt+old_x
761757
nmodl_eigen_j[0] = -1.0
762758
nmodl_eigen_j[2] = 0
@@ -825,15 +821,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
825821
DERIVATIVE states {
826822
LOCAL a, b
827823
IF (a == 1) {
828-
EIGEN_NEWTON_SOLVE[2]{
824+
EIGEN_LINEAR_SOLVE[2]{
829825
LOCAL old_x, old_y
830826
}{
831827
old_x = x
832828
old_y = y
833829
}{
834830
nmodl_eigen_x[0] = x
835831
nmodl_eigen_x[1] = y
836-
}{
837832
nmodl_eigen_f[0] = -nmodl_eigen_x[0]+a*dt+old_x
838833
nmodl_eigen_j[0] = -1.0
839834
nmodl_eigen_j[2] = 0
@@ -875,15 +870,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
875870
})";
876871
std::string expected_result = R"(
877872
DERIVATIVE states {
878-
EIGEN_NEWTON_SOLVE[2]{
873+
EIGEN_LINEAR_SOLVE[2]{
879874
LOCAL a, b, old_x, old_y
880875
}{
881876
old_x = x
882877
old_y = y
883878
}{
884879
nmodl_eigen_x[0] = x
885880
nmodl_eigen_x[1] = y
886-
}{
887881
nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[1]*a*dt+b*dt+old_x
888882
nmodl_eigen_j[0] = -1.0
889883
nmodl_eigen_j[2] = a*dt
@@ -901,15 +895,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
901895
})";
902896
std::string expected_result_cse = R"(
903897
DERIVATIVE states {
904-
EIGEN_NEWTON_SOLVE[2]{
898+
EIGEN_LINEAR_SOLVE[2]{
905899
LOCAL a, b, old_x, old_y
906900
}{
907901
old_x = x
908902
old_y = y
909903
}{
910904
nmodl_eigen_x[0] = x
911905
nmodl_eigen_x[1] = y
912-
}{
913906
nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[1]*a*dt+b*dt+old_x
914907
nmodl_eigen_j[0] = -1.0
915908
nmodl_eigen_j[2] = a*dt
@@ -954,7 +947,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
954947
)";
955948
std::string expected_result = R"(
956949
DERIVATIVE states {
957-
EIGEN_NEWTON_SOLVE[3]{
950+
EIGEN_LINEAR_SOLVE[3]{
958951
LOCAL a, b, c, d, h, old_x, old_y, old_z
959952
}{
960953
old_x = x
@@ -964,7 +957,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
964957
nmodl_eigen_x[0] = x
965958
nmodl_eigen_x[1] = y
966959
nmodl_eigen_x[2] = z
967-
}{
968960
nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[2]*a*dt+b*dt*h+old_x
969961
nmodl_eigen_j[0] = -1.0
970962
nmodl_eigen_j[3] = 0
@@ -986,7 +978,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
986978
})";
987979
std::string expected_cse_result = R"(
988980
DERIVATIVE states {
989-
EIGEN_NEWTON_SOLVE[3]{
981+
EIGEN_LINEAR_SOLVE[3]{
990982
LOCAL a, b, c, d, h, old_x, old_y, old_z
991983
}{
992984
old_x = x
@@ -996,7 +988,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
996988
nmodl_eigen_x[0] = x
997989
nmodl_eigen_x[1] = y
998990
nmodl_eigen_x[2] = z
999-
}{
1000991
nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[2]*a*dt+b*dt*h+old_x
1001992
nmodl_eigen_j[0] = -1.0
1002993
nmodl_eigen_j[3] = 0
@@ -1042,15 +1033,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
10421033
)";
10431034
std::string expected_result = R"(
10441035
DERIVATIVE scheme1 {
1045-
EIGEN_NEWTON_SOLVE[2]{
1036+
EIGEN_LINEAR_SOLVE[2]{
10461037
LOCAL old_mc, old_m
10471038
}{
10481039
old_mc = mc
10491040
old_m = m
10501041
}{
10511042
nmodl_eigen_x[0] = mc
10521043
nmodl_eigen_x[1] = m
1053-
}{
10541044
nmodl_eigen_f[0] = -nmodl_eigen_x[0]*a*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*b*dt+old_mc
10551045
nmodl_eigen_j[0] = -a*dt-1.0
10561046
nmodl_eigen_j[2] = b*dt
@@ -1086,14 +1076,13 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
10861076
)";
10871077
std::string expected_result = R"(
10881078
DERIVATIVE scheme1 {
1089-
EIGEN_NEWTON_SOLVE[2]{
1079+
EIGEN_LINEAR_SOLVE[2]{
10901080
LOCAL old_mc
10911081
}{
10921082
old_mc = mc
10931083
}{
10941084
nmodl_eigen_x[0] = mc
10951085
nmodl_eigen_x[1] = m
1096-
}{
10971086
nmodl_eigen_f[0] = -nmodl_eigen_x[0]*a*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*b*dt+old_mc
10981087
nmodl_eigen_j[0] = -a*dt-1.0
10991088
nmodl_eigen_j[2] = b*dt
@@ -1131,15 +1120,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
11311120
)";
11321121
std::string expected_result = R"(
11331122
DERIVATIVE scheme1 {
1134-
EIGEN_NEWTON_SOLVE[2]{
1123+
EIGEN_LINEAR_SOLVE[2]{
11351124
LOCAL old_mc, old_m
11361125
}{
11371126
old_mc = mc
11381127
old_m = m
11391128
}{
11401129
nmodl_eigen_x[0] = mc
11411130
nmodl_eigen_x[1] = m
1142-
}{
11431131
nmodl_eigen_f[0] = -nmodl_eigen_x[0]*a*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*b*dt+old_mc
11441132
nmodl_eigen_j[0] = -a*dt-1.0
11451133
nmodl_eigen_j[2] = b*dt
@@ -1180,7 +1168,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
11801168
})";
11811169
std::string expected_result = R"(
11821170
DERIVATIVE ihkin {
1183-
EIGEN_NEWTON_SOLVE[5]{
1171+
EIGEN_LINEAR_SOLVE[5]{
11841172
LOCAL alpha, beta, k3p, k4, k1ca, k2, old_c1, old_o1, old_p0
11851173
}{
11861174
evaluate_fct(v, cai)
@@ -1193,7 +1181,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
11931181
nmodl_eigen_x[2] = o2
11941182
nmodl_eigen_x[3] = p0
11951183
nmodl_eigen_x[4] = p1
1196-
}{
11971184
nmodl_eigen_f[0] = -nmodl_eigen_x[0]*alpha*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*beta*dt+old_c1
11981185
nmodl_eigen_j[0] = -alpha*dt-1.0
11991186
nmodl_eigen_j[5] = beta*dt
@@ -1260,13 +1247,12 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
12601247
)";
12611248
std::string expected_result = R"(
12621249
DERIVATIVE scheme1 {
1263-
EIGEN_NEWTON_SOLVE[1]{
1250+
EIGEN_LINEAR_SOLVE[1]{
12641251
LOCAL old_W_0
12651252
}{
12661253
old_W_0 = W[0]
12671254
}{
12681255
nmodl_eigen_x[0] = W[0]
1269-
}{
12701256
nmodl_eigen_f[0] = -nmodl_eigen_x[0]*dt*A[0]+nmodl_eigen_x[0]*dt*B[0]-nmodl_eigen_x[0]+3.0*dt*A[1]+old_W_0
12711257
nmodl_eigen_j[0] = -dt*A[0]+dt*B[0]-1.0
12721258
}{
@@ -1300,15 +1286,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
13001286
)";
13011287
std::string expected_result = R"(
13021288
DERIVATIVE scheme1 {
1303-
EIGEN_NEWTON_SOLVE[2]{
1289+
EIGEN_LINEAR_SOLVE[2]{
13041290
LOCAL old_M_0, old_M_1
13051291
}{
13061292
old_M_0 = M[0]
13071293
old_M_1 = M[1]
13081294
}{
13091295
nmodl_eigen_x[0] = M[0]
13101296
nmodl_eigen_x[1] = M[1]
1311-
}{
13121297
nmodl_eigen_f[0] = -nmodl_eigen_x[0]*dt*A[0]-nmodl_eigen_x[0]+nmodl_eigen_x[1]*dt*B[0]+old_M_0
13131298
nmodl_eigen_j[0] = -dt*A[0]-1.0
13141299
nmodl_eigen_j[2] = dt*B[0]
@@ -1346,13 +1331,12 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
13461331
)";
13471332
std::string expected_result = R"(
13481333
DERIVATIVE scheme1 {
1349-
EIGEN_NEWTON_SOLVE[1]{
1334+
EIGEN_LINEAR_SOLVE[1]{
13501335
LOCAL old_W_0
13511336
}{
13521337
old_W_0 = W[0]
13531338
}{
13541339
nmodl_eigen_x[0] = W[0]
1355-
}{
13561340
nmodl_eigen_f[0] = -nmodl_eigen_x[0]*dt*A[0]+nmodl_eigen_x[0]*dt*B[0]-nmodl_eigen_x[0]+3.0*dt*A[1]+old_W_0
13571341
nmodl_eigen_j[0] = -dt*A[0]+dt*B[0]-1.0
13581342
}{
@@ -2053,7 +2037,7 @@ SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][s
20532037
x
20542038
}
20552039
NONLINEAR nonlin {
2056-
~ x = 5
2040+
~ x * x * x = 5
20572041
})";
20582042
std::string expected_text = R"(
20592043
NONLINEAR nonlin {
@@ -2062,8 +2046,8 @@ SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][s
20622046
}{
20632047
nmodl_eigen_x[0] = x
20642048
}{
2065-
nmodl_eigen_f[0] = 5.0-nmodl_eigen_x[0]
2066-
nmodl_eigen_j[0] = -1.0
2049+
nmodl_eigen_f[0] = 5.0-pow(nmodl_eigen_x[0], 3)
2050+
nmodl_eigen_j[0] = -3.0 * pow(nmodl_eigen_x[0], 2)
20672051
}{
20682052
x = nmodl_eigen_x[0]
20692053
}{
@@ -2084,7 +2068,7 @@ SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][s
20842068
NONLINEAR nonlin {
20852069
~ s[0] = 1
20862070
~ s[1] = 3
2087-
~ s[2] + s[1] = s[0]
2071+
~ s[2] + s[1] = s[0] * s[0]
20882072
})";
20892073
std::string expected_text = R"(
20902074
NONLINEAR nonlin {
@@ -2097,14 +2081,14 @@ SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][s
20972081
}{
20982082
nmodl_eigen_f[0] = 1.0-nmodl_eigen_x[0]
20992083
nmodl_eigen_f[1] = 3.0-nmodl_eigen_x[1]
2100-
nmodl_eigen_f[2] = nmodl_eigen_x[0]-nmodl_eigen_x[1]-nmodl_eigen_x[2]
2084+
nmodl_eigen_f[2] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]-nmodl_eigen_x[2]
21012085
nmodl_eigen_j[0] = -1.0
21022086
nmodl_eigen_j[3] = 0
21032087
nmodl_eigen_j[6] = 0
21042088
nmodl_eigen_j[1] = 0
21052089
nmodl_eigen_j[4] = -1.0
21062090
nmodl_eigen_j[7] = 0
2107-
nmodl_eigen_j[2] = 1.0
2091+
nmodl_eigen_j[2] = 2.0 * nmodl_eigen_x[0]
21082092
nmodl_eigen_j[5] = -1.0
21092093
nmodl_eigen_j[8] = -1.0
21102094
}{

0 commit comments

Comments
 (0)