-
Notifications
You must be signed in to change notification settings - Fork 1.4k
add sigmoid/softmax support and multi-class extension for AsymmetricUnifiedFocalLoss #8642
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from 6 commits
d724a95
d6e4335
8731b30
ea8a6ee
52ccd35
6c77725
c793454
f3af22a
3b7277f
0640d0f
6c25189
e0e48a3
954df60
ee2215a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |
| import warnings | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
| from torch.nn.modules.loss import _Loss | ||
|
|
||
| from monai.networks import one_hot | ||
|
|
@@ -24,7 +25,7 @@ class AsymmetricFocalTverskyLoss(_Loss): | |
| """ | ||
| AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attends to the foreground class. | ||
|
|
||
| Actually, it's only supported for binary image segmentation now. | ||
| It supports both binary and multi-class segmentation. | ||
|
|
||
| Reimplementation of the Asymmetric Focal Tversky Loss described in: | ||
|
|
||
|
|
@@ -35,6 +36,7 @@ class AsymmetricFocalTverskyLoss(_Loss): | |
| def __init__( | ||
| self, | ||
| to_onehot_y: bool = False, | ||
| use_softmax: bool = False, | ||
| delta: float = 0.7, | ||
| gamma: float = 0.75, | ||
| epsilon: float = 1e-7, | ||
|
|
@@ -43,17 +45,29 @@ def __init__( | |
| """ | ||
| Args: | ||
| to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. | ||
| use_softmax: whether to use softmax to transform the original logits into probabilities. | ||
| If True, softmax is used. If False, sigmoid is used. Defaults to False. | ||
| delta : weight of the background. Defaults to 0.7. | ||
| gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. | ||
| epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. | ||
| epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. | ||
| """ | ||
| super().__init__(reduction=LossReduction(reduction).value) | ||
| self.to_onehot_y = to_onehot_y | ||
| self.use_softmax = use_softmax | ||
| self.delta = delta | ||
| self.gamma = gamma | ||
| self.epsilon = epsilon | ||
|
|
||
| def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | ||
| if self.use_softmax: | ||
| y_pred = torch.softmax(y_pred, dim=1) | ||
| else: | ||
| y_pred = torch.sigmoid(y_pred) | ||
|
|
||
| if y_pred.shape[1] == 1: | ||
| y_pred = torch.cat([1 - y_pred, y_pred], dim=1) | ||
| y_true = torch.cat([1 - y_true, y_true], dim=1) | ||
|
|
||
| n_pred_ch = y_pred.shape[1] | ||
|
|
||
| if self.to_onehot_y: | ||
|
|
@@ -75,9 +89,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | |
| fp = torch.sum((1 - y_true) * y_pred, dim=axis) | ||
| dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) | ||
|
|
||
| # Calculate losses separately for each class, enhancing both classes | ||
| # Class 0 is Background | ||
| back_dice = 1 - dice_class[:, 0] | ||
| fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma) | ||
|
|
||
| # Class 1+ is Foreground | ||
| fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma) | ||
|
|
||
| if fore_dice.shape[1] > 1: | ||
| fore_dice = torch.mean(fore_dice, dim=1) | ||
| else: | ||
| fore_dice = fore_dice.squeeze(1) | ||
|
|
||
| # Average class scores | ||
| loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) | ||
|
|
@@ -88,7 +109,7 @@ class AsymmetricFocalLoss(_Loss): | |
| """ | ||
| AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attends to the foreground class. | ||
|
|
||
| Actually, it's only supported for binary image segmentation now. | ||
| It supports both binary and multi-class segmentation. | ||
|
|
||
| Reimplementation of the Asymmetric Focal Loss described in: | ||
|
|
||
|
|
@@ -99,6 +120,7 @@ class AsymmetricFocalLoss(_Loss): | |
| def __init__( | ||
| self, | ||
| to_onehot_y: bool = False, | ||
| use_softmax: bool = False, | ||
| delta: float = 0.7, | ||
| gamma: float = 2, | ||
| epsilon: float = 1e-7, | ||
|
|
@@ -107,17 +129,32 @@ def __init__( | |
| """ | ||
| Args: | ||
| to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. | ||
| use_softmax: whether to use softmax to transform the original logits into probabilities. | ||
| If True, softmax is used. If False, sigmoid is used. Defaults to False. | ||
| delta : weight of the background. Defaults to 0.7. | ||
| gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. | ||
| epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. | ||
| gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 2. | ||
| epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. | ||
| """ | ||
| super().__init__(reduction=LossReduction(reduction).value) | ||
| self.to_onehot_y = to_onehot_y | ||
| self.use_softmax = use_softmax | ||
| self.delta = delta | ||
| self.gamma = gamma | ||
| self.epsilon = epsilon | ||
|
|
||
| def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | ||
| if self.use_softmax: | ||
| y_log_pred = F.log_softmax(y_pred, dim=1) | ||
| y_pred = torch.exp(y_log_pred) | ||
| else: | ||
| y_log_pred = F.logsigmoid(y_pred) | ||
| y_pred = torch.sigmoid(y_pred) | ||
|
|
||
| if y_pred.shape[1] == 1: | ||
| y_pred = torch.cat([1 - y_pred, y_pred], dim=1) | ||
| y_log_pred = torch.log(torch.clamp(y_pred, 1e-7, 1.0)) | ||
| y_true = torch.cat([1 - y_true, y_true], dim=1) | ||
ytl0623 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| n_pred_ch = y_pred.shape[1] | ||
|
|
||
| if self.to_onehot_y: | ||
|
|
@@ -130,23 +167,30 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | |
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") | ||
|
|
||
| y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) | ||
| cross_entropy = -y_true * torch.log(y_pred) | ||
| cross_entropy = -y_true * y_log_pred | ||
|
|
||
| # Class 0: Background | ||
| back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] | ||
| back_ce = (1 - self.delta) * back_ce | ||
|
|
||
| fore_ce = cross_entropy[:, 1] | ||
| # Class 1+: Foreground | ||
| fore_ce = cross_entropy[:, 1:] | ||
| fore_ce = self.delta * fore_ce | ||
|
|
||
| loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1)) | ||
| if fore_ce.shape[1] > 1: | ||
| fore_ce = torch.sum(fore_ce, dim=1) | ||
| else: | ||
| fore_ce = fore_ce.squeeze(1) | ||
|
|
||
| loss = torch.mean(torch.stack([back_ce, fore_ce], dim=-1)) | ||
| return loss | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| class AsymmetricUnifiedFocalLoss(_Loss): | ||
| """ | ||
| AsymmetricUnifiedFocalLoss is a variant of Focal Loss. | ||
|
|
||
| Actually, it's only supported for binary image segmentation now | ||
| It supports both binary and multi-class segmentation. | ||
|
|
||
| Reimplementation of the Asymmetric Unified Focal Tversky Loss described in: | ||
|
|
||
|
|
@@ -157,74 +201,50 @@ class AsymmetricUnifiedFocalLoss(_Loss): | |
| def __init__( | ||
| self, | ||
| to_onehot_y: bool = False, | ||
| num_classes: int = 2, | ||
| weight: float = 0.5, | ||
| gamma: float = 0.5, | ||
| delta: float = 0.7, | ||
| use_softmax: bool = False, | ||
| reduction: LossReduction | str = LossReduction.MEAN, | ||
| ): | ||
| """ | ||
| Args: | ||
| to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. | ||
| num_classes : number of classes, it only supports 2 now. Defaults to 2. | ||
| weight : weight for each loss function. Defaults to 0.5. | ||
| gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5. | ||
| delta : weight of the background. Defaults to 0.7. | ||
| gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. | ||
| epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. | ||
| weight : weight for each loss function, if it's none it's 0.5. Defaults to None. | ||
| use_softmax: whether to use softmax to transform the original logits into probabilities. | ||
| If True, softmax is used. If False, sigmoid is used. Defaults to False. | ||
|
|
||
| Example: | ||
| >>> import torch | ||
| >>> from monai.losses import AsymmetricUnifiedFocalLoss | ||
| >>> pred = torch.ones((1,1,32,32), dtype=torch.float32) | ||
| >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64) | ||
| >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True) | ||
| >>> pred = torch.randn((1, 3, 32, 32)) | ||
| >>> grnd = torch.randint(0, 3, (1, 1, 32, 32)) | ||
| >>> fl = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True) | ||
| >>> fl(pred, grnd) | ||
| """ | ||
| super().__init__(reduction=LossReduction(reduction).value) | ||
| self.to_onehot_y = to_onehot_y | ||
| self.num_classes = num_classes | ||
| self.weight: float = weight | ||
| self.gamma = gamma | ||
| self.delta = delta | ||
| self.weight: float = weight | ||
| self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta) | ||
| self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta) | ||
| self.use_softmax = use_softmax | ||
| self.asy_focal_loss = AsymmetricFocalLoss( | ||
| to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax | ||
| ) | ||
| self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss( | ||
| to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax | ||
| ) | ||
|
Comment on lines
249
to
254
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Numerical instability: `gamma=2` causes explosion in the Tversky component.
Consider using separate gamma values for each sub-loss or clamping the base in Tversky: self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
- to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
+ to_onehot_y=to_onehot_y, gamma=0.75, delta=self.delta, use_softmax=use_softmax
)
Or add clamping in `forward`:
fore_dice_base = torch.clamp(1 - dice_class[:, 1:], min=self.epsilon)
fore_dice = torch.pow(fore_dice_base, 1 - self.gamma)
🤖 Prompt for AI Agents |
||
|
|
||
| # TODO: Implement this function to support multiple classes segmentation | ||
| def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| Args: | ||
| y_pred : the shape should be BNH[WD], where N is the number of classes. | ||
| It only supports binary segmentation. | ||
| The input should be the original logits since it will be transformed by | ||
| a sigmoid in the forward function. | ||
| a sigmoid/softmax in the forward function. | ||
| y_true : the shape should be BNH[WD], where N is the number of classes. | ||
| It only supports binary segmentation. | ||
|
|
||
| Raises: | ||
| ValueError: When input and target are different shape | ||
| ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5 | ||
| ValueError: When num_classes | ||
| ValueError: When the number of classes entered does not match the expected number | ||
| """ | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if y_pred.shape != y_true.shape: | ||
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") | ||
|
|
||
| if len(y_pred.shape) != 4 and len(y_pred.shape) != 5: | ||
| raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}") | ||
|
|
||
| if y_pred.shape[1] == 1: | ||
| y_pred = one_hot(y_pred, num_classes=self.num_classes) | ||
| y_true = one_hot(y_true, num_classes=self.num_classes) | ||
|
|
||
| if torch.max(y_true) != self.num_classes - 1: | ||
| raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}") | ||
|
|
||
| n_pred_ch = y_pred.shape[1] | ||
| if self.to_onehot_y: | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") | ||
| else: | ||
| y_true = one_hot(y_true, num_classes=n_pred_ch) | ||
|
|
||
| asy_focal_loss = self.asy_focal_loss(y_pred, y_true) | ||
| asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Potential numerical issue when `gamma > 1`.
When `gamma > 1`, the exponent `1 - gamma` becomes negative. If `dice_class[:, 1:]` approaches 1.0 (perfect prediction), `(1 - dice_class)` approaches 0, and `0^negative` explodes. The default `gamma=0.75` is safe, but document or validate that `gamma <= 1` is expected for this class.
🤖 Prompt for AI Agents