Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cuthbert/mcmc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
MCMC methods.
"""
10 changes: 10 additions & 0 deletions cuthbert/mcmc/csmc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""
Conditional Sequential Monte Carlo.

This module provides functions to build a conditional particle filter and a corresponding smoother,
following the patterns used in the rest of the `cuthbert` library.
"""
from .conditional_particle_filter import build_csmc_filter
from .smoother import build_csmc_smoother

__all__ = ["build_csmc_filter", "build_csmc_smoother"]
168 changes: 168 additions & 0 deletions cuthbert/mcmc/csmc/conditional_particle_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from functools import partial

import jax
import jax.numpy as jnp
from jax import random

from cuthbert.inference import Filter
from cuthbert.smc.particle_filter import ParticleFilterState
from cuthbert.smc.types import InitSample, LogPotential, PropagateSample
from cuthbert.utils import dummy_tree_like
from cuthbertlib.resampling.protocols import ConditionalResampling
from cuthbertlib.types import ArrayTreeLike, KeyArray


def build_csmc_filter(
    init_sample: InitSample,
    propagate_sample: PropagateSample,
    log_potential: LogPotential,
    n_particles: int,
    resampling_fn: ConditionalResampling,
    ess_threshold: float = 0.0,
) -> Filter:
    """Builds a conditional particle filter object.

    Args:
        init_sample: Function to sample from the initial distribution.
        propagate_sample: Function to sample from the Markov kernel.
        log_potential: Function to compute the log potential.
        n_particles: Number of particles for the filter.
        resampling_fn: Conditional resampling algorithm to use.
        ess_threshold: Effective-sample-size threshold forwarded to
            ``filter_combine``. Defaults to 0.0 (resample at every step).

    Returns:
        Filter object for the conditional particle filter.
    """
    return Filter(
        init_prepare=partial(
            init_prepare,
            init_sample=init_sample,
            log_potential=log_potential,
            n_particles=n_particles,
        ),
        filter_prepare=partial(
            filter_prepare,
            init_sample=init_sample,
            n_particles=n_particles,
        ),
        filter_combine=partial(
            filter_combine,
            propagate_sample=propagate_sample,
            log_potential=log_potential,
            resampling_fn=resampling_fn,
            # Must be bound here: filter_combine requires it, and omitting it
            # made every combine call raise TypeError.
            ess_threshold=ess_threshold,
        ),
        associative=False,
    )


def init_prepare(
    model_inputs: ArrayTreeLike,
    init_sample: InitSample,
    log_potential: LogPotential,
    n_particles: int,
    key: KeyArray | None = None,
) -> ParticleFilterState:
    """Prepare the initial state for the conditional particle filter.

    Args:
        model_inputs: Per-step inputs; expected layout is
            ``(inputs, reference_particle, reference_index)`` — TODO confirm
            this convention against the callers.
        init_sample: Function to sample from the initial distribution.
        log_potential: Function to compute the log potential.
        n_particles: Number of particles for the filter.
        key: JAX PRNG key (required).

    Returns:
        Initial ``ParticleFilterState`` with the reference particle pinned at
        its index.

    Raises:
        ValueError: If ``key`` is not provided.
    """
    if key is None:
        raise ValueError("A JAX PRNG key must be provided.")

    # Sample the initial particle cloud, one key per particle.
    keys = random.split(key, n_particles)
    particles = jax.vmap(init_sample, (0, None))(keys, model_inputs)

    # Pin the reference particle at its index. Using jax.tree.map keeps this
    # correct when particles form a pytree rather than a single array
    # (consistent with the pytree handling in filter_combine).
    _, reference_particle, reference_index = model_inputs
    particles = jax.tree.map(
        lambda leaf, ref: leaf.at[reference_index].set(ref),
        particles,
        reference_particle,
    )

    # Weight each particle; the "previous particle" argument is None at t=0.
    log_weights = jax.vmap(log_potential, (None, 0, None))(
        None, particles, model_inputs
    )

    # Log of the average unnormalized weight: first increment of the
    # log-evidence estimate.
    log_normalizing_constant = jax.nn.logsumexp(log_weights) - jnp.log(n_particles)

    return ParticleFilterState(
        key=key,
        particles=particles,
        log_weights=log_weights,
        ancestor_indices=jnp.arange(n_particles),
        model_inputs=model_inputs,
        log_normalizing_constant=log_normalizing_constant,
    )


