Parameter of MulticlassStatScores and MultilabelStatScores to control which classes/labels to include in the averages #1723

@plonerma

Description

🚀 Feature

Add a parameter to MulticlassStatScores and MultilabelStatScores to control which classes/labels to include in averaging.

Motivation

Sklearn's precision_recall_fscore_support allows users to define the labels used in averaging the computed metrics (as well as their order, when the metrics are not averaged). This makes it possible to calculate "a multiclass average ignoring a majority negative class". E.g., in my use case (sequence tagging), I do want to consider datapoints which carry the out-tag ("O", meaning they are not tagged), because they can contribute to the false positives of other classes. Hence, ignore_index is not sufficient, since it would exclude those datapoints completely.
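For reference, this is the sklearn behavior I am after (the data is made up for illustration):

from sklearn.metrics import precision_recall_fscore_support

y_true = ["O", "O", "PER", "LOC", "O"]
y_pred = ["O", "PER", "PER", "LOC", "LOC"]

# Average only over the entity classes, but keep the "O" datapoints:
# the "O" -> "PER" and "O" -> "LOC" mistakes still count as false
# positives of PER and LOC.
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, labels=["PER", "LOC"], average="macro"
)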

Pitch

Add a parameter classes to MulticlassStatScores and labels to MultilabelStatScores to limit the calculation of true positives, false positives, false negatives, and true negatives to these classes/labels. The resulting averages (e.g. F1 score, accuracy) would then be computed only over the selected classes/labels.

If the classes/labels parameter is specified, num_classes/num_labels would not need to be set (or, if they are set and do not agree with the number of passed classes/labels, an exception should be raised).
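A minimal sketch of the intended usage (the classes parameter is the proposed addition and does not exist in torchmetrics yet):

from torchmetrics.classification import MulticlassStatScores

# Proposed: track and average stats only for classes 1-4, while class 0
# datapoints still count towards the fp/fn of the included classes.
metric = MulticlassStatScores(num_classes=5, classes=[1, 2, 3, 4], average="macro")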

Alternatives

Currently, I am using a very hacky solution:

from typing import Literal, Optional, Type, Union

import torch
from torchmetrics.classification import MulticlassStatScores, MultilabelStatScores


def metric_with_certain_labels_only(
    metric_type: Union[Type[MulticlassStatScores], Type[MultilabelStatScores]],
    included_labels: torch.Tensor,
    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
    **kwargs,
):
    # For "micro", the metric must keep per-class stats (average="none"),
    # otherwise the update step collapses them before we can select labels.
    metric_average = "none" if average == "micro" else average

    metric = metric_type(average=metric_average, **kwargs)

    _final_state_inner = metric._final_state

    def _final_state_wrapper():
        state = _final_state_inner()

        # Restrict the tp/fp/tn/fn state tensors to the included labels.
        return tuple(s[included_labels] for s in state)

    metric._final_state = _final_state_wrapper  # type: ignore

    if average == "micro":
        compute_inner = metric.compute

        def compute_wrapper():
            # Temporarily switch to "micro" so compute sums the
            # (already restricted) per-class stats.
            metric.average = "micro"
            result = compute_inner()
            metric.average = "none"
            return result

        metric.compute = compute_wrapper  # type: ignore

    return metric
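
For illustration, the helper is used like this (MulticlassF1Score subclasses MulticlassStatScores, so it works with the wrapper; the tensors are made up):

import torch
from torchmetrics.classification import MulticlassF1Score

# Macro F1 over classes 1-4 only; predicting an entity class for an
# "O" datapoint (class 0) still counts as a false positive.
f1 = metric_with_certain_labels_only(
    MulticlassF1Score,
    included_labels=torch.tensor([1, 2, 3, 4]),
    average="macro",
    num_classes=5,
)
f1.update(torch.tensor([0, 1, 2, 0, 4]), torch.tensor([0, 1, 2, 3, 0]))
f1.compute()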

I am not happy with this solution for two reasons:

  1. It would be nicer if included_labels were part of the metric's init signature directly, or were (at least) provided via a wrapper. It is possible to rewrite this helper function as a wrapper; however, that would require changing the averaging on an already created metric (and recreating a new state). And, more importantly,
  2. since the classes/labels are selected at read-out of the state (which requires fewer changes), the relevant stats need to be tracked for all classes (even the ones which are not included). This is especially inefficient in case average is set to "micro".

Ideally, the stats would be reduced to the selected classes already in _multiclass_stat_scores_update / _multilabel_stat_scores_update.
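
Sketching the idea (illustrative only, not the actual torchmetrics internals; _restrict_to_selected is a hypothetical helper):

import torch

# Restrict the per-class stat tensors to the selected classes right
# after they are computed in the update step, so no state is kept for
# excluded classes.
def _restrict_to_selected(tp, fp, tn, fn, included_labels):
    idx = torch.as_tensor(included_labels)
    return tp[idx], fp[idx], tn[idx], fn[idx]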

Additional context

I have already opened a discussion about this; however, I believe it cannot be solved without a new feature.
