A generalization of the Feynman-Kac representation of state-space models #30
Replies: 7 comments 20 replies
-
Just adding the extract from Särkkä and García-Fernández.
So the idea is that the methods would have the following structure
-
Ok, I need to think further about this, but it could be really nice. I think discrete and generalised Kalman should fit, with SMC being the main thing to think about. I would be inclined not to abstract away the observation.
-
Been thinking about this a bit and wondering if something like the below makes sense. Although I'm not quite sure how to handle

```python
from typing import NamedTuple, Protocol

from cuthbert.types import (
    ArrayTree,
    ArrayTreeLike,
    KeyArray,
)

# State encoding parameters needed to compute
# f(x_k | x_{k-1}) = p(x_k | y_k, x_{k-1}, inputs)
# and
# g(x_{k-1}) = p(y_k | x_{k-1}, inputs)
FilteringState = ArrayTree

# State encoding parameters needed to compute
# k(x_k | x_{k+1}) = p(x_k | y_{1:k}, x_{k+1}, inputs)
SmoothingState = ArrayTree


# Initiate first filtering state encoding parameters needed to compute
# f(x_0 | None) = p(x_0 | inputs)
# and
# g(None) = 1
class Init(Protocol):
    def __call__(
        self,
        inputs: ArrayTreeLike,
        key: KeyArray | None = None,
    ) -> FilteringState: ...


# Binary operator (could be associative but not necessarily).
# Combines filtering states in a way that is consistent with the
# filtering distribution.
class FilteringOperator(Protocol):
    def __call__(
        self,
        state_prev: FilteringState,
        state: FilteringState,
        observation: ArrayTreeLike,
        inputs: ArrayTreeLike,
        key: KeyArray | None = None,
    ) -> FilteringState: ...


# Binary operator (could be associative but not necessarily).
# Combines smoothing states in a way that is consistent with the
# smoothing distribution.
class SmoothingOperator(Protocol):
    def __call__(
        self,
        state_prev: SmoothingState,
        state: SmoothingState,
        inputs_prev: ArrayTreeLike,
        inputs: ArrayTreeLike,
        key: KeyArray | None = None,
    ) -> SmoothingState: ...


class SSMInference(NamedTuple):
    init: Init
    filtering_operator: FilteringOperator
    smoothing_operator: SmoothingOperator
```
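(Editor's sketch, not from the thread.) One way to see how a concrete method might satisfy `Init` and `FilteringOperator` is a scalar random-walk Kalman filter. All names here (`KFState`, `kf_init`, `kf_filtering_operator`) and the noise variances `q`, `r` are hypothetical; this is the plain sequential, non-associative case, so only `state_prev` is consumed, which the "could be associative but not necessarily" comment permits:

```python
from typing import NamedTuple

import jax.numpy as jnp


class KFState(NamedTuple):
    mean: jnp.ndarray  # filtering mean E[x_k | y_{1:k}]
    cov: jnp.ndarray   # filtering variance


def kf_init(inputs, key=None):
    # p(x_0 | inputs): here a fixed standard-normal prior, ignoring inputs.
    return KFState(mean=jnp.zeros(()), cov=jnp.ones(()))


def kf_filtering_operator(
    state_prev, state, observation, inputs, key=None, q=1.0, r=0.5
):
    # Predict through the random-walk transition x_k = x_{k-1} + N(0, q) ...
    pred_mean = state_prev.mean
    pred_cov = state_prev.cov + q
    # ... then update with the linear-Gaussian observation y_k = x_k + N(0, r).
    gain = pred_cov / (pred_cov + r)
    return KFState(
        mean=pred_mean + gain * (observation - pred_mean),
        cov=(1.0 - gain) * pred_cov,
    )
```

A genuinely associative formulation, as in Särkkä and García-Fernández's parallel-scan paper, would combine both state arguments rather than just `state_prev`.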
-
I was thinking about this rough structure:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm
from typing import NamedTuple, Any


class BootstrapState(NamedTuple):
    """State of the PF at a given time step."""

    key: Any
    weights: Any
    particles: Any
    genealogy: Any
    inps: Any

    def f_rvs(self, key, xs):
        # Random-walk transition kernel
        return xs + jax.random.normal(key, xs.shape)

    def g_rvs(self, xs):
        # Observation log-density of the stored observation
        return norm.logpdf(self.inps, xs)

    @classmethod
    def init_as(cls, key, dist):
        ...


class BootstrapOperator(NamedTuple):
    """Operator for the PF."""

    # Some hyperparameters
    resampling: Any = lambda k, w: jax.random.choice(
        k, w.shape[0], w.shape, p=w / w.sum()
    )

    def combine(self, state_1, state_2):
        key_resampling, key_propose = jax.random.split(state_1.key)
        ancestors = self.resampling(key_resampling, state_1.weights)
        next_particles = state_1.f_rvs(key_propose, state_1.particles[ancestors])
        next_weights = state_2.g_rvs(next_particles)
        return BootstrapState(
            state_2.key, next_weights, next_particles, state_1.genealogy, state_2.inps
        )
```
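(Editor's note.) The `resampling` default above is multinomial resampling: ancestor indices drawn in proportion to the weights. As a standalone sketch of just that piece (function name is my own, and it assumes normalized, non-negative weights rather than the log-weights `g_rvs` returns):

```python
import jax
import jax.numpy as jnp


def multinomial_resampling(key, weights):
    # Draw one ancestor index per particle, with probability
    # proportional to the particle weights.
    return jax.random.choice(
        key, weights.shape[0], weights.shape, p=weights / weights.sum()
    )


key = jax.random.PRNGKey(0)
weights = jnp.array([0.0, 0.0, 1.0, 0.0])  # all mass on particle 2
ancestors = multinomial_resampling(key, weights)
```

With all mass on one particle, every ancestor index is that particle; with uniform weights it reduces to i.i.d. uniform draws over the particle indices.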
-
Unified protocol
-
Some more ramblings from our meeting:

```python
from functools import partial

import jax


def build(resampling):
    return Inference(combine=partial(combine, resampling=resampling))


def combine(state_1, state_2, resampling):
    key_resampling, key_propose = jax.random.split(state_1.key)
    ancestors = resampling(key_resampling, state_1.weights)
    next_particles = state_1.f_rvs(key_propose, state_1.particles[ancestors])
    next_weights = state_2.g_rvs(next_particles)
    return BootstrapState(
        state_2.key, next_weights, next_particles, state_1.genealogy, state_2.inps
    )


# Usage sketch:
inference = pf.build(resampling)
inference = extended_kalman.build(covariance_functions)

state = inference.init(inputs[0])
state = inference.filter_combine(state, inputs[1])
```
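(Editor's sketch.) The `build`/`partial` pattern above, i.e. baking hyperparameters into a generic `combine` to get a concrete `Inference`, can be shown self-contained with a toy. Everything here is hypothetical: a trivial exponential-moving-average "filter" stands in for a real method, and `step_size` plays the role of the baked-in hyperparameter:

```python
from functools import partial
from typing import Callable, NamedTuple


class Inference(NamedTuple):
    init: Callable
    filter_combine: Callable


def _init(inputs):
    # State is (step count, running estimate).
    return (1, float(inputs))


def _combine(state, inputs, step_size):
    # Move the running estimate toward the new input.
    count, mean = state
    mean = mean + step_size * (float(inputs) - mean)
    return (count + 1, mean)


def build(step_size):
    # Bake the hyperparameter in, exposing only the unified interface.
    return Inference(init=_init, filter_combine=partial(_combine, step_size=step_size))


inference = build(step_size=0.5)
state = inference.init(2.0)
state = inference.filter_combine(state, 4.0)
```

The point of the pattern is that callers only ever see `init`/`filter_combine`, regardless of whether the thing behind it is a particle filter or an extended Kalman filter.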
-
Closing this discussion as we have aligned on an interface.


-
Feynman--Kac models are usually represented as a pair $M(x_t \mid x_{t-1})$, $G(x_{t-1:t})$, which is very tied to the SMC implementation and not super compatible with other perspectives (Gaussian-approximated ones, for example).
It does make sense to generalize the structure, and I think the perspective I took in my PhD thesis (Section 5.2.3) represents them in the same way as what Simo and Angel did in their paper:
Of course, we can't really compute these things in general, but we can say how they should be computed. For instance, the kernel $F(x_t \mid x_{t-1})$ and the function $H(x_{t-1})$ can be simulated and computed together.
I'm not sure what the best representation would actually be, but the idea is that we can really just represent objects by the way they should be computed rather than by their probabilistic structure.
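(Editor's note, standard notation rather than anything from this thread.) For context, the usual pair representation mentioned above defines the Feynman--Kac path distribution up to time $t$ as

$$
\mathbb{Q}_t(x_{0:t}) \;\propto\; M_0(x_0)\, G_0(x_0) \prod_{s=1}^{t} M_s(x_s \mid x_{s-1})\, G_s(x_{s-1}, x_s),
$$

with $M$ the proposal/transition kernels and $G$ the potential functions. The point of the generalised $F$, $H$ pair is to specify how such objects are simulated and evaluated rather than committing to this density-based factorisation.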