Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion cuthbertlib/resampling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,20 @@ adaptive_resampling = resampling.adaptive.ess_decorator(
adaptive_resampled_indices, _, adaptive_resampled_particles = adaptive_resampling(
resampling_key, logits, particles, 100
)
```
```

For consistent gradient estimates with respect to model parameters, the [stop-gradient particle filter](https://arxiv.org/abs/2106.10314) is also implemented as a decorator.

```python
differentiable_resampling = resampling.autodiff.stop_gradient_decorator(
resampling.multinomial.resampling
)
# can be combined with adaptive resampling
adaptive_and_differentiable_resampling = resampling.adaptive.ess_decorator(
differentiable_resampling,
threshold=0.5,
)
resampled_indices, _, resampled_particles = adaptive_and_differentiable_resampling(
resampling_key, logits, particles, 100
)
```
9 changes: 8 additions & 1 deletion cuthbertlib/resampling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from cuthbertlib.resampling import adaptive, killing, multinomial, systematic
from cuthbertlib.resampling import (
adaptive,
autodiff,
killing,
multinomial,
systematic,
)
from cuthbertlib.resampling.adaptive import ess_decorator
from cuthbertlib.resampling.autodiff import stop_gradient_decorator
from cuthbertlib.resampling.protocols import ConditionalResampling, Resampling
from cuthbertlib.resampling.utils import inverse_cdf
64 changes: 64 additions & 0 deletions cuthbertlib/resampling/autodiff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Implements decorators for automatic differentiation of resampling schemes.

Currently supported is the stop_gradient resampling scheme, which provides the
classical Fisher estimates for the score function via automatic differentiation.
This can be wrapped around a resampling scheme such as multinomial or systematic
resampling.

See [Scibior and Wood (2021)](https://arxiv.org/abs/2106.10314) for more details.
"""

from functools import wraps

import jax
import jax.numpy as jnp

from cuthbertlib.resampling.protocols import Resampling
from cuthbertlib.resampling.utils import apply_resampling_indices
from cuthbertlib.smc.ess import log_ess
from cuthbertlib.types import Array, ArrayLike, ArrayTree, ArrayTreeLike


def stop_gradient_decorator(func: Resampling) -> Resampling:
    """Wrap a Resampling function to use stop gradient resampling.

    Args:
        func: A resampling function with signature
            (key, logits, positions, n) -> (indices, logits_out, positions_out).

    Returns:
        A Resampling function implementing stop gradient resampling.
    """
    # Compose a docstring for the returned resampler that embeds the
    # documentation of the wrapped scheme.
    wrapped_doc = func.__doc__ or ""
    doc = f"""
    Stop gradient resampling decorator.

    This wrapper will call the provided resampling function, and then apply
    the stop gradient trick of [Scibior and Wood (2021)](https://arxiv.org/abs/2106.10314).
    Resulting estimates of the score function (i.e., the gradient of the
    log-likelihood with respect to model parameters) are unbiased,
    corresponding to the classical Fisher estimate.

    Wrapped resampler documentation:
    {wrapped_doc}
    """

    @wraps(func)
    def resample_with_stop_gradient(
        key: Array, logits: ArrayLike, positions: ArrayTreeLike, n: int
    ) -> tuple[Array, Array, ArrayTree]:
        # Run the underlying resampler on gradient-detached logits so that
        # no gradient flows through the (discrete) index selection itself.
        indices, resampled_logits, resampled_positions = func(
            key, jax.lax.stop_gradient(logits), positions, n
        )

        # Gather the original (differentiable) logits at the resampled
        # indices.  Adding `gathered - stop_gradient(gathered)` leaves the
        # forward value equal to `resampled_logits` while rerouting the
        # backward pass through the gathered logits.
        gathered = apply_resampling_indices(logits, indices)
        corrected_logits = jnp.asarray(
            resampled_logits + gathered - jax.lax.stop_gradient(gathered)
        )
        return indices, corrected_logits, resampled_positions

    # Override the docstring that @wraps copied from `func`, then jit with
    # `n` static since it determines output shapes.
    resample_with_stop_gradient.__doc__ = doc
    return jax.jit(resample_with_stop_gradient, static_argnames=("n",))
7 changes: 5 additions & 2 deletions cuthbertlib/resampling/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
import numba as nb
import numpy as np
from jax.lax import platform_dependent
from jax.lax import platform_dependent, stop_gradient
from jax.scipy.special import logsumexp
from jax.tree_util import tree_map

Expand Down Expand Up @@ -32,7 +32,10 @@ def inverse_cdf(sorted_uniforms: ArrayLike, logits: ArrayLike) -> Array:
"""
weights = jnp.exp(logits - logsumexp(logits))
return platform_dependent(
sorted_uniforms, weights, cpu=_inverse_cdf_cpu, default=_inverse_cdf_default
sorted_uniforms,
stop_gradient(weights),
cpu=_inverse_cdf_cpu,
default=_inverse_cdf_default,
)


Expand Down
Binary file added docs/assets/pf_grad_bias_final_particles.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/assets/pf_grad_bias_mll_scores.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/assets/pf_grad_bias_score_boxplots.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/assets/pf_grad_bias_simulated_data.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 4 additions & 0 deletions docs/cuthbertlib_api/resampling.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,7 @@
- conditional_resampling

::: cuthbertlib.resampling.killing

::: cuthbertlib.resampling.adaptive

::: cuthbertlib.resampling.autodiff
Loading
Loading