Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_multivariate_normal_log_prob(
) -> tuple[LogProbFn, tuple[torch.Tensor, torch.Tensor]]:
mean = torch.randn(dim)
sqrt_cov = torch.randn(dim, dim)
cov = sqrt_cov @ sqrt_cov.T
cov = sqrt_cov @ sqrt_cov.T * 0.3 + torch.eye(dim) * 0.7
chol_cov = torch.linalg.cholesky(cov) # Lower triangular with positive diagonal

def log_prob(p, batch):
Expand Down
2 changes: 1 addition & 1 deletion tests/sgmcmc/test_baoa.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_baoa():
torch.manual_seed(42)

# Set inference parameters (with torch.optim.SGD parameterization)
lr = 1e-3
lr = 1e-2
mu = 0.9
tau = 0.9

Expand Down
2 changes: 1 addition & 1 deletion tests/sgmcmc/test_sghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_sghmc():
torch.manual_seed(42)

# Set inference parameters (with torch.optim.SGD parameterization)
lr = 1e-3
lr = 1e-2
mu = 0.9
tau = 0.9

Expand Down
2 changes: 1 addition & 1 deletion tests/sgmcmc/test_sgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_sgld():
torch.manual_seed(42)

# Set inference parameters
lr = 1e-3
lr = 1e-2
beta = 0.0

# Run MCMC test on Gaussian
Expand Down
2 changes: 1 addition & 1 deletion tests/sgmcmc/test_sgnht.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_sgnht():
torch.manual_seed(42)

# Set inference parameters (with torch.optim.SGD parameterization)
lr = 1e-3
lr = 1e-2
mu = 0.9
tau = 0.9

Expand Down
22 changes: 17 additions & 5 deletions tests/sgmcmc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

def run_test_sgmcmc_gaussian(
transform_builder: Callable[[LogProbFn], Transform],
dim: int = 3,
n_steps: int = 15_000,
burnin: int = 5_000,
dim: int = 2,
n_steps: int = 100_000,
burnin: int = 15_000,
rtol: float = 1e-2, # Relative reduction of KL for final distribution compared to initial distribution
):
# Load log posterior
Expand All @@ -33,6 +33,13 @@ def run_test_sgmcmc_gaussian(
# Remove burnin
all_states = all_states[burnin:]

samples_mean = all_states.params.mean(0)
samples_cov = all_states.params.T.cov()

# Check sufficient statistics
assert torch.allclose(samples_mean, mean, atol=1e-1)
assert torch.allclose(samples_cov, cov, atol=1e-1)

# Check KL divergence between true and inferred Gaussian
kl_init = utils.kl_gaussians(torch.zeros(dim), torch.eye(dim) * init_var, mean, cov)
kl_inferred = utils.kl_gaussians(
Expand All @@ -54,7 +61,12 @@ def run_test_sgmcmc_gaussian(

start_num_samples = 500 # Omit first few KLs with high variance due to few samples
spacing = 100
kl_divs_spaced = kl_divs[start_num_samples::spacing]

kl_divs_min_idx = torch.argmin(kl_divs[2:]) + 2
kl_divs_before_min = kl_divs[:kl_divs_min_idx]
kl_divs_spaced = kl_divs_before_min[start_num_samples::spacing]
spaced_decreasing = kl_divs_spaced[:-1] > kl_divs_spaced[1:]
proportion_decreasing = spaced_decreasing.float().mean()
assert proportion_decreasing > 0.8
assert proportion_decreasing > 0.5, (
f"Proportion decreasing: {proportion_decreasing}"
)