def filter_prepare(
    model_inputs: ArrayTreeLike,
    init_sample: InitSample,
    n_particles: int,
    key: KeyArray | None = None,
) -> ParticleFilterState:
    """Prepare a placeholder state for a conditional particle filter step.

    The particles are dummies with the correct structure, shapes and dtypes;
    they are replaced by ``filter_combine`` during the forward pass.

    Args:
        model_inputs: Per-step inputs for this time step.
        init_sample: Used only via ``jax.eval_shape`` to derive the particle
            structure — it is never actually executed here.
        n_particles: Number of particles for the filter.
        key: JAX PRNG key (required).

    Returns:
        A ``ParticleFilterState`` with dummy particles and neutral weights.

    Raises:
        ValueError: If ``key`` is not provided.
    """
    if key is None:
        raise ValueError("A JAX PRNG key must be provided.")
    # Abstract evaluation: gives shapes/dtypes without sampling.
    dummy_particle = jax.eval_shape(init_sample, key, model_inputs)
    # Preserve each leaf's dtype — the original dropped it, silently
    # defaulting every leaf to float.
    particles = jax.tree.map(
        lambda x: jnp.empty((n_particles,) + x.shape, x.dtype), dummy_particle
    )
    particles = dummy_tree_like(particles)
    return ParticleFilterState(
        key=key,
        particles=particles,
        # Shape (n_particles,) to match init_prepare and the log potentials in
        # filter_combine; the previous (n_particles, 1) shape broadcast
        # incorrectly against (n_particles,) arrays.
        log_weights=jnp.zeros(n_particles),
        ancestor_indices=jnp.arange(n_particles),
        model_inputs=model_inputs,
        log_normalizing_constant=jnp.array(0.0),
    )


def filter_combine(
    state_1: ParticleFilterState,
    state_2: ParticleFilterState,
    propagate_sample: PropagateSample,
    log_potential: LogPotential,
    resampling_fn: ConditionalResampling,
    ess_threshold: float = 0.0,
) -> ParticleFilterState:
    """Combine previous filter state with the state prepared for the current step.

    Args:
        state_1: Filter state at the previous time step.
        state_2: Placeholder state prepared for the current step (carries the
            current model inputs, reference particle and reference index).
        propagate_sample: Function to sample from the Markov kernel.
        log_potential: Function to compute the log potential.
        resampling_fn: Conditional resampling algorithm to use.
        ess_threshold: Currently unused — resampling happens at every step.
            The default keeps partial application by the builder valid.

    Returns:
        The combined ``ParticleFilterState`` for the current step.
    """
    n_particles = state_1.log_weights.shape[0]
    keys = random.split(state_1.key, n_particles + 1)

    # Get conditional info from states; model_inputs layout is assumed to be
    # (inputs, reference_particle, reference_index) — TODO confirm.
    _, _, prev_ref_idx = state_1.model_inputs
    _, current_ref_particle, current_ref_idx = state_2.model_inputs

    # Resample conditionally on the reference trajectory.
    # Here we assume that if conditional is True, a ConditionalResampling
    # function is provided.
    ancestor_indices = resampling_fn(
        keys[0], state_1.log_weights, n_particles, prev_ref_idx, current_ref_idx
    )

    ancestors = jax.tree.map(lambda x: x[ancestor_indices], state_1.particles)
    # Reset weights after resampling. Shape (n_particles,) — the previous
    # (n_particles, 1) broadcast against (n_particles,) log potentials into an
    # (n, n) matrix, corrupting weights and the normalizing constant.
    log_weights = jnp.zeros(n_particles)

    # Propagate each ancestor through the Markov kernel.
    next_particles = jax.vmap(propagate_sample, (0, 0, None))(
        keys[1:], ancestors, state_2.model_inputs
    )

    # Pin the reference particle; pytree-safe, matching the gather above.
    next_particles = jax.tree.map(
        lambda leaf, ref: leaf.at[current_ref_idx].set(ref),
        next_particles,
        current_ref_particle,
    )

    # Reweight with the log potential of each transition.
    log_potentials = jax.vmap(log_potential, (0, 0, None))(
        ancestors, next_particles, state_2.model_inputs
    )
    next_log_weights = log_weights + log_potentials

    # Incremental log normalizing constant: log mean of unnormalized weights.
    logsum_weights = jax.nn.logsumexp(next_log_weights)
    log_normalizing_constant_incr = logsum_weights - jnp.log(n_particles)
    log_normalizing_constant = (
        log_normalizing_constant_incr + state_1.log_normalizing_constant
    )

    return ParticleFilterState(
        key=state_2.key,
        particles=next_particles,
        log_weights=next_log_weights,
        ancestor_indices=ancestor_indices,
        model_inputs=state_2.model_inputs,
        log_normalizing_constant=log_normalizing_constant,
    )
