diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py
index b1f9c39895..6bdb7d0e75 100644
--- a/pymc/sampling/forward.py
+++ b/pymc/sampling/forward.py
@@ -67,6 +67,7 @@
     "compile_forward_sampling_function",
     "draw",
     "sample_posterior_predictive",
+    "sample_prior",
     "sample_prior_predictive",
 )
 
@@ -984,3 +985,95 @@ def sample_posterior_predictive(
         idata.extend(idata_pp)
         return idata
     return idata_pp
+
+
+def sample_prior(
+    draws: int = 500,
+    model: Model | None = None,
+    var_names: Iterable[str] | None = None,
+    random_seed: RandomState = None,
+    return_inferencedata: bool = True,
+    idata_kwargs: dict | None = None,
+    compile_kwargs: dict | None = None,
+) -> InferenceData | dict[str, np.ndarray]:
+    """Generate samples from the prior distribution.
+
+    This function samples only from the prior (unobserved random variables)
+    and deterministics that do not depend on observed variables.
+
+    This is different from `sample_prior_predictive` which samples from both
+    prior and prior predictive distributions.
+
+    Parameters
+    ----------
+    draws : int
+        Number of samples from the prior to generate. Defaults to 500.
+    model : Model (optional if in ``with`` context)
+    var_names : Iterable[str], optional
+        Names of the variables for which to compute the prior samples.
+        Defaults to the unobserved random variables of the model. Observed
+        variables and deterministics that depend on observed variables are
+        dropped from the selection, even when they are requested explicitly.
+    random_seed : int, RandomState or Generator, optional
+        Seed for the random number generator.
+    return_inferencedata : bool
+        Whether to return an :class:`arviz:arviz.InferenceData` (True) object or a dictionary (False).
+        Defaults to True.
+    idata_kwargs : dict, optional
+        Keyword arguments for :func:`pymc.to_inference_data`
+    compile_kwargs : dict, optional
+        Keyword arguments for :func:`pymc.pytensorf.compile_pymc`.
+
+    Returns
+    -------
+    arviz.InferenceData or Dict
+        An ArviZ ``InferenceData`` object containing the prior samples (default),
+        or a dictionary with variable names as keys and samples as numpy arrays.
+
+    Examples
+    --------
+    Basic usage:
+
+    .. code:: python
+
+        import pymc as pm
+
+        with pm.Model() as model:
+            mu = pm.Normal("mu", 0, 1)
+            sigma = pm.HalfNormal("sigma", 1)
+            y = pm.Normal("y", mu, sigma, observed=[1, 2, 3])
+
+            # Sample only from the prior (mu and sigma)
+            prior_samples = pm.sample_prior(draws=1000)
+
+    Specify specific variables:
+
+    .. code:: python
+
+        with model:
+            # Sample only mu from the prior
+            mu_samples = pm.sample_prior(draws=1000, var_names=["mu"])
+    """
+    model = modelcontext(model)
+
+    if var_names is None:
+        # Default to unobserved random variables
+        var_names = [var.name for var in model.unobserved_RVs]
+
+    # Observed variables and the deterministics that depend on them are prior
+    # predictive quantities, so drop them even when requested explicitly.
+    excluded = {*model.observed_RVs, *observed_dependent_deterministics(model)}
+    # Build a list rather than a generator: it is reusable and the
+    # ``model[name]`` lookups run here instead of lazily inside the callee.
+    prior_names = [name for name in var_names if model[name] not in excluded]
+
+    # Delegate the actual sampling to sample_prior_predictive
+    return sample_prior_predictive(
+        draws=draws,
+        model=model,
+        var_names=prior_names,
+        random_seed=random_seed,
+        return_inferencedata=return_inferencedata,
+        idata_kwargs=idata_kwargs,
+        compile_kwargs=compile_kwargs,
+    )
diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py
index d3b41bf667..b70f3339db 100644
--- a/tests/sampling/test_forward.py
+++ b/tests/sampling/test_forward.py
@@ -1801,3 +1801,123 @@ def test_sample_prior_predictive_samples_deprecated_warns() -> None:
     match = "The samples argument has been deprecated"
     with pytest.warns(DeprecationWarning, match=match):
         pm.sample_prior_predictive(model=m, samples=10)
+
+
+class TestSamplePrior:
+    def test_basic_prior_sampling(self, seeded_test):
+        """Test that sample_prior only samples from unobserved random variables."""
+        with pm.Model() as model:
+            mu = pm.Normal("mu", 0, 1)
+            sigma = pm.HalfNormal("sigma", 1)
+            y = pm.Normal("y", mu, sigma, observed=[1, 2, 3])
+            det = pm.Deterministic("det", mu + sigma)
+
+            prior_samples = pm.sample_prior(draws=100, return_inferencedata=False)
+
+        # Should contain unobserved RVs and deterministics that do not
+        # depend on observed variables, but not observed variables
+        assert "mu" in prior_samples
+        assert "sigma" in prior_samples
+        assert "det" in prior_samples  # deterministic is included
+        assert "y" not in prior_samples  # observed variable
+
+        assert prior_samples["mu"].shape == (100,)
+        assert prior_samples["sigma"].shape == (100,)
+        assert prior_samples["det"].shape == (100,)
+
+    def test_specific_var_names(self, seeded_test):
+        """Test sampling specific variables from the prior."""
+        with pm.Model() as model:
+            mu = pm.Normal("mu", 0, 1)
+            sigma = pm.HalfNormal("sigma", 1)
+            y = pm.Normal("y", mu, sigma, observed=[1, 2, 3])
+
+            # Sample only mu
+            mu_samples = pm.sample_prior(draws=100, var_names=["mu"], return_inferencedata=False)
+
+        assert "mu" in mu_samples
+        assert "sigma" not in mu_samples
+        assert "y" not in mu_samples
+        assert mu_samples["mu"].shape == (100,)
+
+    def test_multivariate_prior(self, seeded_test):
+        """Test sampling from multivariate priors."""
+        with pm.Model() as model:
+            mu = pm.Normal("mu", 0, 1, size=3)
+            # Use a simpler multivariate setup to avoid LKJCholeskyCov issues
+            mv = pm.MvNormal("mv", mu, cov=np.eye(3), size=4)
+
+            prior_samples = pm.sample_prior(draws=50, return_inferencedata=False)
+
+        assert "mu" in prior_samples
+        assert "mv" in prior_samples
+        assert prior_samples["mu"].shape == (50, 3)
+        assert prior_samples["mv"].shape == (50, 4, 3)
+
+    def test_only_requested_variables(self, seeded_test):
+        """Test that sample_prior only returns the requested variables."""
+        with pm.Model() as model:
+            mu = pm.Normal("mu", 0, 1)
+            sigma = pm.HalfNormal("sigma", 1)
+            det = pm.Deterministic("det", mu + sigma)
+            y = pm.Normal("y", det, sigma, observed=[1, 2, 3])
+
+            # Request only mu, but y depends on det which depends on mu
+            prior_samples = pm.sample_prior(draws=100, var_names=["mu"], return_inferencedata=False)
+
+        # Should only contain mu, not det or sigma even though they're dependencies
+        assert "mu" in prior_samples
+        assert "sigma" not in prior_samples
+        assert "det" not in prior_samples
+        assert "y" not in prior_samples
+        assert len(prior_samples) == 1
+
+    def test_deterministics_behavior(self, seeded_test):
+        """Test that sample_prior only includes deterministics that don't depend on observed variables."""
+        with pm.Model() as model:
+            mu = pm.Normal("mu", 0, 1)
+            sigma = pm.HalfNormal("sigma", 1)
+            y = pm.Normal("y", mu, sigma, observed=[1, 2, 3])
+
+            # Deterministic that depends only on unobserved RVs
+            det_prior = pm.Deterministic("det_prior", mu + sigma)
+
+            # Deterministic that depends on observed RV
+            det_obs = pm.Deterministic("det_obs", y + mu)
+
+            prior_samples = pm.sample_prior(draws=100, return_inferencedata=False)
+
+        # Should include deterministics that depend only on unobserved RVs
+        assert "det_prior" in prior_samples
+
+        # Should NOT include deterministics that depend on observed variables
+        assert "det_obs" not in prior_samples
+
+        # Should not include observed variables
+        assert "y" not in prior_samples
+
+    def test_empty_var_names_behavior(self, seeded_test):
+        """Test what happens when we pass an empty var_names set."""
+        with pm.Model() as model:
+            mu = pm.Normal("mu", 0, 1)
+            sigma = pm.HalfNormal("sigma", 1)
+            y = pm.Normal("y", mu, sigma, observed=[1, 2, 3])
+            det = pm.Deterministic("det", mu + sigma)
+
+            # Test with empty var_names
+            empty_samples = pm.sample_prior(draws=100, var_names=[], return_inferencedata=False)
+
+        assert empty_samples == {}
+
+    def test_observed_var_names_are_dropped(self, seeded_test):
+        """Test that explicitly requested observed variables are never sampled."""
+        with pm.Model() as model:
+            mu = pm.Normal("mu", 0, 1)
+            sigma = pm.HalfNormal("sigma", 1)
+            y = pm.Normal("y", mu, sigma, observed=[1, 2, 3])
+
+            # Request an observed variable explicitly alongside a prior one
+            samples = pm.sample_prior(draws=100, var_names=["mu", "y"], return_inferencedata=False)
+
+        assert "mu" in samples
+        assert "y" not in samples