Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion docs/notebooks/config.yaml
Original file line number Diff line number Diff line change
@@ -1 +1,19 @@

# NOTE(review): absolute, user-specific path; presumably has to be edited
# on any other machine before this config can be used.
data_dir: "/Users/kdesoto/superphot-plus-data"
transient_data_fn: "transients_spec_group"
sampler_results_fn: "sampler_result_group"
sampler: "superphot_svi"
chisq_cutoff: 1.2
plot: True
model_type: "LightGBM"
use_redshift_features: False
fits_per_majority: 25
# NOTE(review): this key is spelled "threshhold" (double "h"); whatever code
# reads this file must use the same spelling, so do not rename it here alone.
prob_threshhold: 0.5
n_folds: 10
num_epochs: 250
random_seed: 42
n_parallel: 10
# target_label is intentionally left unset. NOTE: if this line is ever
# re-enabled, YAML loads a bare `None` as the string "None"; use `null`
# (or `~`) to get an actual null value.
#target_label: None
neurons_per_layer: 64
num_hidden_layers: 3
learning_rate: 0.001
batch_size: 64
4,129 changes: 164 additions & 3,965 deletions docs/notebooks/full_train_workflow.ipynb

Large diffs are not rendered by default.

95 changes: 55 additions & 40 deletions src/superphot_plus/priors/superphot_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def update(self) -> None:
self._tgb = ((self._df['max'] - self._df['mean']) / self._df['stddev']).to_numpy()

# faster sample calls
self._logged = self._df['logged'].to_numpy()
self._logged = self._df['logged'].to_numpy().astype(bool)
self._mean = self._df['mean'].to_numpy()
self._std = self._df['stddev'].to_numpy()
self._numpyro_sample_arr = jnp.array(
Expand Down Expand Up @@ -218,55 +218,65 @@ def sample(self, cube, use_numpyro=False, num_events=None):
)
)

# Compute the adjustment only for relative ones
relative_shifts = base_vals[self._relative_idxs_jax]
min_constraint = jnp.maximum(min_vals_rel + base_vals[self._relative_idxs_jax], min_vals_base[self._relative_idxs_jax])
max_constraint = jnp.minimum(max_vals_rel + base_vals[self._relative_idxs_jax], max_vals_base[self._relative_idxs_jax])
if len(self._relative_idxs_jax) > 0:

adjusted_locs = init_loc_rel + relative_shifts
# Compute the adjustment only for relative ones
relative_shifts = base_vals[self._relative_idxs_jax]
min_constraint = jnp.maximum(min_vals_rel + base_vals[self._relative_idxs_jax], min_vals_base[self._relative_idxs_jax])
max_constraint = jnp.minimum(max_vals_rel + base_vals[self._relative_idxs_jax], max_vals_base[self._relative_idxs_jax])

# Reapply the constraints to adjusted_locs to make sure they stay within bounds
adjusted_locs_constrained = jnp.clip(adjusted_locs, min_constraint + 1e-6, max_constraint - 1e-6)
adjusted_locs = init_loc_rel + relative_shifts

