diff --git a/cuthbert/mcmc/__init__.py b/cuthbert/mcmc/__init__.py new file mode 100644 index 00000000..c48121af --- /dev/null +++ b/cuthbert/mcmc/__init__.py @@ -0,0 +1,3 @@ +""" +MCMC methods. +""" diff --git a/cuthbert/mcmc/csmc/__init__.py b/cuthbert/mcmc/csmc/__init__.py new file mode 100644 index 00000000..2a2922f6 --- /dev/null +++ b/cuthbert/mcmc/csmc/__init__.py @@ -0,0 +1,10 @@ +""" +Conditional Sequential Monte Carlo. + +This module provides functions to build a conditional particle filter and a corresponding smoother, +following the patterns used in the rest of the `cuthbert` library. +""" +from .conditional_particle_filter import build_csmc_filter +from .smoother import build_csmc_smoother + +__all__ = ["build_csmc_filter", "build_csmc_smoother"] diff --git a/cuthbert/mcmc/csmc/conditional_particle_filter.py b/cuthbert/mcmc/csmc/conditional_particle_filter.py new file mode 100644 index 00000000..4e4fbccc --- /dev/null +++ b/cuthbert/mcmc/csmc/conditional_particle_filter.py @@ -0,0 +1,168 @@ +from functools import partial + +import jax +import jax.numpy as jnp +from jax import random + +from cuthbert.inference import Filter +from cuthbert.smc.particle_filter import ParticleFilterState +from cuthbert.smc.types import InitSample, LogPotential, PropagateSample +from cuthbert.utils import dummy_tree_like +from cuthbertlib.resampling.protocols import ConditionalResampling +from cuthbertlib.types import ArrayTreeLike, KeyArray + + +def build_csmc_filter( + init_sample: InitSample, + propagate_sample: PropagateSample, + log_potential: LogPotential, + n_particles: int, + resampling_fn: ConditionalResampling, +) -> Filter: + """Builds a conditional particle filter object. + Args: + init_sample: Function to sample from the initial distribution. + propagate_sample: Function to sample from the Markov kernel. + log_potential: Function to compute the log potential. + n_particles: Number of particles for the filter. + resampling_fn: Conditional resampling algorithm to use. + Returns: + Filter object for the conditional particle filter. + """ + return Filter( + init_prepare=partial( + init_prepare, + init_sample=init_sample, + log_potential=log_potential, + n_particles=n_particles, + ), + filter_prepare=partial( + filter_prepare, + init_sample=init_sample, + n_particles=n_particles, + ), + filter_combine=partial( + filter_combine, + propagate_sample=propagate_sample, + log_potential=log_potential, + resampling_fn=resampling_fn, + ), + associative=False, + ) + + +def init_prepare( + model_inputs: ArrayTreeLike, + init_sample: InitSample, + log_potential: LogPotential, + n_particles: int, + key: KeyArray | None = None, +) -> ParticleFilterState: + """Prepare the initial state for the conditional particle filter.""" + if key is None: + raise ValueError("A JAX PRNG key must be provided.") + + # Sample + keys = random.split(key, n_particles) + particles = jax.vmap(init_sample, (0, None))(keys, model_inputs) + + # Pin reference particle + _, reference_particle, reference_index = model_inputs + particles = particles.at[reference_index].set(reference_particle) + + # Weight + log_weights = jax.vmap(log_potential, (None, 0, None))( + None, particles, model_inputs + ) + + # Compute the log normalizing constant + log_normalizing_constant = jax.nn.logsumexp(log_weights) - jnp.log(n_particles) + + return ParticleFilterState( + key=key, + particles=particles, + log_weights=log_weights, + ancestor_indices=jnp.arange(n_particles), + model_inputs=model_inputs, + log_normalizing_constant=log_normalizing_constant, + ) + + +def filter_prepare( + model_inputs: ArrayTreeLike, + init_sample: InitSample, + n_particles: int, + key: KeyArray | None = None, +) -> ParticleFilterState: + """Prepare a state for a conditional particle filter step.""" + if key is None: + raise ValueError("A JAX PRNG key must be provided.") + dummy_particle = jax.eval_shape(init_sample, key, model_inputs) + particles = jax.tree.map( + lambda x: jnp.empty((n_particles,) + x.shape), dummy_particle + ) + particles = dummy_tree_like(particles) + return ParticleFilterState( + key=key, + particles=particles, + log_weights=jnp.zeros((n_particles, 1)), + ancestor_indices=jnp.arange(n_particles), + model_inputs=model_inputs, + log_normalizing_constant=jnp.array(0.0), + ) + + +def filter_combine( + state_1: ParticleFilterState, + state_2: ParticleFilterState, + propagate_sample: PropagateSample, + log_potential: LogPotential, + resampling_fn: ConditionalResampling, + ess_threshold: float, +) -> ParticleFilterState: + """Combine previous filter state with the state prepared for the current step.""" + n_particles = state_1.log_weights.shape[0] + keys = random.split(state_1.key, n_particles + 1) + + # Get conditional info from states + _, _, prev_ref_idx = state_1.model_inputs + _, current_ref_particle, current_ref_idx = state_2.model_inputs + + # Resample + # Here we assume that if conditional is True, a ConditionalResampling function is provided. + ancestor_indices = resampling_fn( + keys[0], state_1.log_weights, n_particles, prev_ref_idx, current_ref_idx + ) + + ancestors = jax.tree.map(lambda x: x[ancestor_indices], state_1.particles) + log_weights = jnp.zeros((n_particles, 1)) # Reset weights after resampling + + # Propagate + next_particles = jax.vmap(propagate_sample, (0, 0, None))( + keys[1:], ancestors, state_2.model_inputs + ) + + # Pin reference particle + next_particles = next_particles.at[current_ref_idx].set(current_ref_particle) + + # Reweight + log_potentials = jax.vmap(log_potential, (0, 0, None))( + ancestors, next_particles, state_2.model_inputs + ) + next_log_weights = log_weights + log_potentials + + # Compute the log normalizing constant + logsum_weights = jax.nn.logsumexp(next_log_weights) + log_normalizing_constant_incr = logsum_weights - jnp.log(n_particles) + log_normalizing_constant = ( + log_normalizing_constant_incr + state_1.log_normalizing_constant + ) + + return ParticleFilterState( + key=state_2.key, + particles=next_particles, + log_weights=next_log_weights, + ancestor_indices=ancestor_indices, + model_inputs=state_2.model_inputs, + log_normalizing_constant=log_normalizing_constant, + ) diff --git a/cuthbert/mcmc/csmc/smoother.py b/cuthbert/mcmc/csmc/smoother.py new file mode 100644 index 00000000..1a14ca8e --- /dev/null +++ b/cuthbert/mcmc/csmc/smoother.py @@ -0,0 +1,120 @@ +"""Implements the backward pass for the Conditional Sequential Monte Carlo.""" + +from functools import partial + +import jax +from jax import numpy as jnp + +from cuthbert.smc.particle_filter import ParticleFilterState +from cuthbert.smc.types import LogPotential +from cuthbertlib.mcmc.protocols import AncestorMove +from cuthbertlib.resampling.utils import normalize +from cuthbertlib.types import KeyArray + + + +def _backward_sampling_step( + log_potential_fn: LogPotential, + particle_t: jnp.ndarray, + inp: tuple[ParticleFilterState, KeyArray], +): + """A single step of the backward sampling pass.""" + state_t_minus_1, key_t = inp + log_weights = log_potential_fn( + state_t_minus_1.particles, particle_t, state_t_minus_1.model_inputs + ) + log_weights -= jnp.max(log_weights) + log_weights += state_t_minus_1.log_weights + + weights = normalize(log_weights) + ancestor_idx = jax.random.choice(key_t, weights.shape[0], p=weights, shape=()) + particle_t_minus_1 = state_t_minus_1.particles[ancestor_idx] + + return particle_t_minus_1, (particle_t_minus_1, ancestor_idx) + + +def _backward_trace_step(ancestor_idx_t: int, state_t_minus_1: ParticleFilterState): + """A single step of the backward tracing pass.""" + ancestor_idx_t_minus_1 = state_t_minus_1.ancestor_indices[ancestor_idx_t] + particle_t_minus_1 = state_t_minus_1.particles[ancestor_idx_t_minus_1] + return ancestor_idx_t_minus_1, (particle_t_minus_1, ancestor_idx_t_minus_1) + + +def build_csmc_smoother( + log_potential_fn: LogPotential, + ancestor_move_fn: AncestorMove, + conditional: bool = True, +): + """Builds a CSMC smoother function. + + Args: + log_potential_fn: The log potential function. + ancestor_move_fn: The function to move the final ancestor. + conditional: Whether the pass is conditional. + + Returns: + A smoother function that can be applied to the output of a forward pass. + """ + + def _smoother( + forward_states: ParticleFilterState, key: KeyArray, do_sampling: bool + ): + """The smoother function to be returned. + + Args: + forward_states: The output of the forward pass. + key: JAX random number generator key. + do_sampling: Whether to perform backward sampling or tracing. + """ + T = forward_states.particles.shape[0] + keys = jax.random.split(key, T) + + # Select last ancestor + final_state = jax.tree_map(lambda x: x[-1], forward_states) + final_log_weights = final_state.log_weights + if not conditional: + weights = normalize(final_log_weights) + final_ancestor_idx = jax.random.choice( + keys[-1], weights.shape[0], p=weights + ) + else: + final_ancestor_idx, _ = ancestor_move_fn( + keys[-1], + normalize(final_log_weights), + final_state.particles.shape[0] - 1, + ) + final_particle = final_state.particles[final_ancestor_idx] + + def sampling_pass(): + """Performs a backward sampling pass.""" + backward_step_fn = partial(_backward_sampling_step, log_potential_fn) + init_carry = final_particle + inputs = ( + jax.tree_map(lambda x: x[:-1], forward_states)[::-1], + keys[:-1], + ) + _, (particles, indices) = jax.lax.scan( + backward_step_fn, init_carry, inputs + ) + return particles, indices + + def tracing_pass(): + """Performs a backward tracing pass.""" + backward_step_fn = _backward_trace_step + init_carry = final_ancestor_idx + inputs = jax.tree_map(lambda x: x[:-1], forward_states)[::-1] + _, (particles, indices) = jax.lax.scan( + backward_step_fn, init_carry, inputs + ) + return particles, indices + + particles, indices = jax.lax.cond( + do_sampling, sampling_pass, tracing_pass + ) + + particles = jnp.insert(particles, 0, final_particle, axis=0) + indices = jnp.insert(indices, 0, final_ancestor_idx, axis=0) + + return particles[::-1], indices[::-1] + + return _smoother diff --git a/cuthbertlib/mcmc/__init__.py b/cuthbertlib/mcmc/__init__.py new file mode 100644 index 00000000..ac57850e --- /dev/null +++ b/cuthbertlib/mcmc/__init__.py @@ -0,0 +1,4 @@ +"""MCMC methods for cuthbertlib.""" +from .index_select import barker_move + +__all__ = ["barker_move"] diff --git a/cuthbertlib/mcmc/index_select.py b/cuthbertlib/mcmc/index_select.py new file mode 100644 index 00000000..b6855dc9 --- /dev/null +++ b/cuthbertlib/mcmc/index_select.py @@ -0,0 +1,60 @@ +"""Index selection methods for MCMC.""" + +import jax +import jax.numpy as jnp + +from cuthbertlib.types import Array, KeyArray, ScalarArray, ScalarArrayLike + + +def barker_move(key: KeyArray, weights: Array, pivot: ScalarArrayLike) -> tuple[ScalarArray, ScalarArray]: + """ + A Barker proposal move for a categorical distribution. + + Args: + key: JAX PRNG key. + weights: Normalized weights of the categorical distribution. + pivot: The current index to move from. + + Returns: + A tuple containing the new index and the probability of a new index being selected + """ + M = weights.shape[0] + i = jax.random.choice(key, M, p=weights, shape=()) + return i, 1 - weights[pivot] + +def force_move(key: KeyArray, weights: Array, pivot: ScalarArrayLike) -> tuple[ScalarArray, ScalarArray]: + """A forced-move proposal for a categorical distribution. + + The weights are assumed to be normalised (linear, not log). + + Args: + key: JAX PRNG key. + weights: Normalized weights of the categorical distribution. + pivot: The current index to move from. + + Returns: + A tuple containing the new index and the overall acceptance probability. + """ + n_particles = weights.shape[0] + key_1, key_2 = jax.random.split(key, 2) + + p_pivot = weights[pivot] + one_minus_p_pivot = 1 - p_pivot + + # Create proposal distribution q(i) = w_i / (1 - w_k) for i != k + proposal_weights = weights.at[pivot].set(0) + proposal_weights = proposal_weights / one_minus_p_pivot + + proposal_idx = jax.random.choice(key_1, n_particles, p=proposal_weights, shape=()) + + # Acceptance step to make the move valid: u < (1 - w_k) / (1 - w_i) + u = jax.random.uniform(key_2, shape=()) + accept = u * (1 - weights[proposal_idx]) < one_minus_p_pivot + + new_idx = jax.lax.select(accept, proposal_idx, pivot) + + # The acceptance probability alpha is sum_i q(i) * min(1, (1-w_k)/(1-w_i)) + alpha = jnp.nansum(one_minus_p_pivot * proposal_weights / (1 - weights)) + alpha = jnp.clip(alpha, 0, 1.0) + + return new_idx, alpha diff --git a/cuthbertlib/mcmc/protocols.py b/cuthbertlib/mcmc/protocols.py new file mode 100644 index 00000000..0e796d9d --- /dev/null +++ b/cuthbertlib/mcmc/protocols.py @@ -0,0 +1,24 @@ +"""Protocols for MCMC methods.""" + +from typing import Any, Protocol, runtime_checkable + +from cuthbertlib.types import Array, KeyArray + + +@runtime_checkable +class AncestorMove(Protocol): + """Protocol for ancestor index selection operations.""" + + def __call__(self, key: KeyArray, weights: Array, pivot: int) -> tuple[int, Any]: + """ + Selects an ancestor index, potentially using an MCMC move around a pivot. + + Args: + key: JAX PRNG key. + weights: Normalized weights of the particles. + pivot: The current index to move from. + + Returns: + A tuple containing the new index and any auxiliary output from the move. + """ + ... diff --git a/cuthbertlib/resampling/systematic.py b/cuthbertlib/resampling/systematic.py index cfc40dbb..87ca1d81 100644 --- a/cuthbertlib/resampling/systematic.py +++ b/cuthbertlib/resampling/systematic.py @@ -55,7 +55,7 @@ def conditional_resampling_0_to_0( N = logits.shape[0] weights = jnp.exp(logits - logsumexp(logits)) - tmp = n * weights[0] + tmp = n * weights[0, 0] tmp_floor = jnp.floor(tmp) U, V, W = random.uniform(key, (3,)) diff --git a/cuthbertlib/resampling/utils.py b/cuthbertlib/resampling/utils.py index 4bfb743a..f86d68ca 100644 --- a/cuthbertlib/resampling/utils.py +++ b/cuthbertlib/resampling/utils.py @@ -10,6 +10,23 @@ from cuthbertlib.types import Array, ArrayLike +@jax.jit +def normalize(log_w: Array, log_space: bool = False) -> Array: + """Normalizes a set of log-weights. + + Args: + log_w: The log-weights to normalize. + log_space: Whether to return the normalized weights in log-space. + + Returns: + The normalized weights. + """ + log_w_norm = log_w - logsumexp(log_w) + if log_space: + return log_w_norm + return jnp.exp(log_w_norm) + + @jax.jit def inverse_cdf(sorted_uniforms: ArrayLike, logits: ArrayLike) -> Array: """Inverse CDF sampling for resampling algorithms. diff --git a/tests/cuthbert/mcmc/csmc/test_conditional_particle_filter.py b/tests/cuthbert/mcmc/csmc/test_conditional_particle_filter.py new file mode 100644 index 00000000..38d5edfa --- /dev/null +++ b/tests/cuthbert/mcmc/csmc/test_conditional_particle_filter.py @@ -0,0 +1,104 @@ +import chex +import jax +import jax.numpy as jnp +import pytest +from absl.testing import parameterized + +from cuthbert.mcmc.csmc.conditional_particle_filter import build_csmc_filter +from cuthbert.filtering import filter as apply_filter +from cuthbertlib.resampling.systematic import conditional_resampling + + +# A simple linear Gaussian state-space model for testing +def f(x, _): # state transition + return 0.9 * x + + +def g(x, _): # observation + return 0.5 * x + + +def sample_init(key, _): + return jax.random.normal(key, (1,)) + + +def propagate_sample(key, prev_particles, _): + return f(prev_particles, None) + jax.random.normal(key, prev_particles.shape) + + +def log_potential(_, particles, model_inputs): + y_t, _, _ = model_inputs + return jax.scipy.stats.norm.logpdf(y_t, loc=g(particles, None), scale=1.0) + + +class TestConditionalParticleFilter(chex.TestCase): + @chex.all_variants(with_pmap=False, without_jit=False) + @parameterized.parameters( + {"seed": 0, "n_particles": 100, "seq_len": 10, "conditional": True}, + {"seed": 42, "n_particles": 100, "seq_len": 10, "conditional": False}, + ) + def test_csmc_filter(self, seed, n_particles, seq_len, conditional): + """Tests the conditional particle filter forward pass.""" + key = jax.random.key(seed) + key_truth, key_obs, key_filter = jax.random.split(key, 3) + + # Generate a ground truth trajectory + true_states = [] + x = 0.0 + keys = jax.random.split(key_truth, seq_len) + for i in range(seq_len): + x = f(x, None) + jax.random.normal(keys[i]) + true_states.append(x) + true_states = jnp.array(true_states) + + # Generate observations + observations = g(true_states, None) + jax.random.normal( + key_obs, true_states.shape + ) + + # Define reference trajectory for the filter + reference_particles = true_states + reference_indices = jnp.zeros(seq_len, dtype=int) + + # The model_inputs will be a tuple of (observation, ref_particle, ref_index) + model_inputs = ( + observations, + reference_particles, + reference_indices, + ) + + # Build the filter + csmc_filter = self.variant( + build_csmc_filter, + static_argnums=(3, 5), + )( + init_sample=sample_init, + propagate_sample=propagate_sample, + log_potential=log_potential, + n_particles=n_particles, + resampling_fn=conditional_resampling, + conditional=conditional, + ) + + # Run the filter + filtered_states = apply_filter(csmc_filter, model_inputs, key=key_filter) + + # --- Assertions --- + chex.assert_shape( + filtered_states.particles, (seq_len, n_particles, 1) + ) + chex.assert_shape(filtered_states.log_weights, (seq_len, n_particles)) + + if conditional: + # Check that the reference trajectory is correctly pinned + pinned_particles = filtered_states.particles[ + jnp.arange(seq_len), reference_indices + ] + chex.assert_trees_all_close( + pinned_particles, reference_particles, atol=1e-5 + ) + + # Check that log-likelihood is not NaN or Inf + log_likelihood = filtered_states.log_normalizing_constant[-1] + chex.assert_scalar_not_nan(log_likelihood) + chex.assert_scalar_not_inf(log_likelihood) diff --git a/tests/cuthbert/mcmc/csmc/test_smoother.py b/tests/cuthbert/mcmc/csmc/test_smoother.py new file mode 100644 index 00000000..a7e7e0a9 --- /dev/null +++ b/tests/cuthbert/mcmc/csmc/test_smoother.py @@ -0,0 +1,82 @@ +import chex +import jax +import jax.numpy as jnp +import pytest +from absl.testing import parameterized + +from cuthbert.mcmc.csmc.conditional_particle_filter import build_csmc_filter +from cuthbert.mcmc.csmc.smoother import build_csmc_smoother +from cuthbert.filtering import filter as apply_filter +from cuthbertlib.mcmc import barker_move +from cuthbertlib.resampling.systematic import conditional_resampling + +# Reuse the model from the filter test +from tests.cuthbert.mcmc.csmc.test_conditional_particle_filter import ( + f, + g, + log_potential, + propagate_sample, + sample_init, +) + + +class TestConditionalParticleSmoother(chex.TestCase): + @chex.all_variants(with_pmap=False, without_jit=False) + @parameterized.parameters( + {"seed": 0, "n_particles": 50, "seq_len": 5, "do_sampling": True}, + {"seed": 42, "n_particles": 50, "seq_len": 5, "do_sampling": False}, + ) + def test_csmc_smoother(self, seed, n_particles, seq_len, do_sampling): + """Tests the conditional particle smoother backward pass.""" + key = jax.random.key(seed) + key_truth, key_obs, key_filter, key_smooth = jax.random.split(key, 4) + + # --- Forward Pass --- + true_states = jnp.zeros((seq_len, 1)) # Dummy states + observations = g(true_states, None) + jax.random.normal( + key_obs, true_states.shape + ) + reference_particles = true_states + reference_indices = jnp.zeros(seq_len, dtype=int) + model_inputs = (observations, reference_particles, reference_indices) + + csmc_filter = build_csmc_filter( + init_sample=sample_init, + propagate_sample=propagate_sample, + log_potential=log_potential, + n_particles=n_particles, + resampling_fn=conditional_resampling, + conditional=True, + ) + forward_states = apply_filter(csmc_filter, model_inputs, key=key_filter) + + # --- Backward Pass --- + csmc_smoother = self.variant(build_csmc_smoother)( + log_potential_fn=log_potential, + ancestor_move_fn=barker_move, + conditional=True, + ) + + # Run the smoother + smoothed_particles, smoothed_indices = csmc_smoother( + forward_states, key_smooth, do_sampling + ) + + # --- Assertions --- + chex.assert_shape(smoothed_particles, (seq_len, 1)) + chex.assert_shape(smoothed_indices, (seq_len,)) + + if not do_sampling: + # For backward trace, verify the trajectory is consistent + # Reconstruct the trajectory manually and compare + reconstructed_particles = [] + current_idx = smoothed_indices[-1] + reconstructed_particles.append(forward_states.particles[-1, current_idx]) + for t in reversed(range(seq_len - 1)): + current_idx = forward_states.ancestor_indices[t, current_idx] + reconstructed_particles.append(forward_states.particles[t, current_idx]) + + reconstructed_particles = jnp.array(reconstructed_particles)[::-1] + chex.assert_trees_all_close( + smoothed_particles, reconstructed_particles, atol=1e-5 + ) diff --git a/tests/cuthbertlib/mcmc/test_index_select.py b/tests/cuthbertlib/mcmc/test_index_select.py new file mode 100644 index 00000000..e64f1dc2 --- /dev/null +++ b/tests/cuthbertlib/mcmc/test_index_select.py @@ -0,0 +1,58 @@ +import itertools + +import chex +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from absl.testing import parameterized + +from cuthbertlib.mcmc.index_select import barker_move, force_move +from cuthbertlib.resampling.utils import normalize + +N_PARTICLES = [5, 10] +test_cases = list(itertools.product([0, 42], N_PARTICLES, [barker_move, force_move])) + + +def _check_dist(indices, expected_probs, n_particles): + """Checks that the empirical distribution of indices matches the expected one.""" + n_chains, n_iter = indices.shape + counts = jnp.bincount(jnp.ravel(indices), length=n_particles) + probs = counts / (n_iter * n_chains) + tol = 1 / np.sqrt(n_chains) + chex.assert_trees_all_close(probs, expected_probs, rtol=tol, atol=tol) + + +class TestIndexSelect(chex.TestCase): + def setUp(self): + super().setUp() + self.K = 10_000 # Number of samples for statistical tests + + @chex.all_variants(with_pmap=False, without_jit=False) + @pytest.mark.xdist_group(name="mcmc") + @parameterized.parameters(test_cases) + def test_mcmc(self, seed, n_particles, move): + """Tests the functions.""" + key = jax.random.key(seed) + key_weights, key_test = jax.random.split(key, 2) + + # Generate random weights + log_weights = jax.random.uniform(key_weights, (n_particles,)) + weights = normalize(log_weights) + + # Set a fixed initial pivot + pivot = 0 + + def run_chain(key): + def body(p, key_in): + p, _ = move(key_in, weights, p) + return p, p + + # Run the chain for a few steps to burn in + _, pivot_final = jax.lax.scan(body, pivot, jax.random.split(key, self.K)) + return pivot_final + + keys = jax.random.split(key_test, 25) + final_indices = self.variant(jax.vmap(run_chain))(keys) + + _check_dist(final_indices, weights, n_particles)