Skip to content

Commit 6c2b714

Browse files
committed
The latest JAX release introduced breaking changes for older NumPyro versions, and ArviZ currently has an issue with a function I was relying on.
1 parent c181b29 commit 6c2b714

File tree

7 files changed

+297
-479
lines changed

7 files changed

+297
-479
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Imports:
3636
rstudioapi (>= 0.11),
3737
lifecycle (>= 1.0.2),
3838
reticulate (>= 1.30)
39-
RoxygenNote: 7.3.2
39+
RoxygenNote: 7.3.3
4040
Suggests:
4141
knitr,
4242
covr,

R/dag_numpyro.R

Lines changed: 226 additions & 384 deletions
Large diffs are not rendered by default.

R/dagp_plot.R

Lines changed: 60 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -75,24 +75,23 @@
7575
#' @importFrom stats quantile
7676
#' @importFrom lifecycle badge
7777
#' @export
78-
79-
80-
dagp_plot = function(drawsDF,densityPlot = FALSE, abbrevLabels = FALSE) { # case where untidy posterior draws are provided
78+
dagp_plot = function(drawsDF, densityPlot = FALSE, abbrevLabels = FALSE) { # case where untidy posterior draws are provided
8179
q95 <- density <- reasonableIntervalWidth <- credIQR <- shape <- param <- NULL ## place holder to pass devtools::check
8280

8381
if (densityPlot == TRUE) {
8482
if (abbrevLabels) { ## shorten labels if desired
8583
drawsDF = drawsDF %>%
8684
tidyr::gather() %>%
87-
dplyr::mutate(key = abbreviate(key, minlength = 10))} else {
88-
drawsDF = drawsDF %>%
89-
tidyr::gather()
90-
}
85+
dplyr::mutate(key = abbreviate(key, minlength = 10))
86+
} else {
87+
drawsDF = drawsDF %>%
88+
tidyr::gather()
89+
}
9190
plot = drawsDF %>% ## start with tidy draws
9291
ggplot2::ggplot(ggplot2::aes(x = value,
93-
y = ggplot2::after_stat(density))) +
92+
y = ggplot2::after_stat(density))) +
9493
ggplot2::geom_density(ggplot2::aes(fill = key)) +
95-
ggplot2::facet_wrap( ~ key, scales = "free_x") +
94+
ggplot2::facet_wrap(~ key, scales = "free_x") +
9695
ggplot2::theme_minimal() +
9796
ggplot2::theme(legend.position = "none")
9897

