Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 48 additions & 34 deletions specparam/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,40 +166,8 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None, prechecks=True):
if prechecks:
self.algorithm._fit_prechecks(self.verbose)

# In rare cases, the model fails to fit, and so uses try / except
try:

# If not set to fail on NaN or Inf data at add time, check data here
# This serves as a catch all for curve_fits which will fail given NaN or Inf
# Because FitError's are by default caught, this allows fitting to continue
if not self.data.checks['data']:
if np.any(np.isinf(self.data.power_spectrum)) or \
np.any(np.isnan(self.data.power_spectrum)):
raise FitError("Model fitting was skipped because there are NaN or Inf "
"values in the data, which preclude model fitting.")

# Call the fit function from the algorithm object
self.algorithm._fit()

# Do any parameter conversions
self._convert_params()

# Compute post-fit metrics
self.results.metrics.compute_metrics(self.data, self.results)

except FitError:

# If in debug mode, re-raise the error
if self.algorithm._debug:
raise

# Clear any interim model results that may have run
# Partial model results shouldn't be interpreted in light of overall failure
self.results._reset_results(True)

# Print out status
if self.verbose:
print("Model fitting was unsuccessful.")
# Call the sub-function to fit the model + post-processing
self._fit()


def report(self, freqs=None, power_spectrum=None, freq_range=None,
Expand Down Expand Up @@ -345,6 +313,52 @@ def to_df(self, bands=None):
return model_to_dataframe(self.results.get_results(), self.modes, bands)


def _fit(self):
""""Internal fit function to run the algorithm fit + post processing.

Notes
-----
Post-processing steps are parameter conversions & model metric evaluations.
In rare cases, the model fails to fit. To manage this, this function uses a try / except,
and in the case of failure while check for run status (to continue or not) and clear object
of any interim results.
"""

try:

# If not set to fail on NaN or Inf data at add time, check data here
# This serves as a catch all for curve_fits which will fail given NaN or Inf
# Because FitError's are by default caught, this allows fitting to continue
if not self.data.checks['data']:
if np.any(np.isinf(self.data.power_spectrum)) or \
np.any(np.isnan(self.data.power_spectrum)):
raise FitError("Model fitting was skipped because there are NaN or Inf "
"values in the data, which preclude model fitting.")

# Call the fit function from the algorithm object
self.algorithm._fit()

# Do any parameter conversions
self._convert_params()

# Compute post-fit metrics
self.results.metrics.compute_metrics(self.data, self.results)

except FitError:

# If in debug mode, re-raise the error
if self.algorithm._debug:
raise

# Clear any interim model results that may have run
# Partial model results shouldn't be interpreted in light of overall failure
self.results._reset_results(True)

# Print out status
if self.verbose:
print("Model fitting was unsuccessful.")


def _convert_params(self):
"""Convert fit parameters."""

Expand Down
2 changes: 1 addition & 1 deletion specparam/results/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _par_fit_group(power_spectrum, group):
"""Function to partialize for running in parallel - group."""

group._pass_through_spectrum(power_spectrum)
group.algorithm._fit()
group._fit()

return group.results._get_results()

Expand Down
30 changes: 26 additions & 4 deletions specparam/tests/models/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ def test_fg_fail():
assert np.isnan(ntfg.results.get_params('aperiodic', 'exponent')[null_ind])
assert np.isnan(ntfg.results.get_metrics('error', 'mae')[null_ind])

# Test that fit failures are caught & continued when running in parallel
ntfg2 = ntfg.copy()
ntfg2.fit(fs, ps, n_jobs=2)
assert ntfg2.results.n_null > 0

def test_drop():
"""Test function to drop results from group object."""

Expand Down Expand Up @@ -226,11 +231,28 @@ def test_fit_par():

tfg = SpectralGroupModel(verbose=False)
tfg.fit(xs, ys, n_jobs=2)
out = tfg.results.get_results()

assert out
assert len(out) == n_spectra
assert np.all(out[1].aperiodic_fit)
assert len(tfg.results.get_results()) == n_spectra

aps = tfg.get_params('aperiodic')
assert aps.shape == (n_spectra, tfg.modes.aperiodic.n_params)
assert np.all(~np.isnan(aps))

pes = tfg.get_params('periodic')
assert pes.shape == (sum(tfg.results.n_peaks), tfg.modes.periodic.n_params + 1)
assert np.all(~np.isnan(pes))

peaks = tfg.get_params('peak')
assert peaks.shape == (sum(tfg.results.n_peaks), tfg.modes.periodic.n_params + 1)
assert np.all(~np.isnan(peaks))

errs = tfg.get_metrics('error')
assert np.all(~np.isnan(errs))
assert len(errs) == n_spectra

gofs = tfg.get_metrics('gof')
assert np.all(~np.isnan(gofs))
assert len(gofs) == n_spectra

def test_print(tfg):
"""Check print method (alias)."""
Expand Down
Loading