with numpyro.plate("relative_params", len(min_vals_rel)):
# Re-sample using the adjusted means only for relative parameters
resampled_vals = numpyro.sample(
"relative_samples",
dist.TruncatedNormal(
loc=adjusted_locs_constrained,
scale=init_scale_rel,
low=min_constraint,
high=max_constraint
# Reapply the constraints to adjusted_locs to make sure they stay within bounds
adjusted_locs_constrained = jnp.clip(adjusted_locs, min_constraint + 1e-6, max_constraint - 1e-6)

with numpyro.plate("relative_params", len(min_vals_rel)):
# Re-sample using the adjusted means only for relative parameters
resampled_vals = numpyro.sample(
"relative_samples",
dist.TruncatedNormal(
loc=adjusted_locs_constrained,
scale=init_scale_rel,
low=min_constraint,
high=max_constraint
)
)
)

vals = jnp.concatenate([
base_vals,
resampled_vals
])

vals = jnp.concatenate([
base_vals,
resampled_vals
])
else:
vals = base_vals

vals = vals.at[self._logged_jax].set(10**vals[self._logged_jax])

else:
if cube is None:
cube = self._rng.uniform(size=len(self._df))

vals = np.zeros(len(cube))

vals[~self._relative_mask] = truncnorm.ppf(
cube[~self._relative_mask],
self._tga[~self._relative_mask],
self._tgb[~self._relative_mask],
loc=self._mean[~self._relative_mask],
scale=self._std[~self._relative_mask],
)

vals[self._relative_mask] = truncnorm.ppf(
cube[self._relative_mask],
self._tga[self._relative_mask],
self._tgb[self._relative_mask],
loc=self._mean[self._relative_mask] + vals[self._relative_idxs],
scale=self._std[self._relative_mask]
)

if (len(self._relative_mask) > 0) and np.any(self._relative_mask):
vals[~self._relative_mask] = truncnorm.ppf(
cube[~self._relative_mask],
self._tga[~self._relative_mask],
self._tgb[~self._relative_mask],
loc=self._mean[~self._relative_mask],
scale=self._std[~self._relative_mask],
)

vals[self._relative_mask] = truncnorm.ppf(
cube[self._relative_mask],
self._tga[self._relative_mask],
self._tgb[self._relative_mask],
loc=self._mean[self._relative_mask] + vals[self._relative_idxs],
scale=self._std[self._relative_mask]
)
else:
vals = truncnorm.ppf(
cube, self._tga, self._tgb, loc=self._mean, scale=self._std
)

# log transformations
vals[self._logged] = 10**vals[self._logged]
Expand Down Expand Up @@ -451,15 +461,20 @@ def jax_guide(self, num_events=None):
constraint=dist.constraints.interval(1e-5, 3 * init_scale_base)
)

