87
87
88
88
function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: SimpleAdaptiveTauLeaping ;
89
89
seed = nothing ,
90
- dtmin = 1e-10 )
90
+ dtmin = 1e-10 ,
91
+ saveat = nothing )
91
92
@assert isempty (jump_prob. jump_callback. continuous_callbacks)
92
93
@assert isempty (jump_prob. jump_callback. discrete_callbacks)
93
94
prob = jump_prob. prob
@@ -112,40 +113,65 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
112
113
113
114
# Compute initial stoichiometry and HOR
114
115
nu = zeros (Int, length (u0), numjumps)
116
+ counts_temp = zeros (Int, numjumps)
115
117
for j in 1 : numjumps
116
- counts_temp = zeros (numjumps )
118
+ fill! ( counts_temp, 0 )
117
119
counts_temp[j] = 1
118
120
c (du, u0, p, t[1 ], counts_temp, nothing )
119
121
nu[:, j] = du
120
122
end
121
-
122
123
hor = zeros (Int, size (nu, 2 ))
123
124
for j in 1 : size (nu, 2 )
124
125
hor[j] = sum (abs .(nu[:, j])) > maximum (abs .(nu[:, j])) ? 2 : 1
125
126
end
126
127
128
+ saveat_times = isnothing (saveat) ? Float64[] : saveat isa Number ? collect (range (tspan[1 ], tspan[2 ], step= saveat)) : collect (saveat)
129
+ save_idx = 1
130
+
127
131
while t[end ] < t_end
128
132
u_prev = u[end ]
129
133
t_prev = t[end ]
130
134
# Recompute stoichiometry
131
135
for j in 1 : numjumps
132
- counts_temp = zeros (numjumps )
136
+ fill! ( counts_temp, 0 )
133
137
counts_temp[j] = 1
134
138
c (du, u_prev, p, t_prev, counts_temp, nothing )
135
139
nu[:, j] = du
136
140
end
137
141
rate (rate_cache, u_prev, p, t_prev)
138
142
tau = compute_tau_explicit (u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate, dtmin)
139
143
tau = min (tau, t_end - t_prev)
144
+ if ! isempty (saveat_times)
145
+ if save_idx <= length (saveat_times) && t_prev + tau > saveat_times[save_idx]
146
+ tau = saveat_times[save_idx] - t_prev
147
+ end
148
+ end
140
149
counts .= pois_rand .(rng, max .(rate_cache * tau, 0.0 ))
141
150
c (du, u_prev, p, t_prev, counts, nothing )
142
- u_new = max .( u_prev + du, 0 )
151
+ u_new = u_prev + du
143
152
if any (u_new .< 0 )
153
+ # Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468)
144
154
tau /= 2
145
155
continue
146
156
end
157
+ u_new = max .(u_new, 0 ) # Ensure non-negative states
147
158
push! (u, u_new)
148
159
push! (t, t_prev + tau)
160
+ if ! isempty (saveat_times) && save_idx <= length (saveat_times) && t[end ] >= saveat_times[save_idx]
161
+ save_idx += 1
162
+ end
163
+ end
164
+
165
+ # Interpolate to saveat times if specified
166
+ if ! isempty (saveat_times)
167
+ t_out = saveat_times
168
+ u_out = [u[end ]]
169
+ for t_save in saveat_times
170
+ idx = findlast (ti -> ti <= t_save, t)
171
+ push! (u_out, u[idx])
172
+ end
173
+ t = t_out
174
+ u = u_out[2 : end ]
149
175
end
150
176
151
177
sol = DiffEqBase. build_solution (prob, alg, t, u,
0 commit comments