From d724a9560d5b80c30225f10ac81ecaa919ee36a3 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Tue, 25 Nov 2025 14:41:44 +0800 Subject: [PATCH 01/13] add sigmoid/softmax support and multi-class extension for AsymmetricUnifiedFocalLoss Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 93 +++++++++++++++---------- tests/losses/test_unified_focal_loss.py | 27 +++++-- 2 files changed, 75 insertions(+), 45 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 8484eb67ed..d02066d1d7 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -14,6 +14,7 @@ import warnings import torch +import torch.nn.functional as F from torch.nn.modules.loss import _Loss from monai.networks import one_hot @@ -24,7 +25,9 @@ class AsymmetricFocalTverskyLoss(_Loss): """ AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. - Actually, it's only supported for binary image segmentation now. + It supports both binary and multi-class segmentation. + + The logic assumes channel 0 is Background, and channels 1..N are Foreground. Reimplementation of the Asymmetric Focal Tversky Loss described in: @@ -35,6 +38,7 @@ class AsymmetricFocalTverskyLoss(_Loss): def __init__( self, to_onehot_y: bool = False, + use_softmax: bool = False, delta: float = 0.7, gamma: float = 0.75, epsilon: float = 1e-7, @@ -43,17 +47,25 @@ def __init__( """ Args: to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + use_softmax: whether to use softmax to transform the original logits into probabilities. + If True, softmax is used. If False, sigmoid is used. Defaults to False. delta : weight of the background. Defaults to 0.7. gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. 
""" super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y + self.use_softmax = use_softmax self.delta = delta self.gamma = gamma self.epsilon = epsilon def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: + if self.use_softmax: + y_pred = torch.softmax(y_pred, dim=1) + else: + y_pred = torch.sigmoid(y_pred) + n_pred_ch = y_pred.shape[1] if self.to_onehot_y: @@ -67,17 +79,23 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: # clip the prediction to avoid NaN y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) - axis = list(range(2, len(y_pred.shape))) + + spatial_dims = list(range(2, len(y_pred.shape))) # Calculate true positives (tp), false negatives (fn) and false positives (fp) - tp = torch.sum(y_true * y_pred, dim=axis) - fn = torch.sum(y_true * (1 - y_pred), dim=axis) - fp = torch.sum((1 - y_true) * y_pred, dim=axis) + tp = torch.sum(y_true * y_pred, dim=spatial_dims) + fn = torch.sum(y_true * (1 - y_pred), dim=spatial_dims) + fp = torch.sum((1 - y_true) * y_pred, dim=spatial_dims) dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) # Calculate losses separately for each class, enhancing both classes back_dice = 1 - dice_class[:, 0] - fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma) + fore_dice = (1 - dice_class[:, 1:]) * torch.pow(1 - dice_class[:, 1:], -self.gamma) + + if fore_dice.shape[1] > 1: + fore_dice = torch.mean(fore_dice, dim=1) + else: + fore_dice = fore_dice.squeeze(1) # Average class scores loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) @@ -88,7 +106,7 @@ class AsymmetricFocalLoss(_Loss): """ AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. - Actually, it's only supported for binary image segmentation now. + It supports both binary and multi-class segmentation. Reimplementation of the Asymmetric Focal Loss described in: @@ -99,6 +117,7 @@ class AsymmetricFocalLoss(_Loss): def __init__( self, to_onehot_y: bool = False, + use_softmax: bool = False, delta: float = 0.7, gamma: float = 2, epsilon: float = 1e-7, @@ -107,17 +126,27 @@ def __init__( """ Args: to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. + use_softmax: whether to use softmax to transform the original logits into probabilities. + If True, softmax is used. If False, sigmoid is used. Defaults to False. delta : weight of the background. Defaults to 0.7. gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. 
""" super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y + self.use_softmax = use_softmax self.delta = delta self.gamma = gamma self.epsilon = epsilon def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: + if self.use_softmax: + y_log_pred = F.log_softmax(y_pred, dim=1) + y_pred = torch.exp(y_log_pred) + else: + y_log_pred = F.logsigmoid(y_pred) + y_pred = torch.sigmoid(y_pred) + n_pred_ch = y_pred.shape[1] if self.to_onehot_y: @@ -130,15 +159,20 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) - cross_entropy = -y_true * torch.log(y_pred) + cross_entropy = -y_true * y_log_pred back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] back_ce = (1 - self.delta) * back_ce - fore_ce = cross_entropy[:, 1] + fore_ce = cross_entropy[:, 1:] fore_ce = self.delta * fore_ce - loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1)) + if fore_ce.shape[1] > 1: + fore_ce = torch.sum(fore_ce, dim=1) + else: + fore_ce = fore_ce.squeeze(1) + + loss = torch.mean(torch.stack([back_ce, fore_ce], dim=-1)) return loss @@ -146,7 +180,7 @@ class AsymmetricUnifiedFocalLoss(_Loss): """ AsymmetricUnifiedFocalLoss is a variant of Focal Loss. - Actually, it's only supported for binary image segmentation now + It supports both binary and multi-class segmentation. Reimplementation of the Asymmetric Unified Focal Tversky Loss described in: @@ -157,20 +191,21 @@ class AsymmetricUnifiedFocalLoss(_Loss): def __init__( self, to_onehot_y: bool = False, - num_classes: int = 2, weight: float = 0.5, gamma: float = 0.5, delta: float = 0.7, + use_softmax: bool = False, reduction: LossReduction | str = LossReduction.MEAN, ): """ Args: to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. - num_classes : number of classes, it only supports 2 now. Defaults to 2. delta : weight of the background. Defaults to 0.7. gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. weight : weight for each loss function, if it's none it's 0.5. Defaults to None. + use_softmax: whether to use softmax to transform the original logits into probabilities. + If True, softmax is used. If False, sigmoid is used. Defaults to False. Example: >>> import torch @@ -182,50 +217,32 @@ def __init__( """ super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y - self.num_classes = num_classes self.gamma = gamma self.delta = delta self.weight: float = weight - self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta) - self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta) + self.use_softmax = use_softmax + self.asy_focal_loss = AsymmetricFocalLoss( + gamma=self.gamma, delta=self.delta, use_softmax=use_softmax, to_onehot_y=to_onehot_y + ) + self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss( + gamma=self.gamma, delta=self.delta, use_softmax=use_softmax, to_onehot_y=to_onehot_y + ) - # TODO: Implement this function to support multiple classes segmentation def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: """ Args: y_pred : the shape should be BNH[WD], where N is the number of classes. 
- It only supports binary segmentation. The input should be the original logits since it will be transformed by a sigmoid in the forward function. y_true : the shape should be BNH[WD], where N is the number of classes. - It only supports binary segmentation. Raises: ValueError: When input and target are different shape - ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5 - ValueError: When num_classes ValueError: When the number of classes entered does not match the expected number """ if y_pred.shape != y_true.shape: raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - if len(y_pred.shape) != 4 and len(y_pred.shape) != 5: - raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}") - - if y_pred.shape[1] == 1: - y_pred = one_hot(y_pred, num_classes=self.num_classes) - y_true = one_hot(y_true, num_classes=self.num_classes) - - if torch.max(y_true) != self.num_classes - 1: - raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}") - - n_pred_ch = y_pred.shape[1] - if self.to_onehot_y: - if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") - else: - y_true = one_hot(y_true, num_classes=n_pred_ch) - asy_focal_loss = self.asy_focal_loss(y_pred, y_true) asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) diff --git a/tests/losses/test_unified_focal_loss.py b/tests/losses/test_unified_focal_loss.py index 3b868a560e..ed964f1684 100644 --- a/tests/losses/test_unified_focal_loss.py +++ b/tests/losses/test_unified_focal_loss.py @@ -20,28 +20,41 @@ from monai.losses import AsymmetricUnifiedFocalLoss TEST_CASES = [ - [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + # Case 0: Binary segmentation + [ + {}, { - "y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), + "y_pred": torch.tensor([[[[20.0, -20.0], [-20.0, 20.0]]], [[[20.0, -20.0], [-20.0, 20.0]]]]), "y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), }, 0.0, ], - [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + # Case 1: Same as above but explicit arguments + [ + {"use_softmax": False, "to_onehot_y": False}, { - "y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), + "y_pred": torch.tensor([[[[20.0, -20.0], [-20.0, 20.0]]], [[[20.0, -20.0], [-20.0, 20.0]]]]), "y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), }, 0.0, ], + # Case 2: Multi-class segmentation + [ + {"use_softmax": True, "to_onehot_y": True}, + { + "y_pred": torch.tensor([[[[-20.0]], [[-20.0]], [[20.0]]]]).repeat(2, 1, 1, 1), + "y_true": torch.tensor([[[[2]]]]).repeat(2, 1, 1, 1), + }, + 0.0, + ], ] class TestAsymmetricUnifiedFocalLoss(unittest.TestCase): @parameterized.expand(TEST_CASES) - def test_result(self, input_data, expected_val): - loss = AsymmetricUnifiedFocalLoss() + def test_result(self, input_param, input_data, expected_val): + loss = AsymmetricUnifiedFocalLoss(**input_param) result = loss(**input_data) np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) @@ -52,7 +65,7 @@ def test_ill_shape(self): def test_with_cuda(self): loss = AsymmetricUnifiedFocalLoss() - i = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]) + i = torch.tensor([[[[20.0, -20.0], [-20.0, 20.0]]], [[[20.0, -20.0], [-20.0, 20.0]]]]) j = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]) if torch.cuda.is_available(): i = i.cuda() From d6e433562bdbd547e88606ffb5788b0579ba627e Mon Sep 17 00:00:00 2001 From: ytl0623 
Date: Tue, 25 Nov 2025 14:56:17 +0800 Subject: [PATCH 02/13] add files Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index d02066d1d7..90c2fe2d10 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -27,8 +27,6 @@ class AsymmetricFocalTverskyLoss(_Loss): It supports both binary and multi-class segmentation. - The logic assumes channel 0 is Background, and channels 1..N are Foreground. - Reimplementation of the Asymmetric Focal Tversky Loss described in: - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", @@ -80,12 +78,12 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: # clip the prediction to avoid NaN y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) - spatial_dims = list(range(2, len(y_pred.shape))) + axis = list(range(2, len(y_pred.shape))) # Calculate true positives (tp), false negatives (fn) and false positives (fp) - tp = torch.sum(y_true * y_pred, dim=spatial_dims) - fn = torch.sum(y_true * (1 - y_pred), dim=spatial_dims) - fp = torch.sum((1 - y_true) * y_pred, dim=spatial_dims) + tp = torch.sum(y_true * y_pred, dim=axis) + fn = torch.sum(y_true * (1 - y_pred), dim=axis) + fp = torch.sum((1 - y_true) * y_pred, dim=axis) dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) # Calculate losses separately for each class, enhancing both classes @@ -200,32 +198,31 @@ def __init__( """ Args: to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. - delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. - epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. weight : weight for each loss function, if it's none it's 0.5. Defaults to None. + gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. + delta : weight of the background. Defaults to 0.7. use_softmax: whether to use softmax to transform the original logits into probabilities. If True, softmax is used. If False, sigmoid is used. Defaults to False. 
Example: >>> import torch >>> from monai.losses import AsymmetricUnifiedFocalLoss - >>> pred = torch.ones((1,1,32,32), dtype=torch.float32) - >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64) - >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True) + >>> pred = torch.randn((1, 3, 32, 32)) + >>> grnd = torch.randint(0, 3, (1, 1, 32, 32)) + >>> fl = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True) >>> fl(pred, grnd) """ super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y + self.weight: float = weight self.gamma = gamma self.delta = delta - self.weight: float = weight self.use_softmax = use_softmax self.asy_focal_loss = AsymmetricFocalLoss( - gamma=self.gamma, delta=self.delta, use_softmax=use_softmax, to_onehot_y=to_onehot_y + to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax ) self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss( - gamma=self.gamma, delta=self.delta, use_softmax=use_softmax, to_onehot_y=to_onehot_y + to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax ) def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: From 8731b3019fa4f55e121d3a920a126a44f8dce2c2 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Tue, 25 Nov 2025 15:13:27 +0800 Subject: [PATCH 03/13] Simplify algebraic expression to avoid numerical instability Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 90c2fe2d10..4782df4887 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -49,7 +49,7 @@ def __init__( If True, softmax is used. If False, sigmoid is used. Defaults to False. delta : weight of the background. Defaults to 0.7. gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. - epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. + epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. """ super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y @@ -88,12 +88,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: # Calculate losses separately for each class, enhancing both classes back_dice = 1 - dice_class[:, 0] - fore_dice = (1 - dice_class[:, 1:]) * torch.pow(1 - dice_class[:, 1:], -self.gamma) - if fore_dice.shape[1] > 1: - fore_dice = torch.mean(fore_dice, dim=1) + if n_pred_ch > 1: + fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma) + + if fore_dice.shape[1] > 1: + fore_dice = torch.mean(fore_dice, dim=1) + else: + fore_dice = fore_dice.squeeze(1) else: - fore_dice = fore_dice.squeeze(1) + fore_dice = torch.zeros_like(back_dice) # Average class scores loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) @@ -128,7 +132,7 @@ def __init__( If True, softmax is used. If False, sigmoid is used. Defaults to False. delta : weight of the background. Defaults to 0.7. gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. - epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. + epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. 
""" super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y @@ -198,8 +202,8 @@ def __init__( """ Args: to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. - weight : weight for each loss function, if it's none it's 0.5. Defaults to None. - gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. + weight : weight for each loss function. Defaults to 0.5. + gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5. delta : weight of the background. Defaults to 0.7. use_softmax: whether to use softmax to transform the original logits into probabilities. If True, softmax is used. If False, sigmoid is used. Defaults to False. @@ -235,7 +239,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: Raises: ValueError: When input and target are different shape - ValueError: When the number of classes entered does not match the expected number """ if y_pred.shape != y_true.shape: raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") From ea8a6ee5a238070482ee2bb9e8289466752e37fd Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Tue, 25 Nov 2025 15:25:38 +0800 Subject: [PATCH 04/13] minor fixes Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 4782df4887..ee4ffbbe3e 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -131,7 +131,7 @@ def __init__( use_softmax: whether to use softmax to transform the original logits into probabilities. If True, softmax is used. If False, sigmoid is used. Defaults to False. delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. + gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 2. epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. 
""" super().__init__(reduction=LossReduction(reduction).value) @@ -166,13 +166,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] back_ce = (1 - self.delta) * back_ce - fore_ce = cross_entropy[:, 1:] - fore_ce = self.delta * fore_ce + if n_pred_ch > 1: + fore_ce = cross_entropy[:, 1:] + fore_ce = self.delta * fore_ce - if fore_ce.shape[1] > 1: - fore_ce = torch.sum(fore_ce, dim=1) + if fore_ce.shape[1] > 1: + fore_ce = torch.sum(fore_ce, dim=1) + else: + fore_ce = fore_ce.squeeze(1) else: - fore_ce = fore_ce.squeeze(1) + fore_ce = torch.zeros_like(back_ce) loss = torch.mean(torch.stack([back_ce, fore_ce], dim=-1)) return loss From 52ccd353f4cdb521cc130a8558314e99d1eff449 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Tue, 25 Nov 2025 15:37:30 +0800 Subject: [PATCH 05/13] fix: Binary segmentation foreground loss never evaluated Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 39 +++++++++++++++++------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index ee4ffbbe3e..4bce1a3587 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -64,6 +64,10 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: else: y_pred = torch.sigmoid(y_pred) + if y_pred.shape[1] == 1: + y_pred = torch.cat([1 - y_pred, y_pred], dim=1) + y_true = torch.cat([1 - y_true, y_true], dim=1) + n_pred_ch = y_pred.shape[1] if self.to_onehot_y: @@ -77,7 +81,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: # clip the prediction to avoid NaN y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) - axis = list(range(2, len(y_pred.shape))) # Calculate true positives (tp), false negatives (fn) and false positives (fp) @@ -86,18 +89,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: fp = torch.sum((1 - y_true) * y_pred, dim=axis) dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) - # Calculate losses separately for each class, enhancing both classes + # Class 0 is Background back_dice = 1 - dice_class[:, 0] - if n_pred_ch > 1: - fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma) + # Class 1+ is Foreground + fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma) - if fore_dice.shape[1] > 1: - fore_dice = torch.mean(fore_dice, dim=1) - else: - fore_dice = fore_dice.squeeze(1) + if fore_dice.shape[1] > 1: + fore_dice = torch.mean(fore_dice, dim=1) else: - fore_dice = torch.zeros_like(back_dice) + fore_dice = fore_dice.squeeze(1) # Average class scores loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) @@ -149,6 +150,11 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_log_pred = F.logsigmoid(y_pred) y_pred = torch.sigmoid(y_pred) + if y_pred.shape[1] == 1: + y_pred = torch.cat([1 - y_pred, y_pred], dim=1) + y_log_pred = torch.log(torch.clamp(y_pred, 1e-7, 1.0)) + y_true = torch.cat([1 - y_true, y_true], dim=1) + n_pred_ch = y_pred.shape[1] if self.to_onehot_y: @@ -163,19 +169,18 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) cross_entropy = -y_true * y_log_pred + # Class 0: Background back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] back_ce = (1 - self.delta) * 
back_ce - if n_pred_ch > 1: - fore_ce = cross_entropy[:, 1:] - fore_ce = self.delta * fore_ce + # Class 1+: Foreground + fore_ce = cross_entropy[:, 1:] + fore_ce = self.delta * fore_ce - if fore_ce.shape[1] > 1: - fore_ce = torch.sum(fore_ce, dim=1) - else: - fore_ce = fore_ce.squeeze(1) + if fore_ce.shape[1] > 1: + fore_ce = torch.sum(fore_ce, dim=1) else: - fore_ce = torch.zeros_like(back_ce) + fore_ce = fore_ce.squeeze(1) loss = torch.mean(torch.stack([back_ce, fore_ce], dim=-1)) return loss From 6c777255cdf0b53dc7cd6531469f017b77e79c97 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Tue, 25 Nov 2025 16:10:44 +0800 Subject: [PATCH 06/13] minor fixes Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 4bce1a3587..c165c0de87 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -242,14 +242,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: Args: y_pred : the shape should be BNH[WD], where N is the number of classes. The input should be the original logits since it will be transformed by - a sigmoid in the forward function. + a sigmoid/softmax in the forward function. y_true : the shape should be BNH[WD], where N is the number of classes. - - Raises: - ValueError: When input and target are different shape """ - if y_pred.shape != y_true.shape: - raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") asy_focal_loss = self.asy_focal_loss(y_pred, y_true) asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) From c793454cdeea699b2b3f61364be1001d9b9ed27e Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Tue, 25 Nov 2025 16:42:35 +0800 Subject: [PATCH 07/13] minor fixes Signed-off-by: ytl0623 --- tests/losses/test_unified_focal_loss.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/losses/test_unified_focal_loss.py b/tests/losses/test_unified_focal_loss.py index ed964f1684..f93a68272f 100644 --- a/tests/losses/test_unified_focal_loss.py +++ b/tests/losses/test_unified_focal_loss.py @@ -24,7 +24,7 @@ [ {}, { - "y_pred": torch.tensor([[[[20.0, -20.0], [-20.0, 20.0]]], [[[20.0, -20.0], [-20.0, 20.0]]]]), + "y_pred": torch.tensor([[[[100.0, -100.0], [-100.0, 100.0]]], [[[100.0, -100.0], [-100.0, 100.0]]]]), "y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), }, 0.0, @@ -33,7 +33,7 @@ [ {"use_softmax": False, "to_onehot_y": False}, { - "y_pred": torch.tensor([[[[20.0, -20.0], [-20.0, 20.0]]], [[[20.0, -20.0], [-20.0, 20.0]]]]), + "y_pred": torch.tensor([[[[100.0, -100.0], [-100.0, 100.0]]], [[[100.0, -100.0], [-100.0, 100.0]]]]), "y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), }, 0.0, @@ -42,7 +42,7 @@ [ {"use_softmax": True, "to_onehot_y": True}, { - "y_pred": torch.tensor([[[[-20.0]], [[-20.0]], [[20.0]]]]).repeat(2, 1, 1, 1), + "y_pred": torch.tensor([[[[-100.0]], [[-100.0]], [[100.0]]]]).repeat(2, 1, 1, 1), "y_true": torch.tensor([[[[2]]]]).repeat(2, 1, 1, 1), }, 0.0, @@ -65,7 +65,7 @@ def test_ill_shape(self): def test_with_cuda(self): loss = AsymmetricUnifiedFocalLoss() - i = torch.tensor([[[[20.0, -20.0], [-20.0, 20.0]]], [[[20.0, -20.0], [-20.0, 20.0]]]]) + i = torch.tensor([[[[100.0, -100.0], [-100.0, 100.0]]], [[[100.0, -100.0], [-100.0, 100.0]]]]) j = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]) if 
torch.cuda.is_available(): i = i.cuda() From f3af22a4327d55b5710068a65ea910f9c9cc647f Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Tue, 25 Nov 2025 18:19:06 +0800 Subject: [PATCH 08/13] minor fixes Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index c165c0de87..e8dedd6d85 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -79,8 +79,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: if y_true.shape != y_pred.shape: raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - # clip the prediction to avoid NaN - y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) axis = list(range(2, len(y_pred.shape))) # Calculate true positives (tp), false negatives (fn) and false positives (fp) @@ -117,6 +115,7 @@ class AsymmetricFocalLoss(_Loss): Michael Yeung, Computerized Medical Imaging and Graphics """ + def __init__( self, to_onehot_y: bool = False, From 0640d0f8b2cf52f3f6c4f382697020e4d2cd3255 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Tue, 25 Nov 2025 18:33:42 +0800 Subject: [PATCH 09/13] codeformat Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index e8dedd6d85..98c0be2f46 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -115,7 +115,6 @@ class AsymmetricFocalLoss(_Loss): Michael Yeung, Computerized Medical Imaging and Graphics """ - def __init__( self, to_onehot_y: bool = False, From 6c2518916c19af88e74852c0992aa7fd9500b5bb Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Tue, 25 Nov 2025 19:15:38 +0800 Subject: [PATCH 10/13] minor fixes Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 98c0be2f46..6b3e4921a2 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -176,7 +176,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: fore_ce = self.delta * fore_ce if fore_ce.shape[1] > 1: - fore_ce = torch.sum(fore_ce, dim=1) + fore_ce = torch.mean(fore_ce, dim=1) else: fore_ce = fore_ce.squeeze(1) @@ -241,7 +241,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_pred : the shape should be BNH[WD], where N is the number of classes. The input should be the original logits since it will be transformed by a sigmoid/softmax in the forward function. - y_true : the shape should be BNH[WD], where N is the number of classes. + y_true : the shape should be BNH[WD], or B1H[WD] when to_onehot_y=True. 
""" asy_focal_loss = self.asy_focal_loss(y_pred, y_true) From e0e48a3d551ba54dfe6985db39a77ddd7d5f045a Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Wed, 26 Nov 2025 16:51:16 +0800 Subject: [PATCH 11/13] minor fixes Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 72 +++++++++++++++++------------- 1 file changed, 41 insertions(+), 31 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 6b3e4921a2..3af06f3b58 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -48,8 +48,8 @@ def __init__( use_softmax: whether to use softmax to transform the original logits into probabilities. If True, softmax is used. If False, sigmoid is used. Defaults to False. delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. - epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. + gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. + epsilon : stability factor used to avoid division by zero. Defaults to 1e-7. """ super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y @@ -59,6 +59,17 @@ def __init__( self.epsilon = epsilon def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: + n_pred_ch = y_pred.shape[1] + + if self.use_softmax and n_pred_ch == 1: + raise ValueError("single channel prediction with `use_softmax=True` is not allowed.") + + if self.to_onehot_y: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + else: + y_true = one_hot(y_true, num_classes=n_pred_ch) + if self.use_softmax: y_pred = torch.softmax(y_pred, dim=1) else: @@ -68,17 +79,10 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_pred = torch.cat([1 - y_pred, y_pred], dim=1) y_true = torch.cat([1 - y_true, y_true], dim=1) - n_pred_ch = y_pred.shape[1] - - if self.to_onehot_y: - if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") - else: - y_true = one_hot(y_true, num_classes=n_pred_ch) - if y_true.shape != y_pred.shape: raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") + # Calculate Loss axis = list(range(2, len(y_pred.shape))) # Calculate true positives (tp), false negatives (fn) and false positives (fp) @@ -130,8 +134,8 @@ def __init__( use_softmax: whether to use softmax to transform the original logits into probabilities. If True, softmax is used. If False, sigmoid is used. Defaults to False. delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 2. - epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. + gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2. + epsilon : stability factor used to avoid division by zero. Defaults to 1e-7. 
""" super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y @@ -141,6 +145,20 @@ def __init__( self.epsilon = epsilon def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: + n_pred_ch = y_pred.shape[1] + + if self.use_softmax and n_pred_ch == 1: + raise ValueError("single channel prediction with `use_softmax=True` is not allowed.") + + if self.to_onehot_y: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + else: + y_true = one_hot(y_true, num_classes=n_pred_ch) + + # Save logits for numerical stability in single-channel expansion + y_logits = y_pred + if self.use_softmax: y_log_pred = F.log_softmax(y_pred, dim=1) y_pred = torch.exp(y_log_pred) @@ -148,19 +166,13 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_log_pred = F.logsigmoid(y_pred) y_pred = torch.sigmoid(y_pred) - if y_pred.shape[1] == 1: + # Handle Single Channel (Binary) Expansion + if n_pred_ch == 1: y_pred = torch.cat([1 - y_pred, y_pred], dim=1) - y_log_pred = torch.log(torch.clamp(y_pred, 1e-7, 1.0)) + bg_log_pred = F.logsigmoid(-y_logits) + y_log_pred = torch.cat([bg_log_pred, y_log_pred], dim=1) y_true = torch.cat([1 - y_true, y_true], dim=1) - n_pred_ch = y_pred.shape[1] - - if self.to_onehot_y: - if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") - else: - y_true = one_hot(y_true, num_classes=n_pred_ch) - if y_true.shape != y_pred.shape: raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") @@ -199,20 +211,18 @@ class AsymmetricUnifiedFocalLoss(_Loss): def __init__( self, to_onehot_y: bool = False, - weight: float = 0.5, - gamma: float = 0.5, - delta: float = 0.7, use_softmax: bool = False, + delta: float = 0.7, + gamma: float = 2, reduction: LossReduction | str = LossReduction.MEAN, ): """ Args: - to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. - weight : weight for each loss function. Defaults to 0.5. - gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5. - delta : weight of the background. Defaults to 0.7. + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. use_softmax: whether to use softmax to transform the original logits into probabilities. If True, softmax is used. If False, sigmoid is used. Defaults to False. + delta : weight of the background. Defaults to 0.7. + gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. 
Example: >>> import torch @@ -250,9 +260,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss if self.reduction == LossReduction.SUM.value: - return torch.sum(loss) # sum over the batch and channel dims + return torch.sum(loss) if self.reduction == LossReduction.NONE.value: - return loss # returns [N, num_classes] losses + return loss if self.reduction == LossReduction.MEAN.value: return torch.mean(loss) raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') From 954df6075cc7042b0acb4bab06750b4338782e3c Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Wed, 26 Nov 2025 16:55:51 +0800 Subject: [PATCH 12/13] minor fixes Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 3af06f3b58..b347e8b043 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -214,6 +214,7 @@ def __init__( use_softmax: bool = False, delta: float = 0.7, gamma: float = 2, + weight: float = 0.5, reduction: LossReduction | str = LossReduction.MEAN, ): """ @@ -223,6 +224,7 @@ def __init__( If True, softmax is used. If False, sigmoid is used. Defaults to False. delta : weight of the background. Defaults to 0.7. gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. + weight : weight for each loss function. Defaults to 0.5. Example: >>> import torch From ee2215aab61ef25b507b09628a0be50c5c26a147 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Wed, 26 Nov 2025 17:03:41 +0800 Subject: [PATCH 13/13] minor fixes Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index b347e8b043..bb2dfdd832 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -223,8 +223,8 @@ def __init__( use_softmax: whether to use softmax to transform the original logits into probabilities. If True, softmax is used. If False, sigmoid is used. Defaults to False. delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. - weight : weight for each loss function. Defaults to 0.5. + gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2. + weight: weight for combining the focal and focal-Tversky terms. Defaults to 0.5. Example: >>> import torch @@ -241,10 +241,16 @@ def __init__( self.delta = delta self.use_softmax = use_softmax self.asy_focal_loss = AsymmetricFocalLoss( - to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax + to_onehot_y=to_onehot_y, + use_softmax=use_softmax, + delta=self.delta, + gamma=self.gamma, ) self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss( - to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax + to_onehot_y=to_onehot_y, + use_softmax=use_softmax, + delta=self.delta, + gamma=self.gamma, ) def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
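
A note on the simplification in PATCH 03/13: it relies on the identity (1 - d) * (1 - d)^(-gamma) = (1 - d)^(1 - gamma). The original two-factor form evaluates 0 * inf = NaN whenever a class Dice score d reaches exactly 1 (perfect overlap), while the single-power form stays finite for gamma < 1. A minimal check, with illustrative values and the AsymmetricFocalTverskyLoss default gamma of 0.75:

import torch

d = torch.tensor([0.3, 1.0])  # per-class Dice scores; 1.0 = perfect overlap
gamma = 0.75                  # AsymmetricFocalTverskyLoss default

old = (1 - d) * torch.pow(1 - d, -gamma)  # 0 * inf -> nan at d = 1.0
new = torch.pow(1 - d, 1 - gamma)         # 0 ** 0.25 -> 0.0, well defined

print(old)  # tensor([0.9147, nan])
print(new)  # tensor([0.9147, 0.0000])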
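
Similarly, the single-channel expansion introduced in PATCH 05/13 and reworked in PATCH 11/13 needs the background log-probability log(1 - sigmoid(x)). Computing that literally underflows for large logits; F.logsigmoid(-x) exploits the identity 1 - sigmoid(x) = sigmoid(-x) and stays in log space throughout. A sketch of the difference (logit values are illustrative):

import torch
import torch.nn.functional as F

x = torch.tensor([-30.0, 0.0, 30.0])  # raw logits

naive = torch.log(1 - torch.sigmoid(x))  # 1 - sigmoid(30) underflows to 0 in float32
stable = F.logsigmoid(-x)                # log(sigmoid(-x)), evaluated in log space

print(naive)   # last entry is -inf: log of an underflowed zero
print(stable)  # last entry is ~ -30, matching the exact value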
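
Finally, a minimal end-to-end sketch of the API after this series, mirroring the binary and multi-class test cases above (shapes and values are illustrative, and it assumes the patched monai.losses is importable):

import torch

from monai.losses import AsymmetricUnifiedFocalLoss

# Binary: single-channel logits, sigmoid activation (the default path).
binary_loss = AsymmetricUnifiedFocalLoss()
logits = torch.randn(2, 1, 32, 32)                    # B1HW logits
target = torch.randint(0, 2, (2, 1, 32, 32)).float()  # B1HW binary mask
print(binary_loss(logits, target))

# Multi-class: N-channel logits, softmax activation, integer class labels.
multi_loss = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)
logits = torch.randn(2, 3, 32, 32)                    # BNHW logits, N=3 classes
target = torch.randint(0, 3, (2, 1, 32, 32))          # B1HW class indices
print(multi_loss(logits, target))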