Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.5
hooks:
- id: ruff
args: ["--fix"]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.13.1
hooks:
# Run the linter.
- id: ruff-check
args: [ --fix ]
# Run the formatter.
- id: ruff-format
2 changes: 1 addition & 1 deletion posteriors/ekf/dense_fisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import partial
import torch
from torch.func import grad_and_value
from optree.integration.torch import tree_ravel
from optree.integrations.torch import tree_ravel
from tensordict import TensorClass

from posteriors.tree_utils import tree_size, tree_insert_
Expand Down
2 changes: 1 addition & 1 deletion posteriors/laplace/dense_fisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import partial
import torch
from optree import tree_map
from optree.integration.torch import tree_ravel
from optree.integrations.torch import tree_ravel
from tensordict import TensorClass
from posteriors.types import TensorTree, Transform, LogProbFn
from posteriors.tree_utils import tree_size, tree_insert_
Expand Down
2 changes: 1 addition & 1 deletion posteriors/laplace/dense_ggn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any
import torch
from optree import tree_map
from optree.integration.torch import tree_ravel
from optree.integrations.torch import tree_ravel
from tensordict import TensorClass

from posteriors.types import (
Expand Down
2 changes: 1 addition & 1 deletion posteriors/laplace/dense_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import partial
import torch
from optree import tree_map
from optree.integration.torch import tree_ravel
from optree.integrations.torch import tree_ravel
from tensordict import TensorClass

from posteriors.types import TensorTree, Transform, LogProbFn
Expand Down
2 changes: 1 addition & 1 deletion posteriors/sgmcmc/sgnht.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch.func import grad_and_value
from optree import tree_map
from optree.integration.torch import tree_ravel
from optree.integrations.torch import tree_ravel
from tensordict import TensorClass
from posteriors.types import TensorTree, Transform, LogProbFn, Schedule
from posteriors.tree_utils import flexi_tree_map, tree_insert_
Expand Down
2 changes: 1 addition & 1 deletion posteriors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.func import grad, jvp, vjp, functional_call, jacrev, jacfwd
from torch.distributions import Normal
from optree import tree_map, tree_reduce, tree_flatten, tree_leaves
from optree.integration.torch import tree_ravel
from optree.integrations.torch import tree_ravel

from posteriors.types import TensorTree, ForwardFn, Tensor
from posteriors.tree_utils import tree_size
Expand Down
2 changes: 1 addition & 1 deletion posteriors/vi/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch.func import grad_and_value, vmap
from optree import tree_map
from optree.integration.torch import tree_ravel
from optree.integrations.torch import tree_ravel
import torchopt
from tensordict import TensorClass

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Mathematics",
"License :: OSI Approved :: Apache Software License",
]
dependencies = ["torch>=2.0.0", "torchopt>=0.7.3", "optree>=0.10.0", "tensordict>=0.7.0"]
dependencies = ["torch>=2.0.0", "torchopt>=0.7.3", "optree>=0.17.0", "tensordict>=0.7.0"]

[project.optional-dependencies]
test = ["pre-commit", "pytest-cov", "pytest-xdist", "ruff"]
Expand Down
2 changes: 1 addition & 1 deletion tests/ekf/test_diag_fisher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import partial
import torch
from optree.integration.torch import tree_ravel
from optree.integrations.torch import tree_ravel
from posteriors import ekf
from tests.scenarios import get_multivariate_normal_log_prob
from tests.utils import verify_inplace_update
Expand Down
2 changes: 1 addition & 1 deletion tests/ekf/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Callable
import torch
from optree.integration.torch import tree_ravel
from optree.integrations.torch import tree_ravel
from tests.scenarios import get_multivariate_normal_log_prob
from posteriors.types import LogProbFn, Transform

Expand Down
2 changes: 1 addition & 1 deletion tests/laplace/test_dense_ggn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch.distributions import Normal
from torch.utils.data import DataLoader, TensorDataset
from torch.func import functional_call
from optree.integration.torch import tree_ravel
from optree.integrations.torch import tree_ravel

from posteriors.laplace import dense_ggn
from tests.utils import verify_inplace_update
Expand Down
2 changes: 1 addition & 1 deletion tests/laplace/test_dense_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch.distributions import Normal
from torch.utils.data import DataLoader, TensorDataset
from torch.func import functional_call, hessian
from optree.integration.torch import tree_ravel
from optree.integrations.torch import tree_ravel

from posteriors import tree_size, diag_normal_log_prob
from posteriors.laplace import dense_hessian
Expand Down
2 changes: 1 addition & 1 deletion tests/laplace/test_diag_fisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch.distributions import Normal
from torch.utils.data import DataLoader, TensorDataset
from torch.func import functional_call
from optree.integration.torch import tree_ravel
from optree.integrations.torch import tree_ravel
from optree import tree_map

from posteriors.laplace import diag_fisher
Expand Down
2 changes: 1 addition & 1 deletion tests/laplace/test_diag_ggn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.utils.data import DataLoader, TensorDataset
from torch.func import functional_call
from optree import tree_map
from optree.integration.torch import tree_ravel
from optree.integrations.torch import tree_ravel

from posteriors.laplace import diag_ggn

Expand Down
2 changes: 1 addition & 1 deletion tests/scenarios.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
from optree.integration.torch import tree_ravel
from optree.integrations.torch import tree_ravel
from posteriors.types import LogProbFn


Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import partial
import torch
from optree import tree_map, tree_flatten, tree_reduce
from optree.integration.torch import tree_ravel
from optree.integrations.torch import tree_ravel

from posteriors import (
CatchAuxError,
Expand Down