Remove sigmoid in BinaryPrecisionRecallCurve #3182


Status: Open · wants to merge 5 commits into base: master
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -54,7 +54,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

--
+- Fixed unintended `sigmoid` normalization in `BinaryPrecisionRecallCurve`, which might cause incorrect ROC computation when updating incrementally with raw logits. ([#3179](https://github.com/Lightning-AI/torchmetrics/pull/3179))


---
4 changes: 3 additions & 1 deletion src/torchmetrics/classification/precision_recall_curve.py
@@ -163,7 +163,9 @@ def update(self, preds: Tensor, target: Tensor) -> None:
         """Update metric states."""
         if self.validate_args:
             _binary_precision_recall_curve_tensor_validation(preds, target, self.ignore_index)
-        preds, target, _ = _binary_precision_recall_curve_format(preds, target, self.thresholds, self.ignore_index)
+        preds, target, _ = _binary_precision_recall_curve_format(
+            preds, target, self.thresholds, self.ignore_index, None
+        )
         state = _binary_precision_recall_curve_update(preds, target, self.thresholds)
         if isinstance(state, Tensor):
             self.confmat += state
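The motivation for passing `None` here can be sketched in plain Python. The helper below is illustrative only, not torchmetrics internals: it mimics the old per-batch conditional sigmoid that `update()` effectively applied, and shows how it can leave one batch raw while squashing another.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def old_update_normalize(batch: list) -> list:
    # Sketch of the behavior this PR removes from update(): sigmoid is applied
    # per batch, but only if some value falls outside [0, 1].
    if all(0.0 <= p <= 1.0 for p in batch):
        return list(batch)
    return [sigmoid(p) for p in batch]

batch1 = [0.2, 0.8]    # raw logits that happen to lie inside [0, 1]
batch2 = [-3.0, 4.0]   # raw logits clearly outside [0, 1]

per_batch = old_update_normalize(batch1) + old_update_normalize(batch2)
all_at_once = old_update_normalize(batch1 + batch2)

# batch1 is left raw while batch2 is squashed, so the accumulated state mixes
# logits and probabilities; normalizing once over all predictions does not.
assert per_batch[:2] == [0.2, 0.8]
assert per_batch != all_at_once
```

With the inconsistency moved out of `update()`, normalization can instead happen once, over all accumulated predictions.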
@@ -166,6 +166,7 @@ def _binary_precision_recall_curve_format(
     target: Tensor,
     thresholds: Optional[Union[int, list[float], Tensor]] = None,
     ignore_index: Optional[int] = None,
+    normalization: Optional[str] = "sigmoid",
 ) -> tuple[Tensor, Tensor, Optional[Tensor]]:
"""Convert all input to the right format.

@@ -182,7 +183,7 @@
         preds = preds[idx]
         target = target[idx]

-    preds = normalize_logits_if_needed(preds, "sigmoid")
+    preds = normalize_logits_if_needed(preds, normalization)

     thresholds = _adjust_threshold_arg(thresholds, preds.device)
     return preds, target, thresholds
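As a rough illustration of the changed signature, here is a simplified, list-based stand-in for the formatting helper. Names and behavior are condensed from the diff above; the real function operates on tensors and also handles the thresholds argument.

```python
import math

def binary_pr_format(preds, target, ignore_index=None, normalization="sigmoid"):
    # Simplified stand-in for _binary_precision_recall_curve_format:
    # drop ignored positions, then normalize only if requested AND the
    # predictions look like logits (values outside [0, 1]).
    if ignore_index is not None:
        kept = [(p, t) for p, t in zip(preds, target) if t != ignore_index]
        preds = [p for p, _ in kept]
        target = [t for _, t in kept]
    if normalization == "sigmoid" and not all(0.0 <= p <= 1.0 for p in preds):
        preds = [1.0 / (1.0 + math.exp(-p)) for p in preds]
    return preds, target

# normalization=None (what update() now passes) keeps raw logits intact:
raw, _ = binary_pr_format([-3.0, 4.0], [0, 1], normalization=None)
assert raw == [-3.0, 4.0]
```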
5 changes: 4 additions & 1 deletion src/torchmetrics/utilities/compute.py
@@ -201,7 +201,7 @@ def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor:
return m[indices] * x + b[indices]


-def normalize_logits_if_needed(tensor: Tensor, normalization: Literal["sigmoid", "softmax"]) -> Tensor:
+def normalize_logits_if_needed(tensor: Tensor, normalization: Optional[Literal["sigmoid", "softmax"]]) -> Tensor:
"""Normalize logits if needed.

If input tensor is outside the [0,1] we assume that logits are provided and apply the normalization.
@@ -228,6 +228,9 @@ def normalize_logits_if_needed(tensor: Tensor, normalization: Literal["sigmoid",
tensor([0.0000, 0.5000, 1.0000])

"""
+    # no normalization requested: return the values untouched
+    if not normalization:
+        return tensor
     # on CPU, check the value range eagerly and normalize only if values fall outside [0, 1]
     if tensor.device == torch.device("cpu"):
         if not torch.all((tensor >= 0) * (tensor <= 1)):
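Putting the pieces together, the fixed flow amounts to: store predictions untouched in `update()`, then normalize once over the accumulated values. A toy model of that contract (not the torchmetrics class, and with `compute_scores` as a hypothetical name):

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

class TinyCurveState:
    # Toy model of the fixed flow: update() appends predictions as-is
    # (normalization=None), compute normalizes once over everything.
    def __init__(self):
        self.preds = []

    def update(self, batch):
        self.preds.extend(batch)  # keep raw logits, no per-batch sigmoid

    def compute_scores(self):
        if all(0.0 <= p <= 1.0 for p in self.preds):
            return list(self.preds)
        return [sigmoid(p) for p in self.preds]

state = TinyCurveState()
state.update([0.2, 0.8])   # logits that happen to lie in [0, 1]
state.update([-3.0, 4.0])  # logits outside [0, 1]
scores = state.compute_scores()
# every prediction ends up on the same probability scale
assert all(0.0 <= s <= 1.0 for s in scores)
```

Because the range check runs on the full set of predictions, all batches are normalized consistently regardless of how the data was split across `update()` calls.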