Skip to content
150 changes: 90 additions & 60 deletions monai/losses/unified_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import warnings

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from monai.networks import one_hot
Expand All @@ -24,7 +25,7 @@ class AsymmetricFocalTverskyLoss(_Loss):
"""
AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.

Actually, it's only supported for binary image segmentation now.
It supports both binary and multi-class segmentation.

Reimplementation of the Asymmetric Focal Tversky Loss described in:

Expand All @@ -35,6 +36,7 @@ class AsymmetricFocalTverskyLoss(_Loss):
def __init__(
self,
to_onehot_y: bool = False,
use_softmax: bool = False,
delta: float = 0.7,
gamma: float = 0.75,
epsilon: float = 1e-7,
Expand All @@ -43,30 +45,44 @@ def __init__(
"""
Args:
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
use_softmax: whether to use softmax to transform the original logits into probabilities.
If True, softmax is used. If False, sigmoid is used. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
epsilon : stability factor used to avoid division by zero. Defaults to 1e-7.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.use_softmax = use_softmax
self.delta = delta
self.gamma = gamma
self.epsilon = epsilon

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
n_pred_ch = y_pred.shape[1]

if self.use_softmax and n_pred_ch == 1:
raise ValueError("single channel prediction with `use_softmax=True` is not allowed.")

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)

if self.use_softmax:
y_pred = torch.softmax(y_pred, dim=1)
else:
y_pred = torch.sigmoid(y_pred)

if y_pred.shape[1] == 1:
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
y_true = torch.cat([1 - y_true, y_true], dim=1)

if y_true.shape != y_pred.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

# clip the prediction to avoid NaN
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
# Calculate Loss
axis = list(range(2, len(y_pred.shape)))

# Calculate true positives (tp), false negatives (fn) and false positives (fp)
Expand All @@ -75,9 +91,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
fp = torch.sum((1 - y_true) * y_pred, dim=axis)
dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)

# Calculate losses separately for each class, enhancing both classes
# Class 0 is Background
back_dice = 1 - dice_class[:, 0]
fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)

# Class 1+ is Foreground
fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma)

if fore_dice.shape[1] > 1:
fore_dice = torch.mean(fore_dice, dim=1)
else:
fore_dice = fore_dice.squeeze(1)
Comment on lines +97 to +103
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Potential numerical issue when gamma > 1.

When gamma > 1, the exponent 1 - gamma becomes negative. If dice_class[:, 1:] approaches 1.0 (perfect prediction), (1 - dice_class) approaches 0, and 0^negative explodes.

Default gamma=0.75 is safe, but document or validate that gamma <= 1 is expected for this class.

+        if self.gamma > 1:
+            warnings.warn("gamma > 1 may cause numerical instability in AsymmetricFocalTverskyLoss")
+        
         # Class 1+ is Foreground
         fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma)

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 97 to 103, the computation
fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma) can explode when
gamma > 1 because the exponent becomes negative and values of (1 - dice) near
zero lead to huge results; to fix this, validate or document gamma constraints
and add a guard: either assert or raise a clear ValueError if self.gamma > 1, or
clamp the base with a small epsilon (e.g., base = torch.clamp(1 - dice_class[:,
1:], min=eps)) before the pow, ensuring stable numerical behavior for negative
exponents; include a unit-test or a docstring note that gamma is expected <= 1
if you choose the validation route.


# Average class scores
loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
Expand All @@ -88,7 +111,7 @@ class AsymmetricFocalLoss(_Loss):
"""
AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.

Actually, it's only supported for binary image segmentation now.
It supports both binary and multi-class segmentation.

Reimplementation of the Asymmetric Focal Loss described in:

Expand All @@ -99,6 +122,7 @@ class AsymmetricFocalLoss(_Loss):
def __init__(
self,
to_onehot_y: bool = False,
use_softmax: bool = False,
delta: float = 0.7,
gamma: float = 2,
epsilon: float = 1e-7,
Expand All @@ -107,46 +131,76 @@ def __init__(
"""
Args:
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
use_softmax: whether to use softmax to transform the original logits into probabilities.
If True, softmax is used. If False, sigmoid is used. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
epsilon : stability factor used to avoid division by zero. Defaults to 1e-7.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.use_softmax = use_softmax
self.delta = delta
self.gamma = gamma
self.epsilon = epsilon

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
n_pred_ch = y_pred.shape[1]

