diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index 2f1922de..2f0bd96f 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -13,6 +13,7 @@ raise ImportError("Pyro is required to use effectful.handlers.pyro.") import pyro.distributions as dist +from pyro.distributions import Distribution from pyro.distributions.torch_distribution import ( TorchDistribution, TorchDistributionMixin, @@ -135,10 +136,10 @@ def _broadcast_to_named( def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None: if typing.TYPE_CHECKING: - assert msg["type"] == "sample" - assert msg["name"] is not None - assert msg["infer"] is not None - assert isinstance(msg["fn"], TorchDistributionMixin) + assert msg.get("type") == "sample" + assert msg.get("name") is not None + assert msg.get("infer") is not None + assert isinstance(msg.get("fn"), TorchDistributionMixin) if pyro.poutine.util.site_is_subsample(msg) or pyro.poutine.util.site_is_factor( msg @@ -155,41 +156,50 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None: # compatible with Pyro. In particular, it removes all named dimensions # and stores naming information in the message. Names are replaced by # _pyro_post_sample. - if getattr(self, "_current_site", None) == msg["name"]: + if getattr(self, "_current_site", None) == msg.get("name"): if "_index_naming" in msg: return # We need to identify this pyro shim during post-sample. msg["_pyro_shim_id"] = id(self) # type: ignore[typeddict-unknown-key] - if "_markov_scope" in msg["infer"] and self._current_site: - msg["infer"]["_markov_scope"].pop(self._current_site, None) + if "_markov_scope" in msg.get("infer", {}) and self._current_site: + if "infer" in msg and isinstance(msg["infer"], dict): + if "_markov_scope" in msg["infer"]: + msg["infer"]["_markov_scope"].pop(self._current_site, None) - dist = msg["fn"] - obs = msg["value"] if msg["is_observed"] else None + dist = msg.get("fn") + assert dist is None or isinstance(dist, Distribution) + obs = msg.get("value") if msg.get("is_observed") else None # pdist shape: | named1 | batch_shape | event_shape | # obs shape: | batch_shape | event_shape |, | named2 | where named2 may overlap named1 indices = sizesof(dist) pdist, naming = positional_distribution(dist) - if msg["mask"] is None: + if msg.get("mask") is None: mask = torch.tensor(True) - elif isinstance(msg["mask"], bool): - mask = torch.tensor(msg["mask"]) + elif isinstance(msg.get("mask"), bool): + mask = torch.tensor(msg.get("mask")) else: - mask = msg["mask"] + mask = msg.get("mask") assert set(sizesof(mask).keys()) <= ( set(indices.keys()) | set(sizesof(obs).keys()) ) - pos_mask, _ = PyroShim._broadcast_to_named(mask, dist.batch_shape, indices) - pos_obs: torch.Tensor | None = None - if obs is not None: - pos_obs, naming = PyroShim._broadcast_to_named( - obs, dist.shape(), indices - ) + if isinstance(dist, Distribution): + pos_mask, _ = PyroShim._broadcast_to_named(mask, dist.batch_shape, indices) + + pos_obs: torch.Tensor | None = None + if obs is not None: + pos_obs, naming = PyroShim._broadcast_to_named( + obs, dist.shape(), indices + ) + else: + # Handle the case where dist is None safely + pos_mask, _ = PyroShim._broadcast_to_named(mask, (), indices) + pos_obs = None # Each of the batch dimensions on the distribution gets a # cond_indep_stack frame. @@ -205,25 +215,26 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None: size=indices[var], counter=0, ) - msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"] + msg["cond_indep_stack"] = (frame,) + msg.get("cond_indep_stack") msg["fn"] = pdist msg["value"] = pos_obs msg["mask"] = pos_mask msg["_index_naming"] = naming # type: ignore - assert sizesof(msg["value"]) == {} - assert sizesof(msg["mask"]) == {} + assert sizesof(msg.get("value")) == {} + assert sizesof(msg.get("mask")) == {} # This branch handles the first call to pyro.sample by calling pyro_sample. else: try: - self._current_site = msg["name"] + infer_data = msg.get("infer") or {} + self._current_site = msg.get("name") msg["value"] = pyro_sample( - msg["name"], - msg["fn"], - obs=msg["value"] if msg["is_observed"] else None, - infer=msg["infer"].copy(), + msg.get("name"), + msg.get("fn"), + obs=msg.get("value") if msg.get("is_observed") else None, + infer=infer_data, ) finally: self._current_site = None @@ -233,32 +244,37 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None: msg["done"] = True msg["mask"] = False msg["is_observed"] = True + + # Ensure msg["infer"] exists before modifying it + if "infer" not in msg or msg["infer"] is None: + msg["infer"] = {} # Initialize it as an empty dict + msg["infer"]["is_auxiliary"] = True msg["infer"]["_do_not_trace"] = True def _pyro_post_sample(self, msg: pyro.poutine.runtime.Message) -> None: - assert msg["value"] is not None - - # If this message has been handled already by a different pyro shim, ignore. - if "_pyro_shim_id" in msg and msg["_pyro_shim_id"] != id(self): # type: ignore[typeddict-item] - return + if msg.get("value", None) is not None: + # If this message has been handled already by a different pyro shim, ignore. + if "_pyro_shim_id" in msg and msg["_pyro_shim_id"] != id(self): # type: ignore[typeddict-item] + return - if getattr(self, "_current_site", None) == msg["name"]: - assert "_index_naming" in msg + if "name" in msg: + if getattr(self, "_current_site", None) == msg["name"]: + assert "_index_naming" in msg - # note: Pyro uses a TypedDict for infer, so it doesn't know we've stored this key - naming = msg["_index_naming"] # type: ignore + # note: Pyro uses a TypedDict for infer, so it doesn't know we've stored this key + naming = msg["_index_naming"] # type: ignore - value = msg["value"] + value = msg["value"] - # note: is it safe to assume that msg['fn'] is a distribution? - dist_shape: tuple[int, ...] = msg["fn"].batch_shape + msg["fn"].event_shape # type: ignore - if len(value.shape) < len(dist_shape): - value = value.broadcast_to( - torch.broadcast_shapes(value.shape, dist_shape) - ) - value = naming.apply(value) - msg["value"] = value + # note: is it safe to assume that msg['fn'] is a distribution? + dist_shape: tuple[int, ...] = msg["fn"].batch_shape + msg["fn"].event_shape # type: ignore + if len(value.shape) < len(dist_shape): + value = value.broadcast_to( + torch.broadcast_shapes(value.shape, dist_shape) + ) + value = naming.apply(value) + msg["value"] = value class Naming: diff --git a/effectful/handlers/torch.py b/effectful/handlers/torch.py index 134fd12d..2a81521b 100644 --- a/effectful/handlers/torch.py +++ b/effectful/handlers/torch.py @@ -2,9 +2,18 @@ import typing from collections.abc import Callable, Collection, Mapping, Sequence from types import EllipsisType +from torch._functorch.apis import grad as torch_grad +from torch._functorch.apis import jacfwd as torch_jacfwd +from torch._functorch.apis import jacrev as torch_jacrev +from torch._functorch.apis import vmap as torch_vmap +from torch._functorch.apis import hessian as torch_hessian +from torch._functorch.apis import jvp as torch_jvp +from torch._functorch.apis import vjp as torch_vjp from typing import ( TypeVar, ) +from torch import vmap + try: import torch @@ -164,7 +173,7 @@ def _partial_eval(t: T, order: Collection[Operation[[], int]] | None = None) -> def wrapper(*args): return index_fn(*args) - tpe_torch_fn = torch.func.vmap(wrapper, randomness="different") + tpe_torch_fn = vmap(wrapper, randomness="different") inds = torch.broadcast_tensors( *( @@ -507,7 +516,7 @@ def index_expr(i): return deindexed, reindex -@functools.wraps(torch.func.grad) +@functools.wraps(torch_grad) def grad(func, *args, **kwargs): """Compute the gradient of a function with respect to its arguments. This is a wrapper around `torch.func.grad` that allows the function to be called @@ -515,42 +524,42 @@ def grad(func, *args, **kwargs): """ (deindexed_func, reindex) = _indexed_func_wrapper(func) - f = _register_torch_op(torch.func.grad(deindexed_func, *args, **kwargs)) + f = _register_torch_op(torch_grad(deindexed_func, *args, **kwargs)) return lambda *a, **k: reindex(f(*a, *k)) -@functools.wraps(torch.func.jacfwd) +@functools.wraps(torch_jacfwd) def jacfwd(func, *args, **kwargs): (deindexed_func, reindex) = _indexed_func_wrapper(func) - jacobian = _register_torch_op(torch.func.jacfwd(deindexed_func, *args, **kwargs)) + jacobian = _register_torch_op(torch_jacfwd(deindexed_func, *args, **kwargs)) return lambda *a, **k: reindex(jacobian(*a, *k)) -@functools.wraps(torch.func.jacrev) +@functools.wraps(torch_jacrev) def jacrev(func, *args, **kwargs): (deindexed_func, reindex) = _indexed_func_wrapper(func) - jacobian = _register_torch_op(torch.func.jacrev(deindexed_func, *args, **kwargs)) + jacobian = _register_torch_op(torch_jacrev(deindexed_func, *args, **kwargs)) return lambda *a, **k: reindex(jacobian(*a, *k)) -@functools.wraps(torch.func.hessian) +@functools.wraps(torch_hessian) def hessian(func, *args, **kwargs): (deindexed_func, reindex) = _indexed_func_wrapper(func) - h = _register_torch_op(torch.func.hessian(deindexed_func, *args, **kwargs)) + h = _register_torch_op(torch_hessian(deindexed_func, *args, **kwargs)) return lambda *a, **k: reindex(h(*a, *k)) -@functools.wraps(torch.func.jvp) +@functools.wraps(torch_jvp) def jvp(func, *args, **kwargs): (deindexed_func, reindex) = _indexed_func_wrapper(func) # hide deindexed_func from _register_torch_op - jvp_func = functools.partial(torch.func.jvp, deindexed_func) + jvp_func = functools.partial(torch_jvp, deindexed_func) ret = _register_torch_op(jvp_func)(*args, **kwargs) return tree.map_structure(reindex, ret) -@functools.wraps(torch.func.vjp) +@functools.wraps(torch_vjp) def vjp(func, *indexed_primals, **kwargs): unpacked_primals = [] for t in indexed_primals: @@ -574,7 +583,7 @@ def wrapper(*primals): ) unindexed_primals = [t[0] for t in unpacked_primals] - _, vjpfunc = torch.func.vjp(wrapper, *unindexed_primals, **kwargs) + _, vjpfunc = torch_vjp(wrapper, *unindexed_primals, **kwargs) def vjpfunc_wrapper(*tangents): unindexed_tangents = tree.map_structure( @@ -586,10 +595,10 @@ def vjpfunc_wrapper(*tangents): return indexed_result, vjpfunc_wrapper -@functools.wraps(torch.func.vmap) +@functools.wraps(torch_vmap) def vmap(func, *args, **kwargs): (deindexed_func, reindex) = _indexed_func_wrapper(func) - vmap_func = _register_torch_op(torch.func.vmap(deindexed_func, *args, **kwargs)) + vmap_func = _register_torch_op(torch_vmap(deindexed_func, *args, **kwargs)) # vmap_func returns tensors of shape [vmap_dim, indexed_dim_1, ..., # indexed_dim_n, pos_dim_1, ..., pos_dim_m], so we reapply indexes starting # at dim 1 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..c45def6a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,20 @@ +[tool.pyright] +include = ["tests", "effectful"] +exclude = ["**/node_modules", + "**/__pycache__", + "effectful.egg-info" +] +pythonVersion = "3.12" +reportMissingImports = true +reportWildcardImportFromLibrary = false +reportAttributeAccessIssue = false +reportCallIssue = false +reportGeneralTypeIssues = false +reportMissingModuleSource = false +reportOperatorIssue = false +reportInvalidTypeForm = false +reportIncompatibleVariableOverride = false +reportArgumentType = false +reportIndexIssue = false +reportPossiblyUnboundVariable = false +reportReturnType = false diff --git a/scripts/pyright.sh b/scripts/pyright.sh new file mode 100755 index 00000000..36a71657 --- /dev/null +++ b/scripts/pyright.sh @@ -0,0 +1,5 @@ +#!/bin/bash +set -euxo pipefail + +SRC="tests/ effectful/" +pyright $SRC diff --git a/setup.py b/setup.py index 5dfcf09e..4c3ade5a 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ "pytest-xdist", "pytest-benchmark", "mypy", + "pyright", "ruff", "nbval", "nbqa", diff --git a/tests/test_handlers_pyro.py b/tests/test_handlers_pyro.py index 2769de33..cfb6e8cf 100644 --- a/tests/test_handlers_pyro.py +++ b/tests/test_handlers_pyro.py @@ -124,15 +124,15 @@ def test_smoke_condition_enumerate_hmm_elbo( tr.compute_log_prob() for t in range(num_steps): assert f"x_{t}" in tr.nodes - assert tr.nodes[f"x_{t}"]["type"] == "sample" - assert not tr.nodes[f"x_{t}"]["is_observed"] - assert any(f.name == "plate1" for f in tr.nodes[f"x_{t}"]["cond_indep_stack"]) + assert tr.nodes.get(f"x_{t}", {}).get("type") == "sample" + assert not tr.nodes.get(f"x_{t}", {}).get("is_observed") + assert any(f.name == "plate1" for f in tr.nodes.get(f"x_{t}", {}).get("cond_indep_stack", [])) assert f"y_{t}" in tr.nodes - assert tr.nodes[f"y_{t}"]["type"] == "sample" - assert tr.nodes[f"y_{t}"]["is_observed"] - assert (tr.nodes[f"y_{t}"]["value"] == data[t]).all() - assert any(f.name == "plate1" for f in tr.nodes[f"x_{t}"]["cond_indep_stack"]) + assert tr.nodes.get(f"y_{t}", {}).get("type") == "sample" + assert tr.nodes.get(f"y_{t}", {}).get("is_observed") + assert (tr.nodes.get(f"y_{t}", {}).get("value") == data[t]).all() + assert any(f.name == "plate1" for f in tr.nodes.get(f"x_{t}", {}).get("cond_indep_stack", [])) if use_guide: guide = pyro.infer.config_enumerate(default="parallel")( diff --git a/tests/test_handlers_pyro_dist.py b/tests/test_handlers_pyro_dist.py index 68dd6f30..d0e8da83 100644 --- a/tests/test_handlers_pyro_dist.py +++ b/tests/test_handlers_pyro_dist.py @@ -35,7 +35,7 @@ def random_scale_tril(*args): shape = args data = torch.randn(shape) - return dist.transforms.transform_to(dist.constraints.lower_cholesky)(data) + return dist.transforms.transform_to(dist.constraints.lower_cholesky)(data) # type: ignore @functools.cache