Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 102 additions & 69 deletions src/pint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,147 +1079,180 @@ def dmxstats(
print(f"{(1-selected).sum()} TOAs not selected in any DMX window", file=file)


def dmxparse(
fitter: "pint.fitter.Fitter", save: bool = False
def chromatic_parse(
fitter: "pint.fitter.Fitter", prefix: str = "DM", save: bool = False
) -> Dict[str, Union[u.Quantity, List]]:
"""Run dmxparse in python using PINT objects and results.
"""Run chromatic_parse in python using PINT objects and results.

Based off dmxparse by P. Demorest (https://github.com/nanograv/tempo/tree/master/util/dmxparse)

Parameters
----------
fitter
PINT fitter used to get timing residuals, must have already run a fit
prefix
Prefix parameter type to parse. Defaults to ``DM``. Outputs and returns ``xxx`` or ``xx`` are replaced by the prefix letters, e.g., ``dmx`` or ```dm``, respectively
save : bool or str or file-like object, optional
If not False or None, saves output to specified file in the format of the TEMPO version. If ``True``, assumes output file is ``dmxparse.out``
If not False or None, saves output to specified file in the format of the TEMPO version. If ``True``, assumes output file is of the format ``xxxparse.out``

Returns
-------
dict :

``dmxs`` : mean-subtraced dmx values
``xxxs`` : mean-subtraced xxx values

``dmx_verrs`` : dmx variance errors
``xxx_verrs`` : xxx variance errors

``dmxeps`` : center mjds of the dmx bins
``xxxeps`` : center mjds of the xxx bins

``r1s`` : lower mjd bounds on the dmx bins
``r1s`` : lower mjd bounds on the xxx bins

``r2s`` : upper mjd bounds on the dmx bins
``r2s`` : upper mjd bounds on the xxx bins

``bins`` : dmx bins
``bins`` : xxx bins

``mean_dmx`` : mean dmx value
``mean_xxx`` : mean xxx value

``avg_dm_err`` : uncertainty in average dmx
``avg_pr_err`` : uncertainty in average xxx

Raises
------
RuntimeError
If the model has no DMX parameters, or if there is a parsing problem
If the model has no XXX parameters, or if there is a parsing problem

"""
# We get the DMX values, errors, and mjds (same as in getting the DMX values for DMX v. time)
# Get number of DMX epochs
try:
DMX_mapping = fitter.model.get_prefix_mapping("DMX_")
X_mapping = fitter.model.get_prefix_mapping(f"{prefix}X_")
except ValueError as e:
raise RuntimeError("No DMX values in model!") from e
dmx_epochs = [f"{x:04d}" for x in DMX_mapping.keys()]
DMX_keys = list(DMX_mapping.values())
DMXs = np.zeros(len(dmx_epochs))
DMX_Errs = np.zeros(len(dmx_epochs))
DMX_R1 = np.zeros(len(dmx_epochs))
DMX_R2 = np.zeros(len(dmx_epochs))
mask_idxs = np.zeros(len(dmx_epochs), dtype=np.bool_)
# Get DMX values (will be in units of 10^-3 pc cm^-3)
for ii, epoch in enumerate(dmx_epochs):
DMXs[ii] = getattr(fitter.model, "DMX_{:}".format(epoch)).value
mask_idxs[ii] = getattr(fitter.model, "DMX_{:}".format(epoch)).frozen
DMX_Errs[ii] = getattr(fitter.model, "DMX_{:}".format(epoch)).uncertainty_value
DMX_R1[ii] = getattr(fitter.model, "DMXR1_{:}".format(epoch)).value
DMX_R2[ii] = getattr(fitter.model, "DMXR2_{:}".format(epoch)).value
DMX_center_MJD = (DMX_R1 + DMX_R2) / 2
raise RuntimeError(f"No {prefix}X values in model!") from e
lower = prefix.lower()
xxx_epochs = [f"{x:04d}" for x in X_mapping.keys()]
X_keys = list(X_mapping.values())
Xs = np.zeros(len(xxx_epochs))
X_Errs = np.zeros(len(xxx_epochs))
X_R1 = np.zeros(len(xxx_epochs))
X_R2 = np.zeros(len(xxx_epochs))
mask_idxs = np.zeros(len(xxx_epochs), dtype=np.bool_)
# Get XXX values (DMX, SWX will be in units of 10^-3 pc cm^-3, CMX in units of 10^-3 pc cm^-3 MHz^-2)
for ii, epoch in enumerate(xxx_epochs):
Xs[ii] = getattr(fitter.model, f"{prefix}X_{epoch}").value
mask_idxs[ii] = getattr(fitter.model, f"{prefix}X_{epoch}").frozen
X_Errs[ii] = getattr(fitter.model, f"{prefix}X_{epoch}").uncertainty_value
X_R1[ii] = getattr(fitter.model, f"{prefix}XR1_{epoch}").value
X_R2[ii] = getattr(fitter.model, f"{prefix}XR2_{epoch}").value
X_center_MJD = (X_R1 + X_R2) / 2
# If any value need to be masked, do it
if True in mask_idxs:
log.warning(
"Some DMX bins were not fit for, masking these bins for computation."
f"Some {prefix}X bins were not fit for, masking these bins for computation."
)
DMX_Errs = np.ma.array(DMX_Errs, mask=mask_idxs)
DMX_keys_ma = np.ma.array(DMX_keys, mask=mask_idxs)
X_Errs = np.ma.array(X_Errs, mask=mask_idxs)
X_keys_ma = np.ma.array(X_keys, mask=mask_idxs)
else:
DMX_keys_ma = None
X_keys_ma = None

# Make sure that the fitter has a covariance matrix, otherwise return the initial values
if hasattr(fitter, "parameter_covariance_matrix"):
# now get the full parameter covariance matrix from pint
# access by label name to make sure we get the right values
# make sure they are sorted in ascending order
cc = fitter.parameter_covariance_matrix.get_label_matrix(
sorted([f"DMX_{x}" for x in dmx_epochs])
sorted([f"{prefix}X_{x}" for x in xxx_epochs])
)
n = len(DMX_Errs) - np.sum(mask_idxs)
# Find error in mean DM
DMX_mean = np.mean(DMXs)
DMX_mean_err = np.sqrt(cc.matrix.sum()) / float(n)
# Do the correction for varying DM
n = len(X_Errs) - np.sum(mask_idxs)
# Find error in mean value
X_mean = np.mean(Xs)
X_mean_err = np.sqrt(cc.matrix.sum()) / float(n)
# Do the correction for varying values
m = np.identity(n) - np.ones((n, n)) / float(n)
cc = np.dot(np.dot(m, cc.matrix), m)
DMX_vErrs = np.zeros(n)
X_vErrs = np.zeros(n)
# We also need to correct for the units here
for i in range(n):
DMX_vErrs[i] = np.sqrt(cc[i, i])
X_vErrs[i] = np.sqrt(cc[i, i])
# If array was masked, we need to add values back in where they were masked
if DMX_keys_ma is not None:
if X_keys_ma is not None:
# Only need to add value to DMX_vErrs
DMX_vErrs = np.insert(DMX_vErrs, np.where(mask_idxs)[0], None)
X_vErrs = np.insert(X_vErrs, np.where(mask_idxs)[0], None)
else:
log.warning(
"Fitter does not have covariance matrix, returning values from model"
)
DMX_mean = np.mean(DMXs)
DMX_mean_err = np.mean(DMX_Errs)
DMX_vErrs = DMX_Errs
X_mean = np.mean(Xs)
X_mean_err = np.mean(X_Errs)
X_vErrs = X_Errs
# Check we have the right number of params
if len(DMXs) != len(DMX_Errs) or len(DMXs) != len(DMX_vErrs):
raise RuntimeError("Number of DMX entries do not match!")
if len(Xs) != len(X_Errs) or len(Xs) != len(X_vErrs):
raise RuntimeError(f"Number of {prefix}X entries do not match!")

# Output the results'
if save is not None and save:
if isinstance(save, bool):
save = "dmxparse.out"
save = f"{lower}parse.out"
lines = [
f"# Mean DMX value = {DMX_mean:+.6e} \n",
f"# Uncertainty in average DM = {DMX_mean_err:.5e} \n",
f"# Columns: DMXEP DMX_value DMX_var_err DMXR1 DMXR2 %s_bin \n",
f"# Mean {prefix}X value = {X_mean:+.6e} \n",
f"# Uncertainty in average {prefix} = {X_mean_err:.5e} \n",
f"# Columns: {prefix}XEP {prefix}X_value {prefix}X_var_err {prefix}XR1 {prefix}XR2 %s_bin \n",
]
lines.extend(
f"{DMX_center_MJD[k]:.4f} {DMXs[k] - DMX_mean:+.7e} {DMX_vErrs[k]:.3e} {DMX_R1[k]:.4f} {DMX_R2[k]:.4f} {DMX_keys[k]} \n"
for k in range(len(dmx_epochs))
f"{X_center_MJD[k]:.4f} {Xs[k] - X_mean:+.7e} {X_vErrs[k]:.3e} {X_R1[k]:.4f} {X_R2[k]:.4f} {X_keys[k]} \n"
for k in range(len(xxx_epochs))
)
with open_or_use(save, mode="w") as dmxout:
dmxout.writelines(lines)
with open_or_use(save, mode="w") as xxxout:
xxxout.writelines(lines)
if isinstance(save, (str, Path)):
log.debug(f"Wrote dmxparse output to '{save}'")
log.debug(f"Wrote {lower}parse output to '{save}'")
# return the new mean subtracted values
mean_sub_DMXs = DMXs - DMX_mean
mean_sub_Xs = Xs - X_mean

# Get units to multiply returned arrays by
DMX_units = getattr(fitter.model, "DMX_{:}".format(dmx_epochs[0])).units
DMXR_units = getattr(fitter.model, "DMXR1_{:}".format(dmx_epochs[0])).units
X_units = getattr(fitter.model, f"{prefix}X_{xxx_epochs[0]}").units
XR_units = getattr(fitter.model, f"{prefix}XR1_{xxx_epochs[0]}").units

return {
"dmxs": mean_sub_DMXs * DMX_units,
"dmx_verrs": DMX_vErrs * DMX_units,
"dmxeps": DMX_center_MJD * DMXR_units,
"r1s": DMX_R1 * DMXR_units,
"r2s": DMX_R2 * DMXR_units,
"bins": DMX_keys,
"mean_dmx": DMX_mean * DMX_units,
"avg_dm_err": DMX_mean_err * DMX_units,
f"{lower}s": mean_sub_Xs * X_units,
f"{lower}_verrs": X_vErrs * X_units,
f"{lower}eps": X_center_MJD * XR_units,
"r1s": X_R1 * XR_units,
"r2s": X_R2 * XR_units,
"bins": X_keys,
"mean_xxx": X_mean * X_units,
f"avg_{lower[:2]}_err": X_mean_err * X_units,
}


def dmxparse(
fitter: "pint.fitter.Fitter", save: bool = False
) -> Dict[str, Union[u.Quantity, List]]:
"""Convenience function, running chromatic_parse() with ``DM``. See more information there.

Parameters
----------
fitter
PINT fitter used to get timing residuals, must have already run a fit
save : bool or str or file-like object, optional
If not False or None, saves output to specified file in the format of the TEMPO version. If ``True``, assumes output file is of the format ``dmxparse.out``
"""
return chromatic_parse(fitter, prefix="DM", save=save)


def cmxparse(
fitter: "pint.fitter.Fitter", save: bool = False
) -> Dict[str, Union[u.Quantity, List]]:
"""Convenience function, running chromatic_parse() with ``CM``. See more information there.

Parameters
----------
fitter
PINT fitter used to get timing residuals, must have already run a fit
save : bool or str or file-like object, optional
If not False or None, saves output to specified file in the format of the TEMPO version. If ``True``, assumes output file is of the format ``cmxparse.out``
"""
return chromatic_parse(fitter, prefix="CM", save=save)


def get_prefix_timerange(
model: "pint.models.TimingModel", prefixname: str
) -> Tuple[Time, ...]:
Expand Down
Loading