Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added data/EPTA_feather_new/J0030+0451
Binary file not shown.
Binary file added data/EPTA_feather_new/J0613-0200
Binary file not shown.
Binary file added data/EPTA_feather_new/J0751+1807
Binary file not shown.
Binary file added data/EPTA_feather_new/J0900-3144
Binary file not shown.
Binary file added data/EPTA_feather_new/J1012+5307
Binary file not shown.
Binary file added data/EPTA_feather_new/J1022+1001
Binary file not shown.
Binary file added data/EPTA_feather_new/J1024-0719
Binary file not shown.
Binary file added data/EPTA_feather_new/J1455-3330
Binary file not shown.
Binary file added data/EPTA_feather_new/J1600-3053
Binary file not shown.
Binary file added data/EPTA_feather_new/J1640+2224
Binary file not shown.
Binary file added data/EPTA_feather_new/J1713+0747
Binary file not shown.
Binary file added data/EPTA_feather_new/J1730-2304
Binary file not shown.
Binary file added data/EPTA_feather_new/J1738+0333
Binary file not shown.
Binary file added data/EPTA_feather_new/J1744-1134
Binary file not shown.
Binary file added data/EPTA_feather_new/J1751-2857
Binary file not shown.
Binary file added data/EPTA_feather_new/J1801-1417
Binary file not shown.
Binary file added data/EPTA_feather_new/J1804-2717
Binary file not shown.
Binary file added data/EPTA_feather_new/J1843-1113
Binary file not shown.
Binary file added data/EPTA_feather_new/J1857+0943
Binary file not shown.
Binary file added data/EPTA_feather_new/J1909-3744
Binary file not shown.
Binary file added data/EPTA_feather_new/J1910+1256
Binary file not shown.
Binary file added data/EPTA_feather_new/J1911+1347
Binary file not shown.
Binary file added data/EPTA_feather_new/J1918-0642
Binary file not shown.
Binary file added data/EPTA_feather_new/J2124-3358
Binary file not shown.
Binary file added data/EPTA_feather_new/J2322+2057
Binary file not shown.
409 changes: 409 additions & 0 deletions examples/CURN_and_HD_flow_examples.ipynb

Large diffs are not rendered by default.

Binary file added examples/EPTA_example_files/CURN_EPTA.eqx
Binary file not shown.
Binary file added examples/EPTA_example_files/DF_MAF_J1012.eqx
Binary file not shown.
Binary file added examples/EPTA_example_files/HD_EPTA.eqx
Binary file not shown.
2,990 changes: 2,990 additions & 0 deletions examples/EPTA_example_files/ptmcmc_psr_J1012_EPTAPaper/DEJump_jump.txt

Large diffs are not rendered by default.

300,001 changes: 300,001 additions & 0 deletions examples/EPTA_example_files/ptmcmc_psr_J1012_EPTAPaper/chain_1.txt

Large diffs are not rendered by default.

Binary file not shown.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

3,001 changes: 3,001 additions & 0 deletions examples/EPTA_example_files/ptmcmc_psr_J1012_EPTAPaper/draw_from_prior_jump.txt

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
covarianceJumpProposalAM 0.14
DEJump 0.45
draw_from_prior 0.14
covarianceJumpProposalSCAM 0.27
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
J1012+5307_dm_gp_gamma
J1012+5307_dm_gp_log10_A
J1012+5307_red_noise_gamma
J1012+5307_red_noise_log10_A
250 changes: 250 additions & 0 deletions examples/EPTA_single_pulsar_flow_example.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/discovery/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ def __call__(self, flow, loss, temp):

# ps = logx.to_df(flow.sample(train_key, sample_shape=(2*4096,)))

train_key = jax.random.PRNGKey(42)
#train_key = jax.random.PRNGKey(42) # ORIG
train_key = jax.random.key(42)

