Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 61 additions & 45 deletions effectful/handlers/pyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
raise ImportError("Pyro is required to use effectful.handlers.pyro.")

import pyro.distributions as dist
from pyro.distributions import Distribution
from pyro.distributions.torch_distribution import (
TorchDistribution,
TorchDistributionMixin,
Expand Down Expand Up @@ -135,10 +136,10 @@ def _broadcast_to_named(

def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None:
if typing.TYPE_CHECKING:
assert msg["type"] == "sample"
assert msg["name"] is not None
assert msg["infer"] is not None
assert isinstance(msg["fn"], TorchDistributionMixin)
assert msg.get("type") == "sample"
assert msg.get("name") is not None
assert msg.get("infer") is not None
assert isinstance(msg.get("fn"), TorchDistributionMixin)

if pyro.poutine.util.site_is_subsample(msg) or pyro.poutine.util.site_is_factor(
msg
Expand All @@ -155,41 +156,50 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None:
# compatible with Pyro. In particular, it removes all named dimensions
# and stores naming information in the message. Names are replaced by
# _pyro_post_sample.
if getattr(self, "_current_site", None) == msg["name"]:
if getattr(self, "_current_site", None) == msg.get("name"):
if "_index_naming" in msg:
return

# We need to identify this pyro shim during post-sample.
msg["_pyro_shim_id"] = id(self) # type: ignore[typeddict-unknown-key]

if "_markov_scope" in msg["infer"] and self._current_site:
msg["infer"]["_markov_scope"].pop(self._current_site, None)
if "_markov_scope" in msg.get("infer", {}) and self._current_site:
if "infer" in msg and isinstance(msg["infer"], dict):
if "_markov_scope" in msg["infer"]:
msg["infer"]["_markov_scope"].pop(self._current_site, None)

dist = msg["fn"]
obs = msg["value"] if msg["is_observed"] else None
dist = msg.get("fn")
assert dist is None or isinstance(dist, Distribution)
obs = msg.get("value") if msg.get("is_observed") else None

# pdist shape: | named1 | batch_shape | event_shape |
# obs shape: | batch_shape | event_shape |, | named2 | where named2 may overlap named1
indices = sizesof(dist)
pdist, naming = positional_distribution(dist)

if msg["mask"] is None:
if msg.get("mask") is None:
mask = torch.tensor(True)
elif isinstance(msg["mask"], bool):
mask = torch.tensor(msg["mask"])
elif isinstance(msg.get("mask"), bool):
mask = torch.tensor(msg.get("mask"))
else:
mask = msg["mask"]
mask = msg.get("mask")

assert set(sizesof(mask).keys()) <= (
set(indices.keys()) | set(sizesof(obs).keys())
)
pos_mask, _ = PyroShim._broadcast_to_named(mask, dist.batch_shape, indices)

pos_obs: torch.Tensor | None = None
if obs is not None:
pos_obs, naming = PyroShim._broadcast_to_named(
obs, dist.shape(), indices
)
if isinstance(dist, Distribution):
pos_mask, _ = PyroShim._broadcast_to_named(mask, dist.batch_shape, indices)

pos_obs: torch.Tensor | None = None
if obs is not None:
pos_obs, naming = PyroShim._broadcast_to_named(
obs, dist.shape(), indices
)
else:
# Handle the case where dist is None safely
pos_mask, _ = PyroShim._broadcast_to_named(mask, (), indices)
pos_obs = None

# Each of the batch dimensions on the distribution gets a
# cond_indep_stack frame.
Expand All @@ -205,25 +215,26 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None:
size=indices[var],
counter=0,
)
msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]
msg["cond_indep_stack"] = (frame,) + msg.get("cond_indep_stack")

msg["fn"] = pdist
msg["value"] = pos_obs
msg["mask"] = pos_mask
msg["_index_naming"] = naming # type: ignore

assert sizesof(msg["value"]) == {}
assert sizesof(msg["mask"]) == {}
assert sizesof(msg.get("value")) == {}
assert sizesof(msg.get("mask")) == {}

# This branch handles the first call to pyro.sample by calling pyro_sample.
else:
try:
self._current_site = msg["name"]
infer_data = msg.get("infer") or {}
self._current_site = msg.get("name")
msg["value"] = pyro_sample(
msg["name"],
msg["fn"],
obs=msg["value"] if msg["is_observed"] else None,
infer=msg["infer"].copy(),
msg.get("name"),
msg.get("fn"),
obs=msg.get("value") if msg.get("is_observed") else None,
infer=infer_data,
)
finally:
self._current_site = None
Expand All @@ -233,32 +244,37 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None:
msg["done"] = True
msg["mask"] = False
msg["is_observed"] = True

# Ensure msg["infer"] exists before modifying it
if "infer" not in msg or msg["infer"] is None:
msg["infer"] = {} # Initialize it as an empty dict

msg["infer"]["is_auxiliary"] = True
msg["infer"]["_do_not_trace"] = True

def _pyro_post_sample(self, msg: pyro.poutine.runtime.Message) -> None:
    """Restore named indices on a sample site's value once it has been sampled.

    The message is left untouched when it has no value, when it was tagged
    with the id of a different shim, or when its name is not the site this
    shim is currently sampling. Otherwise the value is broadcast up to the
    distribution's ``batch_shape + event_shape`` if it has fewer dimensions,
    and the naming stored under ``msg["_index_naming"]`` is applied to it.

    :param msg: The Pyro sample message; ``msg["value"]`` is updated in place.
    """
    value = msg.get("value")
    if value is None:
        return

    # If this message has been handled already by a different pyro shim, ignore.
    if "_pyro_shim_id" in msg and msg["_pyro_shim_id"] != id(self):  # type: ignore[typeddict-item]
        return

    # Only the site this shim is in the middle of sampling is rewritten.
    if "name" not in msg or getattr(self, "_current_site", None) != msg["name"]:
        return

    assert "_index_naming" in msg

    # note: Pyro uses a TypedDict for infer, so it doesn't know we've stored this key
    naming = msg["_index_naming"]  # type: ignore

    # note: is it safe to assume that msg['fn'] is a distribution?
    dist_shape: tuple[int, ...] = msg["fn"].batch_shape + msg["fn"].event_shape  # type: ignore
    if len(value.shape) < len(dist_shape):
        value = value.broadcast_to(torch.broadcast_shapes(value.shape, dist_shape))
    msg["value"] = naming.apply(value)


class Naming:
Expand Down
39 changes: 24 additions & 15 deletions effectful/handlers/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,18 @@
import typing
from collections.abc import Callable, Collection, Mapping, Sequence
from types import EllipsisType
from torch._functorch.apis import grad as torch_grad
from torch._functorch.apis import jacfwd as torch_jacfwd
from torch._functorch.apis import jacrev as torch_jacrev
from torch._functorch.apis import vmap as torch_vmap
from torch._functorch.apis import hessian as torch_hessian
from torch._functorch.apis import jvp as torch_jvp
from torch._functorch.apis import vjp as torch_vjp
from typing import (
TypeVar,
)
from torch import vmap


try:
import torch
Expand Down Expand Up @@ -164,7 +173,7 @@ def _partial_eval(t: T, order: Collection[Operation[[], int]] | None = None) ->
def wrapper(*args):
return index_fn(*args)

tpe_torch_fn = torch.func.vmap(wrapper, randomness="different")
tpe_torch_fn = vmap(wrapper, randomness="different")

inds = torch.broadcast_tensors(
*(
Expand Down Expand Up @@ -507,50 +516,50 @@ def index_expr(i):
return deindexed, reindex


@functools.wraps(torch.func.grad)
def grad(func, *args, **kwargs):
    """Compute the gradient of a function with respect to its arguments. This is
    a wrapper around `torch.func.grad` that allows the function to be called
    with indexed arguments.

    """
    (deindexed_func, reindex) = _indexed_func_wrapper(func)
    f = _register_torch_op(torch.func.grad(deindexed_func, *args, **kwargs))
    # Keyword arguments must be forwarded with `**k`; `*k` would pass the
    # keyword *names* as extra positional arguments and drop their values.
    return lambda *a, **k: reindex(f(*a, **k))


@functools.wraps(torch.func.jacfwd)
def jacfwd(func, *args, **kwargs):
    # Wrapper around `torch.func.jacfwd` for functions called with indexed
    # arguments: the function is deindexed, transformed, registered as a torch
    # op, and the result is passed through `reindex`.
    (deindexed_func, reindex) = _indexed_func_wrapper(func)
    jacobian = _register_torch_op(torch.func.jacfwd(deindexed_func, *args, **kwargs))
    # Keyword arguments must be forwarded with `**k`; `*k` would pass the
    # keyword *names* as extra positional arguments and drop their values.
    return lambda *a, **k: reindex(jacobian(*a, **k))


@functools.wraps(torch.func.jacrev)
def jacrev(func, *args, **kwargs):
    # Wrapper around `torch.func.jacrev` for functions called with indexed
    # arguments: the function is deindexed, transformed, registered as a torch
    # op, and the result is passed through `reindex`.
    (deindexed_func, reindex) = _indexed_func_wrapper(func)
    jacobian = _register_torch_op(torch.func.jacrev(deindexed_func, *args, **kwargs))
    # Keyword arguments must be forwarded with `**k`; `*k` would pass the
    # keyword *names* as extra positional arguments and drop their values.
    return lambda *a, **k: reindex(jacobian(*a, **k))


@functools.wraps(torch.func.hessian)
def hessian(func, *args, **kwargs):
    # Wrapper around `torch.func.hessian` for functions called with indexed
    # arguments: the function is deindexed, transformed, registered as a torch
    # op, and the result is passed through `reindex`.
    (deindexed_func, reindex) = _indexed_func_wrapper(func)
    h = _register_torch_op(torch.func.hessian(deindexed_func, *args, **kwargs))
    # Keyword arguments must be forwarded with `**k`; `*k` would pass the
    # keyword *names* as extra positional arguments and drop their values.
    return lambda *a, **k: reindex(h(*a, **k))


@functools.wraps(torch.func.jvp)
def jvp(func, *args, **kwargs):
    # Wrapper around `torch.func.jvp` for functions called with indexed
    # arguments. Unlike the transforms that return a new function, the
    # transform is invoked here directly with the caller's arguments, and
    # `reindex` is mapped over the resulting structure.
    (deindexed_func, reindex) = _indexed_func_wrapper(func)

    # hide deindexed_func from _register_torch_op
    jvp_func = functools.partial(torch.func.jvp, deindexed_func)
    ret = _register_torch_op(jvp_func)(*args, **kwargs)
    return tree.map_structure(reindex, ret)


@functools.wraps(torch.func.vjp)
@functools.wraps(torch_vjp)
def vjp(func, *indexed_primals, **kwargs):
unpacked_primals = []
for t in indexed_primals:
Expand All @@ -574,7 +583,7 @@ def wrapper(*primals):
)

unindexed_primals = [t[0] for t in unpacked_primals]
_, vjpfunc = torch.func.vjp(wrapper, *unindexed_primals, **kwargs)
_, vjpfunc = torch_vjp(wrapper, *unindexed_primals, **kwargs)

def vjpfunc_wrapper(*tangents):
unindexed_tangents = tree.map_structure(
Expand All @@ -586,10 +595,10 @@ def vjpfunc_wrapper(*tangents):
return indexed_result, vjpfunc_wrapper


@functools.wraps(torch.func.vmap)
@functools.wraps(torch_vmap)
def vmap(func, *args, **kwargs):
(deindexed_func, reindex) = _indexed_func_wrapper(func)
vmap_func = _register_torch_op(torch.func.vmap(deindexed_func, *args, **kwargs))
vmap_func = _register_torch_op(torch_vmap(deindexed_func, *args, **kwargs))
# vmap_func returns tensors of shape [vmap_dim, indexed_dim_1, ...,
# indexed_dim_n, pos_dim_1, ..., pos_dim_m], so we reapply indexes starting
# at dim 1
Expand Down
20 changes: 20 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[tool.pyright]
# Directories that pyright analyses, and paths it leaves out.
include = ["tests", "effectful"]
exclude = [
    "**/node_modules",
    "**/__pycache__",
    "effectful.egg-info",
]
pythonVersion = "3.12"

# Diagnostics that are switched on.
reportMissingImports = true

# Diagnostics that are switched off.
reportWildcardImportFromLibrary = false
reportAttributeAccessIssue = false
reportCallIssue = false
reportGeneralTypeIssues = false
reportMissingModuleSource = false
reportOperatorIssue = false
reportInvalidTypeForm = false
reportIncompatibleVariableOverride = false
reportArgumentType = false
reportIndexIssue = false
reportPossiblyUnboundVariable = false
reportReturnType = false
5 changes: 5 additions & 0 deletions scripts/pyright.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/bin/bash
# Run the pyright type checker over the tests and the package sources.
set -euxo pipefail

# Paths are held in an array so each one reaches pyright as its own argument
# without depending on unquoted word splitting.
SRC=(tests/ effectful/)
pyright "${SRC[@]}"
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"pytest-xdist",
"pytest-benchmark",
"mypy",
"pyright",
"ruff",
"nbval",
"nbqa",
Expand Down
14 changes: 7 additions & 7 deletions tests/test_handlers_pyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,15 @@ def test_smoke_condition_enumerate_hmm_elbo(
tr.compute_log_prob()
for t in range(num_steps):
assert f"x_{t}" in tr.nodes
assert tr.nodes[f"x_{t}"]["type"] == "sample"
assert not tr.nodes[f"x_{t}"]["is_observed"]
assert any(f.name == "plate1" for f in tr.nodes[f"x_{t}"]["cond_indep_stack"])
assert tr.nodes.get(f"x_{t}", {}).get("type") == "sample"
assert not tr.nodes.get(f"x_{t}", {}).get("is_observed")
assert any(f.name == "plate1" for f in tr.nodes.get(f"x_{t}", {}).get("cond_indep_stack", []))

assert f"y_{t}" in tr.nodes
assert tr.nodes[f"y_{t}"]["type"] == "sample"
assert tr.nodes[f"y_{t}"]["is_observed"]
assert (tr.nodes[f"y_{t}"]["value"] == data[t]).all()
assert any(f.name == "plate1" for f in tr.nodes[f"x_{t}"]["cond_indep_stack"])
assert tr.nodes.get(f"y_{t}", {}).get("type") == "sample"
assert tr.nodes.get(f"y_{t}", {}).get("is_observed")
assert (tr.nodes.get(f"y_{t}", {}).get("value") == data[t]).all()
assert any(f.name == "plate1" for f in tr.nodes.get(f"x_{t}", {}).get("cond_indep_stack", []))

if use_guide:
guide = pyro.infer.config_enumerate(default="parallel")(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_handlers_pyro_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def random_scale_tril(*args):
shape = args

data = torch.randn(shape)
return dist.transforms.transform_to(dist.constraints.lower_cholesky)(data)
return dist.transforms.transform_to(dist.constraints.lower_cholesky)(data) # type: ignore


@functools.cache
Expand Down