if self.use_softmax and n_pred_ch == 1:
raise ValueError("single channel prediction with `use_softmax=True` is not allowed.")

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)

# Save logits for numerical stability in single-channel expansion
y_logits = y_pred

if self.use_softmax:
y_log_pred = F.log_softmax(y_pred, dim=1)
y_pred = torch.exp(y_log_pred)
else:
y_log_pred = F.logsigmoid(y_pred)
y_pred = torch.sigmoid(y_pred)

# Handle Single Channel (Binary) Expansion
if n_pred_ch == 1:
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
bg_log_pred = F.logsigmoid(-y_logits)
y_log_pred = torch.cat([bg_log_pred, y_log_pred], dim=1)
y_true = torch.cat([1 - y_true, y_true], dim=1)

if y_true.shape != y_pred.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
cross_entropy = -y_true * torch.log(y_pred)
cross_entropy = -y_true * y_log_pred

# Class 0: Background
back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
back_ce = (1 - self.delta) * back_ce

fore_ce = cross_entropy[:, 1]
# Class 1+: Foreground
fore_ce = cross_entropy[:, 1:]
fore_ce = self.delta * fore_ce

loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
if fore_ce.shape[1] > 1:
fore_ce = torch.mean(fore_ce, dim=1)
else:
fore_ce = fore_ce.squeeze(1)

loss = torch.mean(torch.stack([back_ce, fore_ce], dim=-1))
return loss


class AsymmetricUnifiedFocalLoss(_Loss):
"""
AsymmetricUnifiedFocalLoss is a variant of Focal Loss.

Actually, it's only supported for binary image segmentation now
It supports both binary and multi-class segmentation.

Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:

Expand All @@ -157,84 +211,60 @@ class AsymmetricUnifiedFocalLoss(_Loss):
def __init__(
self,
to_onehot_y: bool = False,
num_classes: int = 2,
weight: float = 0.5,
gamma: float = 0.5,
use_softmax: bool = False,
delta: float = 0.7,
gamma: float = 2,
weight: float = 0.5,
reduction: LossReduction | str = LossReduction.MEAN,
):
"""
Args:
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
num_classes : number of classes, it only supports 2 now. Defaults to 2.
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
use_softmax: whether to use softmax to transform the original logits into probabilities.
If True, softmax is used. If False, sigmoid is used. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
weight : weight for each loss function. Defaults to 0.5.

Example:
>>> import torch
>>> from monai.losses import AsymmetricUnifiedFocalLoss
>>> pred = torch.ones((1,1,32,32), dtype=torch.float32)
>>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
>>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
>>> pred = torch.randn((1, 3, 32, 32))
>>> grnd = torch.randint(0, 3, (1, 1, 32, 32))
>>> fl = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)
>>> fl(pred, grnd)
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.num_classes = num_classes
self.weight: float = weight
self.gamma = gamma
self.delta = delta
self.weight: float = weight
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
self.use_softmax = use_softmax
self.asy_focal_loss = AsymmetricFocalLoss(
to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
)
Comment on lines 249 to 254
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Numerical instability: gamma=2 causes explosion in Tversky component.

AsymmetricFocalTverskyLoss computes (1 - dice)^(1 - gamma). With gamma=2 (default here), exponent = -1. When dice approaches 1.0 (near-perfect prediction), this explodes to infinity.

Consider using separate gamma values for each sub-loss or clamping the base in Tversky:

 self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
-    to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
+    to_onehot_y=to_onehot_y, gamma=0.75, delta=self.delta, use_softmax=use_softmax
 )

Or add clamping in AsymmetricFocalTverskyLoss.forward:

