Skip to content

Commit d50610e

Browse files
committed
release 0.6.0
1 parent c9cfa4c commit d50610e

File tree

5 files changed

+75
-15
lines changed

5 files changed

+75
-15
lines changed

NEWS.md

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,35 @@
11
# causact (development version)
22

3-
# causact 0.5.8
3+
# causact 0.6.0
4+
5+
* Updated `dag_numpyro()` to address compatibility issues with `JAX 0.7.1`, resolving the error:
6+
`ImportError: cannot import name 'pjit_p' from 'jax.experimental.pjit'`.
7+
* Removed the dependency on Python’s `arviz` package to improve long-term compatibility across Python environments. The `arviz` dependency has introduced bugs in past releases and remains unstable in its current version.
8+
* Revised `dag_plotp()` to suppress warning messages triggered by recent updates to `ggplot2`.
9+
410

511
# causact 0.5.8
612
* patch release to support changes to object types in ggplot2.
713

814
# causact 0.5.6
15+
916
* patch release to fix numpyro integration issues. The default installation workflow now uses `numpyro==0.16.1` to avoid `TypeError` when using a discrete data distribution, e.g. Bernoulli.
1017

1118
# causact 0.5.5
19+
1220
* patch release to fix installation issues. The default installation workflow now uses `scipy==0.12` to avoid error `cannot import name 'gaussian' from 'scipy.signal'`.
1321

1422
# causact 0.5.4
23+
1524
* patch release to fix installation issues. The default installation workflow now uses `numpyro==0.13.2` to avoid `ModuleNotFoundError: No module named 'jax.linear_util'`.
1625

1726
# causact 0.5.3
27+
1828
* Fixed bug where beta distribution was being treated as Laplace distribution.
1929
* Fixed bug where nested plated were being indexed incorrectly in Python.
2030

2131
# causact 0.5.2
32+
2233
* Switched inference to Python's `numpyro`
2334
* `dag_greta()` is now deprecated; `dag_numpyro()` should be used as a drop-in replacement.
2435
* Added helper function for Python dependencies: `install_causact_deps`.
@@ -28,10 +39,12 @@
2839
* Temporarily removed support for the `dim` argument of probability distributions.
2940

3041
# causact 0.4.2
42+
3143
* Added `vignette("narrative-to-insight-with-causact") to introduce the package's encouraged user workflow.
3244
* Allow `greta::variable()` to be used for flat priors.
3345

3446
# causact 0.4.1
47+
3548
* Fixed bugs introduced by changes in `tidyr::replace_na()`.
3649
* Fixed bug related to rendering generative DAGs where two variables were both on multiple plates.
3750

README.md

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -183,16 +183,16 @@ drawsDF ### see top of data frame
183183
#> # A tibble: 4,000 × 4
184184
#> theta_Toyota.Corolla theta_Subaru.Outback theta_Kia.Forte theta_Jeep.Wrangler
185185
#> <dbl> <dbl> <dbl> <dbl>
186-
#> 1 0.196 0.617 0.293 0.853
187-
#> 2 0.188 0.643 0.214 0.867
188-
#> 3 0.182 0.641 0.215 0.862
189-
#> 4 0.202 0.658 0.245 0.882
190-
#> 5 0.191 0.652 0.293 0.854
191-
#> 6 0.166 0.594 0.236 0.844
192-
#> 7 0.214 0.626 0.230 0.843
193-
#> 8 0.207 0.579 0.249 0.854
194-
#> 9 0.189 0.601 0.271 0.831
195-
#> 10 0.219 0.641 0.219 0.865
186+
#> 1 0.215 0.582 0.226 0.839
187+
#> 2 0.194 0.643 0.264 0.853
188+
#> 3 0.216 0.602 0.243 0.831
189+
#> 4 0.217 0.647 0.177 0.843
190+
#> 5 0.186 0.580 0.309 0.853
191+
#> 6 0.230 0.633 0.237 0.846
192+
#> 7 0.217 0.619 0.263 0.859
193+
#> 8 0.205 0.692 0.274 0.862
194+
#> 9 0.191 0.693 0.332 0.847
195+
#> 10 0.188 0.655 0.174 0.851
196196
#> # ℹ 3,990 more rows
197197
```
198198

