src/discovery/prior.py (2 additions & 2 deletions)

@@ -144,7 +144,7 @@ def to_df(ys, psrs=None):
         return pd.DataFrame(np.array(xs), columns=psrcols).sort_index(axis=1)

     def logprior(ys):
-        return jnp.sum(jnp.log(2.0) - 2.0 * jnp.logaddexp(ys, -ys))
+        return jnp.sum(jnp.log((b - a) / 2) + 2.0 * (jnp.log(2.0) - jnp.logaddexp(ys, -ys)))

     def logL(ys):
         return func(to_dict(ys))
@@ -192,7 +192,7 @@ def to_df(ys):
         return pd.DataFrame(np.array(xs), columns=func.params)

     def logprior(ys):
-        return jnp.sum(jnp.log(2.0) - 2.0 * jnp.logaddexp(ys, -ys))
+        return jnp.sum(jnp.log((b - a) / 2) + 2.0 * (jnp.log(2.0) - jnp.logaddexp(ys, -ys)))

         # return jnp.sum(jnp.log(0.5) - 2.0 * jnp.log(jnp.cosh(ys)))
         # but log(0.5) - 2 * log(cosh(y))
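Note on the change: the new `logprior` adds a per-coordinate `log((b - a) / 2)` term, which appears to be the scale factor of the tanh map from the unbounded coordinate y onto the interval [a, b]. The commented-out `log(cosh)` form above is mathematically equivalent to the `logaddexp` form but overflows for large |y|, since `logaddexp(y, -y) = log(cosh(y)) + log(2)` is evaluated without ever forming `cosh(y)`. A standalone numerical check of that identity (an illustration, not part of this PR):

```python
import jax.numpy as jnp

ys = jnp.array([0.0, 1.0, 100.0])

# naive form: cosh(100) overflows float32, so its log becomes inf
naive = jnp.log(0.5) - 2.0 * jnp.log(jnp.cosh(ys))

# stable form, as used in logprior: logaddexp(y, -y) = log(cosh(y)) + log(2)
stable = jnp.log(2.0) - 2.0 * jnp.logaddexp(ys, -ys)

print(naive)   # approx. [-0.693, -1.561, -inf]
print(stable)  # approx. [-0.693, -1.561, -199.307]
```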
src/discovery/samplers/numpyro.py (23 additions & 3 deletions)

@@ -16,11 +16,23 @@ def makemodel_transformed(mylogl, transform=prior.makelogtransform_uniform, prio

     def numpyro_model():
         pars = numpyro.sample('pars', dist.Normal(0, 10).expand([parlen]))
-        logl = logx(pars)
+        base_logl = logx(pars)
+
+        jac = logx.logprior(pars)
+        logl = base_logl
+
+        # keep track of the original log-likelihood
+        # this is useful for diagnostics and reweighting
+        numpyro.deterministic('logl_original', logl - jac)
         numpyro.factor('logl', logl)
-    numpyro_model.to_df = lambda chain: logx.to_df(chain['pars'])
+
+    def to_df(chain_samples):
+        df = logx.to_df(chain_samples['pars'])
+        logl_arr = chain_samples['logl_original']
+        df['logl'] = logl_arr.reshape(-1)
+        return df
+
+    numpyro_model.to_df = to_df
     return numpyro_model


@@ -29,8 +41,16 @@ def numpyro_model():
         logl = mylogl({par: numpyro.sample(par, dist.Uniform(*prior.getprior_uniform(par, priordict)))
                        for par in mylogl.params})

+        numpyro.deterministic('logl_det', logl)
         numpyro.factor('logl', logl)
-    numpyro_model.to_df = lambda chain: pd.DataFrame(chain)
+
+    def to_df(chain_samples):
+        df = pd.DataFrame(chain_samples)
+        logl_arr = chain_samples['logl_det']
+        df['logl'] = logl_arr.reshape(-1)
+        return df
+
+    numpyro_model.to_df = to_df

     return numpyro_model

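For reference, a minimal usage sketch of the new `to_df` hook; `mylogl` is a placeholder for a discovery likelihood object and the MCMC settings are arbitrary. This assumes NumPyro's `get_samples()` returns deterministic sites such as `logl_original` alongside the latent `pars`, which is what `to_df` relies on:

```python
import jax
from numpyro.infer import MCMC, NUTS

model = makemodel_transformed(mylogl)  # mylogl: placeholder likelihood

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(42))

chain = mcmc.get_samples()  # contains 'pars' and the 'logl_original' deterministic
df = model.to_df(chain)     # parameter DataFrame with a 'logl' column appended
```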
src/discovery/samplers/reweight.py (new file, 95 additions)

@@ -0,0 +1,95 @@
import pandas as pd
import jax
from jax import numpy as jnp
from typing import Callable


def batch_reweight(
    source_df: pd.DataFrame,
    target_logl_fn: Callable,
    batch_size: int = 64
) -> pd.DataFrame:
    """
    Compute the log-likelihood of each sample in source_df under a new model
    (given by target_logl_fn) in batches, and return a copy of source_df
    augmented with a 'logl' column of the recomputed values.

    Args:
        source_df: DataFrame of samples (one row per sample, columns = parameters).
            If it contains a 'logl' column it will be dropped.
        target_logl_fn: Function that maps a dict of parameter values (one scalar
            per parameter name) to a scalar log-likelihood.
        batch_size: Number of samples to process per vmapped function call.

    Returns:
        DataFrame: A copy of source_df with 'logl' from target_logl_fn.
    """
    df = source_df.copy()
    if 'logl' in df.columns:
        df = df.drop(columns=['logl'])

    # one array per parameter column; jax.lax.map iterates over the leading axis,
    # passing target_logl_fn a {parameter: scalar} dict for each sample
    jax_dict = {col: jnp.array(df[col].values) for col in df.columns}

    jitted_logl_fn = jax.jit(target_logl_fn)
    recomputed_logl = jax.lax.map(jitted_logl_fn, jax_dict, batch_size=batch_size)

    result_df = pd.DataFrame(df)
    result_df['logl'] = recomputed_logl
    return result_df

def compute_weights(
    base_df: pd.DataFrame,
    reweighted_df: pd.DataFrame
) -> jnp.ndarray:
    """
    Given two DataFrames with 'logl' columns, compute importance weights
    w_i = exp(logl2_i - logl1_i).

    Args:
        base_df: Original samples with their log-likelihoods under the first model.
        reweighted_df: Same samples with log-likelihoods under the second model.

    Returns:
        Array of weights of shape (N,).
    """
    logl1 = base_df['logl'].to_numpy()
    logl2 = reweighted_df['logl'].to_numpy()
    return jnp.exp(logl2 - logl1)

def compute_bayes_factor_from_weights(
    weights: jnp.ndarray
) -> tuple[float, float]:
    """
    Estimate the Bayes factor between two models:
        BF = E[w] = mean_i exp(logl2_i - logl1_i),
    with uncertainty σ_w / sqrt(N).

    Args:
        weights: Array of weights computed from log-likelihoods.

    Returns:
        A tuple (bayes_factor, uncertainty).
    """
    bf = jnp.mean(weights)
    bf_unc = jnp.std(weights) / jnp.sqrt(len(weights))
    return float(bf), float(bf_unc)

def compute_reweighted_bayes_factor(
    source_df: pd.DataFrame,
    target_logl_fn: Callable,
    batch_size: int = 64
) -> tuple[float, float]:
    """
    Compute the Bayes factor between two models by reweighting samples from source_df.

    Args:
        source_df: DataFrame of samples (one row per sample, columns = parameters),
            with a 'logl' column holding log-likelihoods under the source model.
        target_logl_fn: Function that maps a dict of parameter values (one scalar
            per parameter name) to a scalar log-likelihood.
        batch_size: Number of samples to process per vmapped function call.

    Returns:
        A tuple (bayes_factor, uncertainty).
    """
    reweighted_df = batch_reweight(source_df, target_logl_fn, batch_size)
    weights = compute_weights(source_df, reweighted_df)
    return compute_bayes_factor_from_weights(weights)
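An end-to-end sketch of how these helpers compose with the sampler changes above; `base_df` and `target_logl` are placeholder names, and the effective-sample-size check is a standard importance-sampling diagnostic rather than part of this module:

```python
from jax import numpy as jnp

# base_df: e.g. model.to_df(mcmc.get_samples()) under the source model,
# so it already carries a 'logl' column.
# target_logl: {parameter: value} dict -> scalar log-likelihood of the target model.

bf, bf_unc = compute_reweighted_bayes_factor(base_df, target_logl, batch_size=64)
print(f"BF = {bf:.3g} +/- {bf_unc:.3g}")

# Diagnostic: effective sample size of the importance weights.
# If ESS << N, the two models overlap poorly and the BF estimate is unreliable.
w = compute_weights(base_df, batch_reweight(base_df, target_logl))
ess = float(jnp.sum(w) ** 2 / jnp.sum(w ** 2))
print(f"ESS = {ess:.1f} of N = {len(w)}")
```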