numpyro.sample(
base_samples = numpyro.sample(
"base_samples",
dist.Normal(
loc=svi_loc_base,
scale=svi_scale_base,
)
)

#debug.print("Mu base: {}", svi_loc_base)
#debug.print("Scale base: {}", svi_scale_base)

# Compute the shifts for relative parameters
if len(self._relative_idxs_jax) == 0:
return
relative_shifts = svi_loc_base[self._relative_idxs_jax]
adjusted_locs = init_loc_rel + relative_shifts

Expand Down
38 changes: 25 additions & 13 deletions src/superphot_plus/samplers/dynesty_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def __init__(
self._param_map = None
self._dynamic = dynamic

self._initialize_sampler(sample_strategy, bound, nlive, dynamic)


def _initialize_sampler(self, sample_strategy, bound, nlive, dynamic):
if dynamic:
self._nested_sampler = DynamicNestedSampler(
self._logL, self._prior_func, (self._nparams + 3) * len(self._unique_bands),
Expand All @@ -80,6 +84,7 @@ def __init__(
rstate=self._rng, #walks=50
)


def _logL(self, cube):
"""Define the log-likelihood function.

Expand Down Expand Up @@ -120,6 +125,25 @@ def reset(self):
self._nested_sampler.reset()


def _generate_param_map(self):
    """Drop bands absent from the data and build the per-observation parameter map.

    Reads ``self._X`` (column 1 is compared against the band names),
    ``self._unique_bands``, ``self._nparams``, ``self._base_params`` and
    ``self._params``.

    Side effects
    ------------
    * ``self._unique_bands`` is replaced by the sub-list of bands that occur
      in ``self._X[:, 1]`` (original order kept).
    * ``self._param_map`` is set to an integer array of shape
      ``(self._nparams + 3, len(self._X))``. For row ``i`` and every
      observation in band ``b`` it holds the index in ``self._params`` of
      the entry named ``f'{self._base_params[i]}_{b}'``. Cells that are never
      assigned stay 0.

    Raises
    ------
    IndexError
        If a ``f'{param}_{band}'`` name is not present in ``self._params``.
    """
    band_column = self._X[:, 1]

    # Keep only the bands that actually appear in the data.
    # NOTE(review): this overwrites self._unique_bands, so a band dropped
    # here is not restored on a later call -- confirm this is intended when
    # the same sampler is reused on data with more bands.
    self._unique_bands = [
        band for band in self._unique_bands if band in band_column
    ]

    param_map = np.zeros((self._nparams + 3, len(self._X)), dtype=int)
    for band in self._unique_bands:
        # The observation mask depends only on the band, so compute it once
        # per band instead of once per (parameter, band) pair.
        in_band = band_column == band
        for row, param in enumerate(self._base_params):
            param_map[row, in_band] = np.where(
                self._params == f'{param}_{band}'
            )[0][0]

    # Already an ndarray; no extra np.array(...) copy is needed.
    self._param_map = param_map


def fit(self, X, y):
"""Runs dynesty importance nested sampling on a set of light curves; saves set
of equally weighted posteriors (sets of fit parameters).
Expand All @@ -139,20 +163,8 @@ def fit(self, X, y):
super().fit(X, y)
self._t = self._X[:,0].astype(np.float32)
self._err = self._X[:,2].astype(np.float32)

# map time steps to param values
self._param_map = np.zeros((self._nparams+3, len(self._X)), dtype=int)
for i, param in enumerate(self._base_params):
for b in self._unique_bands:
b_idxs = self._X[:,1] == b
self._param_map[i,b_idxs] = np.where(self._params == f'{param}_{b}')[0][0]

self._param_map = np.array(self._param_map)

# Require data in all bands
for band in self._unique_bands:
if band not in self._X[:, 1]:
return None
self._generate_param_map()

self.reset()

Expand Down
28 changes: 18 additions & 10 deletions src/superphot_plus/samplers/numpyro_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ def stablish_update(svi, svi_state, *args, forward_mode_differentiation=False, *
)
state = svi_state.optim_state
params = svi.optim.get_params(state)
#debug.print("Params: {}", params)
debug.print("Params: {}", params)
loss_val, grads = value_and_grad(loss_fn)(params)
debug.print("Loss_val: {}", loss_val)
debug.print("Grads: {}", grads)

optim_state = lax.cond(
jnp.isfinite(ravel_pytree(grads)[0]).all(),
Expand Down Expand Up @@ -311,7 +313,6 @@ def fit(
length of X.
"""
super().fit(X,y,event_indices=event_indices)

"""
_, band_counts = np.unique(X[:, 1], return_counts=True)
print(band_counts)
Expand Down Expand Up @@ -341,6 +342,7 @@ def _process_samples(self, samples_df):
samples_df = self._priors.transform(samples_df)
self.result = SamplerResult(samples_df, sampler_name=self._sampler_name)
self._is_fitted = True
print(self.result.fit_parameters)
self.result.score = np.array(
self.score(self._X, self._y, orig_num_times=self._orig_num_times)
)
Expand Down Expand Up @@ -535,6 +537,7 @@ def fit(
orig_num_times=orig_num_times,
event_indices=event_indices
)
debug.print("B: {}", self._X[:,1])

if self._svi_state is None:
self.reset()
Expand Down Expand Up @@ -565,6 +568,7 @@ def fit(
)

params = self._svi.get_params(self._svi_state)
print(params)

if event_indices is not None:
params_loc = jnp.concatenate([
Expand Down Expand Up @@ -608,14 +612,18 @@ def fit(
self._process_samples_hierarchical(global_mu_arr, global_scale_arr, indiv_param_arr)

else:
params_loc = jnp.concatenate([
params['loc_base'],
params['loc_relative']
])
params_scale = jnp.concatenate([
params['scale_base'],
params['scale_relative']
])
if 'loc_relative' in params:
params_loc = jnp.concatenate([
params['loc_base'],
params['loc_relative']
])
params_scale = jnp.concatenate([
params['scale_base'],
params['scale_relative']
])
else:
params_loc = params['loc_base']
params_scale = params['scale_base']

param_arr = params_loc + random.normal(
key=self._rng, shape=(1000, len(params_loc))
Expand Down
28 changes: 24 additions & 4 deletions src/superphot_plus/samplers/superphot_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@ def __init__(
self._priors = priors
self._params = self._priors.dataframe['param'].to_numpy()
self._unique_bands = []

for c in self._params:
if c[0] == 'A':
self._unique_bands.append(c[2:])
band_suffix = "_".join(c.split("_")[-2:])
if band_suffix not in self._unique_bands:
self._unique_bands.append(band_suffix)

self._base_params = []
for c in self._params:
if self._unique_bands[0] in c:
self._base_params.append(c.replace("_" +self._unique_bands[0], ""))
self.result = None

print(self._params, self._base_params)

def fit(self, X, y, event_indices=None):
"""Remove elements where filter not included in priors.
Expand All @@ -54,6 +56,7 @@ def fit(self, X, y, event_indices=None):
else:
super().fit(X[mask], y[mask])


def _reformat_cube(self, cube):
"""Reformat cube based on self._param_map"""
return cube[self._param_map]
Expand Down Expand Up @@ -92,4 +95,21 @@ def predict(self, X, num_fits=None):
return self.flux_model(
cube,
val_x[:, 0].astype(np.float32), val_x[:, 1]
), val_x
), val_x


def plot_fit_abs_mag(
self, ax, formatter=None, photometry=None, X=None, dense=True,
):
"""Plots the model fit in absolute magnitude space (placeholder).

The body is currently ``pass``: no argument is read, nothing is
drawn, and the method returns None.

Parameters
----------
ax : object
Axes to draw on (presumably a matplotlib Axes -- TODO confirm).
formatter : object, optional
Not used by this placeholder; intended meaning is not visible here.
photometry : object, optional
Not used by this placeholder; intended meaning is not visible here.
X : np.ndarray, optional
The x data to plot.
dense : bool, optional
Whether to make time array dense for better plotting.
"""
pass
10 changes: 9 additions & 1 deletion src/superphot_plus/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ def train(self, i: int, train_data, val_data):
self.config.input_features = train_df.columns[~train_df.columns.isin(['label', 'score', 'sampler'])]

# extract features
train_df.replace([np.inf, -np.inf], np.nan, inplace=True)
val_df.replace([np.inf, -np.inf], np.nan, inplace=True)
train_df.dropna(axis=0, how='any', subset=self.config.input_features, inplace=True)
val_df.dropna(axis=0, how='any', subset=self.config.input_features, inplace=True)
train_features = train_df.loc[:, self.config.input_features]
val_features = val_df.loc[:, self.config.input_features]

Expand Down Expand Up @@ -224,7 +228,11 @@ def evaluate(self, k_fold, test_data):
if self.config.input_features is None:
self.config.input_features = test_df.columns[~test_df.columns.isin(['label', 'score', 'sampler'])]

probs_avg = model.evaluate(test_df[self.config.input_features])
test_df.replace([np.inf, -np.inf], np.nan, inplace=True)
test_df.dropna(axis=0, how='any', subset=self.config.input_features, inplace=True)
test_features = test_df.loc[:, self.config.input_features]

probs_avg = model.evaluate(test_features)

if self.config.target_label is None:
probs_avg.columns = np.sort(self.config.allowed_types)
Expand Down
20 changes: 0 additions & 20 deletions tests/data/ZTF22abvdwik.csv

This file was deleted.

Binary file removed tests/data/goldens/ZTF22abvdwik_eqwt_NUTS.npz
Binary file not shown.
Binary file removed tests/data/goldens/ZTF22abvdwik_eqwt_dynesty.npz
Binary file not shown.
Binary file removed tests/data/goldens/ZTF22abvdwik_eqwt_svi.npz
Binary file not shown.
Binary file removed tests/data/lsst_lcs/10372190.npz
Binary file not shown.
Binary file removed tests/data/lsst_lcs/16342379.npz
Binary file not shown.
Binary file removed tests/data/lsst_lcs/19513412.npz
Binary file not shown.
Binary file removed tests/data/lsst_lcs/28956522.npz
Binary file not shown.
Loading
Loading