diff --git a/chirho/counterfactual/handlers/exogenation.py b/chirho/counterfactual/handlers/exogenation.py new file mode 100644 index 00000000..f31d2e48 --- /dev/null +++ b/chirho/counterfactual/handlers/exogenation.py @@ -0,0 +1,162 @@ +import pyro +from pyro.poutine.messenger import Messenger +from pyro.distributions.torch_distribution import TorchDistributionMixin +from torch.distributions import TransformedDistribution +import pyro.distributions as dist + +EXOGENATE_META_KEY = "exogenate_meta" + +# TODO check if the base predicate matches Uniform(0, 1) first, and if it does, sample uniform and push through icdf, no unwrapping needed +# because transforms handle that internally when calling icdf on them. + + +class _NonInvertibleTransformError(Exception): + """Raised when transforms cannot be inverted analytically.""" + pass + + +def _invert_transforms_to_noise(value, transforms): + """ + Apply inverse transforms in reverse order to recover noise. + + Raises _NonInvertibleTransformError if any transform is not bijective. + """ + recovered = value + for transform in reversed(transforms): + try: + recovered = transform.inv(recovered) + except (NotImplementedError, RuntimeError) as e: + # TODO not sure if all valid torch transforms are invertible, but keeping as a conceptual hook for surjective transforms. + raise _NonInvertibleTransformError( + f"Transform {type(transform).__name__} is not invertible: {e}" + ) from e + return recovered + + +class ExogenateNoiseMessenger(Messenger): + + def _sample_noise_observed(self, name, current_dist, obs_value, transforms, original_fn): + """Sample noise for an observed site by inverting transforms. This is effectively closed form posterior inference on the noise. 
+ """ + try: + # Compute implied noise by inverting transforms + u_implied = _invert_transforms_to_noise(obs_value, transforms) + except _NonInvertibleTransformError: + # TODO: Think about whether the base distribution should function as a prior on exogenous noise + # under soft conditioning when the inverse map isn't defined analytically. + # For non-invertible transforms, we can't analytically compute the noise, + # so users must rely on soft conditioning at the Delta level for inference to work properly. We would + # want to observe the delta site in that case. + raise NotImplementedError( + f"Cannot exogenate observed site '{name}': transforms are not bijective. " + ) + + # Sample noise at implied value, but mask its log_prob contribution + # The noise still appears in the trace and propagates to counterfactual worlds, + # but doesn't add prior probability to the model + with pyro.poutine.mask(mask=False): + u = pyro.sample(f"{name}_u", current_dist, obs=u_implied) # TODO make _u suffix passable to the handler. + + # Replace the masked noise prior contribution with the original distribution's log_prob + original_log_prob = original_fn.log_prob(obs_value) + pyro.factor(f"{name}_log_prob", original_log_prob) + + return u + + def _sample_noise_unobserved(self, name, current_dist): + """Sample noise for an unobserved site.""" + return pyro.sample(f"{name}_u", current_dist) # TODO make _u suffix passable to the handler. 
+ + # (folded into _pyro_sample) + + def _pyro_sample(self, msg: dict) -> None: + exogenate_meta = msg.get("infer", {}).get(EXOGENATE_META_KEY, None) + if exogenate_meta is None or "base_dist_predicate" not in exogenate_meta: + return + + base_dist_predicate = exogenate_meta["base_dist_predicate"] + original_fn = msg["fn"] + name = msg["name"] + is_observed = msg.get("is_observed", False) + obs_value = msg.get("obs", None) if is_observed else None + + # Handle Independent wrapper + current_dist = original_fn + if isinstance(current_dist, dist.Independent): + event_dim = current_dist.reinterpreted_batch_ndims + current_dist = current_dist.base_dist + else: + event_dim = original_fn.event_dim if hasattr(original_fn, 'event_dim') else 0 + + # Unwrap TransformedDistribution layers and collect transforms until predicate matches + transforms = [] + while isinstance(current_dist, TransformedDistribution) and not base_dist_predicate(current_dist): + # Prepend transforms to maintain correct order when applying them later + transforms = list(current_dist.transforms) + transforms + current_dist = current_dist.base_dist + + # Validate that we found a matching distribution + if not base_dist_predicate(current_dist): + raise ValueError( + f"Could not find base distribution matching predicate in distribution chain. " + f"Reached {type(current_dist).__name__} instead." + ) + + # Validate that the distribution is Pyro-compatible + if not isinstance(current_dist, TorchDistributionMixin): + raise TypeError( + f"Cannot exogenate to {type(current_dist).__name__}: not a Pyro-compatible distribution. " + f"The distribution must be an instance of a Pyro distribution (have TorchDistributionMixin). " + f"Consider exogenating to a base Pyro distribution like Normal, Exponential, etc. See Pyro's " + f"LogNormal implementation for an example of how to define a distribution whose base is a Pyro distribution." 
+ ) + + # Sample noise (method differs for observed vs unobserved) + if is_observed: + u = self._sample_noise_observed(name, current_dist, obs_value, transforms, original_fn) + else: + u = self._sample_noise_unobserved(name, current_dist) + + # Apply forward transforms to push noise back to original space + # TODO if available, use the observed value to avoid numerical issues with inverse transforms? + x = u + for transform in transforms: + x = transform(x) + + # Create Delta sample site that interventions can target + delta_infer = {EXOGENATE_META_KEY: {"original_fn": original_fn}} + msg["value"] = pyro.sample(name, dist.Delta(x, event_dim=event_dim), infer=delta_infer) + + # Stop so that everything else looks at the new sample site instead + msg["stop"] = True + + +def sample_exogenated(name, fn, base_dist_predicate, **kwargs): + """ + Sample from a distribution while exogenating noise to a base distribution. + + This function unwraps TransformedDistribution layers until reaching a base distribution + that matches the provided predicate. It samples noise from that base, then applies + transforms to get back to the original distribution's space. This creates separate + sample sites for noise and reparameterized value, allowing interventions that leave upstream + noise shared and intact across parallel worlds. + + :param name: Name of the sample site + :param fn: Distribution to sample from (may be transformed) + :param base_dist_predicate: Callable that takes a distribution and returns True + if it's the desired base to exogenate to. + The matched distribution must be Pyro-compatible. 
+ :param kwargs: Additional keyword arguments to pass to pyro.sample + :return: Sample value + + Example:: + # Unwrap to any Normal distribution in the chain + sample_exogenated("x", dist.LogNormal(0, 1), lambda d: isinstance(d, dist.Normal)) + + # Unwrap to Normal with specific parameters + sample_exogenated("z", transformed_dist, + lambda d: isinstance(d, dist.Normal) and d.loc == 0) + """ + infer = dict(kwargs.pop("infer", {})) + infer[EXOGENATE_META_KEY] = {"base_dist_predicate": base_dist_predicate} + return pyro.sample(name, fn, infer=infer, **kwargs) diff --git a/tests/counterfactual/test_exogenation.py b/tests/counterfactual/test_exogenation.py new file mode 100644 index 00000000..64ae6a45 --- /dev/null +++ b/tests/counterfactual/test_exogenation.py @@ -0,0 +1,329 @@ +import logging + +import pyro +import pyro.distributions as dist +import pytest +import torch +from pyro.distributions.torch_distribution import TorchDistributionMixin +from torch.distributions.transforms import AffineTransform, ExpTransform + +from chirho.counterfactual.handlers import MultiWorldCounterfactual +from chirho.counterfactual.handlers.exogenation import ( + ExogenateNoiseMessenger, + sample_exogenated, +) +from chirho.interventional.handlers import do +from chirho.observational.handlers import condition + +logger = logging.getLogger(__name__) + + +def _recover_noise_from_value(value, distribution, base_dist_predicate): + """Helper to recover noise by applying inverse transforms until reaching base matching predicate. 
+ + Args: + value: The observed value to invert + distribution: The original distribution + base_dist_predicate: Predicate that returns True for the target base distribution + + Returns: + The recovered noise value + """ + import torch.distributions + + recovered_u = value + current_dist = distribution + + # Unwrap Independent if present + if isinstance(current_dist, dist.Independent): + current_dist = current_dist.base_dist + + # Walk outward-to-inward, applying inverse transforms until reaching the predicate match + # Note: Need to check for both dist.TransformedDistribution and torch.distributions.TransformedDistribution + # because some Pyro distributions like LogNormal are based on PyTorch's TransformedDistribution + while isinstance(current_dist, (dist.TransformedDistribution, torch.distributions.TransformedDistribution)): + if base_dist_predicate(current_dist): + break + for transform in reversed(current_dist.transforms): + recovered_u = transform.inv(recovered_u) + current_dist = current_dist.base_dist + if base_dist_predicate(current_dist): + break + + return recovered_u + + +@pytest.mark.parametrize( + "distribution,base_dist_predicate", + [ + # Simple distributions (no transforms to unwrap, so just a trivial structural equation of y = y_u) + (dist.Normal(2.0, 0.5), lambda d: isinstance(d, dist.Normal)), + (dist.Exponential(1.0), lambda d: isinstance(d, dist.Exponential)), + # LogNormal is TransformedDistribution(Normal, ExpTransform), so base is Normal + (dist.LogNormal(0.0, 1.0), lambda d: isinstance(d, dist.Normal)), + (dist.LogNormal(1.0, 0.5), lambda d: isinstance(d, dist.Normal)), + # Explicitly transformed distributions to test unwrapping all the way to root + # Normal(0, 1) -> affine transform + (dist.TransformedDistribution(dist.Normal(0.0, 1.0), [AffineTransform(loc=3.0, scale=2.0)]), + lambda d: isinstance(d, dist.Normal)), + # Normal(0, 1) -> exp transform + (dist.TransformedDistribution(dist.Normal(0.0, 1.0), [ExpTransform()]), + lambda d: 
isinstance(d, dist.Normal)), + # Multiple nested transforms: should unwrap all the way to Normal + (dist.TransformedDistribution( + dist.TransformedDistribution(dist.Normal(0.0, 1.0), [AffineTransform(loc=0.0, scale=0.5)]), + [ExpTransform(), AffineTransform(loc=0.0, scale=2.0)] + ), lambda d: isinstance(d, dist.Normal)), + # LogNormal with additional transforms - stop at LogNormal (intermediate Pyro distribution) + # LogNormal(0, 1) is already TransformedDistribution(Normal, ExpTransform) + # Add affine transform on top: LogNormal -> scale by 2 + (dist.TransformedDistribution(dist.LogNormal(0.0, 1.0), [AffineTransform(loc=0.0, scale=2.0)]), + lambda d: isinstance(d, dist.LogNormal)), + # Same distribution but unwrap all the way to Normal (the root) + (dist.TransformedDistribution(dist.LogNormal(0.0, 1.0), [AffineTransform(loc=0.0, scale=2.0)]), + lambda d: isinstance(d, dist.Normal)), + # Test Independent wrapper - should unwrap Independent and then find Normal + (dist.Independent(dist.Normal(torch.zeros(3), torch.ones(3)), 1), + lambda d: isinstance(d, dist.Normal)), + # Test Independent wrapper around TransformedDistribution + (dist.Independent(dist.LogNormal(torch.zeros(2, 3), torch.ones(2, 3)), 2), + lambda d: isinstance(d, dist.Normal)), + ], +) +def test_exogenate_to_base_distributions(distribution, base_dist_predicate): + """Test that exogenation unwraps to the base distribution matching the predicate.""" + + def model(): + y = sample_exogenated("y", distribution, base_dist_predicate) + return y + + with pyro.poutine.trace() as tr: + with ExogenateNoiseMessenger(): + y = model() + + # Check that both y and y_u are in the trace + assert "y" in tr.trace.nodes + assert "y_u" in tr.trace.nodes + + # Check that y_u was sampled from a distribution matching the predicate + y_u_fn = tr.trace.nodes["y_u"]["fn"] + assert base_dist_predicate(y_u_fn) + + # Check that y is a Delta distribution + y_fn = tr.trace.nodes["y"]["fn"] + assert isinstance(y_fn, dist.Delta) + + 
# Check that original_fn is stored in exogenate_meta + assert "exogenate_meta" in tr.trace.nodes["y"]["infer"] + assert "original_fn" in tr.trace.nodes["y"]["infer"]["exogenate_meta"] + assert tr.trace.nodes["y"]["infer"]["exogenate_meta"]["original_fn"] is distribution + # Verify base_dist_predicate is NOT in the Delta's metadata (prevents re-triggering) + assert "base_dist_predicate" not in tr.trace.nodes["y"]["infer"]["exogenate_meta"] + + # Verify we can recover y_u by applying inverse transforms until reaching base matching predicate + y_value = tr.trace.nodes["y"]["value"] + u_value = tr.trace.nodes["y_u"]["value"] + + recovered_u = _recover_noise_from_value(y_value, distribution, base_dist_predicate) + assert torch.allclose(recovered_u, u_value, atol=1e-5) + + +@pytest.mark.parametrize( + "distribution,invalid_predicate", + [ + # LogNormal is TransformedDistribution(Normal, ExpTransform), so Exponential is not in the chain + (dist.LogNormal(0.0, 1.0), lambda d: isinstance(d, dist.Exponential)), + # Normal has no transforms, so Exponential is not in the chain + (dist.Normal(0.0, 1.0), lambda d: isinstance(d, dist.Exponential)), + # Exponential has no transforms, so Normal is not in the chain + (dist.Exponential(1.0), lambda d: isinstance(d, dist.Normal)), + ], +) +def test_exogenate_invalid_predicate_error(distribution, invalid_predicate): + """Test that exogenation raises ValueError when predicate doesn't match any dist in the chain.""" + + def model(): + y = sample_exogenated("y", distribution, invalid_predicate) + return y + + with pytest.raises(ValueError, match=r"Could not find base distribution matching predicate"): + with pyro.poutine.trace(): + with ExogenateNoiseMessenger(): + model() + + +def test_exogenate_non_pyro_distribution_error(): + """Test that exogenation raises TypeError when predicate matches a non-Pyro distribution.""" + + # Import torch's TransformedDistribution (not Pyro's) for this test + from torch.distributions import 
TransformedDistribution as TorchTransformedDistribution + + # Create a chain with a vanilla PyTorch distribution that won't have the Pyro mixin + # We'll use a predicate that intentionally matches a TransformedDistribution + vanilla_base = torch.distributions.Normal(0.0, 1.0) + vanilla_transformed = TorchTransformedDistribution(vanilla_base, [AffineTransform(loc=0.0, scale=1.0)]) + + # Wrap in a Pyro distribution to get past initial checks + pyro_dist = dist.TransformedDistribution(vanilla_base, [ExpTransform()]) + + # Try to match the vanilla_base (which is torch.distributions.Normal, not pyro.distributions.Normal) + def model(): + y = sample_exogenated("y", pyro_dist, lambda d: type(d).__name__ == "Normal" and not isinstance(d, TorchDistributionMixin)) + return y + + # This should raise TypeError because vanilla_base doesn't have TorchDistributionMixin + with pytest.raises(TypeError, match=r"Cannot exogenate to .* not a Pyro-compatible distribution"): + with pyro.poutine.trace(): + with ExogenateNoiseMessenger(): + model() + + +@pytest.mark.parametrize( + "distribution,base_dist_predicate", + [ + # Vector Normal + (dist.Normal(torch.zeros(6), torch.ones(6)), lambda d: isinstance(d, dist.Normal)), + # Vector LogNormal unwrapping to Normal + (dist.LogNormal(torch.zeros(5), torch.ones(5)), lambda d: isinstance(d, dist.Normal)), + # Independent(LogNormal) with reinterpreted batch dims + (dist.Independent(dist.LogNormal(torch.zeros(4, 3), torch.ones(4, 3)), 1), lambda d: isinstance(d, dist.Normal)), + # Transformed(LogNormal) stopping at intermediate LogNormal + (dist.TransformedDistribution(dist.LogNormal(torch.zeros(4), torch.ones(4)), [AffineTransform(loc=0.0, scale=2.0)]), lambda d: isinstance(d, dist.LogNormal)), + ], +) +def test_exogenate_multiworld_counterfactual(distribution, base_dist_predicate): + """Test that exogenation works correctly with MultiWorldCounterfactual. 
+ + The key property: noise (y_u) should be shared across factual and counterfactual worlds, + while interventions on y should only affect the counterfactual world. + """ + + intervention_value = torch.tensor(5.0) + event_dim = len(distribution.event_shape) + + def model(): + y = sample_exogenated("y", distribution, base_dist_predicate) + z = pyro.sample("z", dist.Normal(y, 1.0)) + + with MultiWorldCounterfactual(first_available_dim=-2): + with do(actions={"y": intervention_value}): + with pyro.poutine.trace() as tr: # FIXME this must go under do and above ExogenateNoiseMessenger. :/ + with ExogenateNoiseMessenger(): + model() + + # Check that both y and y_u are in the trace + assert "y" in tr.trace.nodes + assert "y_u" in tr.trace.nodes + assert "z" in tr.trace.nodes + + # Get values from trace + y_u_value = tr.trace.nodes["y_u"]["value"] + y_value = tr.trace.nodes["y"]["value"] + z_value = tr.trace.nodes["z"]["value"] + + logger.info(f"y_u_value shape: {y_u_value.shape}") + logger.info(f"y_value shape: {y_value.shape}") + logger.info(f"z_value shape: {z_value.shape}") + + # Verify y_u distribution matches the predicate + y_u_fn = tr.trace.nodes["y_u"]["fn"] + assert base_dist_predicate(y_u_fn), "y_u should be sampled from distribution matching predicate" + + # Check that y is a Delta distribution (exogenated) + y_fn = tr.trace.nodes["y"]["fn"] + assert isinstance(y_fn, dist.Delta) + + # Check world splitting by examining shapes + # y_u should NOT have a world dimension (shared noise across worlds) + # y and z SHOULD have a world dimension of size 2 at position 0 (factual + counterfactual) + assert y_value.shape[0] == 2, f"y should have world dimension of size 2 at position 0, got shape {y_value.shape}" + assert z_value.shape[0] == 2, f"z should have world dimension of size 2 at position 0, got shape {z_value.shape}" + + # y_u should have same shape as one world's worth of y (no world dimension) + assert y_u_value.shape == y_value[0].shape, \ + f"y_u should have 
same shape as single world of y, got y_u: {y_u_value.shape}, y[0]: {y_value[0].shape}" + + # Extract factual (index 0) and counterfactual (index 1) worlds + y_factual = y_value[0] + y_counterfactual = y_value[1] + + # Reconstruct noise from the factual world by inverting transforms + recovered_u = _recover_noise_from_value(y_factual, distribution, base_dist_predicate) + assert torch.allclose(recovered_u, y_u_value, atol=1e-5) + + # Verify intervention is applied to counterfactual world, and factual differs + assert torch.allclose( + y_counterfactual, + intervention_value.expand_as(y_counterfactual), + atol=0, + rtol=0, + ), "Counterfactual world should equal the intervention value" + # Factual should not be identically equal to the intervention; allow rare coincidences by checking any element differs + assert (y_factual - intervention_value).abs().max() > 1e-6, ( + "Factual world unexpectedly equals the intervention value everywhere" + ) + + +def test_observe_downstream_and_intervene_upstream_shared_noise(): + """Observe y and intervene on upstream x; ensure implied noise is shared across worlds. + + Model: x ~ Normal(0,1); y = transform(x, u) with base noise u ~ Normal(0,1) (via exogenation). + We observe y=y_obs in factual world, reconstruct u via inverse transforms, and then under + counterfactual intervention x=x_cf the counterfactual y_cf should be computed using the same u. 
+ """ + + # Distribution for y|x: Transformed(Normal(0,1) -> Affine(loc=x) -> Exp) + base_dist_predicate = lambda d: isinstance(d, dist.Normal) + def y_dist_given(x): + return dist.TransformedDistribution( + dist.Normal(0.0, 1.0), + [AffineTransform(loc=x, scale=1.0), ExpTransform()], + ) + + # observed y and counterfactual x + y_factual = torch.tensor(3.0) + x_cf = torch.tensor(0.7) + + def model(): + x = pyro.sample("x", dist.Normal(0.0, 1.0)) + y = sample_exogenated("y", y_dist_given(x), base_dist_predicate) #, obs=y_factual) + return x, y + + with MultiWorldCounterfactual(first_available_dim=-2): + # Observe downstream y and factual x; intervene upstream on x for CF world + with do(actions={"x": x_cf}): + with pyro.poutine.trace() as tr: + with ExogenateNoiseMessenger(): + model() + + # Check presence of sites + assert "y_u" in tr.trace.nodes and "y" in tr.trace.nodes and "x" in tr.trace.nodes + + # Extract value + x_value = tr.trace.nodes["x"]["value"] + y_u_value = tr.trace.nodes["y_u"]["value"] + y_value = tr.trace.nodes["y"]["value"] + + # y should have two worlds (factual, counterfactual) + assert y_value.shape[0] == 2 + + # Noise should be shared (no world dimension) + assert y_u_value.shape == y_value[0].shape + + # Validate inverse relationship for factual world: y_obs = exp(x_factual + y_u) + # So y_u should equal log(y_obs) - x_factual + expected_u = torch.log(y_factual) - x_factual + assert torch.allclose(y_u_value, expected_u, atol=1e-5) + + # Counterfactual world should use same u but with x_cf: y_cf = exp(x_cf + u) + y_cf_expected = torch.exp(x_cf + y_u_value) + y_factual, y_counterfactual = y_value[0], y_value[1] + assert torch.allclose(y_factual, y_factual, atol=1e-5) + assert torch.allclose(y_counterfactual, y_cf_expected, atol=1e-5) + + +if __name__ == "__main__": + # Run specific tests for debugging + print("Running test_exogenate_multiworld_counterfactual...") + test_observe_downstream_and_intervene_upstream_shared_noise() + print("\n✅ Test 
passed!")