diff --git a/specparam/algorithms/algorithm.py b/specparam/algorithms/algorithm.py index e2e42a48..efa5719c 100644 --- a/specparam/algorithms/algorithm.py +++ b/specparam/algorithms/algorithm.py @@ -33,12 +33,15 @@ class Algorithm(): Data object with spectral data and metadata. results : Results Results object with model fit results and metrics. + model : SpectralModel, optional + The model object this object is linked to, to provide access to other attributes. debug : bool Whether to run in debug state, raising an error if encountered during fitting. """ def __init__(self, name, description, public_settings, private_settings=None, - data_format='spectrum', modes=None, data=None, results=None, debug=False): + data_format='spectrum', modes=None, data=None, results=None, model=None, + debug=False): """Initialize Algorithm object.""" self.name = name @@ -66,6 +69,8 @@ def __init__(self, name, description, public_settings, private_settings=None, self.set_debug(debug) + self._model = model + def _fit_prechecks(self, verbose): """Pre-checks to run before the fit function - if are some, overload this function.""" @@ -178,13 +183,14 @@ class AlgorithmCF(Algorithm): """ def __init__(self, name, description, public_settings, private_settings=None, - data_format='spectrum', modes=None, data=None, results=None, debug=False): + data_format='spectrum', modes=None, data=None, results=None, + model=None, debug=False): """Initialize Algorithm object.""" Algorithm.__init__(self, name=name, description=description, public_settings=public_settings, private_settings=private_settings, data_format=data_format, modes=modes, data=data, results=results, - debug=debug) + model=model, debug=debug) self._cf_settings_desc = CURVE_FIT_SETTINGS self._cf_settings = SettingsValues(self._cf_settings_desc.names) diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index 328f0f33..c2be3fc2 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -106,7 +106,7 @@ class SpectralFitAlgorithm(AlgorithmCF): def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0, peak_threshold=2.0, ap_percentile_thresh=0.025, ap_guess=None, ap_bounds=None, cf_bound=1.5, bw_std_edge=1.0, gauss_overlap_thresh=0.75, maxfev=5000, - tol=0.00001, modes=None, data=None, results=None, debug=False): + tol=0.00001, modes=None, data=None, results=None, model=None, debug=False): """Initialize base model object""" # Initialize base algorithm object with algorithm metadata @@ -115,7 +115,7 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h description='Original parameterizing neural power spectra algorithm.', public_settings=SPECTRAL_FIT_SETTINGS_DEF, private_settings=SPECTRAL_FIT_PRIVATE_SETTINGS_DEF, - modes=modes, data=data, results=results, debug=debug) + modes=modes, data=data, results=results, model=model, debug=debug) ## Public settings self.settings.peak_width_limits = peak_width_limits diff --git a/specparam/data/data.py b/specparam/data/data.py index 7f0d94af..d7f43803 100644 --- a/specparam/data/data.py +++ b/specparam/data/data.py @@ -7,6 +7,7 @@ from specparam.sim.gen import gen_freqs from specparam.data import SpectrumMetaData, ModelChecks +from specparam.utils.array import unlog from specparam.utils.spectral import trim_spectrum from specparam.utils.checks import check_input_options from specparam.reports.strings import gen_data_str @@ -36,6 +37,8 @@ class Data(): Whether to check the spectral data. If so, raises an error for any NaN / Inf values. format : {'power'} The representation format of the data. + model : SpectralModel, optional + The model object this object is linked to, to provide access to other attributes. Attributes ---------- @@ -55,7 +58,7 @@ class Data(): All power values are stored internally in log10 scale. """ - def __init__(self, check_freqs=True, check_data=True, format='power'): + def __init__(self, check_freqs=True, check_data=True, format='power', model=None): """Initialize Data object.""" self._reset_data(True, True) @@ -70,6 +73,7 @@ def __init__(self, check_freqs=True, check_data=True, format='power'): check_input_options(format, FORMATS, 'format') self.format = format + self._model = model @property def has_data(self): @@ -154,6 +158,54 @@ def get_meta_data(self): return SpectrumMetaData(**{key : getattr(self, key) for key in self._meta_fields}) + def get_data(self, component='full', space='log'): + """Get a data component. + + Parameters + ---------- + component : {'full', 'aperiodic', 'peak'} + Which data component to return. + 'full' - full power spectrum + 'aperiodic' - isolated aperiodic data component + 'peak' - isolated peak data component + space : {'log', 'linear'} + Which space to return the data component in. + 'log' - returns in log10 space. + 'linear' - returns in linear space. + + Returns + ------- + output : 1d array + Specified data component, in specified spacing. + + Notes + ----- + The 'space' parameter doesn't just define the spacing of the data component + values, but rather defines the space of the additive data definition such that + `power_spectrum = aperiodic_component + peak_component`. + With space set as 'log', this combination holds in log space. + With space set as 'linear', this combination holds in linear space. + """ + + if not self.has_data: + raise NoDataError("No data available to fit, can not proceed.") + assert space in ['linear', 'log'], "Input for 'space' invalid." + + if component == 'full': + output = self.power_spectrum if space == 'log' \ + else unlog(self.power_spectrum) + elif component == 'aperiodic': + output = self._model.results.model._spectrum_peak_rm if space == 'log' else \ + unlog(self.power_spectrum) / unlog(self._model.results.model._peak_fit) + elif component == 'peak': + output = self._model.results.model._spectrum_flat if space == 'log' else \ + unlog(self.power_spectrum) - unlog(self._model.results.model._ap_fit) + else: + raise ValueError('Input for component invalid.') + + return output + + def plot(self, plt_log=False, **plt_kwargs): """Plot the power spectrum.""" @@ -339,10 +391,10 @@ class Data2D(Data): All power values are stored internally in log10 scale. """ - def __init__(self): + def __init__(self, *args, **kwargs): """Initialize Data2D object.""" - Data.__init__(self) + Data.__init__(self, *args, **kwargs) self.power_spectra = None @@ -451,10 +503,10 @@ class Data2DT(Data2D): All power values are stored internally in log10 scale. """ - def __init__(self): + def __init__(self, *args, **kwargs): """Initialize Data2DT object.""" - Data2D.__init__(self) + Data2D.__init__(self, *args, **kwargs) @property @@ -521,10 +573,10 @@ class Data3D(Data2DT): All power values are stored internally in log10 scale. """ - def __init__(self): + def __init__(self, *args, **kwargs): """Initialize Data3D object.""" - Data2DT.__init__(self) + Data2DT.__init__(self, *args, **kwargs) self.spectrograms = None diff --git a/specparam/models/base.py b/specparam/models/base.py index 0981e2a2..d06c604f 100644 --- a/specparam/models/base.py +++ b/specparam/models/base.py @@ -2,7 +2,6 @@ from copy import deepcopy -from specparam.utils.array import unlog from specparam.utils.checks import check_array_dim from specparam.modes.modes import Modes from specparam.modutils.errors import NoDataError @@ -56,7 +55,7 @@ def add_modes(self, aperiodic_mode, periodic_mode): Mode for periodic component, or string specifying which mode to use. """ - self.modes = Modes(aperiodic=aperiodic_mode, periodic=periodic_mode) + self.modes = Modes(aperiodic=aperiodic_mode, periodic=periodic_mode, model=self) if getattr(self, 'results', None): self.results.modes = self.modes @@ -66,54 +65,6 @@ def add_modes(self, aperiodic_mode, periodic_mode): self.algorithm._reset_subobjects(modes=self.modes, results=self.results) - def get_data(self, component='full', space='log'): - """Get a data component. - - Parameters - ---------- - component : {'full', 'aperiodic', 'peak'} - Which data component to return. - 'full' - full power spectrum - 'aperiodic' - isolated aperiodic data component - 'peak' - isolated peak data component - space : {'log', 'linear'} - Which space to return the data component in. - 'log' - returns in log10 space. - 'linear' - returns in linear space. - - Returns - ------- - output : 1d array - Specified data component, in specified spacing. - - Notes - ----- - The 'space' parameter doesn't just define the spacing of the data component - values, but rather defines the space of the additive data definition such that - `power_spectrum = aperiodic_component + peak_component`. - With space set as 'log', this combination holds in log space. - With space set as 'linear', this combination holds in linear space. - """ - - if not self.data.has_data: - raise NoDataError("No data available to fit, can not proceed.") - assert space in ['linear', 'log'], "Input for 'space' invalid." - - if component == 'full': - output = self.data.power_spectrum if space == 'log' \ - else unlog(self.data.power_spectrum) - elif component == 'aperiodic': - output = self.results.model._spectrum_peak_rm if space == 'log' else \ - unlog(self.data.power_spectrum) / unlog(self.results.model._peak_fit) - elif component == 'peak': - output = self.results.model._spectrum_flat if space == 'log' else \ - unlog(self.data.power_spectrum) - unlog(self.results.model._ap_fit) - else: - raise ValueError('Input for component invalid.') - - return output - - def print_settings(self, description=False, concise=False): """Print out the current settings. diff --git a/specparam/models/event.py b/specparam/models/event.py index 23a5d745..a8e149ca 100644 --- a/specparam/models/event.py +++ b/specparam/models/event.py @@ -57,11 +57,12 @@ def __init__(self, *args, **kwargs): verbose=kwargs.pop('verbose', True), **kwargs) - self.data = Data3D() + self.data = Data3D(model=self) self.results = Results3D(modes=self.modes, metrics=kwargs.pop('metrics', None), - bands=kwargs.pop('bands', None)) + bands=kwargs.pop('bands', None), + model=self) self.algorithm._reset_subobjects(data=self.data, results=self.results) diff --git a/specparam/models/group.py b/specparam/models/group.py index ab617622..0d0e78e9 100644 --- a/specparam/models/group.py +++ b/specparam/models/group.py @@ -60,11 +60,12 @@ def __init__(self, *args, **kwargs): verbose=kwargs.pop('verbose', True), **kwargs) - self.data = Data2D() + self.data = Data2D(model=self) self.results = Results2D(modes=self.modes, metrics=kwargs.pop('metrics', None), - bands=kwargs.pop('bands', None)) + bands=kwargs.pop('bands', None), + model=self) self.algorithm._reset_subobjects(data=self.data, results=self.results) diff --git a/specparam/models/model.py b/specparam/models/model.py index a4666e3c..05423125 100644 --- a/specparam/models/model.py +++ b/specparam/models/model.py @@ -95,14 +95,14 @@ def __init__(self, aperiodic_mode='fixed', periodic_mode='gaussian', update_converters(DEFAULT_CONVERTERS, converters) BaseModel.__init__(self, aperiodic_mode, periodic_mode, converters, verbose) - self.data = Data() + self.data = Data(model=self) - self.results = Results(modes=self.modes, metrics=metrics, bands=bands) + self.results = Results(modes=self.modes, metrics=metrics, bands=bands, model=self) algorithm_settings = {} if algorithm_settings is None else algorithm_settings self.algorithm = check_algorithm_definition(algorithm, ALGORITHMS)( **algorithm_settings, modes=self.modes, data=self.data, - results=self.results, debug=debug, **model_kwargs) + results=self.results, debug=debug, model=self, **model_kwargs) @replace_docstring_sections([docs_get_section(Data.add_data.__doc__, 'Parameters'), diff --git a/specparam/models/time.py b/specparam/models/time.py index 3c7e30d6..d50764a0 100644 --- a/specparam/models/time.py +++ b/specparam/models/time.py @@ -53,11 +53,12 @@ def __init__(self, *args, **kwargs): verbose=kwargs.pop('verbose', True), **kwargs) - self.data = Data2DT() + self.data = Data2DT(model=self) self.results = Results2DT(modes=self.modes, metrics=kwargs.pop('metrics', None), - bands=kwargs.pop('bands', None)) + bands=kwargs.pop('bands', None), + model=self) self.algorithm._reset_subobjects(data=self.data, results=self.results) diff --git a/specparam/modes/modes.py b/specparam/modes/modes.py index 82ef2281..5f4bf707 100644 --- a/specparam/modes/modes.py +++ b/specparam/modes/modes.py @@ -17,9 +17,11 @@ class Modes(): Aperiodic mode. periodic : str or Mode Periodic mode. + model : SpectralModel, optional + The model object this object is linked to, to provide access to other attributes. """ - def __init__(self, aperiodic, periodic): + def __init__(self, aperiodic, periodic, model=None): """Initialize modes.""" # Set list of component names @@ -29,6 +31,8 @@ def __init__(self, aperiodic, periodic): self.aperiodic = check_mode_definition(aperiodic, AP_MODES) self.periodic = check_mode_definition(periodic, PE_MODES) + self.model = model + def check_params(self): """Check the description of the parameters for each mode.""" diff --git a/specparam/results/results.py b/specparam/results/results.py index 4d1c7c0e..da504229 100644 --- a/specparam/results/results.py +++ b/specparam/results/results.py @@ -38,6 +38,8 @@ class Results(): Metrics object with metric definitions. bands : Bands or dict or int or None Bands object with band definitions, or definition that can be turned into a Bands object. + model : SpectralModel, optional + The model object this object is linked to, to provide access to other attributes. Attributes ---------- @@ -54,7 +56,7 @@ class Results(): """ # pylint: disable=attribute-defined-outside-init, arguments-differ - def __init__(self, modes=None, metrics=None, bands=None): + def __init__(self, modes=None, metrics=None, bands=None, model=None): """Initialize Results object.""" self.modes = modes if modes else Modes(None, None) @@ -68,6 +70,8 @@ def __init__(self, modes=None, metrics=None, bands=None): # Initialize results attributes self._reset_results(True) + self._model = model + @property def has_model(self): @@ -258,10 +262,10 @@ class Results2D(Results): Results of the model fit for each power spectrum. """ - def __init__(self, modes=None, metrics=None, bands=None): + def __init__(self, modes=None, metrics=None, bands=None, model=None): """Initialize Results2D object.""" - Results.__init__(self, modes=modes, metrics=metrics, bands=bands) + Results.__init__(self, modes=modes, metrics=metrics, bands=bands, model=model) self._reset_group_results() @@ -437,10 +441,10 @@ class Results2DT(Results2D): Results of the model fit across each time window. """ - def __init__(self, modes=None, metrics=None, bands=None): + def __init__(self, modes=None, metrics=None, bands=None, model=None): """Initialize Results2DT object.""" - Results2D.__init__(self, modes=modes, metrics=metrics, bands=bands) + Results2D.__init__(self, modes=modes, metrics=metrics, bands=bands, model=model) self._reset_time_results() @@ -506,10 +510,10 @@ class Results3D(Results2DT): Each value in the dictionary stores a model fit parameter, as [n_events, n_time_windows]. """ - def __init__(self, modes=None, metrics=None, bands=None): + def __init__(self, modes=None, metrics=None, bands=None, model=None): """Initialize Results3D object.""" - Results2DT.__init__(self, modes=modes, metrics=metrics, bands=bands) + Results2DT.__init__(self, modes=modes, metrics=metrics, bands=bands, model=model) self._reset_event_results() diff --git a/specparam/tests/models/test_model.py b/specparam/tests/models/test_model.py index 9639e9b4..738843de 100644 --- a/specparam/tests/models/test_model.py +++ b/specparam/tests/models/test_model.py @@ -145,7 +145,7 @@ def test_fit_null_conversions(tfm): null_converters = tfm.modes.get_params('dict') ntfm = SpectralModel(converters=null_converters) - ntfm.fit(tfm.data.freqs, tfm.get_data('full', 'linear')) + ntfm.fit(tfm.data.freqs, tfm.data.get_data('full', 'linear')) assert np.all(np.isnan(ntfm.results.get_params('aperiodic', version='converted'))) assert np.all(np.isnan(ntfm.results.get_params('periodic', version='converted'))) @@ -154,7 +154,7 @@ def test_fit_custom_conversions(tfm): converters = {'periodic' : {'pw' : 'lin_sub'}} ntfm = SpectralModel(converters=converters) - ntfm.fit(tfm.data.freqs, tfm.get_data('full', 'linear')) + ntfm.fit(tfm.data.freqs, tfm.data.get_data('full', 'linear')) assert not np.array_equal( tfm.results.get_params('periodic', 'pw'), ntfm.results.get_params('periodic', 'pw')) @@ -295,7 +295,7 @@ def test_get_data(tfm): for comp in ['full', 'aperiodic', 'peak']: for space in ['log', 'linear']: - assert isinstance(tfm.get_data(comp, space), np.ndarray) + assert isinstance(tfm.data.get_data(comp, space), np.ndarray) def test_get_component(tfm):