Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 39 additions & 11 deletions cuthbertlib/linalg/marginal_sqrt_cov.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,62 @@
"""Extract marginal square root covariance from a joint square root covariance."""

from typing import Sequence
"""Extract marginal square root covariance(s) from a joint square root covariance."""

from jax import numpy as jnp
from jax import vmap
from jax.lax import dynamic_slice

from cuthbertlib.linalg.tria import tria
from cuthbertlib.types import Array, ArrayLike


def marginal_sqrt_cov(chol_cov: ArrayLike, start: int, end: int) -> Array:
def marginal_sqrt_cov(chol_cov: ArrayLike, start: int, size: int) -> Array:
"""Extracts square root submatrix from a joint square root matrix.

Specifically, returns B such that
B @ B.T = (chol_cov @ chol_cov.T)[start:end, start:end]
B @ B.T = (chol_cov @ chol_cov.T)[start:start+size, start:start+size]

Args:
chol_cov: Generalized Cholesky factor of the covariance matrix.
start: Start index of the submatrix.
end: End index of the submatrix.
size: Number of contiguous rows/columns of the marginal block.

Returns:
Lower triangular square root matrix of the marginal covariance matrix.
"""
chol_cov = jnp.asarray(chol_cov)

assert chol_cov.ndim == 2, "chol_cov must be a 2D array"
assert chol_cov.shape[0] == chol_cov.shape[1], "chol_cov must be square"
assert start >= 0 and end <= chol_cov.shape[0], (
"start and end must be within the bounds of chol_cov"
)
assert start < end, "start must be less than end"
# assert start >= 0 and start + size <= chol_cov.shape[0], (
# "start and start + size must be within the bounds of chol_cov"
# ) We don't assert based on start since start doesn't need to be a static argument
assert size > 0, "size must be positive"

chol_cov_select_rows = chol_cov[start:end, :]
slice_sizes = (size, chol_cov.shape[1])
chol_cov_select_rows = dynamic_slice(chol_cov, (start, 0), slice_sizes)
return tria(chol_cov_select_rows)


def block_marginal_sqrt_cov(chol_cov: ArrayLike, subdim: int) -> Array:
"""Extracts all square root submatrices of specified size from joint square root matrix.

Args:
chol_cov: Generalized Cholesky factor of the covariance matrix.
subdim: Size of the square root submatrices to extract.
Must be a divisor of the number of rows in chol_cov.

Returns:
Array of shape (chol_cov.shape[0] // subdim, subdim, subdim)
containing the square root submatrices.
"""
chol_cov = jnp.asarray(chol_cov)

assert chol_cov.ndim == 2, "chol_cov must be a 2D array"
assert chol_cov.shape[0] == chol_cov.shape[1], "chol_cov must be square"
assert subdim > 0 and chol_cov.shape[0] % subdim == 0, (
"subdim must be a positive divisor of the number of rows in chol_cov"
)

n_blocks = chol_cov.shape[0] // subdim
return vmap(lambda i: marginal_sqrt_cov(chol_cov, i * subdim, subdim))(
jnp.arange(n_blocks)
)
98 changes: 74 additions & 24 deletions tests/cuthbertlib/linalg/test_marginal_sqrt_cov.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import chex
import jax
import jax.numpy as jnp
import pytest
from absl.testing import parameterized
from jax import random

from cuthbertlib.linalg.marginal_sqrt_cov import marginal_sqrt_cov
from cuthbertlib.linalg.marginal_sqrt_cov import (
block_marginal_sqrt_cov,
marginal_sqrt_cov,
)


@pytest.fixture(scope="module", autouse=True)
Expand All @@ -13,31 +18,76 @@ def config():
jax.config.update("jax_enable_x64", False)


@pytest.mark.parametrize("seed", [0, 42, 99])
@pytest.mark.parametrize(
"n,start,end",
[
(6, 0, 3), # top-left block
(6, 3, 6), # bottom-right block
(8, 2, 5), # middle block
(10, 1, 9), # large block
],
)
def test_marginal_sqrt_cov(seed, n, start, end):
key = random.key(seed)
class TestMarginalSqrtCov(chex.TestCase):
@chex.variants(with_jit=True, without_jit=True)
@parameterized.product(
seed=[0, 42, 99],
block=[
(6, 0, 3), # top-left block
(6, 3, 6), # bottom-right block
(8, 2, 5), # middle block
(10, 1, 9), # large block
],
)
def test_marginal_sqrt_cov(self, seed, block):
key = random.key(seed)
n, start, end = block
size = end - start

# Random lower-triangular joint square root
L = jnp.tril(random.normal(key, (n, n)))

# Extract marginal square root
B = self.variant(
marginal_sqrt_cov,
static_argnames=("size"),
)(L, start, size)

# Expected marginal covariance block
Sigma = L @ L.T
Sigma_block = Sigma[start:end, start:end]

# Check B is lower triangular
chex.assert_trees_all_close(B, jnp.tril(B))

# Check B B^T reproduces marginal covariance
chex.assert_trees_all_close(B @ B.T, Sigma_block)

# Random lower-triangular joint square root
L = jnp.tril(random.normal(key, (n, n)))

# Extract marginal square root
B = marginal_sqrt_cov(L, start, end)
class TestBlockMarginalSqrtCov(chex.TestCase):
@chex.variants(with_jit=True, without_jit=True)
@parameterized.product(
seed=[0, 42],
n_subdim=[
(6, 2),
(6, 3),
(8, 4),
(9, 3),
],
)
def test_block_marginal_sqrt_cov(self, seed, n_subdim):
n, subdim = n_subdim
key = random.key(seed)
L = jnp.tril(random.normal(key, (n, n)))

# Expected marginal covariance block
Sigma = L @ L.T
Sigma_block = Sigma[start:end, start:end]
blocks = self.variant(
block_marginal_sqrt_cov,
static_argnames=("subdim",),
)(L, subdim=subdim)

# Check B is lower triangular
assert jnp.allclose(B, jnp.tril(B))
n_blocks = n // subdim
chex.assert_equal(blocks.shape, (n_blocks, subdim, subdim))

# Check B B^T reproduces marginal covariance
assert jnp.allclose(B @ B.T, Sigma_block)
Sigma = L @ L.T
for i in range(n_blocks):
start, end = i * subdim, (i + 1) * subdim
Sigma_block = Sigma[start:end, start:end]
chex.assert_trees_all_close(
blocks[i], jnp.tril(blocks[i])
) # Check that blocks are lower triangular
chex.assert_trees_all_close(
blocks[i] @ blocks[i].T, Sigma_block
) # Check that blocks reproduce the marginal covariance
chex.assert_trees_all_close(
blocks[i], marginal_sqrt_cov(L, start, subdim)
) # Check that blocks match marginal_sqrt_cov
Loading