120 changes: 120 additions & 0 deletions cuthbert/mcmc/csmc/smoother.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Implements the backward pass for the Conditional Sequential Monte Carlo."""

from functools import partial

import jax
from jax import numpy as jnp

from cuthbert.smc.particle_filter import ParticleFilterState
from cuthbert.smc.types import LogPotential
from cuthbertlib.mcmc.protocols import AncestorMove
from cuthbertlib.resampling.utils import normalize
from cuthbertlib.types import KeyArray



def _backward_sampling_step(
    log_potential_fn: LogPotential,
    particle_t: jnp.ndarray,
    inp: tuple[ParticleFilterState, KeyArray],
):
    """A single step of the backward sampling pass.

    Scores every particle at time t-1 against the already-selected particle at
    time t, then draws one ancestor from the resulting smoothing weights.
    """
    prev_state, step_key = inp

    # Transition scores toward the fixed particle at time t.
    scores = log_potential_fn(prev_state.particles, particle_t, prev_state.model_inputs)
    # Subtract the max for numerical stability, then fold in filter weights.
    scores = scores - jnp.max(scores)
    combined_log_weights = scores + prev_state.log_weights

    probabilities = normalize(combined_log_weights)
    drawn_idx = jax.random.choice(
        step_key, probabilities.shape[0], p=probabilities, shape=()
    )
    chosen_particle = prev_state.particles[drawn_idx]

    # Carry the chosen particle backward; also emit it with its index.
    return chosen_particle, (chosen_particle, drawn_idx)


def _backward_trace_step(ancestor_idx_t: int, state_t_minus_1: ParticleFilterState):
    """A single step of the backward tracing pass.

    Follows the stored ancestry pointer from time t back to time t-1 and emits
    the corresponding particle.
    """
    traced_idx = state_t_minus_1.ancestor_indices[ancestor_idx_t]
    traced_particle = state_t_minus_1.particles[traced_idx]
    return traced_idx, (traced_particle, traced_idx)


def build_csmc_smoother(
    log_potential_fn: LogPotential,
    ancestor_move_fn: AncestorMove,
    conditional: bool = True,
):
    """Builds a CSMC smoother function.

    Args:
        log_potential_fn: The log potential function.
        ancestor_move_fn: The function to move the final ancestor.
        conditional: Whether the pass is conditional.

    Returns:
        A smoother function that can be applied to the output of a forward pass.
    """

    def _smoother(
        forward_states: ParticleFilterState, key: KeyArray, do_sampling: bool
    ):
        """The smoother function to be returned.

        Args:
            forward_states: The output of the forward pass (stacked over time
                along the leading axis of every leaf).
            key: JAX random number generator key.
            do_sampling: Whether to perform backward sampling or tracing.
        """
        T = forward_states.particles.shape[0]
        keys = jax.random.split(key, T)

        # Select the last ancestor.
        final_state = jax.tree.map(lambda x: x[-1], forward_states)
        final_log_weights = final_state.log_weights
        if not conditional:
            weights = normalize(final_log_weights)
            final_ancestor_idx = jax.random.choice(
                keys[-1], weights.shape[0], p=weights
            )
        else:
            # Pivot is the last index — presumably where the reference
            # trajectory is pinned; TODO confirm against the forward pass.
            final_ancestor_idx, _ = ancestor_move_fn(
                keys[-1],
                normalize(final_log_weights),
                final_state.particles.shape[0] - 1,
            )
        final_particle = final_state.particles[final_ancestor_idx]

        # Time-reversed view of states 0..T-2. The reversal must be applied
        # per-leaf inside the tree map: slicing the state object itself with
        # [::-1] (as the original did) does not reverse the time axis.
        reversed_states = jax.tree.map(lambda x: x[:-1][::-1], forward_states)

        def sampling_pass():
            """Performs a backward sampling pass."""
            backward_step_fn = partial(_backward_sampling_step, log_potential_fn)
            _, (particles, indices) = jax.lax.scan(
                backward_step_fn, final_particle, (reversed_states, keys[:-1])
            )
            return particles, indices

        def tracing_pass():
            """Performs a backward tracing pass."""
            _, (particles, indices) = jax.lax.scan(
                _backward_trace_step, final_ancestor_idx, reversed_states
            )
            return particles, indices

        particles, indices = jax.lax.cond(
            do_sampling, sampling_pass, tracing_pass
        )

        # Prepend the final selection, then flip back to forward-time order.
        particles = jnp.insert(particles, 0, final_particle, axis=0)
        indices = jnp.insert(indices, 0, final_ancestor_idx, axis=0)

        return particles[::-1], indices[::-1]

    return _smoother
4 changes: 4 additions & 0 deletions cuthbertlib/mcmc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""MCMC methods for cuthbertlib."""
from .index_select import barker_move