@@ -104,48 +103,62 @@ dagp_plot = function(drawsDF,densityPlot = FALSE, abbrevLabels = FALSE) { # case
104103
if (abbrevLabels) { ## shorten labels if desired
105104
drawsDF = drawsDF %>%
106105
addPriorGroups() %>%
107-
dplyr::mutate(param = abbreviate(param, minlength = 10))} else {
108-
drawsDF = drawsDF %>%
109-
addPriorGroups()
110-
}
106+
dplyr::mutate(param = abbreviate(param, minlength = 10))
107+
} else {
108+
drawsDF = drawsDF %>%
109+
addPriorGroups()
110+
}
111111
drawsDF = drawsDF %>%
112-
dplyr::mutate(priorGroup = ifelse(is.na(priorGroup),999999,priorGroup)) %>%
113-
dplyr::filter(!is.na(priorGroup)) ##if try works, erase this line
114-
priorGroups = unique(drawsDF$priorGroup)
115-
numPriorGroups = length(priorGroups)
116-
for (i in 1:numPriorGroups) {
117-
df = drawsDF %>% dplyr::filter(priorGroup == priorGroups[i])
112+
dplyr::mutate(priorGroup = ifelse(is.na(priorGroup), 999999, priorGroup)) %>%
113+
dplyr::filter(!is.na(priorGroup)) ## if try works, erase this line if desired
118114

119-
# create one plot per group
120-
# groups defined as params with same prior
121-
plotList[[i]] = df %>% dplyr::group_by(param) %>%
122-
dplyr::summarize(q05 = stats::quantile(value,0.05),
123-
q25 = stats::quantile(value,0.55),
124-
q45 = stats::quantile(value,0.45),
125-
q50 = stats::quantile(value,0.50),
126-
q55 = stats::quantile(value,0.55),
127-
q75 = stats::quantile(value,0.75),
128-
q95 = stats::quantile(value,0.95)) %>%
129-
dplyr::mutate(credIQR = q75 - q25) %>%
130-
dplyr::mutate(reasonableIntervalWidth = 1.5 * stats::quantile(credIQR,0.75)) %>%
131-
dplyr::mutate(alphaLevel = ifelse(.data$credIQR > .data$reasonableIntervalWidth, 0.3,1)) %>%
132-
dplyr::arrange(alphaLevel,.data$q50) %>%
133-
dplyr::mutate(param = factor(param, levels = param)) %>%
134-
ggplot2::ggplot(ggplot2::aes(y = param, yend = param)) +
135-
ggplot2::geom_segment(ggplot2::aes(x = q05, xend = q95, alpha = alphaLevel), linewidth = 4, color = "#5f9ea0") +
136-
ggplot2::geom_segment(ggplot2::aes(x = q45, xend = q55, alpha = alphaLevel), linewidth = 4, color = "#11114e") +
137-
ggplot2::scale_alpha_continuous(range = c(0.6,1)) +
138-
ggplot2::guides(alpha = "none") +
139-
ggplot2::theme_minimal(12) +
140-
ggplot2::labs(y = ggplot2::element_blank(),
141-
x = "parameter value",
142-
caption = ifelse(i == numPriorGroups,"Credible Intervals - 10% (dark) & 90% (light)",""))
115+
priorGroups = unique(drawsDF$priorGroup)
116+
numPriorGroups = length(priorGroups)
117+
for (i in 1:numPriorGroups) {
118+
df = drawsDF %>% dplyr::filter(priorGroup == priorGroups[i])
143119

144-
}
120+
# create one plot per group
121+
# groups defined as params with same prior
122+
plotList[[i]] = df %>%
123+
dplyr::group_by(param) %>%
124+
dplyr::summarize(
125+
q05 = stats::quantile(value, 0.05),
126+
q25 = stats::quantile(value, 0.25), # <- fixed from 0.55
127+
q45 = stats::quantile(value, 0.45),
128+
q50 = stats::quantile(value, 0.50),
129+
q55 = stats::quantile(value, 0.55),
130+
q75 = stats::quantile(value, 0.75),
131+
q95 = stats::quantile(value, 0.95),
132+
.groups = "drop"
133+
) %>%
134+
dplyr::mutate(credIQR = q75 - q25) %>%
135+
dplyr::mutate(reasonableIntervalWidth = 1.5 * stats::quantile(credIQR, 0.75)) %>%
136+
dplyr::mutate(alphaLevel = ifelse(.data$credIQR > .data$reasonableIntervalWidth, 0.3, 1)) %>%
137+
dplyr::arrange(alphaLevel, .data$q50) %>%
138+
dplyr::mutate(param = factor(param, levels = param)) %>%
139+
ggplot2::ggplot(ggplot2::aes(y = param, yend = param)) +
140+
ggplot2::geom_segment(ggplot2::aes(x = q05, xend = q95, alpha = alphaLevel),
141+
linewidth = 4, color = "#5f9ea0") +
142+
ggplot2::geom_segment(ggplot2::aes(x = q45, xend = q55, alpha = alphaLevel),
143+
linewidth = 4, color = "#11114e") +
144+
ggplot2::scale_alpha_continuous(range = c(0.6, 1)) +
145+
ggplot2::guides(alpha = "none") +
146+
ggplot2::theme_minimal(12) +
147+
ggplot2::labs(
148+
y = NULL, # <- use NULL/string here, not element_blank()
149+
x = "parameter value",
150+
caption = ifelse(i == numPriorGroups,
151+
"Credible Intervals - 10% (dark) & 90% (light)",
152+
"")
153+
) +
154+
ggplot2::theme(
155+
axis.title.y = ggplot2::element_blank() # hides the Y title cleanly
156+
)
157+
}
145158