if self.vlogx is None:
ps = self.logx.to_df(flow.sample(train_key, sample_shape=(2*4096,)))
Expand Down
2 changes: 1 addition & 1 deletion src/discovery/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,4 +1658,4 @@ def kernelterms(params):

kernelterms.params = self.P_var.params

return kernelterms
return kernelterms
75 changes: 75 additions & 0 deletions src/discovery/models/EPTA.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from .. import signals
from .. import likelihood


# Signle pulsar noise analysis. Data the second data relise of the European Pulsar Timing Array (DR2new+ dataset)
# Note: The exponential dips are not included in this model
def makemodel_singlepulsar(psrs, psr_name):

for p in psrs:
if psr_name in p.name:
sgl_psr = p

model = [sgl_psr.residuals, signals.makenoise_measurement(sgl_psr, sgl_psr.noisedict), signals.makegp_timing(sgl_psr)]

if sgl_psr.noisedict[sgl_psr.name + '_dm_gp_components']:
model.append(signals.makegp_fourier(sgl_psr, signals.powerlaw, sgl_psr.noisedict[sgl_psr.name + '_dm_gp_components'], T=signals.getspan(sgl_psr), name='dm_gp', fourierbasis=signals.make_dmfourierbasis(alpha = 2.0, tndm = False)))

if sgl_psr.noisedict[sgl_psr.name + '_chrom_components']:
model.append(signals.makegp_fourier(sgl_psr, signals.powerlaw, sgl_psr.noisedict[sgl_psr.name + '_chrom_components'], T=signals.getspan(sgl_psr), name='chrom_gp', fourierbasis=signals.make_dmfourierbasis(alpha = 4.0, tndm = False)))

if sgl_psr.noisedict[sgl_psr.name + '_red_components']:
model.append(signals.makegp_fourier(sgl_psr, signals.powerlaw, sgl_psr.noisedict[sgl_psr.name + '_red_components'], T=signals.getspan(sgl_psr), name='red_noise' ))

return likelihood.PulsarLikelihood(model)


# CURN model from the second data relise of the European Pulsar Timing Array (DR2new+ dataset).
# Note: The exponential dips are not included in this model
def makemodel_curn_EPTA(psrs, crn_components = 30):

pslmodels = []
tspan = signals.getspan(psrs)

for p in psrs:

model = [p.residuals, signals.makenoise_measurement(p, p.noisedict), signals.makegp_timing(p, svd=True)]

if p.noisedict[p.name + '_dm_gp_components']:
model.append(signals.makegp_fourier(p, signals.powerlaw, p.noisedict[p.name + '_dm_gp_components'], T=signals.getspan(p), name='dm_gp', fourierbasis=signals.make_dmfourierbasis(alpha = 2.0, tndm = True)))

if p.noisedict[p.name + '_chrom_components']:
model.append(signals.makegp_fourier(p, signals.powerlaw, p.noisedict[p.name + '_chrom_components'], T=signals.getspan(p), name='chrom_gp', fourierbasis=signals.make_dmfourierbasis(alpha = 4.0, tndm = True)))

if p.noisedict[p.name + '_red_components']:
model.append(signals.makegp_fourier(p, signals.powerlaw, p.noisedict[p.name + '_red_components'], T=tspan, name='red_noise' ))

pslmodels.append(likelihood.PulsarLikelihood(model))

return likelihood.GlobalLikelihood(psls = pslmodels, globalgp=signals.makegp_fourier_global(psrs, signals.powerlaw, signals.uncorrelated_orf, components=crn_components, T=tspan, name='gw_crn'))


# HD model from the second data relise of the European Pulsar Timing Array (DR2new+ dataset).
# Note: The exponential dips are not included in this model
def makemodel_hd_EPTA(psrs, gw_components = 30):

pslmodels = []
tspan = signals.getspan(psrs)

for p in psrs:

model = [p.residuals, signals.makenoise_measurement(p, p.noisedict), signals.makegp_timing(p, svd=True)]

