From 2d7e457ef90ee3480d69d2e28f04525f1449313d Mon Sep 17 00:00:00 2001 From: qinqian Date: Sat, 1 Feb 2025 20:51:18 -0500 Subject: [PATCH 1/2] initialize the branch; ignore all pyright errors for now --- Makefile | 3 +++ pyproject.toml | 19 +++++++++++++++++++ scripts/pyright.sh | 5 +++++ setup.py | 1 + 4 files changed, 28 insertions(+) create mode 100644 pyproject.toml create mode 100755 scripts/pyright.sh diff --git a/Makefile b/Makefile index 49c19053..41b0a607 100644 --- a/Makefile +++ b/Makefile @@ -4,6 +4,9 @@ lint: FORCE lint-notebooks: ./scripts/lint_notebooks.sh +pyright: + ./scripts/pyright.sh + format: ./scripts/clean.sh diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..9f470989 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,19 @@ +[tool.pyright] +include = ["tests", "effectful"] +pythonVersion = "3.12" +exclude = ["**/node_modules", + "**/__pycache__", + "effectful.egg-info" +] +reportMissingImports = false +reportAttributeAccessIssue = false +reportCallIssue = false +reportGeneralTypeIssues = false +reportMissingModuleSource = false +reportOperatorIssue = false +reportInvalidTypeForm = false +reportIncompatibleVariableOverride = false +reportArgumentType = false +reportIndexIssue = false +reportPossiblyUnboundVariable = false +reportReturnType = false diff --git a/scripts/pyright.sh b/scripts/pyright.sh new file mode 100755 index 00000000..36a71657 --- /dev/null +++ b/scripts/pyright.sh @@ -0,0 +1,5 @@ +#!/bin/bash +set -euxo pipefail + +SRC="tests/ effectful/" +pyright $SRC diff --git a/setup.py b/setup.py index 5dfcf09e..4c3ade5a 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ "pytest-xdist", "pytest-benchmark", "mypy", + "pyright", "ruff", "nbval", "nbqa", From d94cda6c50e463343ce0787ca51d3f3b7aeb5476 Mon Sep 17 00:00:00 2001 From: qinqian Date: Sun, 2 Feb 2025 17:12:34 -0500 Subject: [PATCH 2/2] fix most of the reportMissingImports and TypedDict errors in pyright --- effectful/handlers/pyro.py | 85 +++++++++++++++++++------------------ effectful/handlers/torch.py | 37 +++++++++------- pyproject.toml | 4 +- tests/test_handlers_pyro.py | 14 +++--- 4 files changed, 75 insertions(+), 65 deletions(-) diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index 2f1922de..49e757e6 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -135,10 +135,10 @@ def _broadcast_to_named( def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None: if typing.TYPE_CHECKING: - assert msg["type"] == "sample" - assert msg["name"] is not None - assert msg["infer"] is not None - assert isinstance(msg["fn"], TorchDistributionMixin) + assert msg.get("type") == "sample" + assert msg.get("name") is not None + assert msg.get("infer") is not None + assert isinstance(msg.get("fn"), TorchDistributionMixin) if pyro.poutine.util.site_is_subsample(msg) or pyro.poutine.util.site_is_factor( msg @@ -155,30 +155,32 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None: # compatible with Pyro. In particular, it removes all named dimensions # and stores naming information in the message. Names are replaced by # _pyro_post_sample. - if getattr(self, "_current_site", None) == msg["name"]: + if getattr(self, "_current_site", None) == msg.get("name"): if "_index_naming" in msg: return # We need to identify this pyro shim during post-sample. msg["_pyro_shim_id"] = id(self) # type: ignore[typeddict-unknown-key] - if "_markov_scope" in msg["infer"] and self._current_site: - msg["infer"]["_markov_scope"].pop(self._current_site, None) + if "_markov_scope" in msg.get("infer") and self._current_site: + if "infer" in msg: + if "_markov_scope" in msg["infer"]: + msg["infer"]["_markov_scope"].pop(self._current_site, None) - dist = msg["fn"] - obs = msg["value"] if msg["is_observed"] else None + dist = msg.get("fn") + obs = msg.get("value") if msg.get("is_observed") else None # pdist shape: | named1 | batch_shape | event_shape | # obs shape: | batch_shape | event_shape |, | named2 | where named2 may overlap named1 indices = sizesof(dist) pdist, naming = positional_distribution(dist) - if msg["mask"] is None: + if msg.get("mask") is None: mask = torch.tensor(True) - elif isinstance(msg["mask"], bool): - mask = torch.tensor(msg["mask"]) + elif isinstance(msg.get("mask"), bool): + mask = torch.tensor(msg.get("mask")) else: - mask = msg["mask"] + mask = msg.get("mask") assert set(sizesof(mask).keys()) <= ( set(indices.keys()) | set(sizesof(obs).keys()) @@ -205,25 +207,25 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None: size=indices[var], counter=0, ) - msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"] + msg["cond_indep_stack"] = (frame,) + msg.get("cond_indep_stack") msg["fn"] = pdist msg["value"] = pos_obs msg["mask"] = pos_mask msg["_index_naming"] = naming # type: ignore - assert sizesof(msg["value"]) == {} - assert sizesof(msg["mask"]) == {} + assert sizesof(msg.get("value")) == {} + assert sizesof(msg.get("mask")) == {} # This branch handles the first call to pyro.sample by calling pyro_sample. else: try: - self._current_site = msg["name"] + self._current_site = msg.get("name") msg["value"] = pyro_sample( - msg["name"], - msg["fn"], - obs=msg["value"] if msg["is_observed"] else None, - infer=msg["infer"].copy(), + msg.get("name"), + msg.get("fn"), + obs=msg.get("value") if msg.get("is_observed") else None, + infer=msg.get("infer").copy(), ) finally: self._current_site = None @@ -233,32 +235,33 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None: msg["done"] = True msg["mask"] = False msg["is_observed"] = True - msg["infer"]["is_auxiliary"] = True - msg["infer"]["_do_not_trace"] = True + if "infer" in msg: + msg["infer"]["is_auxiliary"] = True + msg["infer"]["_do_not_trace"] = True def _pyro_post_sample(self, msg: pyro.poutine.runtime.Message) -> None: - assert msg["value"] is not None - - # If this message has been handled already by a different pyro shim, ignore. - if "_pyro_shim_id" in msg and msg["_pyro_shim_id"] != id(self): # type: ignore[typeddict-item] - return + if msg.get("value", None) is not None: + # If this message has been handled already by a different pyro shim, ignore. + if "_pyro_shim_id" in msg and msg["_pyro_shim_id"] != id(self): # type: ignore[typeddict-item] + return - if getattr(self, "_current_site", None) == msg["name"]: - assert "_index_naming" in msg + if "name" in msg: + if getattr(self, "_current_site", None) == msg["name"]: + assert "_index_naming" in msg - # note: Pyro uses a TypedDict for infer, so it doesn't know we've stored this key - naming = msg["_index_naming"] # type: ignore + # note: Pyro uses a TypedDict for infer, so it doesn't know we've stored this key + naming = msg["_index_naming"] # type: ignore - value = msg["value"] + value = msg["value"] - # note: is it safe to assume that msg['fn'] is a distribution? - dist_shape: tuple[int, ...] = msg["fn"].batch_shape + msg["fn"].event_shape # type: ignore - if len(value.shape) < len(dist_shape): - value = value.broadcast_to( - torch.broadcast_shapes(value.shape, dist_shape) - ) - value = naming.apply(value) - msg["value"] = value + # note: is it safe to assume that msg['fn'] is a distribution? + dist_shape: tuple[int, ...] = msg["fn"].batch_shape + msg["fn"].event_shape # type: ignore + if len(value.shape) < len(dist_shape): + value = value.broadcast_to( + torch.broadcast_shapes(value.shape, dist_shape) + ) + value = naming.apply(value) + msg["value"] = value class Naming: diff --git a/effectful/handlers/torch.py b/effectful/handlers/torch.py index ab3eeeae..8fb9a264 100644 --- a/effectful/handlers/torch.py +++ b/effectful/handlers/torch.py @@ -2,6 +2,13 @@ import typing from collections.abc import Callable, Collection, Mapping, Sequence from types import EllipsisType +from torch._functorch.apis import grad as torch_grad +from torch._functorch.apis import jacfwd as torch_jacfwd +from torch._functorch.apis import jacrev as torch_jacrev +from torch._functorch.apis import vmap as torch_vmap +from torch._functorch.apis import hessian as torch_hessian +from torch._functorch.apis import jvp as torch_jvp +from torch._functorch.apis import vjp as torch_vjp from typing import ( TypeVar, ) @@ -156,7 +163,7 @@ def _partial_eval(t: T, order: Collection[Operation[[], int]] | None = None) -> ] ordered_sized_fvs = reindex_fvs + [(var, sized_fvs[var]) for var in order] - tpe_torch_fn = torch.func.vmap( + tpe_torch_fn = torch_vmap( deffn(t, *[var for (var, _) in ordered_sized_fvs]), randomness="different" ) @@ -501,7 +508,7 @@ def index_expr(i): return deindexed, reindex -@functools.wraps(torch.func.grad) +@functools.wraps(torch_grad) def grad(func, *args, **kwargs): """Compute the gradient of a function with respect to its arguments. This is a wrapper around `torch.func.grad` that allows the function to be called @@ -509,42 +516,42 @@ def grad(func, *args, **kwargs): """ (deindexed_func, reindex) = _indexed_func_wrapper(func) - f = _register_torch_op(torch.func.grad(deindexed_func, *args, **kwargs)) + f = _register_torch_op(torch_grad(deindexed_func, *args, **kwargs)) return lambda *a, **k: reindex(f(*a, *k)) -@functools.wraps(torch.func.jacfwd) +@functools.wraps(torch_jacfwd) def jacfwd(func, *args, **kwargs): (deindexed_func, reindex) = _indexed_func_wrapper(func) - jacobian = _register_torch_op(torch.func.jacfwd(deindexed_func, *args, **kwargs)) + jacobian = _register_torch_op(torch_jacfwd(deindexed_func, *args, **kwargs)) return lambda *a, **k: reindex(jacobian(*a, *k)) -@functools.wraps(torch.func.jacrev) +@functools.wraps(torch_jacrev) def jacrev(func, *args, **kwargs): (deindexed_func, reindex) = _indexed_func_wrapper(func) - jacobian = _register_torch_op(torch.func.jacrev(deindexed_func, *args, **kwargs)) + jacobian = _register_torch_op(torch_jacrev(deindexed_func, *args, **kwargs)) return lambda *a, **k: reindex(jacobian(*a, *k)) -@functools.wraps(torch.func.hessian) +@functools.wraps(torch_hessian) def hessian(func, *args, **kwargs): (deindexed_func, reindex) = _indexed_func_wrapper(func) - h = _register_torch_op(torch.func.hessian(deindexed_func, *args, **kwargs)) + h = _register_torch_op(torch_hessian(deindexed_func, *args, **kwargs)) return lambda *a, **k: reindex(h(*a, *k)) -@functools.wraps(torch.func.jvp) +@functools.wraps(torch_jvp) def jvp(func, *args, **kwargs): (deindexed_func, reindex) = _indexed_func_wrapper(func) # hide deindexed_func from _register_torch_op - jvp_func = functools.partial(torch.func.jvp, deindexed_func) + jvp_func = functools.partial(torch_jvp, deindexed_func) ret = _register_torch_op(jvp_func)(*args, **kwargs) return tree.map_structure(reindex, ret) -@functools.wraps(torch.func.vjp) +@functools.wraps(torch_vjp) def vjp(func, *indexed_primals, **kwargs): unpacked_primals = [] for t in indexed_primals: @@ -568,7 +575,7 @@ def wrapper(*primals): ) unindexed_primals = [t[0] for t in unpacked_primals] - _, vjpfunc = torch.func.vjp(wrapper, *unindexed_primals, **kwargs) + _, vjpfunc = torch_vjp(wrapper, *unindexed_primals, **kwargs) def vjpfunc_wrapper(*tangents): unindexed_tangents = tree.map_structure( @@ -580,10 +587,10 @@ def vjpfunc_wrapper(*tangents): return indexed_result, vjpfunc_wrapper -@functools.wraps(torch.func.vmap) +@functools.wraps(torch_vmap) def vmap(func, *args, **kwargs): (deindexed_func, reindex) = _indexed_func_wrapper(func) - vmap_func = _register_torch_op(torch.func.vmap(deindexed_func, *args, **kwargs)) + vmap_func = _register_torch_op(torch_vmap(deindexed_func, *args, **kwargs)) # vmap_func returns tensors of shape [vmap_dim, indexed_dim_1, ..., # indexed_dim_n, pos_dim_1, ..., pos_dim_m], so we reapply indexes starting # at dim 1 diff --git a/pyproject.toml b/pyproject.toml index 9f470989..cbaabba9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,11 @@ [tool.pyright] include = ["tests", "effectful"] -pythonVersion = "3.12" exclude = ["**/node_modules", "**/__pycache__", "effectful.egg-info" ] -reportMissingImports = false +pythonVersion = "3.12" +reportMissingImports = true reportAttributeAccessIssue = false reportCallIssue = false reportGeneralTypeIssues = false diff --git a/tests/test_handlers_pyro.py b/tests/test_handlers_pyro.py index 2769de33..cfb6e8cf 100644 --- a/tests/test_handlers_pyro.py +++ b/tests/test_handlers_pyro.py @@ -124,15 +124,15 @@ def test_smoke_condition_enumerate_hmm_elbo( tr.compute_log_prob() for t in range(num_steps): assert f"x_{t}" in tr.nodes - assert tr.nodes[f"x_{t}"]["type"] == "sample" - assert not tr.nodes[f"x_{t}"]["is_observed"] - assert any(f.name == "plate1" for f in tr.nodes[f"x_{t}"]["cond_indep_stack"]) + assert tr.nodes.get(f"x_{t}", {}).get("type") == "sample" + assert not tr.nodes.get(f"x_{t}", {}).get("is_observed") + assert any(f.name == "plate1" for f in tr.nodes.get(f"x_{t}", {}).get("cond_indep_stack", [])) assert f"y_{t}" in tr.nodes - assert tr.nodes[f"y_{t}"]["type"] == "sample" - assert tr.nodes[f"y_{t}"]["is_observed"] - assert (tr.nodes[f"y_{t}"]["value"] == data[t]).all() - assert any(f.name == "plate1" for f in tr.nodes[f"x_{t}"]["cond_indep_stack"]) + assert tr.nodes.get(f"y_{t}", {}).get("type") == "sample" + assert tr.nodes.get(f"y_{t}", {}).get("is_observed") + assert (tr.nodes.get(f"y_{t}", {}).get("value") == data[t]).all() + assert any(f.name == "plate1" for f in tr.nodes.get(f"x_{t}", {}).get("cond_indep_stack", [])) if use_guide: guide = pyro.infer.config_enumerate(default="parallel")(