fore_dice_base = torch.clamp(1 - dice_class[:, 1:], min=self.epsilon)
fore_dice = torch.pow(fore_dice_base, 1 - self.gamma)
🤖 Prompt for AI Agents
In monai/losses/unified_focal_loss.py around lines 246-248, the
AsymmetricFocalTverskyLoss is instantiated with gamma=self.gamma which can make
the Tversky term explode when gamma>1; modify AsymmetricFocalTverskyLoss instead
to clamp the base before exponentiation (e.g., compute fore_dice_base =
torch.clamp(1 - dice_class[:, 1:], min=self.epsilon)) and then raise to the (1 -
gamma) power, and/or add an optional separate parameter (e.g., tversky_gamma or
epsilon) to the AsymmetricFocalTverskyLoss constructor so callers can pass
different gamma or epsilon values without changing unified_focal_loss.py. Ensure
the forward uses the clamp value and existing epsilon (or new param) before
calling torch.pow to prevent dividing/exponentiating by values near zero.


# TODO: Implement this function to support multiple classes segmentation
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
y_pred : the shape should be BNH[WD], where N is the number of classes.
It only supports binary segmentation.
The input should be the original logits since it will be transformed by
a sigmoid in the forward function.
y_true : the shape should be BNH[WD], where N is the number of classes.
It only supports binary segmentation.

Raises:
ValueError: When input and target are different shape
ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5
ValueError: When num_classes
ValueError: When the number of classes entered does not match the expected number
a sigmoid/softmax in the forward function.
y_true : the shape should be BNH[WD], or B1H[WD] when to_onehot_y=True.
"""
if y_pred.shape != y_true.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")

if y_pred.shape[1] == 1:
y_pred = one_hot(y_pred, num_classes=self.num_classes)
y_true = one_hot(y_true, num_classes=self.num_classes)

if torch.max(y_true) != self.num_classes - 1:
raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}")

n_pred_ch = y_pred.shape[1]
if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)

asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)

loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss

if self.reduction == LossReduction.SUM.value:
return torch.sum(loss) # sum over the batch and channel dims
return torch.sum(loss)
if self.reduction == LossReduction.NONE.value:
return loss # returns [N, num_classes] losses
return loss
if self.reduction == LossReduction.MEAN.value:
return torch.mean(loss)
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
27 changes: 20 additions & 7 deletions tests/losses/test_unified_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,41 @@
from monai.losses import AsymmetricUnifiedFocalLoss

TEST_CASES = [
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
# Case 0: Binary segmentation
[
{},
{
"y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
"y_pred": torch.tensor([[[[100.0, -100.0], [-100.0, 100.0]]], [[[100.0, -100.0], [-100.0, 100.0]]]]),
"y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
},
0.0,
],
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
# Case 1: Same as above but explicit arguments
[
{"use_softmax": False, "to_onehot_y": False},
{
"y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
"y_pred": torch.tensor([[[[100.0, -100.0], [-100.0, 100.0]]], [[[100.0, -100.0], [-100.0, 100.0]]]]),
"y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
},
0.0,
],
# Case 2: Multi-class segmentation
[
{"use_softmax": True, "to_onehot_y": True},
{
"y_pred": torch.tensor([[[[-100.0]], [[-100.0]], [[100.0]]]]).repeat(2, 1, 1, 1),
"y_true": torch.tensor([[[[2]]]]).repeat(2, 1, 1, 1),
},
0.0,
],
]


class TestAsymmetricUnifiedFocalLoss(unittest.TestCase):

@parameterized.expand(TEST_CASES)
def test_result(self, input_data, expected_val):
loss = AsymmetricUnifiedFocalLoss()
def test_result(self, input_param, input_data, expected_val):
loss = AsymmetricUnifiedFocalLoss(**input_param)
result = loss(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)

Expand All @@ -52,7 +65,7 @@ def test_ill_shape(self):

def test_with_cuda(self):
loss = AsymmetricUnifiedFocalLoss()
i = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
i = torch.tensor([[[[100.0, -100.0], [-100.0, 100.0]]], [[[100.0, -100.0], [-100.0, 100.0]]]])
j = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
if torch.cuda.is_available():
i = i.cuda()
Expand Down
Loading