16 changes: 14 additions & 2 deletions docs/examples/diff_pf_resampling.md
In the default implementation of a particle filter, particles are propagated thr…

Because of this, many different approaches have been proposed to make the resampling step differentiable. A simple approach is that of [Scibior and Wood (2021)](https://arxiv.org/abs/2106.10314), which uses a stop-gradient trick to recover classical estimates of the gradient of the marginal log-likelihood via automatic differentiation.
Formally, it is equivalent to computing

$$
\begin{align*}
\nabla_{\theta} \log p(y_{1:T})
&= \int \nabla_{\theta} \log p(x_{0:T}, y_{1:T}) \, p(x_{0:T} \mid y_{1:T}) \, \mathrm{d}x_{0:T}\\
&\approx \sum_{n=1}^N W_T^{n} \nabla_{\theta} \log p(X^{(n)}_{0:T}, y_{1:T})
\end{align*}
$$

under the high-variance *[backward tracing](https://github.com/state-space-models/cuthbert/blob/main/cuthbertlib/smc/smoothing/tracing.py)* version of the smoothing distribution.
However, this implementation is cheap, black-box, and does not modify the forward pass of the filter when implemented; it is therefore a reasonable starting point for estimating the parameters of the dynamics or observation model in a state-space model (but not parameters of the proposal!). Here, we illustrate the bias of non-differentiable resampling, and show how to use the `stop_gradient_decorator` from `cuthbertlib.resampling` to achieve better gradient estimates.
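To make the idea concrete, here is a minimal sketch of this family of stop-gradient resampling tricks in plain JAX. The function name and signature are illustrative, not the `cuthbertlib` API: the particles are resampled through a stop-gradient, and a zero-valued log-weight correction term reattaches the gradient of the selected particles' log-weights so it survives automatic differentiation.

```python
import jax
import jax.numpy as jnp


def stop_gradient_resample(key, particles, log_weights):
    """Sketch of a stop-gradient resampling step (illustrative, not cuthbertlib).

    Forward pass: ordinary multinomial resampling.
    Backward pass: gradients flow through `correction`, whose forward
    value is exactly zero but whose gradient is that of the selected
    particles' log-weights.
    """
    n = log_weights.shape[0]
    probs = jax.nn.softmax(log_weights)
    # Selection indices are treated as constants w.r.t. theta.
    idx = jax.random.choice(key, n, shape=(n,), p=jax.lax.stop_gradient(probs))
    resampled = particles[idx]
    # Zero in the forward pass; carries d/dtheta log w^{(idx)} in the backward pass.
    correction = log_weights[idx] - jax.lax.stop_gradient(log_weights[idx])
    return resampled, correction
```

Because the forward value of `correction` is identically zero, adding it to the running log-likelihood leaves the filter's estimates unchanged while altering only the gradients.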

```python
from cuthbert.smc import particle_filter as pf
from cuthbertlib.stats.multivariate_normal import logpdf
```

## Defining the Benchmarking Model

The model that we simulate from is linear-Gaussian:
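For concreteness, a standard linear-Gaussian state-space parameterization with scalar noise levels $q$ and $r$ (an illustrative assumption; the benchmark's exact matrices are defined in the code) takes the form

$$
\begin{align*}
x_t &= A x_{t-1} + \varepsilon_t, \qquad \varepsilon_t \sim \mathcal{N}(0, q I),\\
y_t &= H x_t + \eta_t, \qquad\ \ \eta_t \sim \mathcal{N}(0, r I).
\end{align*}
$$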

We can visualize this data.

![Simulated linear–Gaussian SSM](../assets/pf_grad_bias_simulated_data.png)

## Marginal Likelihood Computation

We next set up utilities to examine the marginal log-likelihood (MLL) under the Kalman filter and under the bootstrap PF, with and without differentiable resampling. Let's begin with the Kalman filter.
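As a point of reference, here is a minimal scalar sketch of what a Kalman MLL computation involves (this is not the `cuthbertlib` implementation; the model form and parameter names are assumptions): the MLL accumulates the log-density of each innovation under its predictive distribution.

```python
import numpy as np


def kalman_loglik(ys, a, q, r, m0=0.0, p0=1.0):
    """Exact marginal log-likelihood for the scalar model
    x_t = a x_{t-1} + N(0, q),  y_t = x_t + N(0, r)."""
    m, p, ll = m0, p0, 0.0
    for y in ys:
        # Predict step.
        m, p = a * m, a * a * p + q
        # Innovation likelihood: y_t ~ N(m, p + r).
        s = p + r
        ll += -0.5 * (np.log(2.0 * np.pi * s) + (y - m) ** 2 / s)
        # Update step.
        k = p / s
        m, p = m + k * (y - m), (1.0 - k) * p
    return ll
```

Because this quantity is exact for a linear-Gaussian model, its gradient serves as the ground truth against which the PF score estimates are compared below.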

```{.python #pf_grad_bias-kalman_mll}
...
```

```{.python}
def pf_mll(
    ...
):
    ...

pf_grad = jax.grad(lambda th, ys, **kw: pf_mll(th, ys, **kw))
```

## Comparison: Marginal Likelihood & Score Estimates

We compare the standard bootstrap PF to one that uses differentiable resampling (`stop_gradient` around systematic resampling) on a grid of $\log q$ and $\log r$ values.

```{.python #pf_grad_bias-methods_grids}
...
```

```{.python}
for diff, label in [(False, "PF (standard)"), (True, "PF (diff-resampling)")]:
    ...
    print(f"{label}: {median_ms:.2f} ± {std_ms:.2f} ms (median ± std over {n_timing} evals)")
```

## Illustration of Forward Pass Invariance

Finally, we should verify that the forward pass is not actually modified by the stop-gradient differentiable particle filter, and that both methods track the state accurately. To illustrate this, we display the final-time particles after filtering over the full trajectory; they should track the state well and be identical across the two methods.
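Since the stop-gradient wrapper only alters the backward pass, a direct programmatic check (a hypothetical helper, assuming both filters are run with the same PRNG key) is to compare the particle arrays numerically:

```python
import numpy as np


def forward_pass_invariant(particles_std, particles_diff, atol=1e-6):
    # With a shared PRNG key, the forward particle trajectories of the
    # standard and stop-gradient filters should coincide up to
    # floating-point tolerance.
    return bool(np.allclose(particles_std, particles_diff, atol=atol))
```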

??? "Code to plot final-time particles."