Use ArviZ-stats #232

Merged: 3 commits, Jun 21, 2025
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11"]
python-version: ["3.11", "3.12", "3.13"]

name: Set up Python ${{ matrix.python-version }}
steps:
2 changes: 1 addition & 1 deletion docs/api_reference.rst
@@ -13,4 +13,4 @@ methods in the current release of PyMC-BART.
=============================

.. automodule:: pymc_bart
:members: BART, PGBART, compute_variable_importance, get_variable_inclusion, plot_convergence, plot_ice, plot_pdp, plot_scatter_submodels, plot_variable_importance, plot_variable_inclusion, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
:members: BART, PGBART, compute_variable_importance, get_variable_inclusion, plot_ice, plot_pdp, plot_scatter_submodels, plot_variable_importance, plot_variable_inclusion, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -29,7 +29,7 @@ interpretation of those models and perform variable selection.
Installation
============

PyMC-BART requires a working Python interpreter (3.10+). We recommend installing Python and key numerical libraries using the `Anaconda distribution <https://www.anaconda.com/products/individual#Downloads>`_, which has one-click installers available on all major platforms.
PyMC-BART requires a working Python interpreter (3.11+). We recommend installing Python and key numerical libraries using the `Anaconda distribution <https://www.anaconda.com/products/individual#Downloads>`_, which has one-click installers available on all major platforms.

Assuming a standard Python environment is installed on your machine, PyMC-BART itself can be installed either using pip or conda-forge.

4 changes: 2 additions & 2 deletions env-dev.yml
@@ -3,8 +3,7 @@ channels:
- conda-forge
- defaults
dependencies:
- pymc>=5.16.2,<=5.19.1
- arviz>=0.18.0
- pymc>=5.16.2,<=5.23.0
- numba
- matplotlib
- numpy
@@ -20,4 +19,5 @@ dependencies:
- flake8
- pip
- pip:
- arviz-stats[xarray]>=0.6.0
- -e .
4 changes: 2 additions & 2 deletions env.yml
@@ -3,12 +3,12 @@ channels:
- conda-forge
- defaults
dependencies:
- pymc>=5.16.2,<=5.19.1
- arviz>=0.18.0
- pymc>=5.16.2,<=5.23.0
- numba
- matplotlib
- numpy
- pytensor
- pip
- pip:
- pymc-bart
- arviz-stats[xarray]>=0.6.0
2 changes: 1 addition & 1 deletion pymc_bart/pgbart.py
@@ -346,7 +346,7 @@ def resample(
new_particles.append(particles[idx].copy())
else:
new_particles.append(particles[idx])
seen.append(idx)
seen.append(int(idx))

particles[1:] = new_particles

89 changes: 43 additions & 46 deletions pymc_bart/utils.py
@@ -4,16 +4,16 @@
import warnings
from typing import Any, Callable, Optional, Union

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pytensor.tensor as pt
from arviz_base import rcParams
from arviz_stats.base import array_stats
from numba import jit
from pytensor.tensor.variable import Variable
from scipy.interpolate import griddata
from scipy.signal import savgol_filter
from scipy.stats import norm

from .tree import Tree

@@ -76,12 +76,12 @@ def _sample_posterior(


def plot_convergence(
idata: az.InferenceData,
idata: Any,
var_name: Optional[str] = None,
kind: str = "ecdf",
figsize: Optional[tuple[float, float]] = None,
ax=None,
) -> list[plt.Axes]:
) -> None:
"""
Plot convergence diagnostics.

@@ -102,39 +102,12 @@
-------
list[ax] : matplotlib axes
"""
ess_threshold = idata["posterior"]["chain"].size * 100
ess = np.atleast_2d(az.ess(idata, method="bulk", var_names=var_name)[var_name].values)
rhat = np.atleast_2d(az.rhat(idata, var_names=var_name)[var_name].values)

if figsize is None:
figsize = (10, 3)

if kind == "ecdf":
kind_func: Callable[..., Any] = az.plot_ecdf
sharey = True
elif kind == "kde":
kind_func = az.plot_kde
sharey = False

if ax is None:
_, ax = plt.subplots(1, 2, figsize=figsize, sharex="col", sharey=sharey)

for idx, (essi, rhati) in enumerate(zip(ess, rhat)):
kind_func(essi, ax=ax[0], plot_kwargs={"color": f"C{idx}"})
kind_func(rhati, ax=ax[1], plot_kwargs={"color": f"C{idx}"})

ax[0].axvline(ess_threshold, color="0.7", ls="--")
# Assume Rhats are N(1, 0.005) iid. Then compute the 0.99 quantile
# scaled by the sample size and use it as a threshold.
ax[1].axvline(norm(1, 0.005).ppf(0.99 ** (1 / ess.size)), color="0.7", ls="--")

ax[0].set_xlabel("ESS")
ax[1].set_xlabel("R-hat")
if kind == "kde":
ax[0].set_yticks([])
ax[1].set_yticks([])

return ax
warnings.warn(
"This function has been deprecated. "
"Use az.plot_convergence_dist() instead. "
"See https://arviz-plots.readthedocs.io/en/latest/api/generated/arviz_plots.plot_convergence_dist.html",
FutureWarning,
)
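# A hypothetical usage sketch of the suggested replacement (not part of this
# diff): assuming arviz-plots is installed and that plot_convergence_dist
# accepts an InferenceData plus var_names, the call would look roughly like:
#     import arviz_plots as azp
#     azp.plot_convergence_dist(idata, var_names=[var_name])
# Check the exact signature against the arviz-plots docs linked above.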


def plot_ice(
@@ -408,7 +381,7 @@ def identity(x):
if var in var_discrete:
_, idx_uni = np.unique(new_x, return_index=True)
y_means = p_di.mean(0)[idx_uni]
hdi = az.hdi(p_di)[idx_uni]
hdi = array_stats.hdi(p_di, prob=rcParams["stats.ci_prob"], axis=0)[idx_uni]
axes[count].errorbar(
new_x[idx_uni],
y_means,
@@ -418,11 +391,13 @@
)
axes[count].set_xticks(new_x[idx_uni])
else:
az.plot_hdi(
_plot_hdi(
new_x,
p_di,
smooth=smooth,
fill_kwargs={"alpha": alpha, "color": color},
alpha=alpha,
color=color,
smooth_kwargs=smooth_kwargs,
ax=axes[count],
)
if smooth:
@@ -659,7 +634,7 @@ def _create_pdp_data(
def _smooth_mean(
new_x: npt.NDArray,
p_di: npt.NDArray,
kind: str = "pdp",
kind: str = "neutral",
smooth_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[np.ndarray, np.ndarray]:
"""
@@ -688,7 +663,10 @@ def _smooth_mean(
smooth_kwargs.setdefault("polyorder", 2)
x_data = np.linspace(np.nanmin(new_x), np.nanmax(new_x), 200)
x_data[0] = (x_data[0] + x_data[1]) / 2
if kind == "pdp":

if kind == "neutral":
interp = griddata(new_x, p_di, x_data)
elif kind == "pdp":
interp = griddata(new_x, p_di.mean(0), x_data)
else:
interp = griddata(new_x, p_di.T, x_data)
@@ -800,7 +778,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non


def compute_variable_importance( # noqa: PLR0915 PLR0912
idata: az.InferenceData,
idata: Any,
bartrv: Variable,
X: npt.NDArray,
method: str = "VI",
@@ -904,7 +882,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
[pearsonr2(predicted_all[j], predicted_subset[j]) for j in range(samples)]
)
r2_mean[idx] = np.mean(r_2)
r2_hdi[idx] = az.hdi(r_2)
r2_hdi[idx] = array_stats.hdi(r_2, prob=rcParams["stats.ci_prob"])
preds[idx] = predicted_subset.squeeze()

if method in ["backward", "backward_VI"]:
@@ -954,7 +932,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912

# Save values for plotting later
r2_mean[i_var - init] = max_r_2
r2_hdi[i_var - init] = az.hdi(r_2_without_least_important_vars)
r2_hdi[i_var - init] = array_stats.hdi(r_2_without_least_important_vars)
preds[i_var - init] = least_important_samples.squeeze()

# extend current list of least important variable
@@ -1079,7 +1057,7 @@ def plot_variable_importance(
)
ax.fill_between(
[-0.5, n_vars - 0.5],
*az.hdi(r_2_ref),
*array_stats.hdi(r_2_ref, prob=rcParams["stats.ci_prob"]),
alpha=0.1,
color=plot_kwargs.get("color_ref", "grey"),
)
@@ -1229,3 +1207,22 @@ def pearsonr2(A, B):
am = A - np.mean(A)
bm = B - np.mean(B)
return (am @ bm) ** 2 / (np.sum(am**2) * np.sum(bm**2))


def _plot_hdi(x, y, smooth, color, alpha, smooth_kwargs, ax):
x = np.asarray(x)
y = np.asarray(y)
hdi_prob = rcParams["stats.ci_prob"]
hdi_data = array_stats.hdi(y, hdi_prob, axis=0)
if smooth:
if isinstance(x[0], np.datetime64):
raise TypeError("Cannot deal with x as type datetime. Recommend setting smooth=False.")

x_data, y_data = _smooth_mean(x, hdi_data, smooth_kwargs=smooth_kwargs)
else:
idx = np.argsort(x)
x_data = x[idx]
y_data = hdi_data[idx]

ax.fill_between(x_data, y_data[:, 0], y_data[:, 1], color=color, alpha=alpha)
return ax
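
For reference, a minimal sketch of the arviz-stats HDI call pattern adopted above in place of az.hdi. The toy arrays and shapes below are assumptions for illustration; the array_stats.hdi and rcParams calls mirror the ones used in this diff.

import numpy as np
from arviz_base import rcParams
from arviz_stats.base import array_stats

rng = np.random.default_rng(0)

# 1-D case, as in compute_variable_importance: the interval comes back as a
# length-2 array that can be unpacked into lower/upper bounds.
r2_draws = rng.beta(20, 2, size=500)
lower, upper = array_stats.hdi(r2_draws, prob=rcParams["stats.ci_prob"])

# 2-D case, as in _plot_hdi: for (samples, points) draws, axis=0 yields a
# (points, 2) array whose columns fill_between can consume directly.
grid_draws = rng.normal(size=(500, 50))
band = array_stats.hdi(grid_draws, rcParams["stats.ci_prob"], axis=0)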
6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,5 +1,5 @@
pymc>=5.16.2, <=5.23.0
arviz>=0.18.0
pymc>=5.16.2,<=5.23.0
arviz-stats[xarray]>=0.6.0
numba
matplotlib
numpy
numpy>=2.0
4 changes: 2 additions & 2 deletions setup.py
@@ -29,9 +29,9 @@
"Development Status :: 5 - Production/Stable",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"License :: OSI Approved :: Apache Software License",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering",