diff --git a/specparam/models/model.py b/specparam/models/model.py index a4666e3c..cadcf5ac 100644 --- a/specparam/models/model.py +++ b/specparam/models/model.py @@ -166,40 +166,8 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None, prechecks=True): if prechecks: self.algorithm._fit_prechecks(self.verbose) - # In rare cases, the model fails to fit, and so uses try / except - try: - - # If not set to fail on NaN or Inf data at add time, check data here - # This serves as a catch all for curve_fits which will fail given NaN or Inf - # Because FitError's are by default caught, this allows fitting to continue - if not self.data.checks['data']: - if np.any(np.isinf(self.data.power_spectrum)) or \ - np.any(np.isnan(self.data.power_spectrum)): - raise FitError("Model fitting was skipped because there are NaN or Inf " - "values in the data, which preclude model fitting.") - - # Call the fit function from the algorithm object - self.algorithm._fit() - - # Do any parameter conversions - self._convert_params() - - # Compute post-fit metrics - self.results.metrics.compute_metrics(self.data, self.results) - - except FitError: - - # If in debug mode, re-raise the error - if self.algorithm._debug: - raise - - # Clear any interim model results that may have run - # Partial model results shouldn't be interpreted in light of overall failure - self.results._reset_results(True) - - # Print out status - if self.verbose: - print("Model fitting was unsuccessful.") + # Call the sub-function to fit the model + post-processing + self._fit() def report(self, freqs=None, power_spectrum=None, freq_range=None, @@ -345,6 +313,52 @@ def to_df(self, bands=None): return model_to_dataframe(self.results.get_results(), self.modes, bands) + def _fit(self): + """"Internal fit function to run the algorithm fit + post processing. + + Notes + ----- + Post-processing steps are parameter conversions & model metric evaluations. + In rare cases, the model fails to fit. To manage this, this function uses a try / except, + and in the case of failure while check for run status (to continue or not) and clear object + of any interim results. + """ + + try: + + # If not set to fail on NaN or Inf data at add time, check data here + # This serves as a catch all for curve_fits which will fail given NaN or Inf + # Because FitError's are by default caught, this allows fitting to continue + if not self.data.checks['data']: + if np.any(np.isinf(self.data.power_spectrum)) or \ + np.any(np.isnan(self.data.power_spectrum)): + raise FitError("Model fitting was skipped because there are NaN or Inf " + "values in the data, which preclude model fitting.") + + # Call the fit function from the algorithm object + self.algorithm._fit() + + # Do any parameter conversions + self._convert_params() + + # Compute post-fit metrics + self.results.metrics.compute_metrics(self.data, self.results) + + except FitError: + + # If in debug mode, re-raise the error + if self.algorithm._debug: + raise + + # Clear any interim model results that may have run + # Partial model results shouldn't be interpreted in light of overall failure + self.results._reset_results(True) + + # Print out status + if self.verbose: + print("Model fitting was unsuccessful.") + + def _convert_params(self): """Convert fit parameters.""" diff --git a/specparam/results/utils.py b/specparam/results/utils.py index f8aa27d1..a37941f6 100644 --- a/specparam/results/utils.py +++ b/specparam/results/utils.py @@ -49,7 +49,7 @@ def _par_fit_group(power_spectrum, group): """Function to partialize for running in parallel - group.""" group._pass_through_spectrum(power_spectrum) - group.algorithm._fit() + group._fit() return group.results._get_results() diff --git a/specparam/tests/models/test_group.py b/specparam/tests/models/test_group.py index 2b72e875..d3b0a29e 100644 --- a/specparam/tests/models/test_group.py +++ b/specparam/tests/models/test_group.py @@ -183,6 +183,11 @@ def test_fg_fail(): assert np.isnan(ntfg.results.get_params('aperiodic', 'exponent')[null_ind]) assert np.isnan(ntfg.results.get_metrics('error', 'mae')[null_ind]) + # Test that fit failures are caught & continued when running in parallel + ntfg2 = ntfg.copy() + ntfg2.fit(fs, ps, n_jobs=2) + assert ntfg2.results.n_null > 0 + def test_drop(): """Test function to drop results from group object.""" @@ -226,11 +231,28 @@ def test_fit_par(): tfg = SpectralGroupModel(verbose=False) tfg.fit(xs, ys, n_jobs=2) - out = tfg.results.get_results() - assert out - assert len(out) == n_spectra - assert np.all(out[1].aperiodic_fit) + assert len(tfg.results.get_results()) == n_spectra + + aps = tfg.get_params('aperiodic') + assert aps.shape == (n_spectra, tfg.modes.aperiodic.n_params) + assert np.all(~np.isnan(aps)) + + pes = tfg.get_params('periodic') + assert pes.shape == (sum(tfg.results.n_peaks), tfg.modes.periodic.n_params + 1) + assert np.all(~np.isnan(pes)) + + peaks = tfg.get_params('peak') + assert peaks.shape == (sum(tfg.results.n_peaks), tfg.modes.periodic.n_params + 1) + assert np.all(~np.isnan(peaks)) + + errs = tfg.get_metrics('error') + assert np.all(~np.isnan(errs)) + assert len(errs) == n_spectra + + gofs = tfg.get_metrics('gof') + assert np.all(~np.isnan(gofs)) + assert len(gofs) == n_spectra def test_print(tfg): """Check print method (alias)."""