Skip to content

Commit ca3d352

Browse files
authored
Merge pull request #54 from Project-MONAI/docs_tests
Docs tests
2 parents 9d08ebc + a88e6ed commit ca3d352

File tree

10 files changed

+1465
-230
lines changed

10 files changed

+1465
-230
lines changed

MetricsReloaded/metrics/calibration_measures.py

Lines changed: 146 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import numpy as np
3232
import math
3333
from scipy.special import gamma
34-
34+
import warnings
3535
# from metrics.pairwise_measures import CacheFunctionOutput
3636
from MetricsReloaded.utility.utils import (
3737
CacheFunctionOutput,
@@ -91,7 +91,7 @@ def class_wise_expectation_calibration_error(self):
9191
9292
cwECE = \dfrac{1}{K}\sum_{k=1}^{K}\sum_{i=1}^{N}\dfrac{\vert B_{i,k} \vert}{N} \left(y_{k}(B_{i,k}) - p_{k}(B_{i,k})\right)
9393
94-
94+
:return: cwece
9595
"""
9696

9797
if "bins_ece" in self.dict_args:
@@ -138,11 +138,22 @@ def expectation_calibration_error(self):
138138
"""
139139
Derives the expectation calibration error in the case of binary task
140140
bins_ece is the key in the dictionary for the number of bins to consider
141+
Cheat sheet SN 3.68 p113
142+
Defined in Mahdi Pakdaman Naeini, Gregory Cooper, and Milos Hauskrecht. Obtaining well calibrated probabilities using
143+
bayesian binning. In Twenty-Ninth AAAI Conference on Artificial Intelligence, 2015.
141144
Default is 10
145+
146+
.. math::
147+
148+
ECE = \sum_{m=1}^{M} \dfrac{|B_m|}{n}|\dfrac{1}{|B_m|}\sum_{i \in B_m}1(pred_{ik}==ref_{ik})-\dfrac{1}{|B_m|}\sum_{i \in B_m}pred_{i}|
149+
150+
:return: ece
151+
142152
"""
143153
if "bins_ece" in self.dict_args:
144154
nbins = self.dict_args["bins_ece"]
145155
else:
156+
warnings.warn("Bins ECE not specified in optional arguments dictionary - default set to 10")
146157
nbins = 10
147158
step = 1.0 / nbins
148159
range_values = np.arange(0, 1.00001, step)
@@ -169,7 +180,55 @@ def expectation_calibration_error(self):
169180
else:
170181
list_values.append(nsamples * np.abs(prop - np.mean(pred_sel)))
171182
numb_samples += nsamples
172-
return np.sum(np.asarray(list_values)) / numb_samples
183+
ece = np.sum(np.asarray(list_values)) / numb_samples
184+
return ece
185+
186+
187+
def maximum_calibration_error(self):
188+
"""
189+
Derives the maximum calibration error in the case of binary task
190+
bins_mce is the key in the dictionary for the number of bins to consider
191+
Default is 10
192+
193+
.. math::
194+
195+
MCE = \max_{m}|\dfrac{1}{|B_m|}\sum_{i \in B_m}1(pred_{ik}==ref_{ik})-\dfrac{1}{|B_m|}\sum_{i \in B_m}pred_{i}|
196+
197+
:return: mce
198+
199+
"""
200+
if "bins_mce" in self.dict_args:
201+
nbins = self.dict_args["bins_mce"]
202+
else:
203+
warnings.warn("Bins MCE not specified in optional arguments dictionary - default set to 10")
204+
nbins = 10
205+
step = 1.0 / nbins
206+
range_values = np.arange(0, 1.00001, step)
207+
list_values = []
208+
numb_samples = 0
209+
pred_prob = self.pred[:,1]
210+
for (l, u) in zip(range_values[:-1], range_values[1:]):
211+
ref_tmp = np.where(
212+
np.logical_and(pred_prob > l, pred_prob <= u),
213+
self.ref,
214+
np.ones_like(self.ref) * -1,
215+
)
216+
ref_sel = ref_tmp[ref_tmp > -1]
217+
nsamples = np.size(ref_sel)
218+
prop = np.sum(ref_sel) / nsamples
219+
pred_tmp = np.where(
220+
np.logical_and(pred_prob > l, pred_prob <= u),
221+
pred_prob,
222+
np.ones_like(pred_prob) * -1,
223+
)
224+
pred_sel = pred_tmp[pred_tmp > -1]
225+
if nsamples == 0:
226+
list_values.append(0)
227+
else:
228+
list_values.append(np.abs(prop - np.mean(pred_sel)))
229+
mce = np.max(np.asarray(list_values))
230+
return mce
231+
173232

174233
def brier_score(self):
175234
"""
@@ -179,22 +238,44 @@ def brier_score(self):
179238
Glenn W Brier et al. 1950. Verification of forecasts expressed in terms of probability. Monthly weather review 78, 1
180239
(1950), 1–3.
181240
241+
.. math::
242+
243+
BS = \dfrac{1}{N}\sum_{i=1}^{N}\sum_{c=1}^{C}(p_{ic}-r_{ic})^2
244+
245+
where :math:`p_{ic}` is the probability for class c and :math:`r_{ic}` the binary reference for class c and element i
246+
182247
:return: brier score (BS)
248+
183249
"""
184250
bs = np.mean(np.sum(np.square(self.one_hot_ref - self.pred),1))
185251
return bs
186252

187253
def root_brier_score(self):
188254
"""
255+
Determines the root brier score
256+
189257
Gruber S. and Buettner F., Better Uncertainty Calibration via Proper Scores
190258
for Classification and Beyond, In Proceedings of the 36th International
191259
Conference on Neural Information Processing Systems, 2022
260+
261+
.. math::
262+
263+
RBS = \sqrt{BS}
264+
265+
:return: rbs
192266
"""
193-
return np.sqrt(self.brier_score())
267+
rbs = np.sqrt(self.brier_score())
268+
return rbs
194269

195270
def logarithmic_score(self):
196271
"""
197272
Calculation of the logarithmic score https://en.wikipedia.org/wiki/Scoring_rule
273+
274+
.. math::
275+
276+
LS = \dfrac{1}{N}\sum_{i=1}^{N}\log(pred_{ik})ref_{ik}
277+
278+
:return: ls
198279
"""
199280
eps = 1e-10
200281
log_pred = np.log(self.pred + eps)
@@ -204,27 +285,48 @@ def logarithmic_score(self):
204285
return ls
205286

206287
def distance_ij(self,i,j):
288+
"""
289+
Determines the Euclidean distance between the prediction vectors of two samples i and j
290+
291+
:return: distance
292+
"""
207293
pred_i = self.pred[i,:]
208294
pred_j = self.pred[j,:]
209295
distance = np.sqrt(np.sum(np.square(pred_i - pred_j)))
210296
return distance
211297

212298

213299
def kernel_calculation(self, i,j):
300+
"""
301+
Defines the kernel value for two samples i and j with the following definition for k(x_i,x_j)
302+
303+
.. math::
304+
305+
k(x_i,x_j) = exp(-||x_i-x_j||/ \\nu)I_{N}
306+
307+
where :math:`\\nu` is the bandwidth, defined by the median heuristic if not specified in the options, and N is the number of classes
308+
309+
:return: kernel_value
310+
311+
"""
214312
distance = self.distance_ij(i,j)
215313
if 'bandwidth_kce' in self.dict_args.keys():
216314
bandwidth = self.dict_args['bandwidth_kce']
217315
else:
218316
bandwidth = median_heuristic(self.pred)
219317
value = np.exp(-distance/bandwidth)
220-
identity = np.ones([self.pred.shape[1], self.pred.shape[1]])
221-
return value * identity
318+
identity = np.eye(self.pred.shape[1])
319+
kernel_value = value*identity
320+
return kernel_value
222321

223322
def kernel_calibration_error(self):
224323
"""
225324
Based on the paper Widmann, D., Lindsten, F., and Zachariah, D.
226325
Calibration tests in multi-class classification: A unifying framework.
227326
Advances in Neural Information Processing Systems, 32:12257–12267, 2019.
327+
328+
:return: kce
329+
228330
"""
229331
one_hot_ref = one_hot_encode(self.ref, self.pred.shape[1])
230332
numb_samples = self.pred.shape[0]
@@ -246,6 +348,9 @@ def top_label_classification_error(self):
246348
"""
247349
Calculation of the top-label classification error. Assumes pred is a matrix of shape Numb observations x K
248350
with the probability to be in class k for observation i in position (i,k)
351+
352+
:return: tce
353+
249354
"""
250355
class_max = np.argmax(self.pred, 1)
251356
prob_pred_max = np.max(self.pred, 1)
@@ -271,7 +376,12 @@ def kernel_based_ece(self):
271376
Teodora Popordanoska, Raphael Sayer, and Matthew B Blaschko. 2022. A Consistent and Differentiable Lp Canonical
272377
Calibration Error Estimator. In Advances in Neural Information Processing Systems.
273378
379+
.. math::
380+
381+
ECE\_KDE = 1/N \sum_{j=1}^{N}||\dfrac{\sum_{i \\neq j}k_{Dir}(pred_j,pred_i)ref_i}{\sum_{i \\neq j}k_{Dir}(pred_j,pred_i)} - pred_j ||
382+
274383
:return: ece_kde
384+
275385
"""
276386
ece_kde = 0
277387
one_hot_ref = one_hot_encode(self.ref, self.pred.shape[1])
@@ -298,6 +408,18 @@ def kernel_based_ece(self):
298408
return ece_kde
299409

300410
def gamma_ik(self, i, k):
411+
"""
412+
Definition of gamma value for sample i class k of the predictions
413+
414+
.. math::
415+
416+
gamma_{ik} = \Gamma(pred_{ik}/h + 1)
417+
418+
where h is the bandwidth value set as default to 0.5
419+
420+
:return: gamma_ik
421+
422+
"""
301423
pred_ik = self.pred[i, k]
302424
if "bandwidth" in self.dict_args.keys():
303425
h = self.dict_args["bandwidth"]
@@ -308,6 +430,16 @@ def gamma_ik(self, i, k):
308430
return gamma_ik
309431

310432
def dirichlet_kernel(self, j, i):
433+
"""
434+
Calculation of Dirichlet kernel value for predictions of samples i and j
435+
436+
.. math::
437+
438+
k_{Dir}(x_j,x_i) = \dfrac{\Gamma(\sum_{k=1}^{K}\\alpha_{ik})}{\prod_{k=1}^{K}\Gamma(\\alpha_{ik})}\prod_{k=1}^{K}x_{jk}^{\\alpha_{ik}-1}
439+
440+
:return: kernel_value
441+
442+
"""
311443
pred_i = self.pred[i, :]
312444
pred_j = self.pred[j, :]
313445
nclasses = self.pred.shape[1]
@@ -331,16 +463,22 @@ def negative_log_likelihood(self):
331463
332464
George Cybenko, Dianne P O’Leary, and Jorma Rissanen. 1998. The Mathematics of Information Coding, Extraction
333465
and Distribution. Vol. 107. Springer Science & Business Media.
466+
Cheat Sheet p 116 - Figure SN 3.71
334467
335468
.. math::
336469
337-
-\sum_{i=1}{N} log(p_{i,k} | y_i=k)
470+
NLL = -\dfrac{1}{N}\sum_{i=1}^{N}\sum_{k=1}^{C} y_{ik} \cdot \log(p_{ik})
471+
472+
where the outcome :math:`y_{ik}` is 1 if the class of :math:`y_{i}` is k and :math:`p_{ik}` is the predicted
473+
probability for sample :math:`x_i` and class k
474+
475+
:return: NLL
338476
339477
"""
340478
log_pred = np.log(self.pred)
341479
numb_samples = self.pred.shape[0]
342480
ll = np.sum(log_pred[range(numb_samples), self.ref])
343-
nll = -1 * ll
481+
nll = -1/numb_samples * ll
344482
return nll
345483

346484
def to_dict_meas(self, fmt="{:.4f}"):

0 commit comments

Comments
 (0)