if p.noisedict[p.name + '_dm_gp_components']:
model.append(signals.makegp_fourier(p, signals.powerlaw, p.noisedict[p.name + '_dm_gp_components'], T=signals.getspan(p), name='dm_gp', fourierbasis=signals.make_dmfourierbasis(alpha = 2.0, tndm = True)))

if p.noisedict[p.name + '_chrom_components']:
model.append(signals.makegp_fourier(p, signals.powerlaw, p.noisedict[p.name + '_chrom_components'], T=signals.getspan(p), name='chrom_gp', fourierbasis=signals.make_dmfourierbasis(alpha = 4.0, tndm = True)))

if p.noisedict[p.name + '_red_components']:
model.append(signals.makegp_fourier(p, signals.powerlaw, p.noisedict[p.name + '_red_components'], T=tspan, name='red_noise' ))

pslmodels.append(likelihood.PulsarLikelihood(model))

return likelihood.GlobalLikelihood(psls = pslmodels, globalgp=signals.makegp_fourier_global(psrs, signals.powerlaw, signals.hd_orf, components=gw_components, T=tspan, name='gw_hd'))

50 changes: 16 additions & 34 deletions src/discovery/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,41 +12,32 @@ def logpriorfunc(params):

return logpriorfunc


priordict_standard = {
"(.*_)?efac": [0.9, 1.1],
"(.*_)?t2equad": [-8.5, -5],
"(.*_)?tnequad": [-8.5, -5],
"(.*_)?log10_ecorr": [-8.5, -5],
"(.*_)?rednoise_log10_A.*": [-20, -11],
"(.*_)?rednoise_log10_A.*": [-18, -10],
"(.*_)?rednoise_gamma.*": [0, 7],
"(.*_)?rednoise_log10_fb": [-9, -6],
"(.*_)?red_noise_log10_A.*": [-20, -11], # deprecated
"(.*_)?red_noise_log10_A.*": [-18, -10], # deprecated
"(.*_)?red_noise_gamma.*": [0, 7], # deprecated
"(.*_)?red_noise_log10_fb": [-9, -6],
"crn_log10_A.*": [-18, -11],
"crn_gamma.*": [0, 7],
"crn_log10_fb": [-9, -6],
"gw_(.*_)?log10_A": [-18, -11],
"gw_(.*_)?gamma": [0, 7],
"gw_log10_fb": [-9, -6],
"(.*_)?dmgp_log10_A": [-20, -11],
"(.*_)?dmgp_gamma": [0, 7],
"(.*_)?dmgp_alpha": [1, 3],
"(.*_)?gp_log10_A": [-18, -10],
"(.*_)?gp_gamma": [0, 7],
"(.*_)?gp_alpha": [1, 3],
"crn_log10_rho": [-9, -4],
"gw_(.*_)?log10_rho": [-9, -4],
r"(.*_)?red_noise_log10_rho\(([0-9]*)\)": [-9, -4],
r"(.*_)?red_noise_crn_log10_rho\(([0-9]*)\)": [-9, -4]
"(.*_)?red_noise_log10_rho\(([0-9]*)\)": [-9, -4],
"(.*_)?red_noise_crn_log10_rho\(([0-9]*)\)": [-9, -4],
"(.*_)?log10_Amp": [-10, -2],
"(.*_)?log10_tau": [0, 2.5],
"(.*_)?2_t0": [54650, 54850],
"(.*_)?1_t0": [57490, 57530]
}

def getprior_uniform(par, priordict={}):
priordict = {**priordict_standard, **priordict}

for parname, range in priordict.items():
if re.match(parname, par):
return range

raise KeyError(f'getprior_uniform: no prior for parameter {par}.')

def makelogprior_uniform(params, priordict={}):
priordict = {**priordict_standard, **priordict}
Expand All @@ -70,17 +61,9 @@ def makelogtransform_uniform(func, priordict={}):
# figure out slices when there are vector arguments
slices, offset = [], 0
for par in func.params:
# l = int(par[par.index('(')+1:par.index(')')]) if '(' in par else 1
# slices.append(slice(offset, offset+l))
# offset = offset + l

