Skip to content

Commit ff87c6f

Browse files
committed
Add basic tests
1 parent 4fe3d2b commit ff87c6f

File tree

3 files changed

+87
-0
lines changed

3 files changed

+87
-0
lines changed

DESCRIPTION

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,6 @@ LinkingTo:
3131
BH,
3232
RcppParallel,
3333
rapidjsonr
34+
Suggests:
35+
testthat (>= 3.0.0)
36+
Config/testthat/edition: 3

tests/testthat.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
library(testthat)
2+
library(StanEstimators)
3+
4+
test_check("StanEstimators")

tests/testthat/test-basic.R

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
set.seed(1234)
2+
y <- rnorm(500, 10, 2)
3+
4+
loglik_fun <- function(v, x) {
5+
sum(dnorm(x, v[1], v[2], log = TRUE))
6+
}
7+
8+
inits <- c(0, 5)
9+
10+
grad <- function(v, x) {
11+
inv_sigma <- 1 / v[2]
12+
y_scaled = (x - v[1]) * inv_sigma
13+
scaled_diff = inv_sigma * y_scaled
14+
c(sum(scaled_diff),
15+
sum(inv_sigma * (y_scaled*y_scaled) - inv_sigma)
16+
)
17+
}
18+
19+
test_that("stan_sample runs", {
20+
expect_no_error(
21+
fit <- stan_sample(loglik_fun, inits, y, lower = c(-Inf, 0),
22+
num_chains = 1, seed = 1234)
23+
)
24+
expect_no_error(
25+
fit <- stan_sample(loglik_fun, inits, y, grad_fun = grad,
26+
lower = c(-Inf, 0),
27+
num_chains = 1, seed = 1234)
28+
)
29+
expect_no_error(
30+
fit <- stan_sample(loglik_fun, inits, y, grad_fun = grad,
31+
lower = c(-Inf, 0),
32+
metric = "dense_e",
33+
num_chains = 1, seed = 1234)
34+
)
35+
})
36+
37+
test_that("stan_optimize runs", {
38+
expect_no_error(
39+
fit <- stan_optimize(loglik_fun, inits, y, lower = c(-Inf, 0),
40+
seed = 1234)
41+
)
42+
expect_no_error(
43+
fit <- stan_optimize(loglik_fun, inits, y, grad_fun = grad,
44+
lower = c(-Inf, 0), seed = 1234)
45+
)
46+
expect_no_error(
47+
fit <- stan_optimize(loglik_fun, inits, y, grad_fun = grad,
48+
lower = c(-Inf, 0),
49+
algorithm = "bfgs",
50+
seed = 1234)
51+
)
52+
})
53+
54+
test_that("stan_variational runs", {
55+
expect_no_error(
56+
fit <- stan_variational(loglik_fun, inits, y, lower = c(-Inf, 0),
57+
seed = 1234)
58+
)
59+
expect_no_error(
60+
fit <- stan_variational(loglik_fun, inits, y, grad_fun = grad,
61+
lower = c(-Inf, 0), seed = 1234)
62+
)
63+
expect_no_error(
64+
fit <- stan_variational(loglik_fun, inits, y, grad_fun = grad,
65+
lower = c(-Inf, 0),
66+
algorithm = "fullrank",
67+
seed = 1234)
68+
)
69+
})
70+
71+
test_that("stan_pathfinder runs", {
72+
expect_no_error(
73+
fit <- stan_pathfinder(loglik_fun, inits, y, lower = c(-Inf, 0),
74+
seed = 1234)
75+
)
76+
expect_no_error(
77+
fit <- stan_pathfinder(loglik_fun, inits, y, grad_fun = grad,
78+
lower = c(-Inf, 0), seed = 1234)
79+
)
80+
})

0 commit comments

Comments
 (0)