diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index 5114b6e..233d33e 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -29,7 +29,6 @@ from pytensor.tensor.variable import TensorVariable from .split_rules import SplitRule -from .tree import Tree from .utils import TensorLike, _sample_posterior __all__ = ["BART"] @@ -42,7 +41,6 @@ class BARTRV(RandomVariable): signature = "(m,n),(m),(),(),() -> (m)" dtype: str = "floatX" _print_name: tuple[str, str] = ("BART", "\\operatorname{BART}") - all_trees = list[list[list[Tree]]] def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed idx = dist_params[0].ndim - 2 @@ -55,7 +53,7 @@ def rng_fn( # pylint: disable=W0237 if not size: size = None - if not cls.all_trees: + if not hasattr(cls, "all_trees") or not cls.all_trees: if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)): Y = cls.Y.eval() else: @@ -142,8 +140,9 @@ def __new__( "Options linear and mix are experimental and still not well tested\n" + "Use with caution." ) + # Create a unique manager list for each BART instance manager = Manager() - cls.all_trees = manager.list() + instance_all_trees = manager.list() X, Y = preprocess_xy(X, Y) @@ -154,7 +153,7 @@ def __new__( (BARTRV,), { "name": "BART", - "all_trees": cls.all_trees, + "all_trees": instance_all_trees, # Instance-specific tree storage "inplace": False, "initval": Y.mean(), "X": X, diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py index 014313a..74fb1ce 100644 --- a/pymc_bart/pgbart.py +++ b/pymc_bart/pgbart.py @@ -130,6 +130,7 @@ def __init__( # noqa: PLR0912, PLR0915 model: Optional[Model] = None, initial_point: PointType | None = None, compile_kwargs: dict | None = None, + **kwargs, # Accept additional kwargs for compound sampling ) -> None: model = modelcontext(model) if initial_point is None: @@ -143,7 +144,24 @@ def __init__( # noqa: PLR0912, PLR0915 if vars is None: raise ValueError("Unable to find variables to sample") - value_bart = vars[0] + # Filter to only BART variables + bart_vars = [] + for var in vars: + rv = model.values_to_rvs.get(var) + if rv is not None and isinstance(rv.owner.op, BARTRV): + bart_vars.append(var) + + if not bart_vars: + raise ValueError("No BART variables found in the provided variables") + + if len(bart_vars) > 1: + raise ValueError( + "PGBART can only handle one BART variable at a time. " + "For multiple BART variables, PyMC will automatically create " + "separate PGBART samplers for each variable." + ) + + value_bart = bart_vars[0] self.bart = model.values_to_rvs[value_bart].owner.op if isinstance(self.bart.X, Variable): @@ -227,15 +245,15 @@ def __init__( # noqa: PLR0912, PLR0915 self.num_particles = num_particles self.indices = list(range(1, num_particles)) - shared = make_shared_replacements(initial_point, vars, model) - self.likelihood_logp = logp(initial_point, [model.datalogp], vars, shared) + shared = make_shared_replacements(initial_point, [value_bart], model) + self.likelihood_logp = logp(initial_point, [model.datalogp], [value_bart], shared) self.all_particles = [ [ParticleTree(self.a_tree) for _ in range(self.m)] for _ in range(self.trees_shape) ] self.all_trees = np.array([[p.tree for p in pl] for pl in self.all_particles]) self.lower = 0 self.iter = 0 - super().__init__(vars, shared) + super().__init__([value_bart], shared) def astep(self, _): variable_inclusion = np.zeros(self.num_variates, dtype="int") @@ -408,6 +426,13 @@ def competence(var: pm.Distribution, has_grad: bool) -> Competence: return Competence.IDEAL return Competence.INCOMPATIBLE + @staticmethod + def _make_update_stats_functions(): + def update_stats(step_stats): + return {key: step_stats[key] for key in ("variable_inclusion", "tune")} + + return (update_stats,) + class RunningSd: """Welford's online algorithm for computing the variance/standard deviation""" diff --git a/tests/test_bart.py b/tests/test_bart.py index 226d938..8311c2a 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -256,3 +256,70 @@ def test_categorical_model(separate_trees, split_rule): # Fit should be good enough so right category is selected over 50% of time assert (idata.predictions.y.median(["chain", "draw"]) == Y).all() assert pmb.compute_variable_importance(idata, bartrv=lo, X=X)["preds"].shape == (5, 50, 9, 3) + + +def test_multiple_bart_variables(): + """Test that multiple BART variables can coexist in a single model.""" + X1 = np.random.normal(0, 1, size=(50, 2)) + X2 = np.random.normal(0, 1, size=(50, 3)) + Y = np.random.normal(0, 1, size=50) + + # Create correlated responses + Y1 = X1[:, 0] + np.random.normal(0, 0.1, size=50) + Y2 = X2[:, 0] + X2[:, 1] + np.random.normal(0, 0.1, size=50) + + with pm.Model() as model: + # Two separate BART variables with different covariates + mu1 = pmb.BART("mu1", X1, Y1, m=5) + mu2 = pmb.BART("mu2", X2, Y2, m=5) + + # Combined model + sigma = pm.HalfNormal("sigma", 1) + y = pm.Normal("y", mu1 + mu2, sigma, observed=Y) + + # Sample with automatic assignment of BART samplers + idata = pm.sample(tune=50, draws=50, chains=1, random_seed=3415) + + # Verify both BART variables have their own tree collections + assert hasattr(mu1.owner.op, "all_trees") + assert hasattr(mu2.owner.op, "all_trees") + + # Verify trees are stored separately (different object references) + assert mu1.owner.op.all_trees is not mu2.owner.op.all_trees + + # Verify sampling worked + assert idata.posterior["mu1"].shape == (1, 50, 50) + assert idata.posterior["mu2"].shape == (1, 50, 50) + + +def test_multiple_bart_variables_manual_step(): + """Test that multiple BART variables work with manually assigned PGBART samplers.""" + X1 = np.random.normal(0, 1, size=(30, 2)) + X2 = np.random.normal(0, 1, size=(30, 2)) + Y = np.random.normal(0, 1, size=30) + + # Create simple responses + Y1 = X1[:, 0] + np.random.normal(0, 0.1, size=30) + Y2 = X2[:, 1] + np.random.normal(0, 0.1, size=30) + + with pm.Model() as model: + # Two separate BART variables + mu1 = pmb.BART("mu1", X1, Y1, m=3) + mu2 = pmb.BART("mu2", X2, Y2, m=3) + + # Non-BART variable + sigma = pm.HalfNormal("sigma", 1) + y = pm.Normal("y", mu1 + mu2, sigma, observed=Y) + + # Manually create PGBART samplers for each BART variable + step1 = pmb.PGBART([mu1], num_particles=5) + step2 = pmb.PGBART([mu2], num_particles=5) + + # Sample with manual step assignment + idata = pm.sample(tune=20, draws=20, chains=1, step=[step1, step2], random_seed=3415) + + # Verify both variables were sampled + assert "mu1" in idata.posterior + assert "mu2" in idata.posterior + assert idata.posterior["mu1"].shape == (1, 20, 30) + assert idata.posterior["mu2"].shape == (1, 20, 30)