Skip to content

Add sample_prior function #7833

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"compile_forward_sampling_function",
"draw",
"sample_posterior_predictive",
"sample_prior",
"sample_prior_predictive",
)

Expand Down Expand Up @@ -984,3 +985,91 @@ def sample_posterior_predictive(
idata.extend(idata_pp)
return idata
return idata_pp


def sample_prior(
    draws: int = 500,
    model: Model | None = None,
    var_names: Iterable[str] | None = None,
    random_seed: RandomState = None,
    return_inferencedata: bool = True,
    idata_kwargs: dict | None = None,
    compile_kwargs: dict | None = None,
) -> InferenceData | dict[str, np.ndarray]:
    """Draw samples from the prior of a model.

    Thin wrapper around :func:`sample_prior_predictive` that restricts the
    sampled variables. When ``var_names`` is not given, the names of
    ``model.unobserved_RVs`` are used. In every case, any deterministic reported
    by ``observed_dependent_deterministics(model)`` is removed from the
    selection before sampling.

    Unlike :func:`sample_prior_predictive`, the default selection therefore
    does not include observed variables.

    Parameters
    ----------
    draws : int
        Number of prior draws to generate. Defaults to 500.
    model : Model (optional if in ``with`` context)
        Model to sample from; resolved through ``modelcontext`` when omitted.
    var_names : Iterable[str], optional
        Names of the variables to sample. Defaults to the names of
        ``model.unobserved_RVs``. Explicitly passed names are only filtered
        against deterministics that depend on observed variables; no other
        filtering is applied here.
    random_seed : int, RandomState or Generator, optional
        Seed for the random number generator.
    return_inferencedata : bool
        Whether to return an :class:`arviz:arviz.InferenceData` (True) object
        or a dictionary (False). Defaults to True.
    idata_kwargs : dict, optional
        Keyword arguments for :func:`pymc.to_inference_data`.
    compile_kwargs : dict, optional
        Keyword arguments for :func:`pymc.pytensorf.compile_pymc`.

    Returns
    -------
    arviz.InferenceData or Dict
        The value returned by :func:`sample_prior_predictive` for the selected
        variables: an ArviZ ``InferenceData`` object (default), or a dictionary
        with variable names as keys and samples as numpy arrays.

    Examples
    --------
    Basic usage:

    .. code:: python

        import pymc as pm

        with pm.Model() as model:
            mu = pm.Normal("mu", 0, 1)
            sigma = pm.HalfNormal("sigma", 1)
            y = pm.Normal("y", mu, sigma, observed=[1, 2, 3])

            # Sample only from the prior (mu and sigma)
            prior_samples = pm.sample_prior(draws=1000)

    Restrict sampling to specific variables:

    .. code:: python

        with model:
            # Sample only mu from the prior
            mu_samples = pm.sample_prior(draws=1000, var_names=["mu"])
    """
    model = modelcontext(model)

    # No explicit selection: fall back to the model's unobserved variables.
    requested = (rv.name for rv in model.unobserved_RVs) if var_names is None else var_names

    # Deterministics downstream of observed variables are not part of the prior.
    excluded = observed_dependent_deterministics(model)
    selected = (name for name in requested if model[name] not in excluded)

    # The actual forward sampling is delegated to sample_prior_predictive.
    return sample_prior_predictive(
        draws=draws,
        model=model,
        var_names=selected,
        random_seed=random_seed,
        return_inferencedata=return_inferencedata,
        idata_kwargs=idata_kwargs,
        compile_kwargs=compile_kwargs,
    )
107 changes: 107 additions & 0 deletions tests/sampling/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1801,3 +1801,110 @@ def test_sample_prior_predictive_samples_deprecated_warns() -> None:
match = "The samples argument has been deprecated"
with pytest.warns(DeprecationWarning, match=match):
pm.sample_prior_predictive(model=m, samples=10)


class TestSamplePrior:
    """Checks on which variables ``pm.sample_prior`` returns and their shapes."""

    def test_basic_prior_sampling(self, seeded_test):
        """Default call yields unobserved RVs and prior-only deterministics."""
        n_draws = 100
        with pm.Model():
            mu = pm.Normal("mu", 0, 1)
            sigma = pm.HalfNormal("sigma", 1)
            pm.Normal("y", mu, sigma, observed=[1, 2, 3])
            pm.Deterministic("det", mu + sigma)

            samples = pm.sample_prior(draws=n_draws, return_inferencedata=False)

        # Unobserved RVs and the deterministic built from them are present...
        for name in ("mu", "sigma", "det"):
            assert name in samples
        # ...while the observed variable is not.
        assert "y" not in samples

        for name in ("mu", "sigma", "det"):
            assert samples[name].shape == (n_draws,)

    def test_specific_var_names(self, seeded_test):
        """Passing ``var_names`` restricts the output to those variables."""
        with pm.Model():
            mu = pm.Normal("mu", 0, 1)
            sigma = pm.HalfNormal("sigma", 1)
            pm.Normal("y", mu, sigma, observed=[1, 2, 3])

            # Ask for mu only.
            samples = pm.sample_prior(draws=100, var_names=["mu"], return_inferencedata=False)

        assert "mu" in samples
        for absent in ("sigma", "y"):
            assert absent not in samples
        assert samples["mu"].shape == (100,)

    def test_multivariate_prior(self, seeded_test):
        """Batched and multivariate priors keep their shapes after the draw axis."""
        with pm.Model():
            mu = pm.Normal("mu", 0, 1, size=3)
            # Use a simpler multivariate setup to avoid LKJCholeskyCov issues
            pm.MvNormal("mv", mu, cov=np.eye(3), size=4)

            samples = pm.sample_prior(draws=50, return_inferencedata=False)

        expected_shapes = {"mu": (50, 3), "mv": (50, 4, 3)}
        for name in expected_shapes:
            assert name in samples
        for name, shape in expected_shapes.items():
            assert samples[name].shape == shape

    def test_only_requested_variables(self, seeded_test):
        """Dependencies of a requested variable are not added to the output."""
        with pm.Model():
            mu = pm.Normal("mu", 0, 1)
            sigma = pm.HalfNormal("sigma", 1)
            det = pm.Deterministic("det", mu + sigma)
            pm.Normal("y", det, sigma, observed=[1, 2, 3])

            # Request only mu, but y depends on det which depends on mu
            samples = pm.sample_prior(draws=100, var_names=["mu"], return_inferencedata=False)

        # Only mu comes back; related variables are left out.
        assert "mu" in samples
        for absent in ("sigma", "det", "y"):
            assert absent not in samples
        assert len(samples) == 1

    def test_deterministics_behavior(self, seeded_test):
        """Deterministics depending on observed variables are excluded."""
        with pm.Model():
            mu = pm.Normal("mu", 0, 1)
            sigma = pm.HalfNormal("sigma", 1)
            y = pm.Normal("y", mu, sigma, observed=[1, 2, 3])

            # Built from unobserved RVs only.
            pm.Deterministic("det_prior", mu + sigma)

            # Built from the observed RV.
            pm.Deterministic("det_obs", y + mu)

            samples = pm.sample_prior(draws=100, return_inferencedata=False)

        # The prior-only deterministic is kept.
        assert "det_prior" in samples

        # The observed-dependent deterministic is dropped.
        assert "det_obs" not in samples

        # The observed variable itself is dropped too.
        assert "y" not in samples

    def test_empty_var_names_behavior(self, seeded_test):
        """An empty ``var_names`` produces an empty dictionary."""
        with pm.Model():
            mu = pm.Normal("mu", 0, 1)
            sigma = pm.HalfNormal("sigma", 1)
            pm.Normal("y", mu, sigma, observed=[1, 2, 3])
            pm.Deterministic("det", mu + sigma)

            # Nothing requested at all.
            samples = pm.sample_prior(draws=100, var_names=[], return_inferencedata=False)

        assert samples == {}
Loading