__all__ = ["barker_move"]
60 changes: 60 additions & 0 deletions cuthbertlib/mcmc/index_select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Index selection methods for MCMC."""

import jax
import jax.numpy as jnp

from cuthbertlib.types import Array, KeyArray, ScalarArray, ScalarArrayLike


def barker_move(key: KeyArray, weights: Array, pivot: ScalarArrayLike) -> tuple[ScalarArray, ScalarArray]:
    """
    A Barker proposal move for a categorical distribution.

    Draws a fresh index directly from the categorical weights, independently
    of the pivot.

    Args:
        key: JAX PRNG key.
        weights: Normalized weights of the categorical distribution.
        pivot: The current index to move from.

    Returns:
        A tuple containing the new index and the probability of a new index being selected
    """
    n_categories = weights.shape[0]
    # Sample directly from the (normalized) weights.
    new_index = jax.random.choice(key, n_categories, shape=(), p=weights)
    # Probability that the draw lands somewhere other than the pivot.
    move_probability = 1 - weights[pivot]
    return new_index, move_probability

def force_move(key: KeyArray, weights: Array, pivot: ScalarArrayLike) -> tuple[ScalarArray, ScalarArray]:
    """A forced-move proposal for a categorical distribution.

    The weights are assumed to be normalised (linear, not log). A candidate
    index is drawn from ``q(i) = w_i / (1 - w_k)`` for ``i != k`` (where ``k``
    is the pivot), then accepted with probability
    ``min(1, (1 - w_k) / (1 - w_i))``.

    Args:
        key: JAX PRNG key.
        weights: Normalized weights of the categorical distribution.
        pivot: The current index to move from.

    Returns:
        A tuple containing the new index and the overall acceptance probability.
    """
    n_particles = weights.shape[0]
    key_1, key_2 = jax.random.split(key, 2)

    p_pivot = weights[pivot]
    one_minus_p_pivot = 1 - p_pivot

    # Create proposal distribution q(i) = w_i / (1 - w_k) for i != k
    proposal_weights = weights.at[pivot].set(0)
    proposal_weights = proposal_weights / one_minus_p_pivot

    proposal_idx = jax.random.choice(key_1, n_particles, p=proposal_weights, shape=())

    # Acceptance step to make the move valid: u < (1 - w_k) / (1 - w_i)
    u = jax.random.uniform(key_2, shape=())
    accept = u * (1 - weights[proposal_idx]) < one_minus_p_pivot

    # jnp.where tolerates a Python-int pivot alongside the traced proposal_idx.
    new_idx = jnp.where(accept, proposal_idx, pivot)

    # The acceptance probability alpha is sum_i q(i) * min(1, (1-w_k)/(1-w_i)).
    # The min must be applied per term BEFORE summing: clipping the raw sum
    # (as the original did) overstates alpha whenever any ratio exceeds 1.
    # nansum guards the degenerate w_i == 1 terms (0 * inf -> nan).
    per_term_ratio = jnp.minimum(1.0, one_minus_p_pivot / (1 - weights))
    alpha = jnp.nansum(proposal_weights * per_term_ratio)
    alpha = jnp.clip(alpha, 0.0, 1.0)

    return new_idx, alpha
24 changes: 24 additions & 0 deletions cuthbertlib/mcmc/protocols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Protocols for MCMC methods."""

from typing import Any, Protocol, runtime_checkable

from cuthbertlib.types import Array, KeyArray


@runtime_checkable
class AncestorMove(Protocol):
    """Protocol for ancestor index selection operations.

    Implementations propose a new categorical index given particle weights
    and the current ("pivot") index.
    """

    def __call__(self, key: KeyArray, weights: Array, pivot: int) -> tuple[int, Any]:
        """
        Selects an ancestor index, potentially using an MCMC move around a pivot.

        Args:
            key: JAX PRNG key.
            weights: Normalized (linear-scale) weights of the particles.
            pivot: The current index to move from.

        Returns:
            A tuple containing the new index and any auxiliary output from the
            move (e.g. a move or acceptance probability).
        """
        ...
Loading
Loading