🚀 Feature
Add a parameter to MulticlassStatScores and MultilabelStatScores to control which classes/labels to include in averaging.
Motivation
Sklearn's precision_recall_fscore_support allows users to define the labels used in averaging the computed metrics (as well as their order if the metrics are not averaged). This allows calculating "a multiclass average ignoring a majority negative class". E.g. in my use case (sequence tagging), I do want to consider datapoints which carry the out-tag ("O", meaning they are not tagged), as they may contribute to the false positives of other classes. Hence, ignore_index is not sufficient, as those datapoints would be excluded completely.
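For reference, a minimal sketch of the sklearn behaviour described above (the tags and data are made up for illustration):

```python
from sklearn.metrics import precision_recall_fscore_support

# "O" is the majority negative class; predicting an entity tag for an "O"
# datapoint still counts as a false positive of that entity class.
y_true = ["O", "O", "PER", "LOC", "PER", "O"]
y_pred = ["O", "PER", "PER", "O", "PER", "LOC"]

# Average only over the entity classes; the "O" datapoints are not averaged
# over, but they still influence the per-class counts above.
p, r, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, labels=["PER", "LOC"], average="macro"
)
print(p, r, f1)
```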
Pitch
Add a parameter classes to MulticlassStatScores and labels to MultilabelStatScores to limit the calculation of true positives, false positives, false negatives, and true negatives to these classes/labels. The resulting averages (e.g. f1-score, accuracy) would then be computed over the selected classes/labels only.
If the classes / labels parameter is specified, num_classes / num_labels would not need to be set (or, if they are set and do not agree with the passed number of classes/labels, an exception would need to be raised).
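A rough sketch of how the proposed parameter could look from the user's side (the classes argument is only a suggestion and does not exist in torchmetrics today):

```python
import torch
from torchmetrics.classification import MulticlassF1Score  # builds on MulticlassStatScores

preds = torch.randint(0, 5, (100,))
target = torch.randint(0, 5, (100,))

# Hypothetical: macro F1 averaged over classes 1-4 only, while class 0
# (e.g. the "O" tag) still contributes to their false positives/negatives.
metric = MulticlassF1Score(num_classes=5, average="macro", classes=[1, 2, 3, 4])
print(metric(preds, target))
```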
Alternatives
Currently, I am using a very hacky solution:
```python
from typing import Literal, Optional, Type, Union

import torch
from torchmetrics.classification import MulticlassStatScores, MultilabelStatScores


def metric_with_certain_labels_only(
    metric_type: Union[Type[MulticlassStatScores], Type[MultilabelStatScores]],
    included_labels: torch.Tensor,
    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
    **kwargs,
):
    # Create the metric with average="none" so the per-class state can be
    # sliced before any (micro) reduction happens in compute().
    metric_average = average
    if average == "micro":
        metric_average = "none"
    metric = metric_type(average=metric_average, **kwargs)

    _final_state_inner = metric._final_state

    def _final_state_wrapper():
        state = _final_state_inner()
        # manipulate the state variable: keep only the included classes/labels
        new_state = (s[torch.tensor(included_labels)] for s in state)
        return new_state

    metric._final_state = _final_state_wrapper  # type: ignore

    if average == "micro":
        compute_inner = metric.compute

        def compute_wrapper():
            # Temporarily switch back to "micro" so compute() reduces the
            # already-sliced state the way the caller asked for.
            metric.average = "micro"
            result = compute_inner()
            metric.average = "none"
            return result

        metric.compute = compute_wrapper  # type: ignore

    return metric
```
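For completeness, a hypothetical usage of this helper (MulticlassF1Score and the class indices are just illustrative):

```python
from torchmetrics.classification import MulticlassF1Score

# Macro F1 over classes 1-4 of a 5-class problem; class 0 (the "O" tag) is
# dropped from the average but still counted during the update step.
f1 = metric_with_certain_labels_only(
    MulticlassF1Score,
    included_labels=torch.tensor([1, 2, 3, 4]),
    average="macro",
    num_classes=5,
)
preds = torch.randint(0, 5, (100,))
target = torch.randint(0, 5, (100,))
print(f1(preds, target))
```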
I am not happy with this solution for two reasons:
- It would be nicer if included_labels were part of the metric's init signature directly, or (at least) could be provided via a wrapper. It is possible to rewrite this helper function as a wrapper; however, that would require changing the averaging on an already created metric (and recreating a new state).
- More importantly, since the classes/labels are selected on read-out of the state (which requires fewer changes), the relevant stats need to be tracked for all classes (even the ones which are not included). This is especially inefficient in case average is set to "micro".
Ideally, the stats would be reduced to the selected classes already in _multiclass_stat_scores_update / _multilabel_stat_scores_update.
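As a rough illustration (this is not the actual torchmetrics update code), the per-class counts could be sliced right after the update step, so that the metric state never stores stats for the excluded classes:

```python
import torch

def _keep_selected_classes(tp, fp, tn, fn, classes: torch.Tensor):
    """Hypothetical helper: drop the excluded classes before the counts are
    accumulated into the metric state."""
    return tp[classes], fp[classes], tn[classes], fn[classes]
```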
Additional context
I had already opened a discussion. However, I believe this cannot be solved without a new feature.