-
Notifications
You must be signed in to change notification settings - Fork 5
Implement the Stop-Gradient Differentiable Particle Filter #209
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
AdrienCorenflos
merged 8 commits into
state-space-models:main
from
DanWaxman:dw-make-pf-diff
Mar 3, 2026
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
34beb52
Fix differentiability of the PF implementation
DanWaxman f7345d3
Introduce a Stop Gradient Resampling Decorator
DanWaxman 96be48a
Add a DPF example to the documentation
DanWaxman 819b2f8
Add unittesting for the stop gradient PF
DanWaxman 8e628c6
Enhance explanation of differentiable resampling methods
AdrienCorenflos 41f858b
Move to a unit test on the SG resampling decorator
DanWaxman 90ebed9
Merge pull request #1 from AdrienCorenflos/dw-make-pf-diff
DanWaxman 81baec7
Migrate naming of stop_gradient -> autodiff submodule.
DanWaxman File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,11 @@ | ||
| from cuthbertlib.resampling import adaptive, killing, multinomial, systematic | ||
| from cuthbertlib.resampling import ( | ||
| adaptive, | ||
| autodiff, | ||
| killing, | ||
| multinomial, | ||
| systematic, | ||
| ) | ||
| from cuthbertlib.resampling.adaptive import ess_decorator | ||
| from cuthbertlib.resampling.autodiff import stop_gradient_decorator | ||
| from cuthbertlib.resampling.protocols import ConditionalResampling, Resampling | ||
| from cuthbertlib.resampling.utils import inverse_cdf |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| """Implements decorators for automatic differentiation of resampling schemes. | ||
|
|
||
| Current supported is the stop_gradient resampling scheme, which provides the | ||
| classical Fisher estimates for the score function via automatic differentiation. | ||
| This can be wrapped around a resampling scheme such as multinomial or systematic | ||
| resampling. | ||
|
|
||
| See [Scibior and Wood (2021)](https://arxiv.org/abs/2106.10314) for more details. | ||
| """ | ||
|
|
||
| from functools import wraps | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
|
|
||
| from cuthbertlib.resampling.protocols import Resampling | ||
| from cuthbertlib.resampling.utils import apply_resampling_indices | ||
| from cuthbertlib.smc.ess import log_ess | ||
| from cuthbertlib.types import Array, ArrayLike, ArrayTree, ArrayTreeLike | ||
|
|
||
|
|
||
| def stop_gradient_decorator(func: Resampling) -> Resampling: | ||
| """Wrap a Resampling function to use stop gradient resampling. | ||
|
|
||
| Args: | ||
| func: A resampling function with signature | ||
| (key, logits, positions, n) -> (indices, logits_out, positions_out). | ||
|
|
||
| Returns: | ||
| A Resampling function implementing stop gradient resampling. | ||
| """ | ||
| # Build a descriptive docstring that includes the wrapped function doc | ||
| wrapped_doc = func.__doc__ or "" | ||
| doc = f""" | ||
| Stop gradient resampling decorator. | ||
|
|
||
| This wrapper will call the provided resampling function, and then apply | ||
| the stop gradient trick of [Scibior and Wood (2021)](https://arxiv.org/abs/2106.10314). | ||
| Resulting estimates of the score function (i.e., the gradient of the | ||
| log-likelihood with respect to model parameters) are unbiased, | ||
| corresponding to the classical Fisher estimate. | ||
|
|
||
| Wrapped resampler documentation: | ||
| {wrapped_doc} | ||
| """ | ||
|
|
||
| @wraps(func) | ||
| def _wrapped( | ||
| key: Array, logits: ArrayLike, positions: ArrayTreeLike, n: int | ||
| ) -> tuple[Array, Array, ArrayTree]: | ||
| idx_base, logits_base, positions_base = func( | ||
| key, jax.lax.stop_gradient(logits), positions, n | ||
| ) | ||
|
|
||
| logits = jnp.asarray( | ||
| logits_base | ||
| + apply_resampling_indices(logits, idx_base) | ||
| - jax.lax.stop_gradient(apply_resampling_indices(logits, idx_base)) | ||
| ) | ||
| return idx_base, logits, positions_base | ||
|
|
||
| # Attach the composed docstring and return a jitted version | ||
| _wrapped.__doc__ = doc | ||
| return jax.jit(_wrapped, static_argnames=("n",)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.