Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 57 additions & 6 deletions src/discovery/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,36 @@ def fmatfunc(params):

# for use in ArrayLikelihood. Same process for all pulsars.
def makecommongp_fourier(psrs, prior, components, T, fourierbasis=fourierbasis, common=[], vector=False, name='fourierCommonGP'):
"""make a GP that gets applied to all pulsars that has the same basis length,
but potentially different basis elements for each pulsar.

This should be used with discovery.ArrayLikelihood.

Parameters
----------
psrs : list
list of discovery.Pulsar objects.
prior : Callable
function that takes in frequencies, frequency steps, and parameters and returns a prior.
components : int
max number of components in the basis
T : float, or Iterable
time span of the data, or list of time spans for each pulsar
fourierbasis : Callable, optional
function that returns the basis, by default `fourierbasis`
common : list, optional
list of parameters that are common to all pulsars (e.g. common uncorrelated red noise),
by default []
vector : bool, optional
Set to True to vectorize parameters supplied, by default False
name : str, optional
a name for this model, by default 'fourierCommonGP'

Returns
-------
gp : discovery.matrix.VariableGP
a GP object that can be used in the ArrayLikelihood
"""
argspec = inspect.getfullargspec(prior)

if vector:
Expand All @@ -292,20 +322,41 @@ def makecommongp_fourier(psrs, prior, components, T, fourierbasis=fourierbasis,
if isinstance(components, dict):
components = max(components.values())

fs, dfs, fmats = zip(*[fourierbasis(psr, components, T) for psr in psrs])
f, df = fs[0], dfs[0]
if isinstance(T, Iterable):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume these if statements are ok because they will get evaluated once when compiled?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check that len(T) == len(psrs)

fs, dfs, fmats = zip(*[fourierbasis(psr, components, T[ii]) for ii, psr in enumerate(psrs)])
else:
fs, dfs, fmats = zip(*[fourierbasis(psr, components, T) for psr in psrs])
# f, df = fs[0], dfs[0]
fs = matrix.jnparray(fs)
dfs = matrix.jnparray(dfs)

if vector:
vprior = jax.vmap(prior, in_axes=[None, None] +
[0 if f'({len(psrs)})' in arg else None for arg in argmap])
# different bases for each pulsar
if isinstance(T, Iterable):
in_axes = [0, 0] + [0 if f'({len(psrs)})' in arg else None for arg in argmap]
f = fs
df = dfs
else:
in_axes = [None, None] + [0 if f'({len(psrs)})' in arg else None for arg in argmap]
f = fs[0]
df = dfs[0]
vprior = jax.vmap(prior, in_axes=in_axes)

def priorfunc(params):
return vprior(f, df, *[params[arg] for arg in argmap])

priorfunc.params = sorted(argmap)
else:
vprior = jax.vmap(prior, in_axes=[None, None] +
[0 if isinstance(argmap, list) else None for argmap in argmaps])
# different bases for each pulsar
if isinstance(T, Iterable):
in_axes = [0, 0] + [0 if isinstance(argmap, list) else None for argmap in argmaps]
f = fs
df = dfs
else:
in_axes = [None, None] + [0 if isinstance(argmap, list) else None for argmap in argmaps]
f = fs[0]
df = dfs[0]
vprior = jax.vmap(prior, in_axes=in_axes)

def priorfunc(params):
vpars = [matrix.jnparray([params[arg] for arg in argmap]) if isinstance(argmap, list) else params[argmap]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_compare_enterprise(self):
ll_difference = enterprise_ll - jlogl(initial_position)

# There is a constant offset of ~ -52.4
offset = -52.4
offset = -52.4 - 5866.5585968

# Choose the absolute tolerance
atol = 0.1
Expand Down
69 changes: 69 additions & 0 deletions tests/test_signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/usr/bin/env python3
"""Tests for discovery likelihood"""

import operator
from functools import reduce
from pathlib import Path

import discovery as ds
import jax
import pytest
import numpy as np
import numpy.testing as npt


class TestSignals:
data_dir = Path(__file__).resolve().parent.parent / "data"
psr_files = [
data_dir / "v1p1_de440_pint_bipm2019-B1855+09.feather",
data_dir / "v1p1_de440_pint_bipm2019-B1953+29.feather",
]
psrs = [ds.Pulsar.read_feather(psr) for psr in psr_files]

psr_rn_params = [f"{psrs[0].name}_red_noise_log10_A", f"{psrs[0].name}_red_noise_gamma", f"{psrs[1].name}_red_noise_log10_A", f"{psrs[1].name}_red_noise_gamma"]

# create fake parameter dict for testing
fake_params = {key: np.random.rand()*3 - 13 for key in psr_rn_params if 'log10_A' in key}
fake_params = {**fake_params, **{key: np.random.rand()*3 for key in psr_rn_params if 'gamma' in key}}

@pytest.mark.integration
def test_makecommongp_fourier_basis_construction(self):
tspans = [ds.getspan([psr]) for psr in self.psrs]
# make GP for both pulsars at once

gp = ds.makecommongp_fourier(self.psrs, ds.powerlaw, 14, T=tspans, name="red_noise")

# make two separate GPs
gp_psr1 = ds.makegp_fourier(self.psrs[0], ds.powerlaw, 14, T=tspans[0], name="red_noise")
gp_psr2 = ds.makegp_fourier(self.psrs[1], ds.powerlaw, 14, T=tspans[1], name="red_noise")

# check that bases are correct.
npt.assert_allclose(gp.F[0], gp_psr1.F)
npt.assert_allclose(gp.F[1], gp_psr2.F)

# check that noise matrices are correct
npt.assert_allclose(gp.Phi.getN(self.fake_params), np.vstack([gp_psr1.Phi.getN(self.fake_params), gp_psr2.Phi.getN(self.fake_params)]))

# now make the bases the same, giving a single tspan
# make GP for both pulsars at once
tspan_total = ds.getspan(self.psrs)
gp = ds.makecommongp_fourier(self.psrs, ds.powerlaw, 14, T=tspan_total, name="gw")

# make two separate GPs
gp_psr1 = ds.makegp_fourier(self.psrs[0], ds.powerlaw, 14, T=tspan_total, name="gw")
gp_psr2 = ds.makegp_fourier(self.psrs[1], ds.powerlaw, 14, T=tspan_total, name="gw")
# check that bases are correct.
npt.assert_allclose(gp.F[0], gp_psr1.F)
npt.assert_allclose(gp.F[1], gp_psr2.F)

# check that parameters are what they should be
gp = ds.makecommongp_fourier(self.psrs, ds.powerlaw, 14, T=tspans, name="red_noise")
expected_pars = set(self.psr_rn_params)
assert set(gp.Phi.params) == expected_pars

# check that common parameters are included
common = ["crn_log10_A", "crn_gamma"]
powerlaw = ds.makepowerlaw_crn(14)
gp = ds.makecommongp_fourier(self.psrs, powerlaw, 30, T=tspans, common=common, name="red_noise")
expected_pars = set(self.psr_rn_params + common)
assert set(gp.Phi.params) == expected_pars
Loading