Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion specparam/metrics/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,52 @@
measure='mae',
description='Mean absolute error of the model fit to the data.',
func=compute_mean_abs_error,
space='log',
)

error_mse = Metric(
category='error',
measure='mse',
description='Mean squared error of the model fit to the data.',
func=compute_mean_squared_error
func=compute_mean_squared_error,
space='log',
)

error_rmse = Metric(
category='error',
measure='rmse',
description='Root mean squared error of the model fit to the data.',
func=compute_root_mean_squared_error,
space='log',
)

error_medae = Metric(
category='error',
measure='medae',
description='Median absolute error of the model fit to the data.',
func=compute_median_abs_error,
space='log',
)

error_maelin = Metric(
category='error',
measure='maelin',
description='Mean absolute error of the model fit to the data, in linear space.',
func=compute_mean_abs_error,
space='linear',
)

# Collect available error metrics
ERROR_METRICS = {

# log spacing
'mae' : error_mae,
'mse' : error_mse,
'rmse' : error_rmse,
'medae' : error_medae,

# linear spacing
'maelin' : error_maelin,
}

###################################################################################################
Expand All @@ -53,6 +70,15 @@
measure='rsquared',
description='R-squared between the model fit and the data.',
func=compute_r_squared,
space='log',
)

gof_rsquaredlin = Metric(
category='gof',
measure='rsquaredlin',
description='R-squared between the model fit and the data, in linear space.',
func=compute_r_squared,
space='linear',
)

gof_adjrsquared = Metric(
Expand All @@ -62,12 +88,18 @@
func=compute_adj_r_squared,
kwargs={'n_params' : lambda data, results: \
results.params.periodic.params.size + results.params.aperiodic.params.size},
space='log',
)

# Collect available error metrics
GOF_METRICS = {

# log spacing
'rsquared' : gof_rsquared,
'adjrsquared' : gof_adjrsquared,

# linear spacing
'rsquaredlin' : gof_rsquaredlin,
}

###################################################################################################
Expand Down
10 changes: 8 additions & 2 deletions specparam/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,23 @@ class Metric():
Description of the metric.
func : callable
The function that computes the metric.
space : {'log', 'linear'}
Spacing of the data & model to use for metric evaluation.
kwargs : dictionary
Additional keyword argument to compute the metric.
Each key should be the name of the additional argument.
Each value should be a lambda function that takes 'data' & 'results'
and returns the desired parameter / computed value.
"""

def __init__(self, category, measure, description, func, kwargs=None):
def __init__(self, category, measure, description, func, space='log', kwargs=None):
"""Initialize metric."""

self.category = category
self.measure = measure
self.description = description
self.func = func
self.space = space
self.result = np.nan
self.kwargs = {} if not kwargs else kwargs

Expand Down Expand Up @@ -76,7 +79,10 @@ def compute_metric(self, data, results):
for key, lfunc in self.kwargs.items():
kwargs[key] = lfunc(data, results)

self.result = self.func(data.power_spectrum, results.model.modeled_spectrum, **kwargs)
self.result = self.func(
data.get_data('full', space=self.space),
results.model.get_component('full', space=self.space),
**kwargs)


def reset(self):
Expand Down
2 changes: 1 addition & 1 deletion specparam/tests/metrics/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_metric(tfm):
def test_metric_kwargs(tfm):

metric = Metric('gof', 'ar2', 'Description.', compute_adj_r_squared,
{'n_params' : lambda data, results: \
kwargs={'n_params' : lambda data, results: \
results.params.periodic.params.size + results.params.aperiodic.params.size})

assert isinstance(metric, Metric)
Expand Down
Loading