@@ -253,9 +253,55 @@ numpyroCode = graph %>% dag_numpyro(mcmc = FALSE)
253253
#> mcmc = MCMC(NUTS(graph_model), num_warmup = 1000, num_samples = 4000)
254254
#> rng_key = random.PRNGKey(seed = 1234567)
255255
#> mcmc.run(rng_key,y,x)
256-
#> samples = mcmc.get_samples(group_by_chain=True) # dict: name -> [chains, draws, ...]
257-
#> flat = {k: np.reshape(v, (-1,) + v.shape[2:]) for k, v in samples.items()}
258-
#> drawsDF = pd.DataFrame({k: flat[k].squeeze() for k in flat})
256+
#> axis_labels = {'x_dim': x_crd}
257+
#> rv_to_axes = {'theta': ['x_dim']}
258+
#> drop_names = set()
259+
#> samples = mcmc.get_samples(group_by_chain=True) # dict: name -> [chains, draws, *axes]
260+
#> import numpy as np, pandas as pd, string
261+
#> # Flatten (chains*draws) and expand RVs using rv_to_axes + axis_labels
262+
#> flat = {name: np.reshape(val, (-1,) + val.shape[2:]) for name, val in list(samples.items())}
263+
#> out = {}
264+
#> allowed = set(string.ascii_letters + string.digits + '._')
265+
#> def sanitize_label(s):
266+
#> s = str(s).strip().replace(' ', '.') # spaces -> dots
267+
#> s = ''.join((ch if ch in allowed else '.') for ch in s)
268+
#> # collapse repeated dots without regex
269+
#> while '..' in s:
270+
#> s = s.replace('..', '.')
271+
#> s = s.strip('.')
272+
#> return s if s else '1'
273+
#>
274+
#> for name, arr in list(flat.items()):
275+
#> # Drop deterministic RVs entirely
276+
#> if name in drop_names:
277+
#> continue
278+
#> axes = rv_to_axes.get(name, []) # [] if not present
279+
#> # If scalar per draw, keep as a single column
280+
#> if arr.ndim == 1:
281+
#> out[name] = arr
282+
#> continue
283+
#> # Otherwise (vector/matrix/...): expand to separate columns
284+
#> trailing = arr.shape[1:]
285+
#> arr2 = arr.reshape(arr.shape[0], int(np.prod(trailing)))
286+
#> for j in range(arr2.shape[1]):
287+
#> idx = np.unravel_index(j, trailing)
288+
#> parts = []
289+
#> for axis, i in enumerate(idx):
290+
#> if axis < len(axes):
291+
#> axis_name = axes[axis]
292+
#> labels = axis_labels.get(axis_name)
293+
#> if labels is not None:
294+
#> labs = np.asarray(labels).astype(str)
295+
#> lab = labs[i] if i < labs.shape[0] else str(i + 1)
296+
#> parts.append(sanitize_label(lab))
297+
#> else:
298+
#> parts.append(str(i + 1))
299+
#> else:
300+
#> parts.append(str(i + 1))
301+
#> # No brackets; join labels with '_' so tibble won't append trailing dots
302+
#> col = name if not parts else name + '_' + '_'.join(parts)
303+
#> out[col] = arr2[:, j]
304+
#> drawsDF = pd.DataFrame(out)"
259305
#> ) ## END PYTHON STRING
260306
#> drawsDF = reticulate::py$drawsDF
261307
```

cran-comments.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
This is a very minor patch release to accommodate changes for ggplot2 v4.0.0 to get to CRAN.
1+
This is a minor release to accommodate breaking changes in Python dependencies, most notable JAX 0.7.1.
22

33
## Test environments
4+
45
* local R installation, R 4.5.1
56
* R-hub: ubuntu-gcc-release (Ubuntu 20.04.1 LTS, R-release, GCC), R 4.5.1
67
* win-builder (devel)

man/figures/chimpsGraphPost-1.png

5.91 KB
Loading

man/figures/gretaPost-1.png

12.9 KB
Loading

0 commit comments

Comments
 (0)