diff --git a/docs/examples/diff_pf_resampling.md b/docs/examples/diff_pf_resampling.md
index 05f2a7e..2747223 100644
--- a/docs/examples/diff_pf_resampling.md
+++ b/docs/examples/diff_pf_resampling.md
@@ -4,11 +4,15 @@ In the default implementation of a particle filter, particles are propagated thr
 Because of this, many different approaches have been proposed to make the resampling step differentiable. A simple implementation is that of [Scibior and Wood (2021)](https://arxiv.org/abs/2106.10314), which implements a stop-gradient trick recovering classical estimates of the gradients of the marginal log-likelihood via automatic differentiation. Formally, it is equivalent to computing
-\begin{align}
+
+$$
+\begin{align*}
 \nabla_{\theta} \log p(y_{1:T}) &= \int \nabla_{\theta} \log p(x_{0:T}, y_{1:T}) \, p(x_{0:T} \mid y_{1:T}) \mathrm{d}x_{0:T}\\
 &\approx \sum_{n=1}^N W_T^{n} \nabla_{\theta} \log p(X^{(n)}_{0:T}, y_{1:T})
-\end{align}
+\end{align*}
+$$
+
 under the high-variance *[backward tracing](https://github.com/state-space-models/cuthbert/blob/main/cuthbertlib/smc/smoothing/tracing.py)* version of the smoothing distribution. However, this implementation is cheap, black-box, and does not modify the forward pass of the filter when implemented; it is therefore a reasonable starting point for estimating the parameters of the dynamics or observation model in a state-space model (but not the parameters of the proposal!).
 Here, we illustrate the bias of non-differentiable resampling, and show how to use the `stop_gradient_decorator` from `cuthbertlib.resampling` to achieve better gradient estimates.
@@ -27,6 +31,8 @@ from cuthbert.smc import particle_filter as pf
 from cuthbertlib.stats.multivariate_normal import logpdf
 ```
 
+## Defining the Benchmarking Model
+
 The model that we simulate from is linear-Gaussian:
 
 $$
@@ -85,6 +91,8 @@ We can visualize this data.
 ![Simulated linear–Gaussian SSM](../assets/pf_grad_bias_simulated_data.png)
 
+## Marginal Likelihood Computation
+
 We next set up utilities to examine the MLL of the Kalman filter, and of the bootstrap PF with and without differentiable resampling. Let's begin with the Kalman filter.
 
 ```{.python #pf_grad_bias-kalman_mll}
@@ -195,6 +203,8 @@ def pf_mll(
 pf_grad = jax.grad(lambda th, ys, **kw: pf_mll(th, ys, **kw))
 ```
 
+## Comparison: Marginal Likelihood & Score Estimates
+
 We compare the standard bootstrap PF to one that uses differentiable resampling (`stop_gradient` around systematic resampling) on a grid of $\log q$ and $\log r$ values.
 
 ```{.python #pf_grad_bias-methods_grids}
@@ -481,6 +491,8 @@ for diff, label in [(False, "PF (standard)"), (True, "PF (diff-resampling)")]:
 print(f"{label}: {median_ms:.2f} ± {std_ms:.2f} ms (median ± std over {n_timing} evals)")
 ```
 
+## Illustration of Forward-Pass Invariance
+
 Finally, we should verify that the forward pass is not actually modified in the stop-gradient differentiable particle filter, and that both methods track the state accurately. To illustrate this, we display the final-time particles after filtering over the full trajectory; these should track the state well and be identical across the two methods.
 
 ??? "Code to plot final-time particles."
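
For reviewers, the stop-gradient trick the page relies on can be sketched in a few lines of plain JAX. This is a minimal illustration only, not the `cuthbertlib.resampling.stop_gradient_decorator` implementation: the function name `stop_gradient_resample` is invented here, and it uses multinomial resampling for brevity where the document's experiments use systematic resampling. The key point is that each resampled particle carries a log-weight of $\log W^{a_n} - \texttt{stop\_gradient}(\log W^{a_n})$, which is exactly zero in the forward pass but carries the weight's gradient in the backward pass.

```python
import jax
import jax.numpy as jnp


def stop_gradient_resample(key, particles, log_weights):
    """Resampling with the Scibior & Wood (2021) stop-gradient trick
    (hypothetical sketch; multinomial rather than systematic).

    Forward pass: ordinary resampling with uniform zero log-weights.
    Backward pass: gradients flow through the normalized log-weights.
    """
    n = particles.shape[0]
    # Normalize the log-weights.
    norm_lw = log_weights - jax.scipy.special.logsumexp(log_weights)
    # Ancestor indices are discrete; block gradients through them.
    idx = jax.lax.stop_gradient(
        jax.random.categorical(key, norm_lw, shape=(n,))
    )
    resampled = particles[idx]
    # Value is exactly 0 for every particle, so the forward pass is
    # unchanged; the gradient is that of the selected log-weight.
    new_log_weights = norm_lw[idx] - jax.lax.stop_gradient(norm_lw[idx])
    return resampled, new_log_weights
```

Because the reweighting term evaluates to zero, filtering estimates (and hence the MLL values compared above) are bitwise identical with and without the trick; only the reverse-mode sweep differs.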