Skip to content

Commit 5963f3e

Browse files
test changes
1 parent 41c8dca commit 5963f3e

File tree

1 file changed

+83
-28
lines changed

1 file changed

+83
-28
lines changed

test/regular_jumps.jl

Lines changed: 83 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,90 @@
11
using JumpProcesses, DiffEqBase
2-
using Test, LinearAlgebra
3-
using StableRNGs, Plots
2+
using Test, LinearAlgebra, Statistics
3+
using StableRNGs
44
rng = StableRNG(12345)
55

6-
# Parameters
7-
c1 = 1.0 # S1 -> 0
8-
c2 = 10.0 # S1 + S1 <- S2
9-
c3 = 1000.0 # S1 + S1 -> S2
10-
c4 = 0.1 # S2 -> S3
11-
p = (c1, c2, c3, c4)
12-
13-
regular_rate_implicit = (out, u, p, t) -> begin
14-
out[1] = p[1] * u[1] # S1 -> 0
15-
out[2] = p[2] * u[2] # S1 + S1 <- S2
16-
out[3] = p[3] * u[1] * (u[1] - 1) / 2 # S1 + S1 -> S2
17-
out[4] = p[4] * u[2] # S2 -> S3
18-
end
6+
Nsims = 8000
7+
8+
# SIR model with influx
9+
let
10+
β = 0.1 / 1000.0
11+
ν = 0.01
12+
influx_rate = 1.0
13+
p = (β, ν, influx_rate)
14+
15+
regular_rate = (out, u, p, t) -> begin
16+
out[1] = p[1] * u[1] * u[2] # β*S*I (infection)
17+
out[2] = p[2] * u[2] # ν*I (recovery)
18+
out[3] = p[3] # influx_rate
19+
end
20+
21+
regular_c = (dc, u, p, t, counts, mark) -> begin
22+
dc .= 0.0
23+
dc[1] = -counts[1] + counts[3] # S: -infection + influx
24+
dc[2] = counts[1] - counts[2] # I: +infection - recovery
25+
dc[3] = counts[2] # R: +recovery
26+
end
27+
28+
u0 = [999.0, 10.0, 0.0] # S, I, R
29+
tspan = (0.0, 250.0)
30+
31+
prob_disc = DiscreteProblem(u0, tspan, p)
32+
rj = RegularJump(regular_rate, regular_c, 3)
33+
jump_prob = JumpProblem(prob_disc, Direct(), rj)
34+
35+
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
36+
mean_simple = mean(sol.u[i][1,end] for i in 1:Nsims)
37+
38+
sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
39+
mean_implicit = mean(sol.u[i][1,end] for i in 1:Nsims)
40+
41+
sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims)
42+
mean_adaptive = mean(sol.u[i][1,end] for i in 1:Nsims)
1943

20-
regular_c_implicit = (dc, u, p, t, counts, mark) -> begin
21-
dc .= 0.0
22-
dc[1] = -counts[1] - 2 * counts[3] + 2 * counts[2] # S1: -decay - 2*forward + 2*backward
23-
dc[2] = counts[3] - counts[2] - counts[4] # S2: +forward - backward - decay
24-
dc[3] = counts[4] # S3: +decay
44+
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
45+
@test isapprox(mean_simple, mean_adaptive, rtol=0.05)
2546
end
2647

27-
u0 = [10000.0, 0.0, 0.0] # S1, S2, S3
28-
tspan = (0.0, 4.0)
2948

30-
# Create JumpProblem with proper parameter passing
31-
prob_disc = DiscreteProblem(u0, tspan, p)
32-
rj = RegularJump(regular_rate_implicit, regular_c_implicit, 4)
33-
jump_prob = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
34-
sol = solve(jump_prob, SimpleAdaptiveTauLeaping())
35-
plot(sol)
49+
# SEIR model with exposed compartment
50+
let
51+
β = 0.3 / 1000.0
52+
σ = 0.2
53+
ν = 0.01
54+
p = (β, σ, ν)
55+
56+
regular_rate = (out, u, p, t) -> begin
57+
out[1] = p[1] * u[1] * u[3] # β*S*I (infection)
58+
out[2] = p[2] * u[2] # σ*E (progression)
59+
out[3] = p[3] * u[3] # ν*I (recovery)
60+
end
61+
62+
regular_c = (dc, u, p, t, counts, mark) -> begin
63+
dc .= 0.0
64+
dc[1] = -counts[1] # S: -infection
65+
dc[2] = counts[1] - counts[2] # E: +infection - progression
66+
dc[3] = counts[2] - counts[3] # I: +progression - recovery
67+
dc[4] = counts[3] # R: +recovery
68+
end
69+
70+
# Initial state
71+
u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R
72+
tspan = (0.0, 250.0)
73+
74+
# Create JumpProblem
75+
prob_disc = DiscreteProblem(u0, tspan, p)
76+
rj = RegularJump(regular_rate, regular_c, 3)
77+
jump_prob = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
78+
79+
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
80+
mean_simple = mean(sol.u[i][end,end] for i in 1:Nsims)
81+
82+
sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
83+
mean_implicit = mean(sol.u[i][end,end] for i in 1:Nsims)
84+
85+
sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims)
86+
mean_adaptive = mean(sol.u[i][end,end] for i in 1:Nsims)
87+
88+
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
89+
@test isapprox(mean_simple, mean_adaptive, rtol=0.05)
90+
end

0 commit comments

Comments
 (0)