diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8817d27..3fe0779 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10", "3.11"] + python-version: ["3.11", "3.12", "3.13"] name: Set up Python ${{ matrix.python-version }} steps: diff --git a/docs/api_reference.rst b/docs/api_reference.rst index b6fb8a5..88b910c 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -13,4 +13,4 @@ methods in the current release of PyMC-BART. ============================= .. automodule:: pymc_bart - :members: BART, PGBART, compute_variable_importance, get_variable_inclusion, plot_convergence, plot_ice, plot_pdp, plot_scatter_submodels, plot_variable_importance, plot_variable_inclusion, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule + :members: BART, PGBART, compute_variable_importance, get_variable_inclusion, plot_ice, plot_pdp, plot_scatter_submodels, plot_variable_importance, plot_variable_inclusion, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule diff --git a/docs/index.rst b/docs/index.rst index 78a59fb..e390c3c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -29,7 +29,7 @@ interpretation of those models and perform variable selection. Installation ============ -PyMC-BART requires a working Python interpreter (3.10+). We recommend installing Python and key numerical libraries using the `Anaconda distribution `_, which has one-click installers available on all major platforms. +PyMC-BART requires a working Python interpreter (3.11+). We recommend installing Python and key numerical libraries using the `Anaconda distribution `_, which has one-click installers available on all major platforms. Assuming a standard Python environment is installed on your machine, PyMC-BART itself can be installed either using pip or conda-forge. diff --git a/env-dev.yml b/env-dev.yml index 1e28429..fae1398 100644 --- a/env-dev.yml +++ b/env-dev.yml @@ -3,8 +3,7 @@ channels: - conda-forge - defaults dependencies: - - pymc>=5.16.2,<=5.19.1 - - arviz>=0.18.0 + - pymc>=5.16.2,<=5.23.0 - numba - matplotlib - numpy @@ -20,4 +19,5 @@ dependencies: - flake8 - pip - pip: + - arviz-stats[xarray]>=0.6.0 - -e . diff --git a/env.yml b/env.yml index bd814ae..77f6c13 100644 --- a/env.yml +++ b/env.yml @@ -3,8 +3,7 @@ channels: - conda-forge - defaults dependencies: - - pymc>=5.16.2,<=5.19.1 - - arviz>=0.18.0 + - pymc>=5.16.2,<=5.23.0 - numba - matplotlib - numpy @@ -12,3 +11,4 @@ dependencies: - pip - pip: - pymc-bart + - arviz-stats[xarray]>=0.6.0 diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 014313a..b76c40c 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -346,7 +346,7 @@ def resample( new_particles.append(particles[idx].copy()) else: new_particles.append(particles[idx]) - seen.append(idx) + seen.append(int(idx)) particles[1:] = new_particles diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 3ba6e58..ab10467 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -4,16 +4,16 @@ import warnings from typing import Any, Callable, Optional, Union -import arviz as az import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt import pytensor.tensor as pt +from arviz_base import rcParams +from arviz_stats.base import array_stats from numba import jit from pytensor.tensor.variable import Variable from scipy.interpolate import griddata from scipy.signal import savgol_filter -from scipy.stats import norm from .tree import Tree @@ -76,12 +76,12 @@ def _sample_posterior( def plot_convergence( - idata: az.InferenceData, + idata: Any, var_name: Optional[str] = None, kind: str = "ecdf", figsize: Optional[tuple[float, float]] = None, ax=None, -) -> list[plt.Axes]: +) -> None: """ Plot convergence diagnostics. @@ -102,39 +102,12 @@ def plot_convergence( ------- list[ax] : matplotlib axes """ - ess_threshold = idata["posterior"]["chain"].size * 100 - ess = np.atleast_2d(az.ess(idata, method="bulk", var_names=var_name)[var_name].values) - rhat = np.atleast_2d(az.rhat(idata, var_names=var_name)[var_name].values) - - if figsize is None: - figsize = (10, 3) - - if kind == "ecdf": - kind_func: Callable[..., Any] = az.plot_ecdf - sharey = True - elif kind == "kde": - kind_func = az.plot_kde - sharey = False - - if ax is None: - _, ax = plt.subplots(1, 2, figsize=figsize, sharex="col", sharey=sharey) - - for idx, (essi, rhati) in enumerate(zip(ess, rhat)): - kind_func(essi, ax=ax[0], plot_kwargs={"color": f"C{idx}"}) - kind_func(rhati, ax=ax[1], plot_kwargs={"color": f"C{idx}"}) - - ax[0].axvline(ess_threshold, color="0.7", ls="--") - # Assume Rhats are N(1, 0.005) iid. Then compute the 0.99 quantile - # scaled by the sample size and use it as a threshold. - ax[1].axvline(norm(1, 0.005).ppf(0.99 ** (1 / ess.size)), color="0.7", ls="--") - - ax[0].set_xlabel("ESS") - ax[1].set_xlabel("R-hat") - if kind == "kde": - ax[0].set_yticks([]) - ax[1].set_yticks([]) - - return ax + warnings.warn( + "This function has been deprecated" + "Use az.plot_convergence_dist() instead." + "https://arviz-plots.readthedocs.io/en/latest/api/generated/arviz_plots.plot_convergence_dist.html", + FutureWarning, + ) def plot_ice( @@ -408,7 +381,7 @@ def identity(x): if var in var_discrete: _, idx_uni = np.unique(new_x, return_index=True) y_means = p_di.mean(0)[idx_uni] - hdi = az.hdi(p_di)[idx_uni] + hdi = array_stats.hdi(p_di, prob=rcParams["stats.ci_prob"], axis=0)[idx_uni] axes[count].errorbar( new_x[idx_uni], y_means, @@ -418,11 +391,13 @@ def identity(x): ) axes[count].set_xticks(new_x[idx_uni]) else: - az.plot_hdi( + _plot_hdi( new_x, p_di, smooth=smooth, - fill_kwargs={"alpha": alpha, "color": color}, + alpha=alpha, + color=color, + smooth_kwargs=smooth_kwargs, ax=axes[count], ) if smooth: @@ -659,7 +634,7 @@ def _create_pdp_data( def _smooth_mean( new_x: npt.NDArray, p_di: npt.NDArray, - kind: str = "pdp", + kind: str = "neutral", smooth_kwargs: Optional[dict[str, Any]] = None, ) -> tuple[np.ndarray, np.ndarray]: """ @@ -688,7 +663,10 @@ def _smooth_mean( smooth_kwargs.setdefault("polyorder", 2) x_data = np.linspace(np.nanmin(new_x), np.nanmax(new_x), 200) x_data[0] = (x_data[0] + x_data[1]) / 2 - if kind == "pdp": + + if kind == "neutral": + interp = griddata(new_x, p_di, x_data) + elif kind == "pdp": interp = griddata(new_x, p_di.mean(0), x_data) else: interp = griddata(new_x, p_di.T, x_data) @@ -800,7 +778,7 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non def compute_variable_importance( # noqa: PLR0915 PLR0912 - idata: az.InferenceData, + idata: Any, bartrv: Variable, X: npt.NDArray, method: str = "VI", @@ -904,7 +882,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 [pearsonr2(predicted_all[j], predicted_subset[j]) for j in range(samples)] ) r2_mean[idx] = np.mean(r_2) - r2_hdi[idx] = az.hdi(r_2) + r2_hdi[idx] = array_stats.hdi(r_2, prob=rcParams["stats.ci_prob"]) preds[idx] = predicted_subset.squeeze() if method in ["backward", "backward_VI"]: @@ -954,7 +932,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 # Save values for plotting later r2_mean[i_var - init] = max_r_2 - r2_hdi[i_var - init] = az.hdi(r_2_without_least_important_vars) + r2_hdi[i_var - init] = array_stats.hdi(r_2_without_least_important_vars) preds[i_var - init] = least_important_samples.squeeze() # extend current list of least important variable @@ -1079,7 +1057,7 @@ def plot_variable_importance( ) ax.fill_between( [-0.5, n_vars - 0.5], - *az.hdi(r_2_ref), + *array_stats.hdi(r_2_ref, prob=rcParams["stats.ci_prob"]), alpha=0.1, color=plot_kwargs.get("color_ref", "grey"), ) @@ -1229,3 +1207,22 @@ def pearsonr2(A, B): am = A - np.mean(A) bm = B - np.mean(B) return (am @ bm) ** 2 / (np.sum(am**2) * np.sum(bm**2)) + + +def _plot_hdi(x, y, smooth, color, alpha, smooth_kwargs, ax): + x = np.asarray(x) + y = np.asarray(y) + hdi_prob = rcParams["stats.ci_prob"] + hdi_data = array_stats.hdi(y, hdi_prob, axis=0) + if smooth: + if isinstance(x[0], np.datetime64): + raise TypeError("Cannot deal with x as type datetime. Recommend setting smooth=False.") + + x_data, y_data = _smooth_mean(x, hdi_data, smooth_kwargs=smooth_kwargs) + else: + idx = np.argsort(x) + x_data = x[idx] + y_data = hdi_data[idx] + + ax.fill_between(x_data, y_data[:, 0], y_data[:, 1], color=color, alpha=alpha) + return ax diff --git a/requirements.txt b/requirements.txt index 95fce57..5e6713e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -pymc>=5.16.2, <=5.23.0 -arviz>=0.18.0 +pymc>=5.16.2,<=5.23.0 +arviz-stats[xarray]>=0.6.0 numba matplotlib -numpy +numpy>=2.0 diff --git a/setup.py b/setup.py index e934ae2..0ae76b2 100644 --- a/setup.py +++ b/setup.py @@ -29,9 +29,9 @@ "Development Status :: 5 - Production/Stable", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "License :: OSI Approved :: Apache Software License", "Intended Audience :: Science/Research", "Topic :: Scientific/Engineering",