146-
nCol <- ifelse(numPriorGroups==1,1,floor(1 + sqrt(numPriorGroups)))
147-
cowplot::plot_grid(plotlist = plotList, ncol = nCol)
159+
nCol <- ifelse(numPriorGroups == 1, 1, floor(1 + sqrt(numPriorGroups)))
160+
cowplot::plot_grid(plotlist = plotList, ncol = nCol)
148161
},
149-
error = function(c) dagp_plot(drawsDF, densityPlot = T)) # end try
162+
error = function(c) dagp_plot(drawsDF, densityPlot = TRUE)) # end try
150163
} # end else
151164
} # end function

R/install_causact_deps.R

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
#' Install causact's python dependencies like numpyro, arviz, and xarray.
1+
#' Install causact's python dependencies like numpyro.
22
#'
3-
#' `install_causact_deps()` installs python, the numpyro and arviz packages, and their
3+
#' `install_causact_deps()` installs python, numpyro and their
44
#' direct dependencies.
55
#'
66
#' @details You may be prompted to download and install miniconda if reticulate
@@ -35,11 +35,9 @@ install_causact_deps <-
3535
envname = "r-causact"
3636
## lock in package versions that are
3737
## guaranteed to work together with python 3.11
38-
packages = c("numpyro[cpu]==0.16.1",
39-
"arviz==0.20.0",
40-
"pandas==2.2.2",
41-
"scipy==1.13.1")
42-
python_version = "3.11"
38+
packages = c("numpyro[cpu]==0.19.0",
39+
"pandas==2.3.2")
40+
python_version = "3.13"
4341
pip = TRUE
4442
new_env = TRUE
4543
method = "conda"

README.md

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,6 @@ numpyroCode = graph %>% dag_numpyro(mcmc = FALSE)
228228
#> import numpyro as npo
229229
#> import numpyro.distributions as dist
230230
#> import pandas as pd
231-
#> import arviz as az
232231
#> from jax import random
233232
#> from numpyro.infer import MCMC, NUTS
234233
#> from jax.numpy import transpose as t
@@ -254,19 +253,9 @@ numpyroCode = graph %>% dag_numpyro(mcmc = FALSE)
254253
#> mcmc = MCMC(NUTS(graph_model), num_warmup = 1000, num_samples = 4000)
255254
#> rng_key = random.PRNGKey(seed = 1234567)
256255
#> mcmc.run(rng_key,y,x)
257-
#> drawsDS = az.from_numpyro(mcmc,
258-
#> coords = {'x_dim': x_crd},
259-
#> dims = {'theta': ['x_dim']}
260-
#> ).posterior
261-
#> # prepare xarray dataset for export to R dataframe
262-
#> dimensions_to_keep = ['chain','draw','x_dim']
263-
#> drawsDS = drawsDS.squeeze(drop = True ).drop_dims([dim for dim in drawsDS.dims if dim not in dimensions_to_keep])
264-
#> # unstack plate variables to flatten dataframe as needed
265-
#> for x_da in drawsDS['x_dim']:
266-
#> new_varname = f'theta_{x_da.values}'
267-
#> drawsDS = drawsDS.assign(**{new_varname:drawsDS['theta'].sel(x_dim = x_da)})
268-
#> drawsDS = drawsDS.drop_dims(['x_dim'])
269-
#> drawsDF = drawsDS.squeeze().to_dataframe()"
256+
#> samples = mcmc.get_samples(group_by_chain=True) # dict: name -> [chains, draws, ...]
257+
#> flat = {k: np.reshape(v, (-1,) + v.shape[2:]) for k, v in samples.items()}
258+
#> drawsDF = pd.DataFrame({k: flat[k].squeeze() for k in flat})
270259
#> ) ## END PYTHON STRING
271260
#> drawsDF = reticulate::py$drawsDF
272261
```

man/dag_numpyro.Rd

Lines changed: 0 additions & 24 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/install_causact_deps.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)