if '(' in par:
l = int(par[par.index('(')+1:par.index(')')]) if '(' in par else 1
slices.append(slice(offset, offset+l))
offset = offset + l
else:
slices.append(offset)
offset = offset + 1
l = int(par[par.index('(')+1:par.index(')')]) if '(' in par else 1
slices.append(slice(offset, offset+l))
offset = offset + l

# build vectors of DF column names and of lower and upper uniform limits
a, b = [], []
Expand Down Expand Up @@ -202,7 +185,7 @@ def transformed(ys):
return transformed


def sample_uniform(params, priordict={}, n=1, fail=True):
def sample_uniform(params, priordict={}, n=1, fail = True):
priordict = {**priordict_standard, **priordict}

sample = {}
Expand All @@ -219,7 +202,6 @@ def sample_uniform(params, priordict={}, n=1, fail=True):
sample[par] = np.random.uniform(*range) if n == 1 else np.random.uniform(*range, size=n)
break
else:
if fail:
raise KeyError(f"No known prior for {par}.")
raise KeyError(f"No known prior for {par}.")

return sample
4 changes: 2 additions & 2 deletions src/discovery/samplers/numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def makesampler_nuts(numpyro_model, num_warmup=512, num_samples=1024, num_chains
nutsargs = dict(max_tree_depth=8, dense_mass=False,
forward_mode_differentiation=False, target_accept_prob=0.8,
**{arg: val for arg in kwargs.items() if arg in inspect.getfullargspec(infer.NUTS).args})

mcmcargs = dict(num_warmup=512, num_samples=1024, num_chains=1,
mcmcargs = dict(num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains,
chain_method='vectorized', progress_bar=True,
**{arg: val for arg in kwargs.items() if arg in inspect.getfullargspec(infer.MCMC).kwonlyargs})

Expand Down
42 changes: 42 additions & 0 deletions src/discovery/serialisation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import jax
import optax
import equinox as eqx
import numpy as np
import json
import jax.random as jr
import prior

from flowjax.bijections import RationalQuadraticSpline, Affine
from flowjax.distributions import StandardNormal, Transformed
from flowjax.flows import masked_autoregressive_flow


# build the flow
def make(*, key, loglike, flow_arch = masked_autoregressive_flow, base_distribution = StandardNormal, transform = RationalQuadraticSpline,
n_samples = 512, knots = 8, interval = 5, patience = 100, multibatch = 1, LR = 1e-3, steps = 2000, flow_lay = 8, deepness = 1):

l = np.array([prior.sample_uniform(loglike.logL.params)[k] for k in prior.sample_uniform(loglike.logL.params).keys()])

key, flow_key, train_key = jr.split(key, 3)
flow = flow_arch(flow_key, base_dist= base_distribution(l.shape), flow_layers = flow_lay, nn_depth = deepness, transformer=transform(knots = knots, interval=interval), invert=False)

return flow

# save the hyperparameters
def save(filename, hyperparams, model):
with open(filename, "wb") as f:
hyperparam_str = json.dumps(hyperparams)
f.write((hyperparam_str + "\n").encode())
eqx.tree_serialise_leaves(f, model)

# load the model
def load(filename, loglike, flow_arch = masked_autoregressive_flow, base_distribution = StandardNormal, transform = RationalQuadraticSpline):
with open(filename, "rb") as f:
hyperparams = json.loads(f.readline().decode())
# in case I set an interval instead of a single number
if isinstance(hyperparams['interval'], list):
hyperparams['interval'] = (hyperparams['interval'][0], hyperparams['interval'][1])

model = make(key=jr.PRNGKey(42), loglike = loglike, flow_arch = flow_arch, base_distribution = base_distribution, transform = transform, **hyperparams)
print(hyperparams)
return eqx.tree_deserialise_leaves(f, model)
Loading
Loading