Skip to content

Commit b5c55b2

Browse files
added saveat in SimpleAdaptiveTauLeaping
1 parent 06735ba commit b5c55b2

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

src/simple_regular_solve.jl

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ end
8787

8888
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
8989
seed = nothing,
90-
dtmin = 1e-10)
90+
dtmin = 1e-10,
91+
saveat = nothing)
9192
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
9293
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
9394
prob = jump_prob.prob
@@ -112,40 +113,65 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
112113

113114
# Compute initial stoichiometry and HOR
114115
nu = zeros(Int, length(u0), numjumps)
116+
counts_temp = zeros(Int, numjumps)
115117
for j in 1:numjumps
116-
counts_temp = zeros(numjumps)
118+
fill!(counts_temp, 0)
117119
counts_temp[j] = 1
118120
c(du, u0, p, t[1], counts_temp, nothing)
119121
nu[:, j] = du
120122
end
121-
122123
hor = zeros(Int, size(nu, 2))
123124
for j in 1:size(nu, 2)
124125
hor[j] = sum(abs.(nu[:, j])) > maximum(abs.(nu[:, j])) ? 2 : 1
125126
end
126127

128+
saveat_times = isnothing(saveat) ? Float64[] : saveat isa Number ? collect(range(tspan[1], tspan[2], step=saveat)) : collect(saveat)
129+
save_idx = 1
130+
127131
while t[end] < t_end
128132
u_prev = u[end]
129133
t_prev = t[end]
130134
# Recompute stoichiometry
131135
for j in 1:numjumps
132-
counts_temp = zeros(numjumps)
136+
fill!(counts_temp, 0)
133137
counts_temp[j] = 1
134138
c(du, u_prev, p, t_prev, counts_temp, nothing)
135139
nu[:, j] = du
136140
end
137141
rate(rate_cache, u_prev, p, t_prev)
138142
tau = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate, dtmin)
139143
tau = min(tau, t_end - t_prev)
144+
if !isempty(saveat_times)
145+
if save_idx <= length(saveat_times) && t_prev + tau > saveat_times[save_idx]
146+
tau = saveat_times[save_idx] - t_prev
147+
end
148+
end
140149
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
141150
c(du, u_prev, p, t_prev, counts, nothing)
142-
u_new = max.(u_prev + du, 0)
151+
u_new = u_prev + du
143152
if any(u_new .< 0)
153+
# Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468)
144154
tau /= 2
145155
continue
146156
end
157+
u_new = max.(u_new, 0) # Ensure non-negative states
147158
push!(u, u_new)
148159
push!(t, t_prev + tau)
160+
if !isempty(saveat_times) && save_idx <= length(saveat_times) && t[end] >= saveat_times[save_idx]
161+
save_idx += 1
162+
end
163+
end
164+
165+
# Interpolate to saveat times if specified
166+
if !isempty(saveat_times)
167+
t_out = saveat_times
168+
u_out = [u[end]]
169+
for t_save in saveat_times
170+
idx = findlast(ti -> ti <= t_save, t)
171+
push!(u_out, u[idx])
172+
end
173+
t = t_out
174+
u = u_out[2:end]
149175
end
150176

151177
sol = DiffEqBase.build_solution(prob, alg, t, u,

test/regular_jumps.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ Nsims = 1000
4848
sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1)
4949

5050
# Solve with SimpleAdaptiveTauLeaping
51-
sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims)
51+
sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat = 1.0)
5252

5353
# Compute mean trajectories at t = 0, 1, ..., 250
5454
t_points = 0:1.0:250.0
@@ -106,7 +106,7 @@ end
106106
sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1)
107107

108108
# Solve with SimpleAdaptiveTauLeaping
109-
sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims)
109+
sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat = 1.0)
110110

111111
# Compute mean trajectories at t = 0, 1, ..., 250
112112
t_points = 0:1.0:250.0

0 commit comments

Comments
 (0)